pytorch创建自己的Dataset加载数据集

文章目录

  • 创建一个类并继承torch.utils.data.dataset.Dataset类
  • 创建__getitem__方法
  • 加载数据集

创建一个类并继承torch.utils.data.dataset.Dataset类

class MyDataset(Dataset):
    """Map-style dataset over a list of image-name lines (constructor only).

    data_path: root directory of the dataset
    img_size: image size
    train_lines: list of lines, each holding an image name
    """
    def __init__(self, data_path, img_size, train_lines):
        super().__init__()

        self.data_path = data_path
        self.img_size = img_size
        self.train_lines = train_lines
        self.length = len(train_lines)

    def __len__(self):
        # DataLoader's default samplers call len(dataset), and
        # torch.utils.data.Dataset does not supply a default __len__.
        return self.length

创建__getitem__方法

class MyDataset(Dataset):
    """Map-style dataset pairing ``.tif`` images with ``.png`` labels.

    For each name, the image is read from ``<data_path>/dem/<name>.tif`` and
    the label from ``<data_path>/label/<name>.png``. Both are resized to
    ``img_size x img_size``.

    data_path: root directory of the dataset
    img_size: side length (pixels) that image and label are resized to
    train_lines: sequence of lines whose first whitespace-separated token
        is an image name without extension
    """
    def __init__(self, data_path, img_size, train_lines):
        super().__init__()

        self.data_path = data_path
        self.img_size = img_size
        self.train_lines = train_lines

    def __len__(self):
        # Required by DataLoader's default (sequential/random) sampler;
        # torch.utils.data.Dataset provides no default __len__.
        return len(self.train_lines)

    def __getitem__(self, index):
        annotation_line = self.train_lines[index]
        name = annotation_line.split()[0]  # image name: first token, newline dropped

        image_path = os.path.join(self.data_path, "dem", name + ".tif")
        label_path = os.path.join(self.data_path, "label", name + ".png")

        # Context managers close the files PIL opens once the pixel data
        # has been copied into numpy arrays.
        with Image.open(image_path) as img:
            image = np.array(img)
        with Image.open(label_path) as lbl:
            label = np.array(lbl)

        size = (self.img_size, self.img_size)
        image = cv2.resize(image, size)
        # Nearest-neighbour keeps label values within the original set; the
        # default bilinear interpolation blends neighbouring values into new
        # ones. NOTE(review): assumes labels are discrete class ids; use the
        # default interpolation if they are continuous targets.
        label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)

        # If the model expects an explicit channel axis for a single-band
        # image, add it here: image = image[np.newaxis, :]

        print("images size: {}, label size: {}".format(image.shape, label.shape))

        return image, label

加载数据集

如果不知道如何将文件夹中所有图片的名称逐行写入TXT文件，可以参考《python读取文件夹中的所有图片并将图片名逐行写入txt中》：https://blog.csdn.net/weixin_43598687/article/details/125666776?spm=1001.2014.3001.5501

dataset_path = r"E:/workspace/PyCharmProject/dem_feature/dem/512"

# Open the dataset split files and read one image name per line.
# Empty or whitespace-only lines (e.g. an extra blank line at the end of the
# file) are skipped: MyDataset takes the first token of each line, and such a
# line has none, which would raise IndexError during iteration.
with open(os.path.join(dataset_path, "dem", "train.txt"), "r") as f:
    train_lines = [line for line in f if line.strip()]

with open(os.path.join(dataset_path, "dem", "val.txt"), "r") as f:
    val_lines = [line for line in f if line.strip()]

train_dataset = MyDataset(dataset_path, img_size=512, train_lines=train_lines)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=False)

for iteration, data in enumerate(train_dataloader):
    imgs, labels = data

    print(imgs, labels)

你可能感兴趣的