ccc-pytorch-LSTM(8)
文章目录
- 一、LSTM简介
- 二、LSTM中的核心结构
- 三、如何解决RNN中的梯度消失/爆炸问题
- 四、情感分类实战(google colab)
一、LSTM简介
LSTM(long short-term memory)长短期记忆网络,RNN的改进,克服了RNN中“记忆低下”的问题。通过“门”结构实现信息的添加和移除,通过记忆元将序列处理过程中的相关信息一直传递下去,经典结构如下:
二、LSTM中的核心结构
记忆元(memory cell)-长期记忆:
就像一个cell一样,信息通过这条只有少量线性交互的线传递。传递过程中有3种“门”结构可以告诉它该学习或者保存哪些信息
三个门结构-短期记忆
遗忘门:用来决定当前状态哪些信息被移除
输入门:决定放入哪些信息到细胞状态
输出门:决定哪些信息用于输出
细节注意:
- 新的细胞状态只需要遗忘门和输入门就可以更新,公式为:Ct=ft∗Ct−1+it∗Ct~C_t=f_t*C_{t-1}+i_t* \tilde{C_t}Ct=ft∗Ct−1+it∗Ct~(注意所有的∗*∗都表示Hadamard 乘积)
- 只有隐状态h_t会传递到输出层,记忆元完全属于内部信息,不可手动修改
三、如何解决RNN中的梯度消失/爆炸问题
解决是指很大程度上缓解,不是让它彻底消失。先解释RNN为什么会有这些问题:
∂Lt∂U=∑k=0t∂Lt∂Ot∂Ot∂St(∏j=k+1t∂Sj∂Sj−1)∂Sk∂U∂Lt∂W=∑k=0t∂Lt∂Ot∂Ot∂St(∏j=k+1t∂Sj∂Sj−1)∂Sk∂W\begin{aligned} &\frac{\partial L_t}{\partial U}= \sum_{k=0}^{t}\frac{\partial L_t}{\partial O_t}\frac{\partial O_t}{\partial S_t}(\prod_{j=k+1}^{t}\frac{\partial S_j}{\partial S_{j-1}})\frac{\partial S_k}{\partial U}\\&\frac{\partial L_t}{\partial W}= \sum_{k=0}^{t}\frac{\partial L_t}{\partial O_t}\frac{\partial O_t}{\partial S_t}(\prod_{j=k+1}^{t}\frac{\partial S_j}{\partial S_{j-1}})\frac{\partial S_k}{\partial W} \end{aligned} ∂U∂Lt=k=0∑t∂Ot∂Lt∂St∂Ot(j=k+1∏t∂Sj−1∂Sj)∂U∂Sk∂W∂Lt=k=0∑t∂Ot∂Lt∂St∂Ot(j=k+1∏t∂Sj−1∂Sj)∂W∂Sk(具体过程可以看这里)
上面是训练过程任意时刻更新W、U需要用到的求偏导的结果。实际使用会加上激活函数,通常为tanh、sigmoid等
tanh和其导数图像如下
sigmoid和其导数如下
这些激活函数的导数都比1要小,又因为∏j=k+1t∂Sj∂Sj−1=∏j=k+1ttanh′(Ws)\prod_{j=k+1}^{t}\frac{\partial S_j}{\partial S_{j-1}}=\prod_{j=k+1}^{t}tanh'(W_s)∏j=k+1t∂Sj−1∂Sj=∏j=k+1ttanh′(Ws),所以当WsW_sWs过小过大就会分别造成梯度消失和爆炸的问题,特别是过小。
LSTM如何缓解
由链式法则和三个门的公式可以得到:
∂Ct∂Ct−1=∂Ct∂ft∂ft∂ht−1∂ht−1∂Ct−1+∂Ct∂it∂it∂ht−1∂ht−1∂Ct−1+∂Ct∂Ct~∂Ct~∂ht−1∂ht−1∂Ct−1+∂Ct∂Ct−1=Ct−1σ′(⋅)Wf∗ot−1tanh′(Ct−1)+Ct~σ′(⋅)Wi∗ot−1tanh′(Ct−1)+ittanh′(⋅)Wc∗ot−1tanh′(Ct−1)+ft\begin{aligned} &\frac{\partial C_t}{\partial C_{t-1}}\\&=\frac{\partial C_t}{\partial f_t}\frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_t}{\partial i_t}\frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_t}{\partial \tilde{C_t}}\frac{\partial \tilde{C_t}}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_t}{\partial C_{t-1}}\\ &=C_{t-1}\sigma '(\cdot)W_f*o_{t-1}tanh'(C_{t-1})+\tilde{C_t}\sigma '(\cdot)W_i*o_{t-1}tanh'(C_{t-1})\\&+i_ttanh'(\cdot)W_c*o_{t-1}tanh'(C_{t-1})+f_t \end{aligned}∂Ct−1∂Ct=∂ft∂Ct∂ht−1∂ft∂Ct−1∂ht−1+∂it∂Ct∂ht−1∂it∂Ct−1∂ht−1+∂Ct~∂Ct∂ht−1∂Ct~∂Ct−1∂ht−1+∂Ct−1∂Ct=Ct−1σ′(⋅)Wf∗ot−1tanh′(Ct−1)+Ct~σ′(⋅)Wi∗ot−1tanh′(Ct−1)+ittanh′(⋅)Wc∗ot−1tanh′(Ct−1)+ft
- 由相乘变成了相加,不容易叠加
- sigmoid函数使单元间传递结果非常接近0或者1,使模型变成非线性,并且可以在学习过程中内部调整
四、情感分类实战(google colab)
环境和库:
!pip install torch
!pip install torchtext
!python -m spacy download en# K80 gpu for 12 hours
import torch
from torch import nn, optim
from torchtext import data, datasetsprint('GPU:', torch.cuda.is_available())torch.manual_seed(123)
加载数据集:
TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)print(train_data.examples[15].text)
print(train_data.examples[15].label)
网络结构:
class RNN(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim):""""""super(RNN, self).__init__()# [0-10001] => [100]self.embedding = nn.Embedding(vocab_size, embedding_dim)# [100] => [256]self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=0.5)# [256*2] => [1]self.fc = nn.Linear(hidden_dim*2, 1)self.dropout = nn.Dropout(0.5)def forward(self, x):"""x: [seq_len, b] vs [b, 3, 28, 28]"""# [seq, b, 1] => [seq, b, 100]embedding = self.dropout(self.embedding(x))# output: [seq, b, hid_dim*2]# hidden/h: [num_layers*2, b, hid_dim]# cell/c: [num_layers*2, b, hid_di]output, (hidden, cell) = self.rnn(embedding)# [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)# [b, hid_dim*2] => [b, 1]hidden = self.dropout(hidden)out = self.fc(hidden)return out
Embedding
rnn = RNN(len(TEXT.vocab), 100, 256)pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
criteon = nn.BCEWithLogitsLoss().to(device)
rnn.to(device)
训练并测试
import numpy as npdef binary_acc(preds, y):"""get accuracy"""preds = torch.round(torch.sigmoid(preds))correct = torch.eq(preds, y).float()acc = correct.sum() / len(correct)return accdef train(rnn, iterator, optimizer, criteon):avg_acc = []rnn.train()for i, batch in enumerate(iterator):# [seq, b] => [b, 1] => [b]pred = rnn(batch.text).squeeze(1)# loss = criteon(pred, batch.label)acc = binary_acc(pred, batch.label).item()avg_acc.append(acc)optimizer.zero_grad()loss.backward()optimizer.step()if i%10 == 0:print(i, acc)avg_acc = np.array(avg_acc).mean()print('avg acc:', avg_acc)def eval(rnn, iterator, criteon):avg_acc = []rnn.eval()with torch.no_grad():for batch in iterator:# [b, 1] => [b]pred = rnn(batch.text).squeeze(1)#loss = criteon(pred, batch.label)acc = binary_acc(pred, batch.label).item()avg_acc.append(acc)avg_acc = np.array(avg_acc).mean()print('>>test:', avg_acc)for epoch in range(10):eval(rnn, test_iterator, criteon)train(rnn, train_iterator, optimizer, criteon)
最后得到的准确率结果如下:
完整colab链接:lstm
完整代码:
# -*- coding: utf-8 -*-
"""lstmAutomatically generated by Colaboratory.Original file is located athttps://colab.research.google.com/drive/1GX0Rqur8T45MSYhLU9MYWAbycfLH4-Fu
"""!pip install torch
!pip install torchtext
!python -m spacy download en# K80 gpu for 12 hours
import torch
from torch import nn, optim
from torchtext import data, datasetsprint('GPU:', torch.cuda.is_available())torch.manual_seed(123)TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)print('len of train data:', len(train_data))
print('len of test data:', len(test_data))print(train_data.examples[15].text)
print(train_data.examples[15].label)# word2vec, glove
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')
LABEL.build_vocab(train_data)batchsz = 30
device = torch.device('cuda')
train_iterator, test_iterator = data.BucketIterator.splits((train_data, test_data),batch_size = batchsz,device=device
)class RNN(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim):""""""super(RNN, self).__init__()# [0-10001] => [100]self.embedding = nn.Embedding(vocab_size, embedding_dim)# [100] => [256]self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=0.5)# [256*2] => [1]self.fc = nn.Linear(hidden_dim*2, 1)self.dropout = nn.Dropout(0.5)def forward(self, x):"""x: [seq_len, b] vs [b, 3, 28, 28]"""# [seq, b, 1] => [seq, b, 100]embedding = self.dropout(self.embedding(x))# output: [seq, b, hid_dim*2]# hidden/h: [num_layers*2, b, hid_dim]# cell/c: [num_layers*2, b, hid_di]output, (hidden, cell) = self.rnn(embedding)# [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)# [b, hid_dim*2] => [b, 1]hidden = self.dropout(hidden)out = self.fc(hidden)return outrnn = RNN(len(TEXT.vocab), 100, 256)pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
criteon = nn.BCEWithLogitsLoss().to(device)
rnn.to(device)import numpy as npdef binary_acc(preds, y):"""get accuracy"""preds = torch.round(torch.sigmoid(preds))correct = torch.eq(preds, y).float()acc = correct.sum() / len(correct)return accdef train(rnn, iterator, optimizer, criteon):avg_acc = []rnn.train()for i, batch in enumerate(iterator):# [seq, b] => [b, 1] => [b]pred = rnn(batch.text).squeeze(1)# loss = criteon(pred, batch.label)acc = binary_acc(pred, batch.label).item()avg_acc.append(acc)optimizer.zero_grad()loss.backward()optimizer.step()if i%10 == 0:print(i, acc)avg_acc = np.array(avg_acc).mean()print('avg acc:', avg_acc)def eval(rnn, iterator, criteon):avg_acc = []rnn.eval()with torch.no_grad():for batch in iterator:# [b, 1] => [b]pred = rnn(batch.text).squeeze(1)#loss = criteon(pred, batch.label)acc = binary_acc(pred, batch.label).item()avg_acc.append(acc)avg_acc = np.array(avg_acc).mean()print('>>test:', avg_acc)for epoch in range(10):eval(rnn, test_iterator, criteon)train(rnn, train_iterator, optimizer, criteon)
相关文章:
ccc-pytorch-LSTM(8)
文章目录一、LSTM简介二、LSTM中的核心结构三、如何解决RNN中的梯度消失/爆炸问题四、情感分类实战(google colab)一、LSTM简介 LSTM(long short-term memory)长短期记忆网络,RNN的改进,克服了RNN中“记忆…...
教育小程序开发解决方案
如今无论是国家还是家庭对于教育的重视性也越来越高,都希望自己的孩子能够赢在起跑线上,但是因为工作的缘故许多家长并没有过多的精力去辅导孩子学习,再加上许多家长对于教育也并没有经验与技巧。而这些都充分体现了正确教育的重要性。 那么一…...
动态规划之股票问题大总结
参考资料:代码随想录 (programmercarl.com)一、只能买卖一次题目链接:121. 买卖股票的最佳时机 - 力扣(LeetCode)算法思想:设置两种状态:0表示已持有股票,1表示未持有股票1.dp[i][0]表示第i天已持有股票时&…...
我来跟你讲vue进阶
一、组件(重点) 组件(Component)是 Vue.js 最强大的功能之一。 组件可以扩展 HTML 元素,封装可重用的代码。 组件系统让我们可以用独立可复用的小组件来构建大型应用,几乎任意类型的应用的界面都可以抽象…...
#847(Div3)E. Vlad and a Pair of Numbers
原题链接: E. Vlad and a Pair of Numbers 题意: 题目有公式 a⊕b(ab)/2xa ⊕ b (a b) / 2 xa⊕b(ab)/2x, 给你的是 xxx,让输出一组满足题目要求的 a,ba,ba,b,没有就输出−1-1…...
怎么把pdf转换成图片?这个方法你值得拥有
想要高效率的工作,除了需要大家合理安排时间之外,一些能够辅助高效工作的工具也是必不可少的。就拿要把一份pdf文件转换成若干图片来说,如果不知道方法,找不到合适的转换工具,那么想要完成这一任务,势必要花…...
go语言使用append向二维数组添加一维数组
var ans [][]int ans append(ans, append([]int(nil), nums...))(正确写法)需要注意的是,为了避免对原切片造成影响,代码在将当前排列追加到结果数组 ans 时,使用了 append(ans, append([]int(nil), nums…)) 的方式…...
YOLOv5训练大规模的遥感实例分割数据集 iSAID从切图到数据集制作及训练
最近想训练遥感实例分割,纵观博客发现较少相关 iSAID数据集的切分及数据集转换内容,思来想去应该在繁忙之中抽出时间写个详细的教程。 iSAID数据集下载 iSAID数据集链接 下载上述数据集。 百度网盘中的train和val中包含了实例和语义分割标签。 上述…...
js学习5(函数)
目录 定义函数 函数的特性 使用函数模拟类 模拟私有属性和方法 闭包 函数特性利用 箭头函数 定义函数 function func1(name) { console.log(name); } func2 function (name) { console.log(name); } func3 function func0(name) { console.log(name); } co…...
用Qt画一个仪表盘
关于Qt Qt是一个跨平台的C图形用户界面应用程序框架,通过使用Qt,可以快速开发出跨平台的多平台应用程序,包括Windows、Mac OS X、Linux和其他Unix系统。Qt提供了强大的图形操作界面(GUI)程序开发和移植的能力…...
linux 端口查询命令
任何知识都是用进废退,有段时间没摸linux,这大脑里的知识点仿佛全部消失了,就无语。 索性,再写一篇记录,加强一下记忆,下次需要就看自己的资料好了。lsof命令Linux端口查询命令可以通过lsof实现:…...
C语言函数: 字符串函数及模拟实现strtok()、strstr()、strerror()
C语言函数: 字符串函数及模拟实现strtok()、strstr()、strerror() strstr()函数: 作用:字符串查找。在一串字符串中,查找另一串字符串是否存在。 形参: str2在str1中寻找。返回值是char*的指针 原理:如果在str1中找到了str2&…...
【学习笔记】人工智能哲学研究:《心智、语言和机器》
关于人工智能哲学,我曾在这篇文章里 【脑洞大开】从哲学角度看人工智能:介绍徐英瑾的《心智、语言和机器》 做过介绍。图片来源:http://product.dangdang.com/29419969.html在我完成了一些人工智能相关的工作以后,我再来分享《心智…...
设计模式之门面模式(外观模式)
目录 1.模式定义 2.应用场景 2.1 电源总开关例子 2.2 股民炒股场景 编辑 3. 实例如下 4. 门面模式的优缺点 传送门: 项目中用到的责任链模式 给对象讲工厂模式,必须易懂易会 策略模式,工作中你用上了吗? 1.模式定…...
MySQL - 多表查询
目录1. 多表查询示例2. 多表查询分类2.1 等/非等值连接2.1.1 等值连接2.1.2非等值连接2.2 自然/非自然连接2.3 内/外连接2.3.1 内连接2.3.2 外连接3.UNION的使用3.1 合并查询结果3.1.1 UNION操作符3.1.2 UNION ALL操作符4. 7种JOIN操作5. join 多张表多表查询,也称为…...
自定义报表是什么?
自定义报表是指根据用户的需求和要求,自行设计和生成的报表。自定义报表可以根据用户的具体需求,选择需要的数据和指标,进行灵活的排列和组合,生成符合用户要求的报表。自定义报表可以帮助用户更好地了解业务情况,发现…...
windows安装docker-小白用【避坑】【伸手党福利】
目录实操开启 Hyper-V 和容器特性下载docker安装dockercmd中,使用命令测试是否成功报错解决办法:下载linux模拟器wsl:双击打开docker重新打开cmd,输入命令,成功显示sever和clinet实操 开启 Hyper-V 和容器特性 控制面…...
环形链表相关的练习
目录 一、相交链表 二、环形链表 三、环形链表 || 一、相交链表 给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点,返回 null 。 图示两个链表在节点 c1 开始相交: 题目数据…...
C++ 提示对话框
头文件 #include<iostream>#include<cstdio> using namespace std; 函数格式 MessageBox( HWND hWnd, LPCTSTR lpText, LPCTSTR lpCaption, UINT uType) 参数 hWnd :此参数代表消息框拥有的窗口。如果为NULL,则消息框没有拥有窗口。 lp…...
SprintBoot打包及profile文件配置
打成Jar包 需要添加打包组件将项目中的资源、配置、依赖包打到一个jar包中,可以使用maven的package;运行: java -jar xxx(jar包名) 操作步骤 第一步: 引入Spring Boot打包插件 <!--打包的插件--> <build><!--修改jar的名字--><fi…...
java面试-java集合
说说你如何选用集合? 需要键值对选用 map 接口下的集合,需要排序用 TreeMap, 不需要排序用 HashMap 不需要键值对仅存放元素则选择 Collection 下实现的接口,保证元素唯一使用 Set, 不需要则选用 List Collection 和 Collections 有什么区别…...
Node.js简介
客户端访问网页时向服务器端发送请求要访问服务器中的页面,服务器收到请求后向数据库中进行搜索,搜索到相关数据然后返回结果给客户端显示; 这个过程就类似于:客人(客户端)去饭馆(服务端&#…...
每天学一点之Lambda表达式
Lambda表达式 思想导入: 函数式编程思想: 在数学中,函数就是有输入量、输出量的一套计算方案,也就是“拿什么东西做什么事情”。编程中的函数,也有类似的概念,你调用我的时候,给我实参为形参赋…...
Raft分布式共识算法学习笔记
1. Raft算法 Raft算法属于Multi-Paxos算法,它是在Multi-Paxos思想的基础上,做了一些简化和限制,比如增加了日志必须是连续的,只支持领导者、跟随者和候选人三种状态,在理解和算法实现上都相对容易许多 从本质上说&am…...
中介者模式
介绍 Java中介者模式(Mediator Pattern)是一种行为设计模式,它可以降低多个对象之间的耦合性,通过一个中介者对象来协调这些对象的交互. 在中介者模式中,多个对象之间的交互不是直接进行的,而是通过一个中介者对象来进行的.这个中介者对象封装了对象之间的交互逻辑,每个对象只…...
Kaggle赛题解析:Google手语识别
文章目录一、比赛前言信息二、比赛背景三、比赛任务四、评价指标五、数据描述六、解题思路一、比赛前言信息 比赛名称:Google - Isolated Sign Language Recognition 中文名称:帮助用户从PopSign游戏学习美国手语 比赛链接:https://www.ka…...
什么是ChatGPT?
目录前言一、什么是GPT?二、什么是ChatGPT?三、ChatGPT应用场景四、ChatGPT未来展望五、OpenAI介绍前言 3月3号,早上6:30就有人发消息给我,来问我有关GPT API的事件。 那是因为3月2号,OpenAI 发布了ChatGPT 3.5的开放…...
深入理解Zookeeper的ZAB协议
ZAB是什么ZAB(Zookeeper Atomic Broadcast):Zookeeper原子广播ZAB是为了保证Zookeeper数据一致性而产生的算法(指的是Zookeeper集群模式)。它不仅能解决正常情况下的数据一致性问题,还可以保证主节点发生宕…...
opencv-图像几何处理
缩放 缩放只是调整图像的大小。为此,opencv提供了一个cv2.resize()函数,可以手动指定图像大小,也可以指定缩放因子。你可以使用任意一种方法调整图像的大小: import cv2 from matplotlib import pyplot as pltlogo cv2.imread(…...
[前端笔记030]vue之hello、数据绑定、MVVM、数据代理、事件处理、计算属性和监视属性
前言 本笔记参考视频,尚硅谷:BV1Zy4y1K7SH p1 -p25官网文档完善,本文只做笔记使用,官网下载vue的开发版和生产版或者使用CDN,并去谷歌商店下载开发插件 简介 组件化模式,提高代码复用率,更好维护声明式编…...
有专门做食品的网站吗/友情链接交换平台免费
协方差(covariance): 该公式可以有如下理解:如果有X,Y两个变量,每个时刻的“X值与其均值之差”乘以“Y值与其均值之差”得到一个乘积,再对这每时刻的乘积求和并求出均值(其实是求“期望”,但就不引申太多新概念了&…...
中国建设服务信息官网/网站优化有哪些类型
据悉华为即将在今天下午发布2019年的业绩,2019年前三季度它取得了超过两成的增长,但是随着美国因素的影响加大,2019年四季度它的业绩或将出现下滑。华为公布的业绩显示,2019年上半年营收同比增长23.2%,2019年前三季度营…...
北京 建设官方网站/百度关键词快排
Cocos-2d中,涉及到4种坐标系: GL坐标系Cocos2D以OpenglES为图形库,所以它使用OpenglES坐标系。GL坐标系原点在屏幕左下角,x轴向右,y轴向上。 屏幕坐标系苹果的Quarze2D使用的是不同的坐标系统,原点在屏幕左…...
酒店网站建设公司/海外推广营销系统
Hive的几种常见的数据导入方式这里介绍四种:(1)、从本地文件系统中导入数据到Hive表;(2)、从HDFS上导入数据到Hive表;(3)、从别的表中查询出相应的数据并导入到Hive表中&…...
b2b建设网站公司/小程序设计
在网络编程中,出于节约带宽或者编码的需要,通常需要以原生方式处理long和int,而不是转换为string。public class ByteOrderUtils {public static byte[] int2byte(int res) {byte[] targets new byte[4];targets[3] (byte) (res & 0xff…...
网站建设规划图/网络推广引流有哪些渠道
一、个人见解(直播难与易) 直播难:个人认为要想把直播从零开始做出来,绝对是牛逼中的牛逼,大牛中的大牛,因为直播中运用到的技术难点非常之多,视频/音频处理,图形处理,视…...