『零基础+1』一文看懂LSTM原理-《动手学深度学习》

P粉084495128
发布: 2025-07-30 11:23:41
原创
740人浏览过
长短期记忆网络(LSTM)为解决隐变量模型的长期信息保存与短期输入缺失问题而设计,含记忆元及输入门、遗忘门、输出门三个门控机制,通过特定计算控制信息留存更新。文中介绍其数学原理、从零开始及简洁实现,提及变体(如带猫眼连接)、与GRU的区别,并展示了训练和预测示例。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

『零基础+1』一文看懂lstm原理-《动手学深度学习》 - php中文网

1 长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) Hochreiter.Schmidhuber.1997。

它有许多与门控循环单元(9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年。

1.1 门控记忆元

可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。

长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。

有些文献认为记忆元是隐状态的一种特殊类型,

它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。

为了控制记忆元,我们需要许多门。

其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。

另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。

我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理,

这种设计的动机与门控循环单元相同,能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。

注:

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网        

Sigmoid 层的输出值在 0 到 1 间,表示每个部分所通过的信息。0 表示「对所有信息关上大门」;1 表示「我家大门常打开」。

一个 LSTM 有三个这样的门,控制 cell 的状态。

  • 门实质上是控制有百分之多少的信息保留下来。门操作由一个 sigmoid 网络层计算得到【0,1】的小数与输入数据流按位乘操作构成。

  • 门的操作是相同的,只是根据不同的设计思想,不同的数据流,叫不同的名字

1.2 输入门、忘记门和输出门

就如在门控循环单元中一样,当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,

它们由三个具有sigmoid激活函数的全连接层处理,以计算输入门、遗忘门和输出门的值。因此,这三个门的值都在(0,1)(0,1)的范围内。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网 『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网        


首先,LSTM 的第一步需要决定我们需要从 cell 中抛弃哪些信息。这个决定是从 sigmoid 中的「遗忘层」来实现的。

它的输入是 ht-1 和 xt,输出为一个 0 到 1 之间的数。Ct−1 就是每个在 cell 中所有在 0 和 1 之间的数值,就像我们刚刚所说的,0 代表全抛弃,1 代表全保留。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网        

下一步,我们需要决定什么样的信息应该被存储起来。这个过程主要分两步。

  • 首先是 sigmoid 层(输入门)决定我们需要更新哪些值;

  • 随后,tanh 层生成了一个新的候选向量 C`,它能够加入状态中。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网        

接下来,我们就可以更新 cell 的状态了。

将旧状态与 ft 相乘,忘记此前我们想要忘记的内容,然后加上 C`。此时遗忘门为ftft

得到的结果便是新的候选值,依照itit进行缩放。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网        

最后,我们需要决定要输出什么。此输出将基于我们处理后的单元状态。

  • 首先,我们会运行一个 sigmoid 层决定 cell 状态输出哪一部分。

  • 随后,我们把 cell 状态通过 tanh 函数,将输出值保持在-1 到 1 间。

  • 之后,我们再乘以 sigmoid 门的输出值,就可以得到结果了。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网        


我们来细化一下长短期记忆网络的数学表达。

假设有hh个隐藏单元,批量大小为nn,输入数为dd。

因此,输入为XtRn×dXt∈Rn×d,

前一时间步的隐状态为Ht1Rn×hHt−1∈Rn×h。

相应地,时间步tt的门被定义如下:

输入门是ItRn×hIt∈Rn×h,

遗忘门是FtRn×hFt∈Rn×h,

输出门是OtRn×hOt∈Rn×h。

它们的计算方法如下:

It=σ(XtWxi+Ht1Whi+bi)It=σ(XtWxi+Ht−1Whi+bi)

Ft=σ(XtWxf+Ht1Whf+bf),Ft=σ(XtWxf+Ht−1Whf+bf),

今天学点啥
今天学点啥

秘塔AI推出的AI学习助手

今天学点啥 258
查看详情 今天学点啥

Ot=σ(XtWxo+Ht1Who+bo)Ot=σ(XtWxo+Ht−1Who+bo)

其中Wxi,Wxf,WxoRd×hWxi,Wxf,Wxo∈Rd×h

Whi,Whf,WhoRh×hWhi,Whf,Who∈Rh×h是权重参数,

bi,bf,boR1×hbi,bf,bo∈R1×h是偏置参数。


我们将其中的一些操作集合命名为不同的记忆元名称

1.3 候选记忆元

由于还没有指定各种门的操作,所以先介绍候选记忆元(candidate memory cell) C~tRn×hC~t∈Rn×h。 它的计算与上面描述的三个门的计算类似, 但是使用tanhtanh函数作为激活函数,函数的值范围为(1,1)(−1,1)。 下面导出在时间步tt处的方程:

C~t=tanh(XtWxc+Ht1Whc+bc),C~t=tanh(XtWxc+Ht−1Whc+bc),

其中WxcRd×hWxc∈Rd×h和 WhcRh×hWhc∈Rh×h是权重参数, bcR1×hbc∈R1×h是偏置参数。


1.4 记忆元

在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。 类似地,在长短期记忆网络中,也有两个门用于这样的目的: 输入门ItIt控制采用多少来自C~tC~t的新数据, 而遗忘门FtFt控制保留多少过去的 记忆元Ct1Rn×hCt−1∈Rn×h的内容。 使用按元素乘法,得出:

Ct=FtCt1+ItC~t.Ct=Ft⊙Ct−1+It⊙C~t.

如果遗忘门始终为11且输入门始终为00, 则过去的记忆元Ct1Ct−1 将随时间被保存并传递到当前时间步。 引入这种设计是为了缓解梯度消失问题, 并更好地捕获序列中的长距离依赖关系。


1.5 隐状态

最后,我们需要定义如何计算隐状态 HtRn×hHt∈Rn×h, 这就是输出门发挥作用的地方。 在长短期记忆网络中,它仅仅是记忆元的tanhtanh的门控版本。 这就确保了HtHt的值始终在区间(1,1)(−1,1)内:

Ht=Ottanh(Ct).          (9.2.4)Ht=Ot⊙tanh(Ct).          (9.2.4)

只要输出门接近11,我们就能够有效地将所有记忆信息传递给预测部分, 而对于输出门接近00,我们只保留记忆元内的所有信息,而不需要更新隐状态。

2 从零开始实现

现在,我们从零开始实现长短期记忆网络。

我们首先加载时光机器数据集。

In [1]
import paddlefrom paddle import nnfrom d2l import paddle as d2limport paddle.nn.functional as Function

batch_size, num_steps = 32, 35train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
登录后复制
   

2.1 初始化模型参数

接下来,我们需要定义和初始化模型参数。

如前所述,超参数num_hiddens定义隐藏单元的数量。

我们按照标准差0.010.01的高斯分布初始化权重,并将偏置项设为00。

In [2]
def get_lstm_params(vocab_size, num_hiddens):
    num_inputs = num_outputs = vocab_size    def normal(shape):
        return paddle.randn(shape=shape)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                paddle.zeros([num_hiddens]))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = paddle.zeros([num_outputs])    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]    for param in params:
        param.stop_gradient = False
    return params
登录后复制
   

2.2 定义模型

在[初始化函数]中,长短期记忆网络的隐状态需要返回一个额外的记忆元,单元的值为0,形状为(批量大小,隐藏单元数)。

因此,我们得到以下的状态初始化。

In [3]
def init_lstm_state(batch_size, num_hiddens):
    return (paddle.zeros([batch_size, num_hiddens]),
            paddle.zeros([batch_size, num_hiddens]))
登录后复制
   

实际模型的定义与我们前面讨论的一样:提供三个门和一个额外的记忆元。

  • 请注意:只有隐状态才会传递到输出层,而记忆元CtCt不直接参与输出计算。
In [4]
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []    for X in inputs:
        I = Function.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = Function.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = Function.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = paddle.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * paddle.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)    return paddle.concat(outputs, axis=0), (H, C)
登录后复制
   

2.3 训练 和 预测

让我们通过实例化8.5节中,引入的RNNModelScratch类来训练一个长短期记忆网络。

此外,我们还加入了额外的模型测试。

In [6]
##  训练vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1.0model = d2l.RNNModelScratch(len(vocab), num_hiddens, device,get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
登录后复制
   
In [10]
##  预测# 自定义 prefix , num_preds 进行预测prefix = 'tr'num_preds = 5net = model
d2l.predict_ch8(prefix, num_preds, net, vocab, device)
登录后复制
       
'treasth'
登录后复制
               

2.4 简洁实现

使用高级API,我们可以直接实例化LSTM模型。

高级API封装了前文介绍的所有配置细节。

这段代码的运行速度要快得多,因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

In [7]
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens, time_major=True)
model = d2l.RNNModel(lstm_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
登录后复制
   

2.5 结构拓展

比较流行的 LSTM 变体就是 Gers & Schmidhuber (2000) 提出的「猫眼连接」(peephole connections)的神经网络,也就是说,门连接层能够接收到 cell 的状态。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网        

上图展示了全加上「猫眼连接」的效果,但实际上论文中并不会加这么多。

另一种变体就是采用一对门,分别叫遗忘门(forget)及输入门(input)。

与分开决定遗忘及输入的内容不同,现在的变体会将这两个流程一同实现。

我们只有在将要输入新信息时才会遗忘,而也只会在忘记信息的同时才会有新的信息输入。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网        

一个比较知名的变体为 GRU(Gated Recurrent),由 Cho, et al. (2014) 提出。他将遗忘门与输入门结合在一起,名为**「更新门」**(update gate),并将 cell 状态与隐藏层状态合并在一起,此外还有一些小的改动。

『零基础+1』一文看懂LSTM原理-《动手学深度学习》 - php中文网        

GRU和LSTM的区别

  • LSTM有三个门,而GRU有两个门
  • 去掉了细胞单元C
  • 输出的时候取消了二阶的非线性函数

这个模型比起标准 LSTM 模型简单一些,因此也变得更加流行了。

当然,这里所列举的只是一管窥豹,还有很多其它的变体,

比如 Yao, et al. (2015) 提出的 Depth Gated RNNs;或是另辟蹊径处理长期依赖问题的 Clockwork RNNs,由 Koutnik, et al. (2014) 提出。

哪个是最好的呢?而这些变化是否真的意义深远?

Greff, et al. (2015) 曾经对比较流行的几种变种做过对比,发现它们基本上都差不多;

Jozefowicz, et al. (2015) 测试了超过一万种 RNN 结构,发现有一些能够在特定任务上超过 LSTMs。

以上就是『零基础+1』一文看懂LSTM原理-《动手学深度学习》的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号