游戏AI的创造思路-技术基础-深度学习(3)
继续填坑,本篇介绍深度学习中的长短期记忆网络~~~~
目录
3.3. 长短期记忆网络(LSTM)
3.3.1. 什么是长短期记忆网络
3.3.2. 形成过程与运行原理
3.3.2.1. 细胞状态与门结构
3.3.2.2. 遗忘门
3.3.2.3. 输入门
3.3.2.4. 细胞状态更新
3.3.2.5. 输出门
3.3.2.6. 以上各步骤的示例代码
3.3.3. 优缺点
3.3.4. 存在的问题及解决方法
3.3.5. 示例代码
3.3. 长短期记忆网络(LSTM)
3.3.1. 什么是长短期记忆网络
长短期记忆网络(LSTM,Long Short-Term Memory)算法是一种特殊的循环神经网络(RNN),它旨在解决传统RNN在处理长序列数据时遇到的梯度消失和梯度爆炸问题,从而更有效地学习序列中的长期依赖关系。
- 为了最小化训练误差,通常使用梯度下降法,如应用时序性倒传递算法,来依据错误修改每次的权重。此外,LSTM有多种变体,其中一个重要的版本是门控循环单元(GRU)。
- LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件。其表现通常比时间递归神经网络及隐马尔科夫模型(HMM)更好。例如,在不分段连续手写识别上,LSTM模型曾赢得过ICDAR手写识别比赛冠军。此外,LSTM还广泛应用于自主语音识别,并在2013年使用TIMIT自然演讲数据库达成了17.7%的错误率纪录。
- LSTM的成功在很大程度上促进了深度学习和人工智能领域的发展。尽管近年来出现了新的模型结构,如基于注意力机制的Transformer,但LSTM仍然是许多序列建模任务的可靠选择。随着时间的推移,LSTM被广泛应用于自然语言处理、语音识别、文本生成、视频分析等多个领域
3.3.2. 形成过程与运行原理
LSTM通过引入“门”结构和“细胞状态”来更好地捕捉序列中的长期依赖关系。(通过借鉴脑神经学的知识来组建序列中的长期依赖关系)
3.3.2.1. 细胞状态与门结构
LSTM的核心是细胞状态,它像一条传送带,在整个链上运行,只有一些小的线性操作作用其上,信息在上面流传保持不变会很容易。LSTM通过精心设计的门结构来去除或增加信息到细胞状态,这些门结构包括遗忘门、输入门和输出门。
3.3.2.2. 遗忘门
决定从细胞状态中丢弃什么信息。它查看当前的输入和前一个时间步的隐藏状态,并为细胞状态中的每个数字输出一个在0到1之间的数字,1表示“完全保留”,0表示“完全舍弃”。
遗忘门决定了从上一个时间步的细胞状态中丢弃哪些信息。其计算公式为:
其中,表示输入门在时刻的值,是时刻 ( t ) 的输入,是前一个时刻的隐藏状态,和 是对应的权重矩阵,而是偏置项。函数表示sigmoid激活函数。
3.3.2.3. 输入门
决定什么新信息将被存储在细胞状态中。这包括两部分,一部分是输入门决定我们将更新哪些部分,另一部分是tanh层创建一个新的候选值向量,这个向量可能会被添加到细胞状态中。
类似地,表示遗忘门在时刻的值,其他符号的含义与输入门公式中的相同,只是权重和偏置项是针对遗忘门的。
3.3.2.4. 细胞状态更新
首先,旧细胞状态与遗忘门相乘,丢弃掉需要丢弃的信息。然后,将输入门的输出与tanh层的输出相乘,得出新的候选细胞状态。最后,将这两个值相加,形成新的细胞状态。
- 旧细胞状态与遗忘门相乘:
这里,表示经过遗忘门处理后的旧细胞状态,是前一个时刻的细胞状态, 是遗忘门在时刻的输出,而表示逐元素相乘(Hadamard乘积)。这一步的目的是丢弃掉不需要的信息。
- 计算新的候选细胞状态:
其中,是新的候选细胞状态,是时刻 的输入, 是前一个时刻的隐藏状态, 和 是对应的权重矩阵,是偏置项。函数 是双曲正切激活函数,它将输入值压缩到 ( -1 ) 到 ( 1 ) 的范围内。
- 将候选细胞状态与输入门相乘:
这里,是输入门在时刻的输出,表示逐元素相乘。这一步的目的是根据输入门的选择来决定哪些新的信息被加入到细胞状态中。
- 更新细胞状态:
最终,新的细胞状态是经过遗忘门处理后的旧细胞状态 与经过输入门处理后的新候选细胞状态 之和。这一步完成了细胞状态的更新,使得LSTM能够记住长期依赖关系。
3.3.2.5. 输出门
基于细胞状态来决定输出什么。首先,运行一个sigmoid层来确定细胞状态的哪个部分将输出,然后将细胞状态通过tanh进行处理(得到一个在-1到1之间的值),并将其与sigmoid门的输出相乘,最终得到输出。
在这里,是输出门在时刻的值,其他参数和符号的意义与前面公式中的一致,但针对输出门。
3.3.2.6. 以上各步骤的示例代码
Python代码示例
import numpy as np def sigmoid(x): return 1 / (1 + np.exp(-x)) def tanh(x): return np.tanh(x) # LSTM Cell 参数初始化
input_size = 10
hidden_size = 20 Wf = np.random.randn(hidden_size, hidden_size + input_size) # 遗忘门权重
Wi = np.random.randn(hidden_size, hidden_size + input_size) # 输入门权重
Wc = np.random.randn(hidden_size, hidden_size + input_size) # 候选细胞状态权重
Wo = np.random.randn(hidden_size, hidden_size + input_size) # 输出门权重 # LSTM Cell 前向传播
def lstm_cell_forward(xt, ht_prev, ct_prev, Wf, Wi, Wc, Wo): # 拼接前一个隐藏状态和当前输入 concat = np.concatenate((ht_prev, xt), axis=0) # 计算遗忘门 ft = sigmoid(np.dot(Wf, concat)) # 计算输入门 it = sigmoid(np.dot(Wi, concat)) # 计算候选细胞状态 cct = tanh(np.dot(Wc, concat)) # 细胞状态更新 ct = ft * ct_prev + it * cct # 计算输出门 ot = sigmoid(np.dot(Wo, concat)) # 计算隐藏状态 ht = ot * tanh(ct) return ht, ct # 示例使用
xt = np.random.randn(input_size) # 当前输入
ht_prev = np.zeros(hidden_size) # 前一个隐藏状态
ct_prev = np.zeros(hidden_size) # 前一个细胞状态 ht, ct = lstm_cell_forward(xt, ht_prev, ct_prev, Wf, Wi, Wc, Wo)
C++代码示例
#include <Eigen/Dense>
#include <cmath> using namespace Eigen; // 激活函数
double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x));
} double tanh(double x) { return std::tanh(x);
} // LSTM单元前向传播
void LSTMCellForward(const VectorXd& xt, const VectorXd& ht_prev, const VectorXd& ct_prev, const MatrixXd& Wf, const MatrixXd& Wi, const MatrixXd& Wc, const MatrixXd& Wo, VectorXd& ht, VectorXd& ct) { int input_size = xt.size(); int hidden_size = ht_prev.size(); VectorXd concat(input_size + hidden_size); concat << ht_prev, xt; // 计算遗忘门 VectorXd ft = concat.unaryExpr([](double elem) { return sigmoid(elem); }) * Wf.transpose(); // 计算输入门 VectorXd it = concat.unaryExpr([](double elem) { return sigmoid(elem); }) * Wi.transpose(); // 计算候选细胞状态 VectorXd cct = concat.unaryExpr([](double elem) { return tanh(elem); }) * Wc.transpose(); // 细胞状态更新 ct = ft.array() * ct_prev.array() + it.array() * cct.array(); // 计算输出门 VectorXd ot = concat.unaryExpr([](double elem) { return sigmoid(elem); }) * Wo.transpose(); // 计算隐藏状态 ht = ot.array() * ct.array().unaryExpr([](double elem) { return tanh(elem); });
} int main() { int input_size = 10; int hidden_size = 20; MatrixXd Wf = MatrixXd::Random(hidden_size, hidden_size + input_size); // 遗忘门权重 MatrixXd Wi = MatrixXd::Random(hidden_size, hidden_size + input_size); // 输入门权重 MatrixXd Wc = MatrixXd::Random(hidden_size, hidden_size + input_size); // 候选细胞状态权重 MatrixXd Wo = MatrixXd::Random(hidden_size, hidden_size + input_size); // 输出门权重 VectorXd xt = VectorXd::Random(input_size); // 当前输入 VectorXd ht_prev = VectorXd::Zero(hidden_size); // 前一个隐藏状态 VectorXd ct_prev = VectorXd::Zero(hidden_size); // 前一个细胞状态 VectorXd ht(hidden_size), ct(hidden_size); LSTMCellForward(xt, ht_prev, ct_prev, Wf, Wi, Wc, Wo, ht, ct); // Do something with ht and ct... return 0;
}
这些代码是简化示例,实际应用中LSTM的实现会更加复杂,包括多个时间步的迭代、批处理支持、梯度计算和权重更新等。
在生产环境中,建议使用成熟的深度学习框架如TensorFlow或PyTorch来实现LSTM哦。
3.3.3. 优缺点
优点:
- 能够有效地解决传统RNN中的梯度消失和梯度爆炸问题。
- 能够更好地捕捉序列中的长期依赖关系。
- 在处理长序列数据时具有优势。
缺点:
- LSTM模型相对复杂,计算成本较高。
- 对于输入序列长度较长时,可能会出现过拟合现象,导致泛化能力下降。
3.3.4. 存在的问题及解决方法
过拟合问题:可以通过正则化、dropout等技术来减轻过拟合现象。
无法有效捕捉时间上下文关系:可以引入双向LSTM(Bidirectional LSTM)结构来提高对于时间上下文之间关系的建模能力。
对输入数据序列顺序敏感:在实际应用中,可以通过数据增强、序列颠倒等方法来减轻模型对输入数据序列顺序的敏感性。
3.3.5. 示例代码
Python代码
由于篇幅限制,这里提供一个简化的Python示例,使用PyTorch库实现LSTM:
import torch
import torch.nn as nn # 定义一个简单的LSTM模型
class SimpleLSTM(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(SimpleLSTM, self).__init__() self.hidden_size = hidden_size self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x, hidden): lstm_out, hidden = self.lstm(x, hidden) output = self.fc(lstm_out[:, -1, :]) # 取最后一个时间步的输出进行分类 return output, hidden def init_hidden(self, batch_size): return (torch.zeros(1, batch_size, self.hidden_size), torch.zeros(1, batch_size, self.hidden_size)) # 模型参数
input_size = 10
hidden_size = 20
output_size = 2
batch_size = 1
sequence_length = 5 # 创建模型实例
model = SimpleLSTM(input_size, hidden_size, output_size) # 创建虚拟输入数据和初始隐藏状态
x = torch.randn(batch_size, sequence_length, input_size)
hidden = model.init_hidden(batch_size) # 前向传播
output, hidden = model(x, hidden)
print(output)
C++代码
在C++中使用LSTM,我们通常会借助PyTorch的C++ API,也称为LibTorch。以下是一个简单的示例:
#include <torch/script.h> // 包含TorchScript的头文件
#include <iostream> int main() { // 加载一个预先训练好的LSTM模型(这里假设你已经有一个用PyTorch训练的模型并导出了TorchScript) torch::jit::script::Module module; try { module = torch::jit::load("lstm_model.pt"); // 加载模型 } catch (const c10::Error& e) { std::cerr << "模型加载错误\n"; return -1; } // 创建一个输入张量,假设输入大小为[1, 5, 10](batch_size, sequence_length, input_size) torch::Tensor input = torch::randn({1, 5, 10}); // 执行模型前向传播 std::vector<torch::jit::IValue> inputs; inputs.push_back(input); torch::Tensor output = module.forward(inputs).toTensor(); std::cout << output << std::endl; return 0;
}
请注意,C++ 示例中的模型需要是预先训练好并导出为TorchScript的模型。TorchScript是PyTorch的一个子集,允许模型在没有Python运行时的环境中执行。
在C++中直接使用LSTM而不依赖预先训练的模型会更复杂,因为你需要手动实现LSTM的所有细节。这通常不是推荐的做法,除非你有特定的性能要求或需要深度定制LSTM的行为。
在大多数情况下,使用PyTorch等高级库会更加方便和高效。
相关文章:
游戏AI的创造思路-技术基础-深度学习(3)
继续填坑,本篇介绍深度学习中的长短期记忆网络~~~~ 目录 3.3. 长短期记忆网络(LSTM) 3.3.1. 什么是长短期记忆网络 3.3.2. 形成过程与运行原理 3.3.2.1. 细胞状态与门结构 3.3.2.2. 遗忘门 3.3.2.3. 输入门 3.3.2.4. 细胞状态更新 3.…...
贪心算法练习题(2024/6/24)
1K 次取反后最大化的数组和 给你一个整数数组 nums 和一个整数 k ,按以下方法修改该数组: 选择某个下标 i 并将 nums[i] 替换为 -nums[i] 。 重复这个过程恰好 k 次。可以多次选择同一个下标 i 。 以这种方式修改数组后,返回数组 可能的最…...
大厂程序员上班猝死成常态?
大家好,我是瑶琴呀,拥有一头黑长直秀发的女程序员。 近日,连续看到大厂程序员猝死、低血糖晕倒的新闻,同为程序员感到很难受。互联网加班成常态这是既定事实,尤其在这个内卷严重、经济不景气的环境中,加班…...
深度学习 —— 1.单一神经元
深度学习初级课程 1.单一神经元2.深度神经网络3.随机梯度下降法4.过拟合和欠拟合5.剪枝、批量标准化6.二分类 前言 本套课程仍为 kaggle 课程《Intro to Deep Learning》,仍按之前《机器学习》系列课程模式进行。前一系列《Keras入门教程》内容,与本系列…...
Android 12.0 通知发送过程源码分析-Framework
以下NotificationManagerService简称 NMS 1. 通知的发送: NotificationManager.notify(int id, Notification notification) 开始. 源码路径: /frameworks/base/core/java/android/app/NotificationManager.java/***发布通知以显示在状态栏中。 如果通知带有* 相同的 ID 已被…...
提防远程攻击:了解正向 Shell 和反向 Shell 确保服务器安全
前言 在当今网络安全形势日益复杂的环境中,了解正向 Shell 和反向 Shell 的工作原理和使用场景,对于保护你的服务器免受远程攻击至关重要。本文不仅深入解析这两种常见的远程控制技术,还将提供有效的防护建议,帮助你提升服务器的…...
RabbitMQ中CorrelationData 与DeliveryTag的区别
在RabbitMQ中,CorrelationData是一个用于封装业务ID信息的类,它主要在消息确认机制中发挥作用。以下是关于CorrelationData在RabbitMQ中的详细作用: 封装业务ID信息: 当发送消息时,可以将业务ID信息封装在Correlation…...
数据恢复篇:如何在Android上恢复删除的短信
如果您不小心删除了Android设备上的短信并想要检索它们,则可以尝试以下方法: 如何在Android上恢复删除的短信 检查您的备份: 如果您之前备份了Android设备,则可以从备份中恢复已删除的短信。检查您设备的内部存储空间或 Google 云…...
花了大几万的踩坑经验!宠物空气净化器哪个牌子好:希喂、小米、有哈PK
我的闺蜜最近向我大吐苦水,自从家里养了猫之后,她发现家里的空气质量大不如前。宠物的浮毛和排泄物的气味在空气中飘散,让她非常怀念以前没有养猫时家里清新的呼吸环境。她觉得这些漂浮的毛发和异味大大降低了居家的舒适度。 还引起了身体上…...
查普曼大学团队使用惯性动捕系统制作动画短片
道奇电影和媒体艺术学院是查普曼大学的知名学院,同时也是美国首屈一指的电影学院之一,拥有一流电影制作工作室。 最近,道奇学院的一个学生制作团队接手了一个项目,该项目要求使用真人动作、视觉效果以及真人演员和CG角色之间的互动…...
vue 代理
一、常用的发送一个ajax请求: 1、xhr new XMLHttpRequest(),真正开发中不常用 2、jq,jq主要功能是获取dom,周边才是请求接口 3、axios(大名鼎鼎的) axios.get("url").then(response>{},error>{} )4、…...
[leetcode]24-game
. - 力扣(LeetCode) class Solution { public:static constexpr int TARGET 24;static constexpr double EPSILON 1e-6;static constexpr int ADD 0, MULTIPLY 1, SUBTRACT 2, DIVIDE 3;bool judgePoint24(vector<int> &nums) {vector&l…...
网络爬虫的原理
网络爬虫的原理 网络爬虫,作为信息检索和数据分析的重要工具,其原理的核心在于模拟人类浏览网页的行为,通过自动化的方式从互联网上收集所需的数据。在了解了网络爬虫的基本原理后,我们可以进一步探讨其在实际应用中的工作机制以…...
游戏AI的创造思路-技术基础-机器学习(2)
本篇存在大量的公式,数学不好的孩子们要开始恶补数学了,尤其是统计学和回归方程类的内容。 小伙伴们量力而行~~~~~ 游戏呢,其实最早就是数学家、元祖程序员编写的数学游戏,一脉相承传承至今,囊括了更多的设计师、美术…...
【深度学习】记录为什么没有调用GPU
排查CLIP为什么评测推理没有调用GPU,主要是这个代码:https://github.com/OFA-Sys/Chinese-CLIP/blob/master/cn_clip/eval/extract_features.py 第一次认为:因为model并没有to.cuda()。 但是又发现,model.cuda(args.gpu) # 已经加…...
vite 创建vue3项目 集成 ESLint、Prettier、Sass等
在网上找了一大堆vue3脚手架的东西,无非就是vite或者vue-cli,在vue2时代,vue-cli用的人挺多的,也很好用,然而vue3大多是和vite搭配搭建的,而且个人感觉vite这个脚手架并没有那么的好用,搭建项目时只能做两个…...
计算机系统基础知识(上)
目录 计算机系统的概述 计算机的硬件 处理器 存储器 总线 接口 外部设备 计算机的软件 操作系统 数据库 文件系统 计算机系统的概述 如图所示计算机系统分为软件和硬件:硬件包括:输入输出设备、存储器,处理器 软件则包括系统软件和…...
[深度学习]循环神经网络RNN
RNN(Recurrent Neural Network,即循环神经网络)是一类用于处理序列数据的神经网络,广泛应用于自然语言处理(NLP)、时间序列预测、语音识别等领域。与传统的前馈神经网络不同,RNN具有循环结构&am…...
【C++:list】
list概念 list是一个带头的双向循环链表,双向循环链表的特色:每一个节点拥有两 个指针进行维护,俩指针分别为prev和next,prev指该节点的前一个节点,next为该节点的后一个节点 list的底层实现中为什么对迭代器单独写一个结构体进行…...
解锁 Apple M1/M2 上的深度学习力量:安装 TensorFlow 完全指南
前言 随着 Apple M1 和 M2 芯片的问世,苹果重新定义了笔记本电脑和台式机的性能标准。这些强大的芯片不仅适用于日常任务,还能处理复杂的机器学习和深度学习工作负载。本文将详细介绍如何在 Apple M1 或 M2 芯片上安装和配置 TensorFlow,助你…...
Apache Iceberg:现代数据湖存储格式的未来
Apache Iceberg 是一个开源的表格式,用于在分布式数据湖中管理大规模数据集。它由 Netflix 开发,并捐赠给 Apache 基金会。Iceberg 的设计目标是解决传统数据湖存储格式(如 Apache Hive 和 Apache Parquet)在大规模数据管理中的一…...
【离散数学·图论】(复习)
一、基本概念 1.一些基本术语: 2.点u,v邻接(或相邻): 边e称为关联顶点u和v,or e连接u和v; 3.G(V,E)中,顶点v所有邻居的集合:N(v), 成为v的邻域。 4.度 : deg(v) 5.悬挂点:度为1的…...
【ONLYOFFICE震撼8.1】ONLYOFFICE8.1版本桌面编辑器测评
随着远程工作的普及和数字化办公的发展,越来越多的人开始寻找一款具有强大功能和便捷使用的办公软件。在这个时候,ONLYOFFICE 8.1应运而生,成为了许多用户的新选择。ONLYOFFICE 8.1是一种办公套件软件,它提供了文档处理、电子表格…...
Shell 脚本编程保姆级教程(上)
一、运行第一个 Shell 脚本 1.1 Shell 脚本 Shell 脚本(shell script),是一种为 shell 编写的脚本程序。 业界所说的 shell 通常都是指 shell 脚本,但读者朋友要知道,shell 和 shell script 是两个不同的概念。 由…...
凸优化相关文章汇总
深度学习/机器学习入门基础数学知识整理(三):凸优化,Hessian,牛顿法_深度学习和凸优化-CSDN博客 深度学习/机器学习入门基础数学知识整理(四):拟牛顿法、BFGS、L-BFGS、DFP、共轭梯…...
Java鲜花下单预约系统源码小程序源码
让美好触手可及 🌸一、开启鲜花新篇章 在繁忙的都市生活中,我们总是渴望那一抹清新与美好。鲜花,作为大自然的馈赠,总能给我们带来无尽的惊喜与愉悦。但你是否曾因为工作繁忙、时间紧张而错过了亲自挑选鲜花的机会?今…...
网络变压器和RJ45接线的方法
网络变压器在以太网硬件电路设计中扮演着重要的角色,它主要用于信号电平耦合、隔离外部干扰、实现阻抗匹配以及增加传输距离。而RJ45接口则是以太网连接的标准化接口,它提供了与网络电缆的连接点。 网络变压器与RJ45的接线方法通常遵循以下步骤…...
Matlab/simulink三段式电流保护
电流1段仿真波形如下所示 电流2段仿真波形如下所示 电流3段仿真波形如下所示...
OOXML入门学习
进入-飞入 <par> <!-- 这是一个并行动画序列的开始。"par"代表并行,意味着在这个标签内的所有动画将同时开始。 --><cTn id"5" presetID"2" presetClass"entr" presetSubtype"4" fill"hold&…...
k8s集群node节点加入失败
出现这种情况: [preflight] FYI: You can look at this config file with kubectl -n kube-system get cm kubeadm-config -o yaml [kubelet-start] Writing kubelet configuration to file "/var/lib/kubelet/config.yaml" [kubelet-start] Writing kub…...
微信可以做网站吗/全国最新的疫情数据
Excel:常见的错误信息以及解决方法(转)在Excel中建立了一张工作表,往往希望所有数据都是正确的。但是,基本上这是不可能的!而偏偏计算机是个“较真”的家伙,如果你不改正错误,它会就此罢工,不再…...
怎么在服务器中安装WordPress/南京seo优化推广
我们在做接口测试时,除了常见的http接口,还有一种比较多见,就是socket接口,今天讲解下怎么用Python进行websocket接口测试。现在大多数用的都是websocket,那我们就先来安装一下websocket的安装包。pip install websock…...
包头网站建设易通/游戏推广论坛
一、函数 1.函数初始:函数就是封装一个功能 2.函数名,函数体,关键字,函数的返回值 def 关键字,定义一个函数 my_len 函数名书写规则和变量一样 def 与函数名中间一个空格 函数名()&…...
上海网站建设公司哪家好?/广告优化
源码介绍请注意:该源码来源网友分享,素材虎不提供技术支持,没有技术能力的小白勿拍。(如需安装服务费用另算)这套TPfang房产今天看到了就拿出来测试无奈里面没教程只能自己摸索,搞来搞去吧页面显示出来了确找不到后台地址…...
网站上传/企业营销策划书模板
一、先百度搜索 node 找 Download | Node.js 二、下载 Node 根据你自己的操作系统下载 三、安装 node (注意:如果安装失败,使用管理员身份打开安装) 四、安装成功之后 使用 WIN R 输入 node 测试 五、安装 cnpm 替换 npm …...
惠州响应式网站建设公司/最新热搜新闻事件
本文将介绍了最先进的深度学习优化方法,帮助神经网络训练得更快,表现得更好。有很多个不同形式的优化器,这里我们只找最基础、最常用、最有效和最新的来介绍。 优化器 首先,让我们定义优化。当我们训练我们的模型以使其表现更好…...