pytorch之torch.utils.data学习
1、概述
PyTorch 数据加载利用的核心是torch.utils.data.DataLoader类 。它表示在数据集上 Python 可迭代,支持
map-style and iterable-style datasets(地图样式和可迭代样式数据集),
customizing data loading order(自定义数据加载顺序),
automatic batching(自动批处理),
single- and multi-process data loading(单进程和多进程数据加载),
automatic memory pinning(自动内存固定)。
这些选项由 a 的构造函数参数配置 DataLoader,其签名为:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
以下部分详细描述了这些选项的效果和用法。
2、Dataset Types(数据类型)
DataLoader 构造函数最重要的参数是dataset,它指示要从中加载数据的数据集对象。PyTorch 支持两种不同类型的数据集:
map-style datasets,(地图式数据集),
iterable-style datasets.(可迭代式数据集。)
map-style datasets 地图式数据集
映射式数据集是一种实现__getitem__()和 len()协议的数据集,并表示从(可能是非整数)索引/键到数据样本的映射。
例如,这样的数据集,当使用 访问时dataset[idx],可以从磁盘上的文件夹中读取第 idx-th个图像及其相应的标签。
请参阅Dataset了解更多详情。
iterable-style datasets可迭代式数据集
CLASStorch.utils.data.IterableDataset(*args, **kwds)
可迭代样式数据集是实现协议__iter__()的IterableDataset 子类的实例,并表示数据样本上的一个迭代迭代。这种类型的数据集特别适合随机读取成本昂贵甚至不可能的情况,以及批量大小取决于获取的数据的情况。
例如,这样的数据集在调用时iter(dataset)可以返回从数据库、远程服务器甚至实时生成的日志读取的数据流。
所有表示数据样本可迭代对象的数据集都应该继承它。当数据来自流时,这种形式的数据集特别有用。
请参阅IterableDataset了解更多详情。
Example 1: splitting workload across all workers in iter():
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
iter_start = self.start
iter_end = self.end
else: # in a worker process
# split workload
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return iter(range(iter_start, iter_end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)
# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])
# Mult-process loading with two worker processes
# Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
Example 2: splitting workload across all workers using worker_init_fn:
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end
def __iter__(self):
return iter(range(self.start, self.end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)
# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
# Directly doing multi-process loading yields duplicate data
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset # the dataset copy in this worker process
overall_start = dataset.start
overall_end = dataset.end
# configure the dataset to only process the split workload
per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
worker_id = worker_info.id
dataset.start = overall_start + worker_id * per_worker
dataset.end = min(dataset.start + per_worker, overall_end)
# Mult-process loading with the custom `worker_init_fn`
# Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
笔记
当使用多进程数据加载的IterableDataset时。在每个工作进程上复制相同的数据集对象,因此必须对副本进行不同的配置,以避免重复数据。请参阅IterableDataset文档了解如何实现这一点。
3、数据加载顺序和Sampler
对于可迭代样式的数据集,数据加载顺序完全由用户定义的可迭代控制。这允许更容易地实现块读取和动态批量大小(例如,通过每次生成批量样本)。
本节的其余部分涉及地图样式数据集的情况 。torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的序列。它们表示数据集索引上的可迭代对象。例如,在随机梯度下降 (SGD) 的常见情况下,a Sampler可以随机排列一系列索引并一次生成每个索引,或者为小批量 SGD 生成少量索引。
A sequential or shuffled sampler将根据DataLoader 的shuffle参数自动构。或者,用户可以使用sampler参数来指定一个自定义Sampler对象,该对象每次都会生成下一个要获取的索引/键。
自定义Sampler一次生成批次索引列表,可以作为batch_sampler参数传递。自动批处理也可以通过batch_size和drop_last 参数启用。有关这方面的更多详细信息,请参阅 下一节。
注意
sampler和batch_sampler都不兼容可迭代风格的数据集,因为这样的数据集没有概念。
4、加载批量和非批量数据
DataLoader支持自动将单独获取的数据样本通过参数 batch_size、drop_last、batch_sampler和 collate_fn(具有默认功能)整理为批次
自动批处理(默认)
这是最常见的情况,对应于获取小批量数据并将它们整理成批量样本,即包含一个维度为批量维度(通常是第一个维度)的张量。
当batch_size(默认1)不是None时,数据加载器将生成批量样本而不是单个样本。batch_size和 drop_last参数用于指定数据加载器如何获取批量的数据集键。对于地图样式数据集,用户也可以指定batch_sampler,这一次会生成一个键列表。
在使用来自sampler的索引获取样本列表之后,使用作为collate_fn参数传递的函数将样本列表整理成批。
在这种情况下,从地图样式的数据集加载大致相当于:
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
从可迭代风格的数据集加载大致相当于:
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
禁用自动批处理
在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者简单地加载单个样本。例如,直接加载批处理数据(例如,从数据库中批量读取或读取连续的内存块)可能更便宜,或者批处理大小依赖于数据,或者程序被设计为处理单个样本。在这些场景下,最好不要使用自动批处理(其中collate_fn用于整理样本),而是让数据加载器直接返回数据集object的每个成员。
当batch_size和batch_sampler都为None时(batch_sampler的默认值已经为None),自动批处理被禁用。从数据集中获得的每个样本都使用作为collate_fn参数传递的函数进行处理。
当自动批处理被禁用时,默认的collate_fn只是将NumPy数组转换为PyTorch张量,并保持其他所有内容不变。
在这种情况下,从地图样式的数据集加载大致相当于:
for index in sampler:
yield collate_fn(dataset[index])
从可迭代风格的数据集加载大致相当于:
for data in iter(dataset):
yield collate_fn(data)
Working with collate_fn
例如,如果每个数据样本由一个3通道图像和一个整数类标签组成,也就是说,数据集的每个元素返回一个元组(image, class_index),默认collate_fn将这样的元组列表整理成一个批处理图像张量和一个批处理类标签张量的元组。特别是,默认的collate_fn具有以下属性:
5、Single- and Multi-process Data Loading
默认情况下,DataLoader使用单进程数据加载。
torch.utils.data.get_worker_info()返回工作进程中的各种有用信息(包括工作进程id,数据集副本,初始种子等),并在主进程中返回None。用户可以在数据集代码和/或worker_init_fn中使用这个函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。例如,这在对数据集进行分片时特别有用。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!