长短时记忆网络(Long Short-Term Memory,LSTM)是一种特殊的循环神经网络,旨在解决普通RNN的长期依赖问题。
它引入了遗忘门、输入门和输出门来控制细胞状态,以选择性遗忘或保留信息。
主要包含以下要素:
- 遗忘门 f:控制上一时刻的细胞状态有多少被遗忘。
- 输入门 i:控制当前输入有多少被存入细胞状态。
- 细胞状态 C:记录长期信息的细胞记忆。
- 输出门 o:控制从细胞状态输出有多少记忆。
- 隐藏状态 h:LSTM的输出。
具体公式:
f = σ(Wf * [h(t-1), x(t)]) # 遗忘门
i = σ(Wi * [h(t-1), x(t)]) # 输入门
C̃ = tanh(Wc * [h(t-1), x(t)]) # 候选细胞状态
C = f * C(t-1) + i * C̃ # 细胞状态
o = σ(Wo * [h(t-1), x(t)]) # 输出门
h = o * tanh(C) # 隐藏状态
LSTM代码示例:
python
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size):
super(LSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.Wf = nn.Parameter(torch.Tensor(hidden_size, hidden_size + input_size))
self.Wi = nn.Parameter(torch.Tensor(hidden_size, hidden_size + input_size))
self.Wo = nn.Parameter(torch.Tensor(hidden_size, hidden_size + input_size))
self.Wc = nn.Parameter(torch.Tensor(hidden_size, hidden_size + input_size))
def forward(self, x, h, c):
h = h.squeeze()
c = c.squeeze()
i = torch.sigmoid(self.Wi.mm(torch.cat((h, x))))
f = torch.sigmoid(self.Wf.mm(torch.cat((h, x))))
o = torch.sigmoid(self.Wo.mm(torch.cat((h, x))))
g = torch.tanh(self.Wc.mm(torch.cat((h, x))))
c_next = f * c + i * g
h_next = o * torch.tanh(c_next)
return h_next, c_next
所以,LSTM通过引入门控机制解决了RNN的长期依赖问题,增强了其处理长序列数据的能力。它已成为深度学习和NLP中最为流行和有效的序列建模方法之一。