什么是长短时记忆网络?

长短时记忆网络(Long Short-Term Memory,LSTM)是一种特殊的循环神经网络,旨在解决普通RNN的长期依赖问题。
它引入了遗忘门、输入门和输出门来控制细胞状态,以选择性遗忘或保留信息。

主要包含以下要素:

  1. 遗忘门 f:控制上一时刻的细胞状态有多少被遗忘。
  2. 输入门 i:控制当前输入有多少被存入细胞状态。
  3. 细胞状态 C:记录长期信息的细胞记忆。
  4. 输出门 o:控制从细胞状态输出有多少记忆。
  5. 隐藏状态 h:LSTM的输出。

具体公式:

f = σ(Wf * [h(t-1), x(t)])    # 遗忘门
i = σ(Wi * [h(t-1), x(t)])    # 输入门 
C̃ = tanh(Wc * [h(t-1), x(t)]) # 候选细胞状态
C = f * C(t-1) + i * C̃        # 细胞状态  
o = σ(Wo * [h(t-1), x(t)])    # 输出门
h = o * tanh(C)                # 隐藏状态

LSTM代码示例:

python
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTM, self).__init__()

        self.input_size = input_size 
        self.hidden_size = hidden_size 

        self.Wf = nn.Parameter(torch.Tensor(hidden_size, hidden_size + input_size))
        self.Wi = nn.Parameter(torch.Tensor(hidden_size, hidden_size + input_size)) 
        self.Wo = nn.Parameter(torch.Tensor(hidden_size, hidden_size + input_size))
        self.Wc = nn.Parameter(torch.Tensor(hidden_size, hidden_size + input_size))

    def forward(self, x, h, c):
        h = h.squeeze()
        c = c.squeeze()

        i = torch.sigmoid(self.Wi.mm(torch.cat((h, x))))
        f = torch.sigmoid(self.Wf.mm(torch.cat((h, x)))) 
        o = torch.sigmoid(self.Wo.mm(torch.cat((h, x))))
        g = torch.tanh(self.Wc.mm(torch.cat((h, x))))

        c_next = f * c + i * g 
        h_next = o * torch.tanh(c_next)

        return h_next, c_next 

所以,LSTM通过引入门控机制解决了RNN的长期依赖问题,增强了其处理长序列数据的能力。它已成为深度学习和NLP中最为流行和有效的序列建模方法之一。