什么是残差网络?

残差网络(Residual Network)是一种广泛使用的深度神经网络结构。

它的基本思想是:
通过在相邻层之间添加“残差连接”,实现信息的直接传递,避免信息在经过多层处理后消失或爆炸。
一个简单的残差块,当输入x经过两层网络计算得到F(x)时,如果直接将x添加到F(x)中,那么最终的输出就是x + F(x),也就是输入x和理论上F(x)应该逼近的目标值之和。
这实际上使得网络只需要学习输入x与输出y之间的差值(残差)F(x),而不是直接学习 x 到 y 的映射关系。这减轻了深层网络的参数优化难度,有助于解决梯度消失和爆炸问题。

代码示例:

python
import torch
import torch.nn as nn

# 残差块 
class ResidualBlock(nn.Module): 
    def __init__(self, in_channels, out_channels, stride=1, downsample=False):  
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if downsample:  
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            ) 
        else:
            self.downsample = None   

    def forward(self, x):
        residual = x 
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample:
            residual = self.downsample(x)

        out += residual 
        out = self.relu(out)
        return out

可以看到,残差块通过residual连接实现了输入的直接传递,输出是输入与block内两层网络输出的和。这使得网络只需要学习输入和输出的差值,简化了优化难度。

残差网络通过堆叠多个残差块,实现了信息的直接传播,解决了深层网络训练中的梯度消失问题,使得网络可以继续加深,达到上百层的深度。这是残差网络可以训练超深度模型的关键。

所以,残差网络通过残差学习和信息直接传播的思想,成功地训练了深度达上百层的模型,解决了深层网络训练的瓶颈,大大提高了模型的表征能力。理解残差网络的原理与结构,可以帮助我们构建更深更强的神经网络模型。