【模型】Bi-LSTM模型详解
1. 模型架构与计算过程
Bi-LSTM 由两个LSTM层组成,一个是正向LSTM(从前到后处理序列),另一个是反向LSTM(从后到前处理序列)。每个LSTM单元都可以通过门控机制对序列的长期依赖进行建模。

1. 遗忘门

遗忘门决定了前一时刻的单元状态 c t − 1 c_{t-1} ct−1中哪些信息应该被遗忘,哪些应该保留。其计算方式如下:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)
其中:
- f t f_t ft 是遗忘门的输出。
- h t − 1 h_{t-1} ht−1 是前一时刻的隐藏状态。
- x t x_t xt 是当前时刻的输入。
- W f W_f Wf 和 b f b_f bf 是遗忘门的权重和偏置。
2. 输入门

输入门决定了当前时刻的输入信息 x t x_t xt 多少应该被存储到单元状态中。它计算一个值 i t i_t it,这个值将与候选单元状态 c t ~ \tilde{c_t} ct~ 一起更新当前的单元状态。
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)
其中:
- i t i_t it 是输入门的输出。
- W i W_i Wi 和 b i b_i bi 是输入门的权重和偏置。
3. 更新单元状态

单元状态 c t c_t ct 是LSTM的长期记忆,它根据遗忘门和输入门的输出进行更新:
c t = f t ⋅ c t − 1 + i t ⋅ c t ~ ~c_t = f_t \cdot c_{t-1} + i_t \cdot \tilde{c_t} ct=ft⋅ct−1+it⋅ct~
其中:
- c t − 1 c_{t-1} ct−1 是前一时刻的单元状态。
- f t f_t ft 是遗忘门的输出。
- i t i_t it 是输入门的输出。
- c t ~ \tilde{c_t} ct~ 是当前候选的单元状态。
4. 输出门(Output Gate)

输出门决定了当前时刻的隐藏状态 h t h_t ht(即模型的输出)。它基于当前单元状态 c t c_t ct 和上一时刻的隐藏状态 h t − 1 h_{t-1} ht−1,通过sigmoid激活函数计算输出:
o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
然后,将输出门的结果与当前单元状态的tanh激活值相乘,得到当前的隐藏状态:
h t = o t ⋅ tanh ( c t ) h_t = o_t \cdot \tanh(c_t) ht=ot⋅tanh(ct)
其中:
- o t o_t ot 是输出门的输出。
- h t h_t ht 是当前的隐藏状态,也是模型的输出。
总结:
- 遗忘门 f t f_t ft 控制历史信息的遗忘程度。
- 输入门 i t i_t it 控制新信息的加入程度。
- 更新单元状态 c t c_t ct 结合了历史状态和新信息,更新了长期记忆。
- 输出门 o t o_t ot 决定了哪些信息被传递到下一层或作为最终输出。
2. PyTorch实现
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import mean_squared_error
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt# 定义Bi-LSTM模型类
class BiLSTM(nn.Module):def __init__(self, input_size, hidden_size, num_layers, forecast_horizon):super(BiLSTM, self).__init__()self.num_layers = num_layersself.input_size = input_sizeself.hidden_size = hidden_size self.forecast_horizon = forecast_horizon# 定义双向LSTM层self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size,num_layers=self.num_layers, batch_first=True, bidirectional=True)# 定义全连接层self.fc1 = nn.Linear(self.hidden_size * 2, 20) # 由于是双向,hidden_size要乘以2self.fc2 = nn.Linear(20, self.forecast_horizon)# Dropout层,防止过拟合self.dropout = nn.Dropout(0.2)def forward(self, x):# 初始化隐藏状态和细胞状态h_0 = torch.randn(self.num_layers * 2, x.size(0), self.hidden_size).to(device) # 双向,所以乘以2c_0 = torch.randn(self.num_layers * 2, x.size(0), self.hidden_size).to(device)# 通过双向LSTM层进行前向传播out, _ = self.lstm(x, (h_0, c_0))# 只取最后一个时间步的输出(双向LSTM的输出将是[batch_size, time_steps, hidden_size*2])out = F.relu(self.fc1(out[:, -1, :])) # 只取最后一个时间步的输出,经过全连接层1并激活out = self.fc2(out) # 输出层return out# 设置设备,使用GPU(如果可用)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')# 假设你已经准备好了数据,X_train, X_test, y_train, y_test等
# 将数据转换为torch tensors,并转移到设备(GPU/CPU)
X_train_tensor = torch.Tensor(X_train).to(device)
X_test_tensor = torch.Tensor(X_test).to(device)
y_train_tensor = torch.Tensor(y_train).squeeze(-1).to(device) # 确保y_train是正确形状
y_test_tensor = torch.Tensor(y_test).squeeze(-1).to(device)# 创建训练数据和测试数据集
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)# 定义 DataLoader
batch_size = 512
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 初始化Bi-LSTM模型
input_size = X_train.shape[2] # 特征数量
hidden_size = 64 # 隐藏层神经元数量
num_layers = 2 # LSTM层数
forecast_horizon = 5 # 预测的目标步数model = BiLSTM(input_size, hidden_size, num_layers, forecast_horizon).to(device)# 定义训练函数
def train_model_with_dataloader(model, train_loader, test_loader, epochs=50, lr=0.001):criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr=lr)train_loss = []val_loss = []for epoch in range(epochs):# 训练阶段model.train()epoch_train_loss = 0for X_batch, y_batch in train_loader:optimizer.zero_grad()# 前向传播output_train = model(X_batch)# 计算损失loss = criterion(output_train, y_batch)loss.backward() # 反向传播optimizer.step() # 更新参数epoch_train_loss += loss.item() # 累计批次损失train_loss.append(epoch_train_loss / len(train_loader)) # 计算平均损失# 验证阶段model.eval()epoch_val_loss = 0with torch.no_grad():for X_batch, y_batch in test_loader:output_val = model(X_batch)loss = criterion(output_val, y_batch)epoch_val_loss += loss.item()val_loss.append(epoch_val_loss / len(test_loader)) # 计算平均验证损失# 打印日志if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss[-1]:.4f}, Validation Loss: {val_loss[-1]:.4f}')# 绘制训练损失和验证损失曲线plt.plot(train_loss, label='Train Loss')plt.plot(val_loss, label='Validation Loss')plt.title('Loss vs Epochs')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.show()# 训练模型
train_model_with_dataloader(model, train_loader, test_loader, epochs=50)# 评估模型性能
def evaluate_model_with_dataloader(model, test_loader):model.eval()y_pred_list = []y_test_list = []with torch.no_grad():for X_batch, y_batch in test_loader:y_pred = model(X_batch)y_pred_list.append(y_pred.cpu().numpy())y_test_list.append(y_batch.cpu().numpy())# 将所有批次结果拼接y_pred_rescaled = np.concatenate(y_pred_list, axis=0)y_test_rescaled = np.concatenate(y_test_list, axis=0)# 计算均方误差mse = mean_squared_error(y_test_rescaled, y_pred_rescaled)print(f'Mean Squared Error: {mse:.4f}')return y_pred_rescaled, y_test_rescaled# 评估模型性能
y_pred_rescaled, y_test_rescaled = evaluate_model_with_dataloader(model, test_loader)# 保存模型
def save_model(model, path='./model_files/bi_lstm_model.pth'):torch.save(model.state_dict(), path)print(f'Model saved to {path}')# 保存训练好的模型
save_model(model)
2.1代码解析
(1)为什么LSTM需要c_0
def forward(self, x):# 初始化隐藏状态和细胞状态h_0 = torch.randn(self.num_layers * 2, x.size(0), self.hidden_size).to(device) # 双向,所以乘以2c_0 = torch.randn(self.num_layers * 2, x.size(0), self.hidden_size).to(device)# 通过双向LSTM层进行前向传播out, _ = self.lstm(x, (h_0, c_0))# 只取最后一个时间步的输出(双向LSTM的输出将是[batch_size, time_steps, hidden_size*2])out = F.relu(self.fc1(out[:, -1, :])) # 只取最后一个时间步的输出,经过全连接层1并激活out = self.fc2(out) # 输出层return out
这是因为 RNN 和 LSTM 在其内部状态(即隐藏状态和细胞状态)管理上有所不同:
- RNN只需要隐藏状态 h t h_t ht来捕捉短期记忆,因此不需要单独的单元状态。
- LSTM 需要隐藏状态 h t h_t ht 和单元状态 c t c_t ct来分别管理短期记忆和长期记忆,所以需要同时初始化并传递这两个状态。
3. 优缺点
优点:
- 双向信息捕捉:通过双向LSTM,模型能够同时捕捉过去和未来的信息,提高了对序列的理解。
- 长序列依赖:LSTM可以有效地解决长序列中的梯度消失和梯度爆炸问题,适用于长时间依赖的任务。
- 更好的预测效果:相比单向LSTM,Bi-LSTM能够在某些任务上提供更好的预测效果,尤其是在需要了解上下文的应用场景中。
缺点:
- 计算复杂度高:双向LSTM需要计算两次LSTM,计算和存储开销较大,尤其是序列很长时。
- 训练时间长:由于参数量增加,模型训练的时间和内存消耗也较大。
- 可能过拟合:在数据量较小或噪声较大的情况下,双向LSTM容易产生过拟合。
4. 模型算法变种
- Attention机制:结合Bi-LSTM与Attention机制,可以进一步增强模型对序列中重要部分的关注能力,提升性能。
- GRU变种:将Bi-LSTM替换为双向GRU(Gated Recurrent Units),GRU相较于LSTM计算量更少,适用于计算资源有限的场景。
- Stacked Bi-LSTM:堆叠多个Bi-LSTM层,进一步提升模型的表现,能够捕捉更复杂的时序依赖关系。
- CRF(条件随机场):结合Bi-LSTM与CRF用于序列标注任务,CRF层能够对标签之间的依赖关系进行建模,进一步提高精度。
5. 模型特点
- 双向信息:通过双向LSTM能够同时捕捉到序列中前向和反向的依赖关系,增强了模型对序列数据的理解能力。
- 序列建模:Bi-LSTM能很好地处理序列数据,尤其是在涉及时间序列或文本等领域。
- 长期依赖捕获:LSTM的设计能够克服传统RNN在处理长序列时的梯度消失问题,适合处理长时间依赖。
6. 应用场景
-
自然语言处理:
- 语音识别:通过Bi-LSTM捕捉语音中的上下文信息,提升识别准确性。
- 机器翻译:在翻译过程中,Bi-LSTM能够同时考虑源语言句子的前后文,提高翻译质量。
- 命名实体识别(NER):通过双向LSTM处理文本,识别出文本中的实体(如人名、地名等)。
- 语义分析:在文本分类、情感分析等任务中,Bi-LSTM可以捕捉更丰富的上下文信息。
-
时间序列预测:
- 财务数据预测:例如,股票价格预测,通过Bi-LSTM能够捕捉时间序列中的长短期依赖。
- 销售预测:对销售数据进行分析,Bi-LSTM可以帮助识别趋势和周期性变化。
-
语音与音频处理:
- 语音情感识别:Bi-LSTM能处理语音信号中的上下文信息,帮助识别说话者的情感状态。
- 音乐生成:生成与输入音频相关的音乐,Bi-LSTM能够理解音频序列的长期依赖性。
相关文章:
【模型】Bi-LSTM模型详解
1. 模型架构与计算过程 Bi-LSTM 由两个LSTM层组成,一个是正向LSTM(从前到后处理序列),另一个是反向LSTM(从后到前处理序列)。每个LSTM单元都可以通过门控机制对序列的长期依赖进行建模。 1. 遗忘门 遗忘…...
directx12 3d开发过程中出现的报错 一
报错:“&”要求左值 “& 要求左值” 这个错误通常是因为你在尝试获取一个临时对象或者右值的地址,而 & 运算符只能用于左值(即可以放在赋值语句左边的表达式,代表一个可以被引用的内存位置)。 可能出现错…...
Ubuntu 24.04 安装 Poetry:Python 依赖管理的终极指南
Ubuntu 24.04 安装 Poetry:Python 依赖管理的终极指南 1. 更新系统包列表2. 安装 Poetry方法 1:使用官方安装脚本方法 2:使用 Pipx 安装 3. 配置环境变量4. 验证安装5. 配置 Poetry(可选)设置虚拟环境位置配置镜像源 6…...
读写锁: ReentrantReadWriteLock
在多线程编程场景中,对共享资源的访问控制极为关键。传统的锁机制在同一时刻只允许一个线程访问共享资源,这在读写操作频繁的场景下,会因为读操作相互不影响数据一致性,而造成不必要的性能损耗。ReentrantReadWriteLock࿰…...
上海路网道路 水系铁路绿色住宅地工业用地面图层shp格式arcgis无偏移坐标2023年
标题和描述中提到的资源是关于2023年上海市地理信息数据的集合,主要包含道路、水系、铁路、绿色住宅区以及工业用地的图层数据,这些数据以Shapefile(shp)格式存储,并且是适用于ArcGIS软件的无偏移坐标系统。这个压缩包…...
爬虫学习笔记之Robots协议相关整理
定义 Robots协议也称作爬虫协议、机器人协议,全名为网络爬虫排除标准,用来告诉爬虫和搜索引擎哪些页面可以爬取、哪些不可以。它通常是一个叫做robots.txt的文本文件,一般放在网站的根目录下。 robots.txt文件的样例 对有所爬虫均生效&#…...
Python小游戏29乒乓球
import pygame import sys # 初始化pygame pygame.init() # 屏幕大小 screen_width 800 screen_height 600 screen pygame.display.set_mode((screen_width, screen_height)) pygame.display.set_caption("打乒乓球") # 颜色定义 WHITE (255, 255, 255) BLACK (…...
220.存在重复元素③
目录 一、题目二、思路三、解法四、收获 一、题目 给你一个整数数组 nums 和两个整数 indexDiff 和 valueDiff 。 找出满足下述条件的下标对 (i, j): i ! j, abs(i - j) < indexDiff abs(nums[i] - nums[j]) < valueDiff 如果存在,返回 true &a…...
使用 Go 语言调用 DeepSeek API:完整指南
引言 DeepSeek 是一个强大的 AI 模型服务平台,本文将详细介绍如何使用 Go 语言调用 DeepSeek API,实现流式输出和对话功能。 Deepseek的api因为被功击已不能用,本文以 DeepSeek:https://cloud.siliconflow.cn/i/vnCCfVaQ 为例子进…...
AJAX笔记原理篇
黑马程序员视频地址: AJAX-Day03-01.XMLHttpRequest_基本使用https://www.bilibili.com/video/BV1MN411y7pw?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes&p33https://www.bilibili.com/video/BV1MN411y7pw?vd_sour…...
ubuntu直接运行arm环境qemu-arm-static
qemu-arm-static 嵌入式开发有时会在ARM设备上使用ubuntu文件系统。开发者常常会面临这样一个问题,想预先交叉编译并安装一些应用程序,但是交叉编译的环境配置以及依赖包的安装十分繁琐,并且容易出错。想直接在目标板上进行编译和安装&#x…...
尝试把clang-tidy集成到AWTK项目
前言 项目经过一段时间的耕耘终于进入了团队开发阶段,期间出现了很多问题,其中一个就是开会讨论团队的代码风格规范,目前项目代码风格比较混乱,有的模块是驼峰,有的模块是匈牙利,后面经过讨论,…...
一文了解性能优化的方法
背景 在应用上线后,用户感知较明显的,除了功能满足需求之外,再者就是程序的性能了。因此,在日常开发中,我们除了满足基本的功能之外,还应该考虑性能因素。关注并可以优化程序性能,也是体现开发能…...
【怎么用系列】短视频戒断——对推荐算法进行干扰
如今推荐算法已经渗透到人们生活的方方面面,尤其是抖音等短视频核心就是推荐算法。 【短视频的危害】 1> 会让人变笨,慢慢让人丧失注意力与专注力 2> 让人丧失阅读长文的能力 3> 让人沉浸在一个又一个快感与嗨点当中。当我们刷短视频时&#x…...
C#中的委托(Delegate)
什么是委托? 首先,我们要知道C#是一种强类型的编程语言,强类型的编程语言的特性,是所有的东西都是特定的类型 委托是一种存储函数的引用类型,就像我们定义的一个 string str 一样,这个 str 变量就是 string 类型. 因为C#中没有函数类型,但是可以定义一个委托类型,把这个函数…...
PostCss
什么是 PostCss 如果把 CSS 单独拎出来看,光是样式本身,就有很多事情要处理。 既然有这么多事情要处理,何不把这些事情集中到一起统一处理呢? PostCss 就是基于这样的理念出现的。 PostCss 类似于一个编译器,可以将…...
Linux 系统上安装 Docker 并进行配置
Docker 是一种开源的应用容器引擎,它允许开发者打包他们的应用以及应用的依赖包到一个可移植的容器中,然后发布到任何流行的 Linux 机器上,也可以实现虚拟化。容器是完全使用沙箱机制,相互之间不会有任何接口(类似 iPh…...
DeepSeek 等 AI 技术能否推动股市的繁荣?
在科技浪潮汹涌澎湃的当下,DeepSeek 等 AI 技术宛如闪耀在天际的耀眼星辰,吸引着全球各界的高度关注。面对这些前沿技术,投资者和市场参与者心中不禁泛起疑问:它们是否能成为推动股市繁荣的强劲动力?这一问题不仅左右着…...
【网络】应用层协议http
文章目录 1. 关于http协议2. 认识URL3. http协议请求与响应格式3.1 请求3.2 响应 3. http的常见方法4. 状态码4.1 常见状态码4.2 重定向 5. Cookie与Session5.1 Cookie5.1.1 认识Cookie5.1.2 设置Cookie5.1.3 Cookie的生命周期 5.2 Session 6. HTTP版本(了解&#x…...
大数据数仓实战项目(离线数仓+实时数仓)2
1.课程目标和课程内容介绍 2.数仓维度建模设计 3.数仓为什么要分层 4.数仓分层思想和作用 下面是阿里的一种分层方式 5.数仓中表的种类和同步策略 6.数仓中表字段介绍以及表关系梳理 订单表itcast_orders 订单明细表 itcast_order_goods 商品信息表 itcast_goods 店铺表 itcast…...
K8S认证|CKS题库+答案| 11. AppArmor
目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...
shell脚本--常见案例
1、自动备份文件或目录 2、批量重命名文件 3、查找并删除指定名称的文件: 4、批量删除文件 5、查找并替换文件内容 6、批量创建文件 7、创建文件夹并移动文件 8、在文件夹中查找文件...
渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet: https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...
《基于Apache Flink的流处理》笔记
思维导图 1-3 章 4-7章 8-11 章 参考资料 源码: https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...
MySQL用户和授权
开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务: test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...
【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统
目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!
简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求,并检查收到的响应。它以以下模式之一…...
Java编程之桥接模式
定义 桥接模式(Bridge Pattern)属于结构型设计模式,它的核心意图是将抽象部分与实现部分分离,使它们可以独立地变化。这种模式通过组合关系来替代继承关系,从而降低了抽象和实现这两个可变维度之间的耦合度。 用例子…...
MySQL 主从同步异常处理
阅读原文:https://www.xiaozaoshu.top/articles/mysql-m-s-update-pk MySQL 做双主,遇到的这个错误: Could not execute Update_rows event on table ... Error_code: 1032是 MySQL 主从复制时的经典错误之一,通常表示ÿ…...
