【即插即用完整代码】CVPR 2024部分单头注意力SHSA,分类、检测和分割SOTA!
文章末尾,扫码添加公众号,领取
完整版即插即用模块代码!
适用于所有的CV二维任务:图像分割、超分辨率、目标检测、图像识别、低光增强、遥感检测等
摘要(Abstract)
-
背景与动机:
-
近年来,高效的视觉Transformer(ViT)在资源受限的设备上表现出色,具有低延迟和良好的性能。传统的高效ViT模型通常使用4×4的补丁嵌入和4阶段结构,并在微观层面上采用复杂的多头注意力机制。
-
然而,这些模型在宏观和微观设计层面上仍存在计算冗余,尤其是在处理高分辨率图像时,全局注意力模块的计算复杂度与图像大小成二次关系。
-
-
方法介绍:
-
本文提出了一种单头视觉Transformer(SHViT),旨在以内存高效的方式解决所有设计层面的计算冗余。
-
通过使用更大的步长补丁化茎(patchify stem),减少了内存访问成本,并在早期阶段利用具有减少空间冗余的token表示,从而在保持竞争力性能的同时提高了效率。
-
进一步分析表明,早期阶段的注意力层可以用卷积替代,而后期阶段的多头注意力机制存在计算冗余。因此,引入了单头注意力模块,通过并行结合全局和局部信息来提高准确性。
-
-
实验结果:
-
在ImageNet-1k数据集上,SHViT-S4在GPU、CPU和iPhone 12移动设备上的速度分别比MobileViTv2 ×1.0快3.3倍、8.1倍和2.4倍,同时准确度提高了1.3%。
-
在使用Mask R-CNN头进行目标检测和实例分割的MS COCO任务中,SHViT在GPU和移动设备上的性能与FastViT-SA12相当,但背骨延迟分别降低了3.8倍和2.0倍。
-
引言(Introduction)
-
视觉Transformer的优势与挑战:
-
视觉Transformer(ViT)在各种计算机视觉任务中表现出色,能够有效地建模长距离依赖关系,并且随着训练数据和模型参数的增加而扩展。
-
然而,ViT缺乏归纳偏置,需要更多的训练数据,并且全局注意力模块的计算复杂度与图像大小成二次关系,导致计算效率低下。
-
-
现有研究的局限性:
-
以往的研究通常将ViT与CNN结合,或引入成本高效的注意力变体。然而,这些方法主要关注如何聚合token,而不是如何构建token。
-
尽管在多头自注意力(MHSA)方面取得了一些进展,但宏观和微观设计中的冗余仍未得到充分理解和解决。
-
-
研究目标与贡献:
-
本文旨在通过系统分析设计层面的冗余,提出内存高效的设计原则来解决这些问题。
-
引入单头视觉Transformer(SHViT),在多种设备上实现了良好的准确度和速度权衡。
-
通过广泛的实验验证了SHViT的高效性和有效性。
-
方法论(Method)
2.1 宏观设计中的冗余分析
-
补丁嵌入尺寸的影响:
-
传统的高效模型通常使用4×4的补丁嵌入。本文通过实验发现,使用更大的步长补丁化茎(如16×16)可以在早期阶段减少空间冗余,并且不会显著降低性能。
-
例如,使用16×16补丁化茎的模型在GPU和CPU上的速度分别提高了3.0倍和2.8倍,尽管其性能略有下降。
-
-
宏观设计的优势:
-
通过减少token数量,可以显著降低内存访问成本。
-
由于步长设计的激进性,当分辨率增加时,吞吐量的下降幅度较小,从而有效提高了性能。
-
2.2 微观设计中的冗余分析
-
多头注意力机制的冗余:
-
多头注意力机制在计算上需求较高,但研究表明许多注意力头并不是必不可少的。
-
通过注意力图可视化、头相似性分析和头消融研究,发现后期阶段的多头机制存在显著的冗余。
-
-
单头自注意力(SHSA)的提出:
-
提出了一种新的单头自注意力模块,仅对输入通道的一部分应用自注意力,而其他通道保持不变。
-
这种设计不仅消除了多头机制的计算冗余,还通过处理部分通道降低了内存访问成本。
-
2.3 单头视觉Transformer(SHViT)
-
架构概述:
-
SHViT的输入图像首先经过四个3×3的步进卷积层,以提取更好的局部表示。
-
然后,tokens通过三个阶段的SHViT块进行层次化表示提取。每个SHViT块包括深度卷积层、单头自注意力层和前馈网络。
-
为了在不损失信息的情况下减少tokens,使用了高效的下采样层。
-
-
设计细节:
-
在第一阶段不使用单头自注意力层,以提高效率。
-
通过全局平均池化和全连接层输出预测结果。
-
使用层归一化和批量归一化来优化模型速度。
-
实验细节配置(Experiments)
3.1 实施细节
-
数据集与任务:
-
在ImageNet-1K数据集上进行图像分类任务,包含1.28M训练图像和50K验证图像,涵盖1000个类别。
-
-
训练设置:
-
使用AdamW优化器从头开始训练模型,训练300个周期,学习率为10^-3,总批量大小为2048。
-
使用余弦学习率调度器,并在前5个周期进行线性预热。
-
对于384×384和512×512分辨率,进行30个周期的微调,学习率为0.004,权重衰减为10^-8。
-
-
硬件平台与性能评估:
-
在Nvidia A100 GPU上测量GPU吞吐量,批量大小为256。
-
在Intel Xeon Gold 5218R CPU @ 2.10GHz处理器上评估CPU和CPUONNX的运行时间,批量大小为16(使用单线程)。
-
在iPhone 12上测量移动延迟,iOS版本为16.5,批量大小为1。
-
将模型导出为ONNX格式进行CPUONNX评估。
-
在COCO数据集上使用RetinaNet和Mask R-CNN头进行目标检测和实例分割任务的验证。
-
视频代码展示:
主要代码展示:
class GroupNorm(torch.nn.GroupNorm):
"""
Group Normalization with 1 group.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class Conv2d_BN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1):
super().__init__()
self.add_module('c', torch.nn.Conv2d(
a, b, ks, stride, pad, dilation, groups, bias=False))
self.add_module('bn', torch.nn.BatchNorm2d(b))
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
torch.nn.init.constant_(self.bn.bias, 0)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
device=c.weight.device)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class BN_Linear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
self.add_module('l', torch.nn.Linear(a, b, bias=bias))
trunc_normal_(self.l.weight, std=std)
if bias:
torch.nn.init.constant_(self.l.bias, 0)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps)**0.5
b = bn.bias - self.bn.running_mean * \
self.bn.weight / (bn.running_var + bn.eps)**0.5
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class SHSA(torch.nn.Module):
"""Single-Head Self-Attention"""
def __init__(self, dim, qk_dim, pdim):
super().__init__()
self.scale = qk_dim ** -0.5
self.qk_dim = qk_dim
self.dim = dim
self.pdim = pdim
self.pre_norm = GroupNorm(pdim)
self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
dim, dim, bn_weight_init = 0))
def forward(self, x):
B, C, H, W = x.shape
x1, x2 = torch.split(x, [self.pdim, self.dim - self.pdim], dim = 1)
x1 = self.pre_norm(x1)
qkv = self.qkv(x1)
q, k, v = qkv.split([self.qk_dim, self.qk_dim, self.pdim], dim = 1)
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
attn = (q.transpose(-2, -1) @ k) * self.scale
attn = attn.softmax(dim = -1)
x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W)
x = self.proj(torch.cat([x1, x2], dim = 1))
return x
if __name__ == '__main__':
x = torch.randn(1,64,32,32)
shsa = SHSA(dim=64, qk_dim=64, pdim=64)
print(shsa)
output = shsa(x)
print(f"Input shape: {x.shape}")
print(f"output shape: {output.shape}")
运行结果展示:
因篇幅问题不能全部显示,请点此查看更多更全内容