【霹雳吧啦】手把手带你入门语义分割の番外4:FCN 源码讲解(PyTorch)—— 关于 my_dataset.py 代码讲解

2023-12-26 11:34:14

目录

前言

Preparation

一、VOCSegmentation 类

1、__init__ 函数

2、__getitem__ 函数

3、collate_fn 函数

二、cat_list 函数


前言

文章性质:学习笔记 📖

视频教程:FCN源码解析(Pytorch)- 3 自定义读取数据集

主要内容:根据 视频教程 中提供的 FCN 源代码(PyTorch),对 my_dataset.py 文件进行具体讲解。

Preparation

FCN 源码:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_segmentation/fcn

【补充】除此之外,需要大家去 PASCAL VOC 官网下载数据集,下载后得到 VOCdevkit 文件夹:

?

一、VOCSegmentation 类

?VOCSegmentation 类继承自 data.Dataset 父类。

1、__init__ 函数

这个类使用 __init__ 初始化数据集对象,传入?VOCdevkit 根目录路径 voc_root、数据集年份 year、数据预处理操作 transforms、训练文件列表的文件名 txt_name 等参数,year 默认为 2012 ,txt_name 默认为 train.txt 。

关于 txt_name 取值的具体说明:

  • 当 txt_name 取值为 train.txt 时,读取的是 训练集 数据,可根据 train.txt 中保存的图片名称去寻找对应的图片数据和标签数据。
  • 当 txt_name 取值为 val.txt 时,读取的是 验证集 数据,可根据 val.txt 中保存的图片名称去寻找对应的图片数据和标签数据。

?

【代码解析】对 VOCSegmentation 类代码的具体解析(结合上图):

  1. ?我们这里的 VOC 数据集只支持 2007 和 2012 的,如果年份不是 2007 和 2012 则引发 AssertionError 异常。
  2. ?通过拼接 voc_root 、VOCdevkit 和 VOC2012 得到 root 路径,然后判断路径是否存在。
  3. ?通过拼接?root 路径和固定的目录名称得到 图片目录 image_dir 分割标签目录 mask_dir
  4. ?通过拼接 root 路径、固定的目录名称和 txt_name 得到 train.txt 文件的目录或者 val.txt 文件的目录,然后判断路径是和否存在。
  5. ?遍历 txt_path 路径指定的 txt 文件,读取所有的非空行,并通过 strip 方法将行首行尾的空格去掉。
  6. ?构建出 file_names ,内含训练集或者验证集的图片名称。
  7. ?构建出图片文件的路径 os.path.join(image_dir,?x + ".jpg")?和标签文件的路径?os.path.join(mask_dir,?x + ".png") 。

?

2、__getitem__ 函数

这个类使用 __getitem__ 方法实现根据索引获取数据的功能:

?

【代码解析】根据索引 index 打开对应的图片文件(转?RGB 格式)和标签文件,并进行一系列预处理操作,然后返回处理后的图片和标签。

3、collate_fn 函数

这个类使用 collate_fn 方法在加载数据时,将一个批次的数据进行整理和处理:

?

【代码解析】对 collate_fn 函数代码的具体解析(结合上图):

  1. ?将图片?images 和标签 targets 分别打包成两个列表。
  2. ?使用 cat_list 函数将图片列表 images 进行拼接,得到一个批次的图像数据,若图像尺寸不足最大尺寸,则用 0 进行填充。
  3. ?使用 cat_list 函数将标签列表 targets 进行拼接,得到一个批次的标签数据,若标签尺寸不足最大尺寸,则用 255 进行填充。

【补充】针对上面的第一步,我们可以通过断点调试的方式查看打包前后的区别:

?

?

二、cat_list 函数

可使用?cat_list 方法将一个批次的图像数据进行拼接,相关代码截图如下:

?

Step1 计算这个 batch 图像数据中的通道数 channel 、高度 h 、宽度 w 的最大值:

  1. ?通过遍历 images 图像列表获取图像数据。
  2. ?使用 zip(*[img.shape for img in images]) 将所有图像的对应维度取出并打包成元组。
  3. ?使用 max 函数求得各维度的最大值,得到 max_size 元组。

Step2 不同大小的 images 图片 打包成 一个 Tensor ,然后输入到网络当中进行运算

  1. ?构建批次图像数据的形状,即 batch_shape ,经断点调试可知为 [4, 3, 480, 480],这四个维度分别是图像数量、通道数、高度和宽度。
  2. ?创建与 batch_shape 形状相同的新张量 batched_imgs ,并使用 fill_ 方法将其元素填充为指定的 fill_value 。
  3. ?使用 pad_img[..., :img.shape[-2], :img.shape[-1].copy_(img)] 将原始图像 images 的内容复制到对应位置的批次图像 batched_imgs 上。

【补充】我们可以通过断点调试的方式查看 cat_list 函数的返回结果:

?

文章来源:https://blog.csdn.net/nanzhou520/article/details/135107213
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。