什么是ResNet?有几种类型?

ResNet(Residual Network)是一种深度残差神经网络,由微软研究院提出。它的关键在于使用残差结构,解决在训练深度神经网络时出现的梯度消失问题。

ResNet的基本块是残差块,主要有两种类型:

  1. ResNet V1:
x = ConvLayer(x) 
x = ConvLayer(x)
x = x + shortcut   # shortcut直连x
x = Activation(x) 
  1. ResNet V2:
x = ConvLayer(x)
x = ConvLayer(x)
shortcut = ConvLayer(x)  # shortcut也做卷积变换
x = x + shortcut
x = Activation(x)

ResNet通过在主路径上加残差块,实现深层网络结构。主要有以下优点:

  1. 残差连接可以跨过多层网络直接传播信号,有效地解决梯度消失问题,利于网络训练。
  2. 残差块内包含若干个卷积层,但输出和输入尺度保持一致,便于残差相加。
  3. 残差网络极深,152层的ResNet在ImageNet上取得 state-of-the-art的结果。网络越深代表提取的特征越抽象和语义化。
  4. 残差连接引入的 shortcut路径使得网络中低层特征也可以直接到达最终输出,这种多阶特征结合的方式可以产生更强的表示能力。
  5. 残差块使得网络每一部分的功能非常清晰,有利于训练和理解。

ResNet代码示例:

python
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

所以,ResNet通过残差 Learning 的方式成功地训练了152层的超深度神经网络,大大推动了CNN在计算机视觉的进展。它已成为图像分类和目标检测等视觉任务的基础网络结构之一。