import torch
import torch.nn as nn
class SELayer(nn.Module):
    """Squeeze-and-excitation layer that rescales the channels of its input.

    The input is averaged over its first dimension ("squeeze"), the
    averaged vector is passed through a two-layer bottleneck ending in a
    sigmoid ("excitation"), and the resulting per-channel gates are
    multiplied back onto the input.

    Args:
        in_channels: Size of the last dimension of the input; also the
            number of gates produced.
        se_channels: Width of the hidden bottleneck layer.
    """

    def __init__(self, in_channels: int, se_channels: int) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.se_channels = se_channels
        # Bottleneck: in_channels -> se_channels -> in_channels. The final
        # sigmoid maps each gate into the [0, 1] range. No custom weight
        # initialisation is applied here.
        self.encoder_decoder = nn.Sequential(
            nn.Linear(in_channels, se_channels),
            nn.ELU(),
            nn.Linear(se_channels, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reweight ``x`` by gates computed from its mean over dimension 0.

        Args:
            x: Input tensor whose last dimension has size ``in_channels``.
                NOTE(review): presumably shaped ``(num_items, in_channels)``
                -- confirm against callers.

        Returns:
            ``x`` multiplied element-wise (with broadcasting) by the gate
            tensor; same shape as ``x``.
        """
        # Squeeze: aggregate the input representation over dimension 0.
        x_global = torch.mean(x, dim=0)
        # Excite: compute the reweighting vector s from the aggregate.
        s = self.encoder_decoder(x_global)
        # Broadcasting applies the same gates to every slice along dim 0.
        return x * s