PyTorch 实现通道注意力机制 SENet 的代码

通道注意力机制SENet

[图1:SE(Squeeze-and-Excitation)模块结构示意图]

Diagram of a Squeeze-and-Excitation building block

实现代码如下

import torch
import torch.nn as nn

class SELayer(nn.Module):
    """Squeeze-and-Excitation (SE) channel-attention layer.

    The layer average-pools each channel of a 4-D input down to a single
    value ("squeeze"), passes the resulting per-channel vector through a
    two-layer bottleneck MLP ending in a sigmoid ("excitation"), and
    multiplies the input by the resulting per-channel gates.

    Args:
        channel: Number of channels of the input tensor (size of dim 1).
        reduction: Divisor applied to ``channel`` to obtain the bottleneck
            width. The width is clamped to at least 1 so that layers with
            ``channel < reduction`` still get a usable bottleneck.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # ``channel // reduction`` is 0 when channel < reduction; a
        # zero-width bottleneck makes the gate a constant sigmoid(0) = 0.5
        # that ignores the input and cannot be trained, so keep >= 1 unit.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale every channel of ``x`` by its learned gate.

        Args:
            x: 4-D tensor whose second dimension has ``channel`` entries.

        Returns:
            A tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        # Squeeze: (b, c, h, w) -> (b, c, 1, 1) -> (b, c).
        y = self.avg_pool(x).view(b, c)
        # Excitation: per-channel gates, reshaped so they broadcast over
        # the two trailing dimensions of x.
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

if __name__ == "__main__":
    # Smoke test: push an all-ones batch through the layer and print the
    # shape of the result.
    batch, channels, height, width = 32, 128, 26, 26
    dummy_input = torch.ones((batch, channels, height, width))
    layer = SELayer(channel=channels, reduction=16)
    result = layer(dummy_input)
    print(result.shape)

你可能感兴趣的