您的当前位置:首页正文

【即插即用完整代码】CVPR 2024部分单头注意力SHSA,分类、检测和分割SOTA!

来源:华佗健康网

文章末尾,扫码添加公众号,领取

完整版即插即用模块代码!

适用于所有的CV二维任务:图像分割、超分辨率、目标检测、图像识别、低光增强、遥感检测等

摘要(Abstract)

  • 背景与动机

    • 近年来,高效的视觉Transformer(ViT)在资源受限的设备上表现出色,具有低延迟和良好的性能。传统的高效ViT模型通常使用4×4的补丁嵌入和4阶段结构,并在微观层面上采用复杂的多头注意力机制。

    • 然而,这些模型在宏观和微观设计层面上仍存在计算冗余,尤其是在处理高分辨率图像时,全局注意力模块的计算复杂度与图像大小成二次关系。

  • 方法介绍

    • 本文提出了一种单头视觉Transformer(SHViT),旨在以内存高效的方式解决所有设计层面的计算冗余。

    • 通过使用更大的步长补丁化茎(patchify stem),减少了内存访问成本,并在早期阶段利用具有减少空间冗余的token表示,从而在保持竞争力性能的同时提高了效率。

    • 进一步分析表明,早期阶段的注意力层可以用卷积替代,而后期阶段的多头注意力机制存在计算冗余。因此,引入了单头注意力模块,通过并行结合全局和局部信息来提高准确性。

  • 实验结果

    • 在ImageNet-1k数据集上,SHViT-S4在GPU、CPU和iPhone 12移动设备上的速度分别比MobileViTv2 ×1.0快3.3倍、8.1倍和2.4倍,同时准确度提高了1.3%。

    • 在使用Mask R-CNN头进行目标检测和实例分割的MS COCO任务中,SHViT在GPU和移动设备上的性能与FastViT-SA12相当,但背骨延迟分别降低了3.8倍和2.0倍。

引言(Introduction)

  • 视觉Transformer的优势与挑战

    • 视觉Transformer(ViT)在各种计算机视觉任务中表现出色,能够有效地建模长距离依赖关系,并且随着训练数据和模型参数的增加而扩展。

    • 然而,ViT缺乏归纳偏置,需要更多的训练数据,并且全局注意力模块的计算复杂度与图像大小成二次关系,导致计算效率低下。

  • 现有研究的局限性

    • 以往的研究通常将ViT与CNN结合,或引入成本高效的注意力变体。然而,这些方法主要关注如何聚合token,而不是如何构建token。

    • 尽管在多头自注意力(MHSA)方面取得了一些进展,但宏观和微观设计中的冗余仍未得到充分理解和解决。

  • 研究目标与贡献

    • 本文旨在通过系统分析设计层面的冗余,提出内存高效的设计原则来解决这些问题。

    • 引入单头视觉Transformer(SHViT),在多种设备上实现了良好的准确度和速度权衡。

    • 通过广泛的实验验证了SHViT的高效性和有效性。

方法论(Method)

2.1 宏观设计中的冗余分析
  • 补丁嵌入尺寸的影响

    • 传统的高效模型通常使用4×4的补丁嵌入。本文通过实验发现,使用更大的步长补丁化茎(如16×16)可以在早期阶段减少空间冗余,并且不会显著降低性能。

    • 例如,使用16×16补丁化茎的模型在GPU和CPU上的速度分别提高了3.0倍和2.8倍,尽管其性能略有下降。

  • 宏观设计的优势

    • 通过减少token数量,可以显著降低内存访问成本。

    • 由于步长设计的激进性,当分辨率增加时,吞吐量的下降幅度较小,从而有效提高了性能。

2.2 微观设计中的冗余分析
  • 多头注意力机制的冗余

    • 多头注意力机制在计算上需求较高,但研究表明许多注意力头并不是必不可少的。

    • 通过注意力图可视化、头相似性分析和头消融研究,发现后期阶段的多头机制存在显著的冗余。

  • 单头自注意力(SHSA)的提出

    • 提出了一种新的单头自注意力模块,仅对输入通道的一部分应用自注意力,而其他通道保持不变。

    • 这种设计不仅消除了多头机制的计算冗余,还通过处理部分通道降低了内存访问成本。

2.3 单头视觉Transformer(SHViT)
  • 架构概述

    • SHViT的输入图像首先经过四个3×3的步进卷积层,以提取更好的局部表示。

    • 然后,tokens通过三个阶段的SHViT块进行层次化表示提取。每个SHViT块包括深度卷积层、单头自注意力层和前馈网络。

    • 为了在不损失信息的情况下减少tokens,使用了高效的下采样层。

  • 设计细节

    • 在第一阶段不使用单头自注意力层,以提高效率。

    • 通过全局平均池化和全连接层输出预测结果。

    • 使用层归一化和批量归一化来优化模型速度。

实验细节配置(Experiments)

3.1 实施细节
  • 数据集与任务

    • 在ImageNet-1K数据集上进行图像分类任务,包含1.28M训练图像和50K验证图像,涵盖1000个类别。

  • 训练设置

    • 使用AdamW优化器从头开始训练模型,训练300个周期,学习率为10^-3,总批量大小为2048。

    • 使用余弦学习率调度器,并在前5个周期进行线性预热。

    • 对于384×384和512×512分辨率,进行30个周期的微调,学习率为0.004,权重衰减为10^-8。

  • 硬件平台与性能评估

    • 在Nvidia A100 GPU上测量GPU吞吐量,批量大小为256。

    • 在Intel Xeon Gold 5218R CPU @ 2.10GHz处理器上评估CPU和CPUONNX的运行时间,批量大小为16(使用单线程)。

    • 在iPhone 12上测量移动延迟,iOS版本为16.5,批量大小为1。

    • 将模型导出为ONNX格式进行CPUONNX评估。

    • 在COCO数据集上使用RetinaNet和Mask R-CNN头进行目标检测和实例分割任务的验证。

视频代码展示:

主要代码展示:

class GroupNorm(torch.nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
class Conv2d_BN(torch.nn.Sequential):
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)
    @torch.no_grad()
    def fuse(self):
        c, bn = self._modules.values()
        
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
            0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
class BN_Linear(torch.nn.Sequential):
    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            torch.nn.init.constant_(self.l.bias, 0)
    @torch.no_grad()
    def fuse(self):
        bn, l = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        b = bn.bias - self.bn.running_mean * \
            self.bn.weight / (bn.running_var + bn.eps)**0.5
        
            b = (l.weight @ b[:, None]).view(-1) + self.l.bias
        m = torch.nn.Linear(w.size(1), w.size(0))
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
    
class SHSA(torch.nn.Module):
    """Single-Head Self-Attention"""
    def __init__(self, dim, qk_dim, pdim):
        super().__init__()
        self.scale = qk_dim ** -0.5
        self.qk_dim = qk_dim
        self.dim = dim
        self.pdim = pdim
        self.pre_norm = GroupNorm(pdim)
        self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
        self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
            dim, dim, bn_weight_init = 0))
        
    def forward(self, x):
        B, C, H, W = x.shape
        x1, x2 = torch.split(x, [self.pdim, self.dim - self.pdim], dim = 1)
        x1 = self.pre_norm(x1)
        qkv = self.qkv(x1)
        q, k, v = qkv.split([self.qk_dim, self.qk_dim, self.pdim], dim = 1)
        q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
        
        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim = -1)
        x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W)
        x = self.proj(torch.cat([x1, x2], dim = 1))
        return x
if __name__ == '__main__':
    x = torch.randn(1,64,32,32)
    shsa = SHSA(dim=64, qk_dim=64, pdim=64)
    print(shsa)
    output = shsa(x)
    print(f"Input shape: {x.shape}")
    print(f"output shape: {output.shape}")

运行结果展示: 

因篇幅问题不能全部显示,请点此查看更多更全内容