
本文针对 PyTorch 中使用自定义 Sampler 时,DataLoader 只能迭代一个 epoch 的问题进行了分析和解决。通过修改 Sampler 的 `__next__` 方法,在抛出 `StopIteration` 异常时重置索引,使得 DataLoader 可以在多个 epoch 中正常迭代。文章提供了一个完整的代码示例,演示了如何实现一个可以根据不同 batch size 采样数据的自定义 Sampler,并确保其在训练循环中正常工作。
在使用 PyTorch 进行深度学习模型训练时,DataLoader 是一个非常重要的工具,它负责数据的加载和预处理。DataLoader 可以与 Sampler 结合使用,以控制数据的采样方式。然而,当使用自定义的 Sampler 时,可能会遇到 DataLoader 只能迭代一个 epoch 的问题。这通常是由于 Sampler 在一个 epoch 结束后没有正确地重置其内部状态导致的。
当 DataLoader 迭代 Sampler 时,它会不断调用 Sampler 的 __next__ 方法来获取下一个 batch 的索引。当 Sampler 完成一次完整的数据集遍历后,它应该抛出一个 StopIteration 异常来通知 DataLoader 停止迭代。然而,如果 Sampler 在抛出 StopIteration 异常后没有重置其内部索引,那么在下一个 epoch 开始时,Sampler 仍然处于完成状态,导致 DataLoader 无法继续迭代。
解决这个问题的方法是在 Sampler 的 __next__ 方法中,当检测到数据集已经遍历完毕并准备抛出 StopIteration 异常时,同时重置 Sampler 的内部索引。
下面是一个示例,展示了如何修改一个自定义的 Sampler 来解决这个问题。假设我们有一个 VariableBatchSampler,它可以根据预定义的 batch_sizes 列表来生成不同大小的 batch。
import torch
import numpy as np
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, TensorDataset
class VariableBatchSampler(Sampler):
def __init__(self, dataset_len: int, batch_sizes: list):
self.dataset_len = dataset_len
self.batch_sizes = batch_sizes
self.batch_idx = 0
self.start_idx = 0
self.end_idx = self.batch_sizes[self.batch_idx]
def __iter__(self):
return self
def __next__(self):
if self.start_idx >= self.dataset_len:
self.batch_idx = 0
self.start_idx = 0
self.end_idx = self.batch_sizes[self.batch_idx]
raise StopIteration
batch_indices = list(range(self.start_idx, self.end_idx))
self.start_idx = self.end_idx
self.batch_idx += 1
try:
self.end_idx += self.batch_sizes[self.batch_idx]
except IndexError:
self.end_idx = self.dataset_len
return batch_indices在这个 VariableBatchSampler 中,我们在 __next__ 方法中添加了以下代码:
if self.start_idx >= self.dataset_len:
self.batch_idx = 0
self.start_idx = 0
self.end_idx = self.batch_sizes[self.batch_idx]
raise StopIteration这段代码在 self.start_idx 大于或等于 self.dataset_len 时执行,这意味着我们已经遍历了整个数据集。此时,我们将 self.batch_idx、self.start_idx 和 self.end_idx 重置为初始值,以便在下一个 epoch 中重新开始迭代。
下面是一个完整的示例,展示了如何使用修改后的 VariableBatchSampler 和 DataLoader 进行多 epoch 训练。
import torch
import numpy as np
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, TensorDataset
class VariableBatchSampler(Sampler):
def __init__(self, dataset_len: int, batch_sizes: list):
self.dataset_len = dataset_len
self.batch_sizes = batch_sizes
self.batch_idx = 0
self.start_idx = 0
self.end_idx = self.batch_sizes[self.batch_idx]
def __iter__(self):
return self
def __next__(self):
if self.start_idx >= self.dataset_len:
self.batch_idx = 0
self.start_idx = 0
self.end_idx = self.batch_sizes[self.batch_idx]
raise StopIteration
batch_indices = list(range(self.start_idx, self.end_idx))
self.start_idx = self.end_idx
self.batch_idx += 1
try:
self.end_idx += self.batch_sizes[self.batch_idx]
except IndexError:
self.end_idx = self.dataset_len
return batch_indices
x_train = torch.randn(23)
y_train = torch.randint(0, 2, (23,))
batch_sizes = [4, 10, 7, 2]
train_dataset = TensorDataset(x_train, y_train)
sampler = VariableBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
dataloader_train = DataLoader(train_dataset, sampler=sampler)
max_epoch = 4
for epoch in np.arange(1, max_epoch):
print("Epoch: ", epoch)
for x_batch, y_batch in dataloader_train:
print(x_batch.shape)这段代码会输出每个 epoch 中每个 batch 的形状,证明 DataLoader 可以在多个 epoch 中正常迭代。
当使用自定义的 Sampler 时,确保在 __next__ 方法中正确地重置内部索引,以便 DataLoader 可以在多个 epoch 中正常迭代。 否则,DataLoader 在第一个epoch后会停止工作。 通过本文提供的示例,您可以更好地理解如何实现一个自定义的 Sampler,并解决 DataLoader 迭代问题。
以上就是PyTorch DataLoader 自定义 Sampler 迭代问题解决的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号