
在计算机视觉领域,尤其是在视频理解任务中,利用预训练模型进行微调是一种高效且常用的策略。pytorch video库中的i3d(inflated 3d convnet)模型因其在kinetics等大型视频数据集上的出色表现而广受欢迎。然而,当我们需要将这些模型应用于具有不同类别数量的自定义数据集时,核心挑战在于如何正确地修改模型的输出层,使其与新任务的类别数匹配。本教程将详细阐述这一过程,并解决在修改模型时可能遇到的attributeerror问题。
首先,我们需要从PyTorch Hub加载预训练的I3D模型。facebookresearch/pytorchvideo提供了方便的接口来加载这些模型。
import torch
import torch.nn as nn
from pytorchvideo.models import i3d_r50
# 加载在Kinetics 400上预训练的I3D模型
model = torch.hub.load("facebookresearch/pytorchvideo", i3d_r50, pretrained=True)
print("原始模型结构示例:")
print(model)通过print(model),我们可以看到模型的详细结构。对于I3D模型,其分类头通常位于模型深层的一个特定模块中。
在进行微调时,关键是找到并修改模型的最终分类层。对于PyTorch Video的I3D模型,其分类头通常是一个ResNetBasicHead模块,其中包含一个名为proj的Linear层,负责最终的分类输出。
通过打印模型结构,我们可以观察到类似以下的部分:
(blocks): Sequential(
...
(6): ResNetBasicHead(
(pool): AvgPool3d(...)
(dropout): Dropout(...)
(proj): Linear(in_features=2048, out_features=400, bias=True) # 原始分类层
(output_pool): AdaptiveAvgPool3d(...)
)
)从上述结构可以看出,ResNetBasicHead是blocks模块的第7个子模块(索引为6),而proj层是ResNetBasicHead内部的分类层。
为什么直接访问 model.ResNetBasicHead 会出错?
用户在尝试 model.ResNetBasicHead.proj = ... 时会遇到 AttributeError: 'Net' object has no attribute 'ResNetBasicHead'。这是因为 ResNetBasicHead 并不是 model 对象的一个直接属性。它被封装在 model 的 blocks 属性中,而 blocks 又是一个 Sequential 容器,其子模块通过索引或名称来访问。因此,正确的访问路径应该是 model.blocks[6] 来获取 ResNetBasicHead 模块。
现在我们已经了解了如何定位分类层,接下来介绍两种修改模型输出层的方法。假设我们的自定义数据集有 num_classes = 4 个输出类别。
这是最常见且直接的微调方法。我们获取原始 proj 层的输入特征维度,然后创建一个新的 Linear 层来替换它,新层的输出特征维度设置为自定义的类别数。
num_classes = 4
# 正确访问并替换分类层
# 获取原始proj层的输入特征维度
in_features = model.blocks[6].proj.in_features
# 创建一个新的Linear层
new_proj_layer = nn.Linear(in_features, num_classes)
# 替换原始的proj层
model.blocks[6].proj = new_proj_layer
print("\n替换分类层后的模型结构示例:")
print(model.blocks[6])替换后的 ResNetBasicHead 将会是:
(6): ResNetBasicHead( (pool): AvgPool3d(kernel_size=(4, 7, 7), stride=(1, 1, 1), padding=(0, 0, 0)) (dropout): Dropout(p=0.5, inplace=False) (proj): Linear(in_features=2048, out_features=4, bias=True) # 输出类别已修改为4 (output_pool): AdaptiveAvgPool3d(output_size=1) )
这种方法确保了模型输出的维度与自定义数据集的类别数完全匹配,是进行分类任务微调的标准做法。
除了替换原有层,我们也可以选择在模型现有结构的基础上追加新的分类层。这在某些特定场景下可能有用,例如当你想保留原有预训练的分类头作为特征提取的一部分,并在其后添加一个新的分类器。
A. 在 blocks 模块末尾追加新的线性层
这种方法会在模型的 blocks 模块的末尾添加一个全新的线性层,它将接收 ResNetBasicHead 模块(在 proj 层之前的特征)的输出作为输入。
num_classes = 4
# 获取ResNetBasicHead的输入特征维度(即其proj层的输入特征维度)
# 这里假设新的线性层直接接收ResNetBasicHead的中间特征输出
in_features_for_new_layer = model.blocks[6].proj.in_features
new_linear_layer = nn.Linear(in_features_for_new_layer, num_classes)
# 将新的线性层追加到model.blocks模块的末尾
model.blocks.add_module("custom_linear_classifier", new_linear_layer)
print("\n追加新的分类层到model.blocks后的模型结构示例:")
print(model.blocks)此时,模型结构会变为:
(blocks): Sequential(
...
(6): ResNetBasicHead(
(pool): AvgPool3d(...)
(dropout): Dropout(...)
(proj): Linear(in_features=2048, out_features=400, bias=True) # 原始分类层依然存在
(output_pool): AdaptiveAvgPool3d(...)
)
(custom_linear_classifier): Linear(in_features=2048, out_features=4, bias=True) # 新增的分类层
)B. 在 ResNetBasicHead 模块内部追加新的线性层
此方法在 ResNetBasicHead 模块内部添加一个线性层。这意味着 ResNetBasicHead 将包含两个线性层 (proj 和新添加的 linear)。这通常不用于简单的类别数修改,但可能用于更复杂的架构设计。
num_classes = 4
# 获取原始proj层的输入特征维度
in_features_for_new_layer_in_head = model.blocks[6].proj.in_features
new_linear_layer_in_head = nn.Linear(in_features_for_new_layer_in_head, num_classes)
# 将新的线性层追加到ResNetBasicHead模块内部
model.blocks[6].add_module("custom_linear_in_head", new_linear_layer_in_head)
print("\n追加新的分类层到ResNetBasicHead内部后的模型结构示例:")
print(model.blocks[6])此时,ResNetBasicHead 结构会变为:
(6): ResNetBasicHead( (pool): AvgPool3d(kernel_size=(4, 7, 7), stride=(1, 1, 1), padding=(0, 0, 0)) (dropout): Dropout(p=0.5, inplace=False) (proj): Linear(in_features=2048, out_features=400, bias=True) # 原始分类层依然存在 (output_pool): AdaptiveAvgPool3d(output_size=1) (custom_linear_in_head): Linear(in_features=2048, out_features=4, bias=True) # 新增的层 )
请注意,在方法二的两种追加方式中,原始的 proj 层仍然存在。这意味着在模型前向传播时,您需要明确如何使用这些输出。对于大多数简单的分类任务,直接替换 proj 层(方法一)是更清晰和推荐的做法。
正确地修改预训练模型的输出层是进行迁移学习和微调的关键一步。通过本教程,我们学习了如何加载PyTorch I3D模型,分析其结构,并以两种主要方式(替换或追加)修改其分类头,以适应自定义数据集的类别数量。在大多数情况下,直接替换 proj 层(方法一)是实现分类任务微调最直接有效的方法。理解模型结构和PyTorch的模块访问机制,是成功进行模型定制的基础。
以上就是PyTorch I3D模型在自定义数据集上的微调指南的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号