This post reproduces an MLP paper from Ming-Ming Cheng and Shuicheng Yan's teams: Vision Permutator (ViP), a model that encodes information along the h, w, and c dimensions and fuses the three branches with learned weights. Without spatial convolution, attention, or extra large-scale training data, it performs on par with CNNs and ViTs. The post covers the model definition, network construction, and a structure summary, runs a quick validation on Cifar10, and notes that MLP-like methods still have considerable room for improvement.

Hi guys, we meet again. This time we reproduce an MLP-related paper.
The paper is a new exploration of MLP architectures from Ming-Ming Cheng and Shuicheng Yan's teams: it introduces a mechanism that encodes information along the h, w, and c dimensions separately, and proposes a weighted way to fuse the three branches.

Performance-wise, it is competitive with CNN and ViT models (see the results table later in this post).
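Before walking through the full implementation, here is a minimal sketch (a toy example written for this post, not part of the original script) of the core trick in the Permute-MLP branch: by grouping the C channels into segment_dim segments and folding the height axis into each segment, a plain nn.Linear over the last dimension ends up mixing information along H. The same is done along W, a third branch mixes along C, and the three results are fused with softmax weights, as the WeightedPermuteMLP class below implements. The sketch assumes segment_dim equals the spatial size so that H * S == C.

import paddle
import paddle.nn as nn

B, H, W, C = 1, 4, 4, 16
segment_dim = 4                      # chosen so that H * (C // segment_dim) == C
S = C // segment_dim                 # channels per segment

x = paddle.randn([B, H, W, C])
mlp_h = nn.Linear(C, C)

# group channels into segments, swap the H axis into the last dimension, then mix
h = x.reshape([B, H, W, segment_dim, S]).transpose([0, 3, 2, 1, 4]).reshape([B, segment_dim, W, H * S])
h = mlp_h(h)                         # a single Linear now mixes tokens along the height axis
h = h.reshape([B, segment_dim, W, H, S]).transpose([0, 3, 2, 1, 4]).reshape([B, H, W, C])
print(h.shape)                       # [1, 4, 4, 16]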

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

trunc_normal_ = nn.initializer.TruncatedNormal(std=.02)
zeros_ = nn.initializer.Constant(value=0.)
ones_ = nn.initializer.Constant(value=1.)
def drop_path(x, drop_prob=0., training=False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = paddle.to_tensor(keep_prob) + paddle.rand(shape)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x / keep_prob * random_tensor       # rescale so the expectation is unchanged
    return output


class DropPath(nn.Layer):
    """Drop paths (stochastic depth) per sample."""
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Identity(nn.Layer):
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, input):
        return input
class Mlp(nn.Layer):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class WeightedPermuteMLP(nn.Layer):
    def __init__(self, dim, segment_dim=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.segment_dim = segment_dim
        self.mlp_c = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.mlp_h = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.mlp_w = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.reweight = Mlp(dim, dim // 4, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        S = C // self.segment_dim

        # encode along the height axis: fold H into the channel segments
        h = x.reshape([B, H, W, self.segment_dim, S]).transpose([0, 3, 2, 1, 4]).reshape([B, self.segment_dim, W, H * S])
        h = self.mlp_h(h).reshape([B, self.segment_dim, W, H, S]).transpose([0, 3, 2, 1, 4]).reshape([B, H, W, C])

        # encode along the width axis
        w = x.reshape([B, H, W, self.segment_dim, S]).transpose([0, 1, 3, 2, 4]).reshape([B, H, self.segment_dim, W * S])
        w = self.mlp_w(w).reshape([B, H, self.segment_dim, W, S]).transpose([0, 1, 3, 2, 4]).reshape([B, H, W, C])

        # encode along the channel axis
        c = self.mlp_c(x)

        # weighted fusion of the three branches
        a = (h + w + c).transpose([0, 3, 1, 2]).flatten(2).mean(2)
        a = self.reweight(a).reshape([B, C, 3]).transpose([2, 0, 1])
        a = F.softmax(a, axis=0).unsqueeze(2).unsqueeze(2)

        x = h * a[0] + w * a[1] + c * a[2]
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class PermutatorBlock(nn.Layer):
    def __init__(self, dim, segment_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn=WeightedPermuteMLP):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = mlp_fn(dim, segment_dim=segment_dim, qkv_bias=qkv_bias, qk_scale=None, attn_drop=attn_drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
        self.skip_lam = skip_lam

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam
        x = x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam
        return x


class PatchEmbed(nn.Layer):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # B, C, H, W
        return x


class Downsample(nn.Layer):
    """ Patch merging between stages
    """
    def __init__(self, in_embed_dim, out_embed_dim, patch_size):
        super().__init__()
        self.proj = nn.Conv2D(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = x.transpose([0, 3, 1, 2])
        x = self.proj(x)  # B, C, H, W
        x = x.transpose([0, 2, 3, 1])
        return x


def basic_blocks(dim, index, layers, segment_dim, mlp_ratio=3., qkv_bias=False, qk_scale=None,
                 attn_drop=0, drop_path_rate=0., skip_lam=1.0, mlp_fn=WeightedPermuteMLP, **kwargs):
    blocks = []
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(PermutatorBlock(dim, segment_dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                      attn_drop=attn_drop, drop_path=block_dpr, skip_lam=skip_lam, mlp_fn=mlp_fn))
    blocks = nn.Sequential(*blocks)
    return blocks


class VisionPermutator(nn.Layer):
""" Vision Permutator
"""
def __init__(self, layers, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dims=None, transitions=None, segment_dim=None, mlp_ratios=None, skip_lam=1.0,
qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=nn.LayerNorm,mlp_fn = WeightedPermuteMLP):
super().__init__()
self.num_classes = num_classes
self.patch_embed = PatchEmbed(img_size = img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
network = [] for i in range(len(layers)):
stage = basic_blocks(embed_dims[i], i, layers, segment_dim[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
qk_scale=qk_scale, attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate, norm_layer=norm_layer, skip_lam=skip_lam,
mlp_fn = mlp_fn)
network.append(stage) if i >= len(layers) - 1: break
if transitions[i] or embed_dims[i] != embed_dims[i+1]:
patch_size = 2 if transitions[i] else 1
network.append(Downsample(embed_dims[i], embed_dims[i+1], patch_size))
self.network = nn.LayerList(network)
self.norm = norm_layer(embed_dims[-1]) # Classifier head
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else Identity()
self.apply(self._init_weights) def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias) elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight) def forward_embeddings(self, x):
x = self.patch_embed(x) # B,C,H,W-> B,H,W,C
x = x.transpose([0, 2, 3, 1]) return x def forward_tokens(self,x):
for idx, block in enumerate(self.network):
x = block(x)
B, H, W, C = x.shape
x = x.reshape([B, -1, C]) return x def forward(self, x):
x = self.forward_embeddings(x) # B, H, W, C -> B, N, C
x = self.forward_tokens(x)
x = self.norm(x) return self.head(x.mean(1))def vip_s14(**kwargs):
    layers = [4, 3, 8, 3]
    transitions = [False, False, False, False]
    segment_dim = [16, 16, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [384, 384, 384, 384]
    model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=14, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs)
    return model


def vip_s7(**kwargs):
    layers = [4, 3, 8, 3]
    transitions = [True, False, False, False]
    segment_dim = [32, 16, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [192, 384, 384, 384]
    model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs)
    return model


def vip_m7(**kwargs):
    layers = [4, 3, 14, 3]
    transitions = [False, True, False, False]
    segment_dim = [32, 32, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [256, 256, 512, 512]
    model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs)
    return model


def vip_l7(**kwargs):
    layers = [8, 8, 16, 4]
    transitions = [True, False, False, False]
    segment_dim = [32, 16, 16, 16]
    mlp_ratios = [3, 3, 3, 3]
    embed_dims = [256, 512, 512, 512]
    model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                             segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs)
    return model


paddle.Model(vip_s7()).summary((1, 3, 224, 224))
---------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
=================================================================================
Conv2D-1 [[1, 3, 224, 224]] [1, 192, 32, 32] 28,416
PatchEmbed-1 [[1, 3, 224, 224]] [1, 192, 32, 32] 0
LayerNorm-1 [[1, 32, 32, 192]] [1, 32, 32, 192] 384
Linear-2 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-3 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-1 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-4 [[1, 192]] [1, 48] 9,264
GELU-1 [[1, 48]] [1, 48] 0
Dropout-1 [[1, 576]] [1, 576] 0
Linear-5 [[1, 48]] [1, 576] 28,224
Mlp-1 [[1, 192]] [1, 576] 0
Linear-6 [[1, 32, 32, 192]] [1, 32, 32, 192] 37,056
Dropout-2 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
WeightedPermuteMLP-1 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
Identity-1 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
LayerNorm-2 [[1, 32, 32, 192]] [1, 32, 32, 192] 384
Linear-7 [[1, 32, 32, 192]] [1, 32, 32, 576] 111,168
GELU-2 [[1, 32, 32, 576]] [1, 32, 32, 576] 0
Dropout-3 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
Linear-8 [[1, 32, 32, 576]] [1, 32, 32, 192] 110,784
Mlp-2 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
PermutatorBlock-1 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
LayerNorm-3 [[1, 32, 32, 192]] [1, 32, 32, 192] 384
Linear-10 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-11 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-9 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-12 [[1, 192]] [1, 48] 9,264
GELU-3 [[1, 48]] [1, 48] 0
Dropout-4 [[1, 576]] [1, 576] 0
Linear-13 [[1, 48]] [1, 576] 28,224
Mlp-3 [[1, 192]] [1, 576] 0
Linear-14 [[1, 32, 32, 192]] [1, 32, 32, 192] 37,056
Dropout-5 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
WeightedPermuteMLP-2 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
Identity-2 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
LayerNorm-4 [[1, 32, 32, 192]] [1, 32, 32, 192] 384
Linear-15 [[1, 32, 32, 192]] [1, 32, 32, 576] 111,168
GELU-4 [[1, 32, 32, 576]] [1, 32, 32, 576] 0
Dropout-6 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
Linear-16 [[1, 32, 32, 576]] [1, 32, 32, 192] 110,784
Mlp-4 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
PermutatorBlock-2 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
LayerNorm-5 [[1, 32, 32, 192]] [1, 32, 32, 192] 384
Linear-18 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-19 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-17 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-20 [[1, 192]] [1, 48] 9,264
GELU-5 [[1, 48]] [1, 48] 0
Dropout-7 [[1, 576]] [1, 576] 0
Linear-21 [[1, 48]] [1, 576] 28,224
Mlp-5 [[1, 192]] [1, 576] 0
Linear-22 [[1, 32, 32, 192]] [1, 32, 32, 192] 37,056
Dropout-8 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
WeightedPermuteMLP-3 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
Identity-3 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
LayerNorm-6 [[1, 32, 32, 192]] [1, 32, 32, 192] 384
Linear-23 [[1, 32, 32, 192]] [1, 32, 32, 576] 111,168
GELU-6 [[1, 32, 32, 576]] [1, 32, 32, 576] 0
Dropout-9 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
Linear-24 [[1, 32, 32, 576]] [1, 32, 32, 192] 110,784
Mlp-6 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
PermutatorBlock-3 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
LayerNorm-7 [[1, 32, 32, 192]] [1, 32, 32, 192] 384
Linear-26 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-27 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-25 [[1, 32, 32, 192]] [1, 32, 32, 192] 36,864
Linear-28 [[1, 192]] [1, 48] 9,264
GELU-7 [[1, 48]] [1, 48] 0
Dropout-10 [[1, 576]] [1, 576] 0
Linear-29 [[1, 48]] [1, 576] 28,224
Mlp-7 [[1, 192]] [1, 576] 0
Linear-30 [[1, 32, 32, 192]] [1, 32, 32, 192] 37,056
Dropout-11 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
WeightedPermuteMLP-4 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
Identity-4 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
LayerNorm-8 [[1, 32, 32, 192]] [1, 32, 32, 192] 384
Linear-31 [[1, 32, 32, 192]] [1, 32, 32, 576] 111,168
GELU-8 [[1, 32, 32, 576]] [1, 32, 32, 576] 0
Dropout-12 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
Linear-32 [[1, 32, 32, 576]] [1, 32, 32, 192] 110,784
Mlp-8 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
PermutatorBlock-4 [[1, 32, 32, 192]] [1, 32, 32, 192] 0
Conv2D-2 [[1, 192, 32, 32]] [1, 384, 16, 16] 295,296
Downsample-1 [[1, 32, 32, 192]] [1, 16, 16, 384] 0
LayerNorm-9 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-34 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-35 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-33 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-36 [[1, 384]] [1, 96] 36,960
GELU-9 [[1, 96]] [1, 96] 0
Dropout-13 [[1, 1152]] [1, 1152] 0
Linear-37 [[1, 96]] [1, 1152] 111,744
Mlp-9 [[1, 384]] [1, 1152] 0
Linear-38 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-14 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-5 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-5 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-10 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-39 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-10 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-15 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-40 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-10 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-5 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-11 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-42 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-43 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-41 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-44 [[1, 384]] [1, 96] 36,960
GELU-11 [[1, 96]] [1, 96] 0
Dropout-16 [[1, 1152]] [1, 1152] 0
Linear-45 [[1, 96]] [1, 1152] 111,744
Mlp-11 [[1, 384]] [1, 1152] 0
Linear-46 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-17 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-6 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-6 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-12 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-47 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-12 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-18 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-48 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-12 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-6 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-13 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-50 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-51 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-49 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-52 [[1, 384]] [1, 96] 36,960
GELU-13 [[1, 96]] [1, 96] 0
Dropout-19 [[1, 1152]] [1, 1152] 0
Linear-53 [[1, 96]] [1, 1152] 111,744
Mlp-13 [[1, 384]] [1, 1152] 0
Linear-54 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-20 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-7 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-7 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-14 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-55 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-14 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-21 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-56 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-14 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-7 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-15 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-58 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-59 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-57 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-60 [[1, 384]] [1, 96] 36,960
GELU-15 [[1, 96]] [1, 96] 0
Dropout-22 [[1, 1152]] [1, 1152] 0
Linear-61 [[1, 96]] [1, 1152] 111,744
Mlp-15 [[1, 384]] [1, 1152] 0
Linear-62 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-23 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-8 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-8 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-16 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-63 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-16 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-24 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-64 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-16 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-8 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-17 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-66 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-67 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-65 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-68 [[1, 384]] [1, 96] 36,960
GELU-17 [[1, 96]] [1, 96] 0
Dropout-25 [[1, 1152]] [1, 1152] 0
Linear-69 [[1, 96]] [1, 1152] 111,744
Mlp-17 [[1, 384]] [1, 1152] 0
Linear-70 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-26 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-9 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-9 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-18 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-71 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-18 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-27 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-72 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-18 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-9 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-19 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-74 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-75 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-73 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-76 [[1, 384]] [1, 96] 36,960
GELU-19 [[1, 96]] [1, 96] 0
Dropout-28 [[1, 1152]] [1, 1152] 0
Linear-77 [[1, 96]] [1, 1152] 111,744
Mlp-19 [[1, 384]] [1, 1152] 0
Linear-78 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-29 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-10 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-10 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-20 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-79 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-20 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-30 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-80 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-20 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-10 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-21 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-82 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-83 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-81 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-84 [[1, 384]] [1, 96] 36,960
GELU-21 [[1, 96]] [1, 96] 0
Dropout-31 [[1, 1152]] [1, 1152] 0
Linear-85 [[1, 96]] [1, 1152] 111,744
Mlp-21 [[1, 384]] [1, 1152] 0
Linear-86 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-32 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-11 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-11 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-22 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-87 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-22 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-33 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-88 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-22 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-11 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-23 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-90 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-91 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-89 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-92 [[1, 384]] [1, 96] 36,960
GELU-23 [[1, 96]] [1, 96] 0
Dropout-34 [[1, 1152]] [1, 1152] 0
Linear-93 [[1, 96]] [1, 1152] 111,744
Mlp-23 [[1, 384]] [1, 1152] 0
Linear-94 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-35 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-12 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-12 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-24 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-95 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-24 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-36 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-96 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-24 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-12 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-25 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-98 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-99 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-97 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-100 [[1, 384]] [1, 96] 36,960
GELU-25 [[1, 96]] [1, 96] 0
Dropout-37 [[1, 1152]] [1, 1152] 0
Linear-101 [[1, 96]] [1, 1152] 111,744
Mlp-25 [[1, 384]] [1, 1152] 0
Linear-102 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-38 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-13 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-13 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-26 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-103 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-26 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-39 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-104 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-26 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-13 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-27 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-106 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-107 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-105 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-108 [[1, 384]] [1, 96] 36,960
GELU-27 [[1, 96]] [1, 96] 0
Dropout-40 [[1, 1152]] [1, 1152] 0
Linear-109 [[1, 96]] [1, 1152] 111,744
Mlp-27 [[1, 384]] [1, 1152] 0
Linear-110 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-41 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-14 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-14 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-28 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-111 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-28 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-42 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-112 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-28 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-14 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-29 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-114 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-115 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-113 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-116 [[1, 384]] [1, 96] 36,960
GELU-29 [[1, 96]] [1, 96] 0
Dropout-43 [[1, 1152]] [1, 1152] 0
Linear-117 [[1, 96]] [1, 1152] 111,744
Mlp-29 [[1, 384]] [1, 1152] 0
Linear-118 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-44 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-15 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-15 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-30 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-119 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-30 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-45 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-120 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-30 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-15 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-31 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-122 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-123 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-121 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-124 [[1, 384]] [1, 96] 36,960
GELU-31 [[1, 96]] [1, 96] 0
Dropout-46 [[1, 1152]] [1, 1152] 0
Linear-125 [[1, 96]] [1, 1152] 111,744
Mlp-31 [[1, 384]] [1, 1152] 0
Linear-126 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-47 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-16 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-16 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-32 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-127 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-32 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-48 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-128 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-32 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-16 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-33 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-130 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-131 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-129 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-132 [[1, 384]] [1, 96] 36,960
GELU-33 [[1, 96]] [1, 96] 0
Dropout-49 [[1, 1152]] [1, 1152] 0
Linear-133 [[1, 96]] [1, 1152] 111,744
Mlp-33 [[1, 384]] [1, 1152] 0
Linear-134 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-50 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-17 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-17 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-34 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-135 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-34 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-51 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-136 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-34 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-17 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-35 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-138 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-139 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-137 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,456
Linear-140 [[1, 384]] [1, 96] 36,960
GELU-35 [[1, 96]] [1, 96] 0
Dropout-52 [[1, 1152]] [1, 1152] 0
Linear-141 [[1, 96]] [1, 1152] 111,744
Mlp-35 [[1, 384]] [1, 1152] 0
Linear-142 [[1, 16, 16, 384]] [1, 16, 16, 384] 147,840
Dropout-53 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
WeightedPermuteMLP-18 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Identity-18 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-36 [[1, 16, 16, 384]] [1, 16, 16, 384] 768
Linear-143 [[1, 16, 16, 384]] [1, 16, 16, 1152] 443,520
GELU-36 [[1, 16, 16, 1152]] [1, 16, 16, 1152] 0
Dropout-54 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
Linear-144 [[1, 16, 16, 1152]] [1, 16, 16, 384] 442,752
Mlp-36 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
PermutatorBlock-18 [[1, 16, 16, 384]] [1, 16, 16, 384] 0
LayerNorm-37 [[1, 256, 384]] [1, 256, 384] 768
Linear-145 [[1, 384]] [1, 1000] 385,000
=================================================================================
Total params: 25,114,984
Trainable params: 25,114,984
Non-trainable params: 0
---------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 319.20
Params size (MB): 95.81
Estimated Total Size (MB): 415.58
---------------------------------------------------------------------------------
{'total_params': 25114984, 'trainable_params': 25114984}

| Model | # Param | Top-1 Acc. | Top-5 Acc. |
|---|---|---|---|
| vip s7 | 25M | 0.814 | 0.958 |
| vip m7 | 55M | 0.827 | 0.961 |
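As a rough sanity check on the parameter counts in the table, the models defined above can be counted directly (count_params is a small helper written for this post, not part of the original script):

import numpy as np

def count_params(model):
    # total number of trainable elements across all parameters
    return sum(int(np.prod(p.shape)) for p in model.parameters() if not p.stop_gradient)

print(count_params(vip_s7()))   # roughly 25M, matching the summary above
print(count_params(vip_m7()))   # roughly 55M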
Next, load the pretrained weights:

# vip s7
vip_s = vip_s7()
vip_s.set_state_dict(paddle.load('/home/aistudio/data/data96765/vip_s7.pdparams'))

# vip m7
vip_m = vip_m7()
vip_m.set_state_dict(paddle.load('/home/aistudio/data/data96765/vip_m7.pdparams'))

For validation we use the Cifar10 dataset with no heavy data augmentation.
import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10

paddle.set_device('gpu')

# data preparation
transform = T.Compose([
    T.Resize(size=(224, 224)),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], data_format='HWC'),
    T.ToTensor()
])

train_dataset = Cifar10(mode='train', transform=transform)
val_dataset = Cifar10(mode='test', transform=transform)

vip_m = vip_m7(num_classes=10)
vip_m.set_state_dict(paddle.load('/home/aistudio/data/data96765/vip_m7.pdparams'))
model = paddle.Model(vip_m)

Due to limited time we only train for 5 epochs; interested readers can continue training for longer.
model.prepare(optimizer=paddle.optimizer.AdamW(learning_rate=0.0001, parameters=model.parameters()),
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=paddle.metric.Accuracy())

visualdl = paddle.callbacks.VisualDL(log_dir='visual_log')  # enable training visualization

model.fit(
    train_data=train_dataset,
    eval_data=val_dataset,
    batch_size=32,
    epochs=5,
    verbose=1,
    callbacks=[visualdl]
)
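After the short fine-tuning run, the high-level API can report validation accuracy directly (the batch size here is illustrative):

result = model.evaluate(val_dataset, batch_size=32, verbose=1)
print(result)   # loss and top-1 accuracy on the Cifar10 test split

As noted at the start, MLP-like methods such as ViP still leave considerable room for improvement.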