卷积神经网络|制作自己的Dataset
在编写代码训练神经网络之前,导入数据是必不可少的。PyTorch提供了许多预加载的数据集(如FashionMNIST),这些数据集 子类并实现特定于特定数据的函数。?
它们可用于对模型进行原型设计和基准测试,加载这些数据集是十分简单的。好吧,那如何加载自己制作的数据集呢?
简单来讲,自定义数据集类必须实现三个函数:__init__、__len__和__getitem__。下面代码就实现了一个Dataset
import osimport torchfrom torch.utils.data import Datasetfrom torchvision import transformsfrom PIL import Imageimport numpy as np?class MyDataset(Dataset):def __init__(self, path_file,transform=None,label_transform=None):self.path_file=path_fileself.imgs=[name for name in os.listdir(path_file)]#获取path_file路径下所有文件名self.transform = transformself.label_transform = label_transform?def __len__(self):return len(self.imgs)?def __getitem__(self, idx):#get the imageimg_path = os.path.join(self.path_file,self.imgs[idx])#获得图片完整路径image=Image.open(img_path)image=image.resize((28,28))#修改图片为默认大小image = np.array(image)image=torch.from_numpy(image)#将numpy数组转换为张量image=image.permute(2,0,1)#将H,W,C转换为C,H,W?if self.transform:image = self.transform(image)?#get the labelstr1=self.imgs[idx].split('.')label=torch.tensor(eval(str1[1]))?if self.label_transform:label=self.label_transform(label)?return image, label
注:上述代码从路径path_file读取文件,准确来讲应该是我们准备的训练图片,格式如下:? ? ?
? ? ? ??? ? ? ? ?cat1.0.jpg
? ? ? ? ? ? ? ? ? cat2.0.jpg
? ? ? ? ? ? ? ? ? ...
? ? ? ? ? ? ? ? ? dog1.1.jpg
? ? ? ? ? ? ? ? ? dog2.1.jpg
? ? ? ? ? ? ? ? ? ...
图片名重要含义:类别(0,1等)
而cat1,dog1这些并不重要,因为0,1,已经反映了图片的类别,这里仅仅是一个习惯,同样jpg也是如此。
实际上,在我们准备图片时,图片名往往不是这样,但直接写个简单的文件处理程序便很容易转变为上述格式。
之所以这样命名,就是为容易获得图片和对应的类别,也就是实现自己的Dataset。当然,其它还有许多方法,但核心就是加载自己的数据时获得图片和对应的类别。
再次看一下实现自己的Dataset的架构:
class CustomImageDataset(Dataset):def __init__(self, path_file, transform=None, target_transform=None):.........?def __len__(self):return len(...)???def __getitem__(self, idx):.........if self.transform:image = self.transform(image)if self.label_transform:label = self.label_transform(label)????????return?image,?label
在训练模型时,我们通常希望 在“小批量”中传递样本,在每个时期重新洗牌数据以减少模型过度拟合,并使用 Python 的 加快数据检索速度。
DataLoader是一个迭代对象,它在一个简单的 API 中为我们抽象了这种复杂性。下面我们将Dataset带入DataLoader.
path="E:\\3-10\\dogandcats\\train"#图片所在目录training_data=MyDataset(path)train_dataloader?=?torch.utils.data.DataLoader(training_data,?batch_size=2,?shuffle=True)
让我们run一下:
>>> trainimg,label=next(iter(train_dataloader))>>> trainimg.size()torch.Size([2, 3, 28, 28])>>> label.size()torch.Size([2])
结果符合预期,与在使用pytorch预加载的数据集格式一样!

?
点点点,赞和在看都在这儿!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!