Pytorch如何加载自己的数据集(使用DataLoader读取Dataset)

2022-12-14 22:46:28
目录
1.Pytorch加载数据集会用到官方整理好的数据集2.Dataset3.DataLoader4.查看数据5.总结

1.Pytorch加载数据集会用到官方整理好的数据集

很多时候我们需要加载自己的数据集,这时候我们需要使用Dataset和DataLoader

    Dataset:是被封装进DataLoader里,实现该方法封装自己的数据和标签。DataLoader:被封装入DataLoaderIter里,实现该方法达到数据的划分。

    2.Dataset

    阅读源码后,我们可以指导,继承该方法必须实现两个方法:

      _getitem_()_len_()

      因此,在实现过程中我们测试如下:

      import torch
      import numpy as np
      
      # 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
      class GetLoader(torch.utils.data.Dataset):
      	# 初始化函数,得到数据
          def __init__(self, data_root, data_label):
              self.data = data_root
              self.label = data_label
          # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
          def __getitem__(self, index):
              data = self.data[index]
              labels = self.label[index]
              return data, labels
          # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
          def __len__(self):
              return len(self.data)
      
      # 随机生成数据,大小为10 * 20列
      source_data = np.random.rand(10, 20)
      # 随机生成标签,大小为10 * 1列
      source_label = np.random.randint(0,2,(10, 1))
      # 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
      torch_data = GetLoader(source_data, source_label)
      

      3.DataLoader

      提供对Dataset的操作,操作如下:

      torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
      

      参数含义如下:

        dataset:加载torch.utils.data.Dataset对象数据batch_size:每个batch的大小shuffle:是否对数据进行打乱drop_last:是否对无法整除的最后一个datasize进行丢弃num_workers:表示加载的时候子进程数

        因此,在实现过程中我们测试如下(紧跟上述用例):

        from torch.utils.data import DataLoader
        
        # 读取数据
        datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)
        

        此时,我们的数据已经加载完毕了,只需要在训练过程中使用即可。

        4.查看数据

        我们可以通过迭代器(enumerate)进行输出数据,测试如下:

        for i, data in enumerate(datas):
        	# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
            print("第 {} 个Batch \n{}".format(i, data))
        

        输出结果如下图:

        结果说明:由于数据的是10个,batchsize大小为6,且drop_last=False,因此第一个大小为6,第二个为4。

        每一个batch中包含data和对应的labels。

        当我们想取出data和对应的labels时候,只需要用下表就可以啦,测试如下:

        # 表示输出数据
        print(data[0])
        # 表示输出标签
        print(data[1])
        

        结果如图:

        5.总结

        以上为个人经验,希望能给大家一个参考,也希望大家多多支持易采站长站。