如何基于PyTorch框架自定义数据集类获取数据
在PyTorch框架中,可以通过自定义数据集类来加载和处理数据
要自定义数据集类,需要继承 PyTorch提供的 torch.utils.data.Dataset
类,并实现两个主要方法:__len__
和 __getitem__
下面是一个示例,展示如何基于PyTorch框架来自定义数据集类以获取数据:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = self.data[index]
# 在这里对数据进行预处理、转换等操作
# 返回一个样本(通常是一个字典)
return item
# 创建数据集实例
data = [...] # 数据列表,包含训练样本
dataset = CustomDataset(data)
# 创建数据加载器
batch_size = 32
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 遍历数据加载器获取数据批次
for batch in dataloader:
# 处理每个批次的数据
inputs = batch['input']
labels = batch['label']
# 在这里进行模型训练、推理等操作
在此示例中,定义了一个名为 CustomDataset
的自定义数据集类,该类继承自torch.utils.data. Dataset
__init__
方法是构造函数,传入数据列表 data 并将其保存为类的属性 self.data
__len__
方法返回数据集的长度,即样本数量
__getitem__
方法通过索引获取单个样本
然后,创建了一个数据集实例 dataset
,并使用 torch.utils.data.DataLoader
创建了一个数据加载器 dataloader
通过遍历数据加载器可以获取每个批次 输入数据inputs 以及 标签数据labels,进行模型训练、推理等操作
注意:根据具体的应用需求,可以在__getitem__
方法中对数据进行预处理、转换等操作,并将处理后的样本作为字典或其他形式返回, 这样,在训练过程中可以方便地获取输入数据和标签数据 ,并进行相应的操作
下面再来看一个例子,该例通过在 __getitem__方法中对数据进行预处理,并最终返回一个包含图片数据、对应的标签数据以及图像文件名的字典
class BipedDataset(Dataset): # 定义了一个名为BipedDataset的类,它继承自PyTorch的Dataset类,用于自定义数据集
'''
用于构建一个自定义数据集,可以在训练神经网络时使用
它提供了加载图像、预处理数据等功能,以便用于深度学习模型的训练
'''
def __init__(self,
data_root,
img_height,
img_width,
mean_bgr, # 图像的均值(以BGR通道顺序表示)
train_mode='train', # 训练模式,可以是 'train' 或 'test' 之一,默认为 'train
crop_img=False,
arg=None
):
'''
这是类的构造函数,用于初始化对象的属性
它接受许多参数,包括数据根目录 data_root、图像高度 img_height、图像宽度 img_width、均值 mean_bgr、训练模式 train_mode 等
'''
self.data_root = data_root
self.train_mode = train_mode
self.img_height = img_height
self.img_width = img_width
self.mean_bgr = mean_bgr
self.crop_img = crop_img
self.arg = arg
self.data_index = self._build_index()
def _build_index(self): # 用于构建数据索引
data_root = os.path.abspath(self.data_root)
sample_indices = [] # 用于存储图像和标签的文件路径对
# 构建图像和标签的文件路径,其中 images_path 和 labels_path 分别指向数据集中图像和标签的存储路径
# 使用两层循环遍历图像目录中的所有文件,构建图像和标签的文件路径,并将其添加到 sample_indices 列表中
images_path = os.path.join(data_root,'edges\\imgs',self.train_mode)
labels_path = os.path.join(data_root,'edges\\labels',self.train_mode)
for file_name_ext in os.listdir(images_path):
file_name = os.path.splitext(file_name_ext)[0]
sample_indices.append(
( os.path.join(images_path, file_name + '.tif'),
os.path.join(labels_path, file_name + '.tif'), )
)
return sample_indices # 返回构建好的图像和标签的文件路径对列表
def __len__(self): # 返回数据集的长度,即样本的数量
return len(self.data_index)
def __getitem__(self, idx): # 用于获取指定索引处的数据样本,它接受一个索引 idx 作为参数
# get data sample
'''
首先,根据索引获取图像路径和标签路径
然后,使用OpenCV加载图像和标签
接下来,调用self.transform方法进行数据变换
最后,返回一个包含图像、对应标签以及图像文件名的字典
'''
image_path, label_path = self.data_index[idx]
# load data
image = cv2.imread(image_path, cv2.IMREAD_COLOR)
label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
image, label = self.transform(img=image, gt=label) # transform方法:用于对图像和标签进行预处理
img_name = os.path.basename(image_path)
file_name = os.path.splitext(img_name)[0] + ".png"
return dict(images=image, labels=label, file_names=file_name)
def transform(self, img, gt):
# 将标签转换为浮点型数组,并将其归一化到 [0, 1] 的范围内
gt = np.array(gt, dtype=np.float32)
if len(gt.shape) == 3:
gt = gt[:, :, 0]
gt /= 255.
# 将图像转换为浮点型数组,并减去均值 self.mean_bgr
img = np.array(img, dtype=np.float32)
img -= self.mean_bgr
i_h, i_w, _ = img.shape # 获取图像的高度、宽度和通道数
# 根据设定的裁剪大小 crop_size 对图像进行裁剪或缩放
crop_size = self.img_height if self.img_height == self.img_width else None
# 对于裁剪过程,它会在图像中随机选择一个位置来裁剪
if i_w > crop_size and i_h > crop_size:
i = random.randint(0, i_h - crop_size)
j = random.randint(0, i_w - crop_size)
img = img[i:i + crop_size, j:j + crop_size]
gt = gt[i:i + crop_size, j:j + crop_size]
else: # 如果图像的尺寸小于 crop_size,则会使用双线性插值进行缩放
# New addidings
img = cv2.resize(img, dsize=(crop_size, crop_size))
gt = cv2.resize(gt, dsize=(crop_size, crop_size))
# 对标签gt进行一些额外的处理,然后将图像img和标签gt转换为PyTorch的张量形式
gt[gt > 0.1] += 0.2
gt = np.clip(gt, 0., 1.)
img = img.transpose((2, 0, 1))
img = torch.from_numpy(img.copy()).float()
gt = torch.from_numpy(np.array([gt])).float()
return img, gt
在此处就定义完成了一个数据集类 BipedDataset
如何使用自定义的 BipedDataset 类来对数据进行加载呢?下面以加载验证集数据为例来进行说明
首先,对这个类进行实例化得到实例化后的数据集对象 dataset_val
dataset_val = BipedDataset(args.input_dir,
img_width =args.img_width,
img_height =args.img_height,
mean_bgr =args.mean_pixel_values,
train_mode ='test',
arg =args
)
其次,将该对象传入DataLoader中创建验证集数据加载器 dataloader_val
dataloader_val = DataLoader(dataset_val,
batch_size=1,
shuffle=False,
num_workers=args.workers)
然后,将数据集加载器 dataloader_val 作为参数传入进行验证过程的函数 validate_one_epoch 中
val_precision,val_recall,val_IoU = validate_one_epoch(epoch,
dataloader_val,
model,
device,
img_test_dir,
arg=args)
def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None):
precision = 0.0
recall = 0.0
IoU = 0.0
model.eval()
with torch.no_grad():
for _, sample_batched in enumerate(dataloader):
images = sample_batched['images'].to(device)
labels = sample_batched['labels'].to(device)
file_names = sample_batched['file_names']
preds = model(images)
labels = normalize_image(labels)
preds = normalize_image(preds)
precision += calculate_precision(preds, labels)
recall += calculate_recall(preds, labels)
IoU += calculate_iou(preds, labels)
save_image_batch_to_disk(preds, output_dir, file_names,arg=arg)
precision = precision / len(dataloader)
recall = recall / len(dataloader)
IoU = IoU / len(dataloader)
print(time.ctime(), '[Val_Epoch]: {0} Precision:{1} Recall:{2} IoU:{3} '.format(epoch, precision, recall, IoU))
print(f"第{epoch}次迭代的验证精确度为{precision},验证召回率为{recall},验证交并比为{IoU}")
return precision, recall, IoU
最后,我们可以看到将 dataloader_val验证集数据加载器 传入 函数validate_one_epoch 中,通过遍历 dataloader 中的数据,可以通过 自定义类BipedDataset 返回的包含三个元素的字典来获取图像数据、对应的标签数据以及图像文件名,如下图所示
images = sample_batched['images'].to(device)
labels = sample_batched['labels'].to(device)
file_names = sample_batched['file_names']
综上所述, 就是关于如何基于PyTorch深度学习框架自定义数据集来获取数据的详细步骤了,如果你觉得有用,麻烦点赞关注一下哈,谢谢!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!