YOLOV5 添加注意力机制,以添加ECAttention为例(系列)
来源:华佗健康网
总述:在yolov5的网络结构中添加相关的注意力机制已经不是什么新鲜的事儿了,各种新奇的注意力机制近年来也是层出不穷,当然,选择什么类型的注意力机制、采取什么分布规范的数据集以及将注意力机制添加在网络的什么位置都会或多或少地对结果如P\R\mAP等产生好的、坏的影响,这里的确是一千个读者有一千个哈姆雷特,接下来从代码的角度讲详细的介绍如何在主干网络中添加ECA注意机制。
1、在common.py函数里面添加ECA代码。
2、在yolo.py中注册ECA。
3、修改对应的yaml文件。
步骤:
1、在common.py函数里面添加ECA代码。
class ECA(nn.Module):
"""Constructs a ECA module.
Args:
channel: Number of channels of the input feature map
k_size: Adaptive selection of kernel size
"""
def __init__(self, c1,c2, k_size=3):
super(ECA, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# feature descriptor on the global spatial information
y = self.avg_pool(x)
# print(y.shape,y.squeeze(-1).shape,y.squeeze(-1).transpose(-1, -2).shape)
# Two different branches of ECA module
# 50*C*1*1
#50*C*1
#50*1*C
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
# Multi-scale information fusion
y = self.sigmoid(y)
return x * y.expand_as(x)
2、在yolo.py中注册ECA。对yolo.py详细修改见下面代码。
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
C3, C3TR]:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)
args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3,eca_layer]:
args.insert(2, n) # number of repeats
n = 1
elif m is nn.BatchNorm2d:
args = [ch[f]]
elif m is Concat:
c2 = sum([ch[x] for x in f])
elif m is Detect:
args.append([ch[x] for x in f])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
elif m is Contract:
c2 = ch[f] * args[0] ** 2
elif m is Expand:
c2 = ch[f] // args[0] ** 2
elif m is eca_layer:
channel=args[0]
channel=make_divisible(channel*gw,8)if channel != no else channel
args=[channel]
else:
c2 = ch[f]
3、修改对应的yaml文件。
# parameters
nc: 2 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, C3, [1024, False]], # 9
[-1,1,ECA,[1024]], #ECA
]
# YOLOv5 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 15], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 11], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
[[18, 21, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
自学心得:正如开头所述,添加什么、添加何处、采用什么数据集等等一切取决于你的实际工作是什么。最后祝大家多发SCI!!!
欢迎讨论,有问题直接评论区!看到我会第一时间解答,同时也欢迎各位积极讨论!!!
因篇幅问题不能全部显示,请点此查看更多更全内容