PyTorch DataLoader 自定义 Sampler 迭代问题解决

DDD
发布: 2025-10-18 08:01:20
原创
702人浏览过

pytorch dataloader 自定义 sampler 迭代问题解决

本文针对 PyTorch 中使用自定义 Sampler 时,DataLoader 只能迭代一个 epoch 的问题进行了分析和解决。通过修改 Sampler 的 `__next__` 方法,在抛出 `StopIteration` 异常时重置索引,使得 DataLoader 可以在多个 epoch 中正常迭代。文章提供了一个完整的代码示例,演示了如何实现一个可以根据不同 batch size 采样数据的自定义 Sampler,并确保其在训练循环中正常工作。

在使用 PyTorch 进行深度学习模型训练时,DataLoader 是一个非常重要的工具,它负责数据的加载和预处理。DataLoader 可以与 Sampler 结合使用,以控制数据的采样方式。然而,当使用自定义的 Sampler 时,可能会遇到 DataLoader 只能迭代一个 epoch 的问题。这通常是由于 Sampler 在一个 epoch 结束后没有正确地重置其内部状态导致的。

问题分析

当 DataLoader 迭代 Sampler 时,它会不断调用 Sampler 的 __next__ 方法来获取下一个 batch 的索引。当 Sampler 完成一次完整的数据集遍历后,它应该抛出一个 StopIteration 异常来通知 DataLoader 停止迭代。然而,如果 Sampler 在抛出 StopIteration 异常后没有重置其内部索引,那么在下一个 epoch 开始时,Sampler 仍然处于完成状态,导致 DataLoader 无法继续迭代。

解决方案

解决这个问题的方法是在 Sampler 的 __next__ 方法中,当检测到数据集已经遍历完毕并准备抛出 StopIteration 异常时,同时重置 Sampler 的内部索引。

下面是一个示例,展示了如何修改一个自定义的 Sampler 来解决这个问题。假设我们有一个 VariableBatchSampler,它可以根据预定义的 batch_sizes 列表来生成不同大小的 batch。

import torch
import numpy as np
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, TensorDataset



class VariableBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]

    def __iter__(self):
        return self

    def __next__(self):
        if self.start_idx >= self.dataset_len:
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]
            raise StopIteration

        batch_indices = list(range(self.start_idx, self.end_idx))
        self.start_idx = self.end_idx
        self.batch_idx += 1

        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            self.end_idx = self.dataset_len

        return batch_indices
登录后复制

在这个 VariableBatchSampler 中,我们在 __next__ 方法中添加了以下代码:

AI建筑知识问答
AI建筑知识问答

用人工智能ChatGPT帮你解答所有建筑问题

AI建筑知识问答 22
查看详情 AI建筑知识问答
if self.start_idx >= self.dataset_len:
    self.batch_idx = 0
    self.start_idx = 0
    self.end_idx = self.batch_sizes[self.batch_idx]
    raise StopIteration
登录后复制

这段代码在 self.start_idx 大于或等于 self.dataset_len 时执行,这意味着我们已经遍历了整个数据集。此时,我们将 self.batch_idx、self.start_idx 和 self.end_idx 重置为初始值,以便在下一个 epoch 中重新开始迭代。

完整示例

下面是一个完整的示例,展示了如何使用修改后的 VariableBatchSampler 和 DataLoader 进行多 epoch 训练。

import torch
import numpy as np
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, TensorDataset



class VariableBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]

    def __iter__(self):
        return self

    def __next__(self):
        if self.start_idx >= self.dataset_len:
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]
            raise StopIteration

        batch_indices = list(range(self.start_idx, self.end_idx))
        self.start_idx = self.end_idx
        self.batch_idx += 1

        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            self.end_idx = self.dataset_len

        return batch_indices


x_train = torch.randn(23)
y_train = torch.randint(0, 2, (23,))

batch_sizes = [4, 10, 7, 2]
train_dataset = TensorDataset(x_train, y_train)
sampler = VariableBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
dataloader_train = DataLoader(train_dataset, sampler=sampler)

max_epoch = 4
for epoch in np.arange(1, max_epoch):
    print("Epoch: ", epoch)
    for x_batch, y_batch in dataloader_train:
         print(x_batch.shape)
登录后复制

这段代码会输出每个 epoch 中每个 batch 的形状,证明 DataLoader 可以在多个 epoch 中正常迭代。

总结

当使用自定义的 Sampler 时,确保在 __next__ 方法中正确地重置内部索引,以便 DataLoader 可以在多个 epoch 中正常迭代。 否则,DataLoader 在第一个epoch后会停止工作。 通过本文提供的示例,您可以更好地理解如何实现一个自定义的 Sampler,并解决 DataLoader 迭代问题。

以上就是PyTorch DataLoader 自定义 Sampler 迭代问题解决的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号