深度学习笔记:不同的反向传播迭代方法
1 随机梯度下降法SGD
随机梯度下降法每次迭代取梯度下降最大的方向更新。这一方法实现简单,但是在很多函数中,梯度下降的方向不一定指向函数最低点,这使得梯度下降呈现“之”字形,其效率较低
class SGD:"""随机梯度下降法(Stochastic Gradient Descent)"""def __init__(self, lr=0.01):self.lr = lrdef update(self, params, grads):for key in params.keys():params[key] -= self.lr * grads[key]
2 Momentum
momentum即动量。该方法设置变量v代表梯度下降的速度,其中dL/dW(梯度值)代表改变速度的“受力”,而α则作为“阻力”,限制v变化。该方法进行梯度下降可以类比一个小球在三维平面上滚动。
在下面的示例中,可以看到虽然迭代方向还是呈“之”字形,但是在x方向,虽然梯度较小,但是由于受力始终在一个方向,速度逐渐加快。在y方向,虽然梯度大,但上下受力相反,使得y方向不会有很大偏移
class Momentum:"""Momentum SGD"""def __init__(self, lr=0.01, momentum=0.9):self.lr = lrself.momentum = momentumself.v = Nonedef update(self, params, grads):if self.v is None:self.v = {}for key, val in params.items(): self.v[key] = np.zeros_like(val)for key in params.keys():self.v[key] = self.momentum*self.v[key] - self.lr*grads[key] params[key] += self.v[key]
在程序里一开始v设为None,在第一次调用update时会将v更新为和各权重形状一样的0矩阵
3 AdaGrad
AdaGrad的思路是根据上一轮迭代的变化量动态调整每一个权重的学习率。一个权重在迭代中变化量越大,其在下一轮中学习率就会减少更多。
在公式中,我们用h记录过去所有梯度的平方和(⊙代表矩阵元素相乘),在更新权重时之前变化较大的权重值变化量会变小。
由于h是不断累加的平方和,如果学习一直持续下去,W更新率会不断趋于0,要改善这一问题可以参考RMSProp,该方法会对较早更新的梯度逐渐“遗忘”,而更多反应新更新的状态
AdaGrad
class AdaGrad:"""AdaGrad"""def __init__(self, lr=0.01):self.lr = lrself.h = Nonedef update(self, params, grads):if self.h is None:self.h = {}for key, val in params.items():self.h[key] = np.zeros_like(val)for key in params.keys():self.h[key] += grads[key] * grads[key]params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)
在这里注意我们在h的每个元素中加上了微小的1e-7,这是为了防止h中有元素为0时,作为除数会报错。
RMSProp
class RMSprop:"""RMSprop"""def __init__(self, lr=0.01, decay_rate = 0.99):self.lr = lrself.decay_rate = decay_rateself.h = Nonedef update(self, params, grads):if self.h is None:self.h = {}for key, val in params.items():self.h[key] = np.zeros_like(val)for key in params.keys():self.h[key] *= self.decay_rateself.h[key] += (1 - self.decay_rate) * grads[key] * grads[key]params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)
RMSProp的方法和AdaGrad类似,除了每一轮迭代时会将h乘上一个decay_rate(大小在0-1)以减小之前梯度对h的影响
如图,一开始由于y方向梯度变化大,所以更新快,但因此y方向上学习率也减小较快,使得网络在后期逐渐沿x方向更新
Adam
Adam类似于momentum和AdaGrad两种方法的结合,其具体原理较为复杂,可以找原论文http://arxiv.org/abs/1412.6980v8
class Adam:"""Adam (http://arxiv.org/abs/1412.6980v8)"""def __init__(self, lr=0.001, beta1=0.9, beta2=0.999):self.lr = lrself.beta1 = beta1self.beta2 = beta2self.iter = 0self.m = Noneself.v = Nonedef update(self, params, grads):if self.m is None:self.m, self.v = {}, {}for key, val in params.items():self.m[key] = np.zeros_like(val)self.v[key] = np.zeros_like(val)self.iter += 1lr_t = self.lr * np.sqrt(1.0 - self.beta2**self.iter) / (1.0 - self.beta1**self.iter) for key in params.keys():#self.m[key] = self.beta1*self.m[key] + (1-self.beta1)*grads[key]#self.v[key] = self.beta2*self.v[key] + (1-self.beta2)*(grads[key]**2)self.m[key] += (1 - self.beta1) * (grads[key] - self.m[key])self.v[key] += (1 - self.beta2) * (grads[key]**2 - self.v[key])params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7)#unbias_m += (1 - self.beta1) * (grads[key] - self.m[key]) # correct bias#unbisa_b += (1 - self.beta2) * (grads[key]*grads[key] - self.v[key]) # correct bias#params[key] += self.lr * unbias_m / (np.sqrt(unbisa_b) + 1e-7)
利用mnist数据集对几种训练方式进行比较:
在该测试程序中,我们使用5层神经网络,每层神经元个数100。利用SGD, momentum, AdaGrad, Adam, RMSProp分别进行2000次迭代,并比较最终各网络的总损失
# coding: utf-8
import os
import sys
sys.path.append("D:\AI learning source code") # 为了导入父目录的文件而进行的设定
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.util import smooth_curve
from common.multi_layer_net import MultiLayerNet
from common.optimizer import *# 0:读入MNIST数据==========
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)train_size = x_train.shape[0]
batch_size = 128
max_iterations = 2000# 1:进行实验的设置==========
optimizers = {}
optimizers['SGD'] = SGD()
optimizers['Momentum'] = Momentum()
optimizers['AdaGrad'] = AdaGrad()
optimizers['Adam'] = Adam()
optimizers['RMSprop'] = RMSprop()networks = {}
train_loss = {}
for key in optimizers.keys():networks[key] = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100],output_size=10)train_loss[key] = [] # 2:开始训练==========
for i in range(max_iterations):batch_mask = np.random.choice(train_size, batch_size)x_batch = x_train[batch_mask]t_batch = t_train[batch_mask]for key in optimizers.keys():grads = networks[key].gradient(x_batch, t_batch)optimizers[key].update(networks[key].params, grads)loss = networks[key].loss(x_batch, t_batch)train_loss[key].append(loss)if i % 100 == 0:print( "===========" + "iteration:" + str(i) + "===========")for key in optimizers.keys():loss = networks[key].loss(x_batch, t_batch)print(key + ":" + str(loss))# 3.绘制图形==========
markers = {"SGD": "o", "Momentum": "x", "AdaGrad": "s", "Adam": "D", "RMSprop": "v"}
x = np.arange(max_iterations)
for key in optimizers.keys():plt.plot(x, smooth_curve(train_loss[key]), marker=markers[key], markevery=100, label=key)
plt.xlabel("iterations")
plt.ylabel("loss")
plt.ylim(0, 1)
plt.legend()
plt.show()
实验结果如下
相关文章:
深度学习笔记:不同的反向传播迭代方法
1 随机梯度下降法SGD 随机梯度下降法每次迭代取梯度下降最大的方向更新。这一方法实现简单,但是在很多函数中,梯度下降的方向不一定指向函数最低点,这使得梯度下降呈现“之”字形,其效率较低 class SGD:"""随机…...
ElasticSearch 学习笔记总结(三)
文章目录一、ES 相关名词 专业介绍二、ES 系统架构三、ES 创建分片副本 和 elasticsearch-head插件四、ES 故障转移五、ES 应对故障六、ES 路由计算 和 分片控制七、ES集群 数据写流程八、ES集群 数据读流程九、ES集群 更新流程 和 批量操作十、ES 相关重要 概念 和 名词十一、…...
深入理解border以及应用
深入border属性以及应用👏👏 border这个属性在开发过程中很常用,常常用它来作为边界的。但是大家真的了解border吗?以及它的形状是什么样子的。 我们先来看这样一段代码:👏 <!--* Author: syk 185901…...
如何复现论文?什么是论文复现?
参考资料: 学习篇—顶会Paper复现方法 - 知乎 如何读论文?复现代码?_复现代码是什么意思 - CSDN 我是如何复现我人生的第一篇论文的 - 知乎 在我看来,论文复现应该有一个大前提和分为两个层次。 大前提是你要清楚地懂得自己要…...
22.2.28打卡 Codeforces Round #851 (Div. 2) A~C
A题 One and Two 题面翻译 题目描述 给你一个数列 a1,a2,…,ana_1, a_2, \ldots, a_na1,a2,…,an . 数列中的每一个数的值要么是 111 要么是 222 . 找到一个最小的正整数 kkk,使之满足: 1≤k≤n−11 \leq k \leq n-11≤k≤n−1 , anda1⋅a2⋅……...
Learining C++ No.12【vector】
引言: 北京时间:2023/2/27/11:42,高数考试还在进行中,我充分意识到了学校的不高级,因为题目真的没什么意思,虽然挺平易近人,但是……,考试期间时间比较放松,所以不能耽误…...
【数电基础】——逻辑代数运算
目录 1.概念 1.基本逻辑概念 2.基本逻辑电路(与或非) 逻辑与运算 与门电路: 逻辑或运算 或门电路: 逻辑非运算(逻辑反) 非门电路编辑 3.复合逻辑电路(运算) 与非逻辑…...
【Redis】什么是缓存与数据库双写不一致?怎么解决?
1. 热点缓存重建 我们以热点缓存 key 重建来一步步引出什么是缓存与数据库双写不一致,及其解决办法。 1.1 什么是热点缓存重建 在实际开发中,开发人员使用 “缓存 过期时间” 的策略来实现加速数据读写和内存使用率,这种策略能满足大多数…...
互联网衰退期,测试工程师35岁之路怎么走...
国内的互联网行业发展较快,所以造成了技术研发类员工工作强度比较大,同时技术的快速更新又需要员工不断的学习新的技术。因此淘汰率也比较高,超过35岁的基层研发类员工,往往因为家庭原因、身体原因,比较难以跟得上工作…...
动态规划(以背包问题为例)
1) 要求达到的目标为装入的背包的总价值最大,并且重量不超出2) 要求装入的物品不能重复动态规划(Dynamic Programming)算法的核心思想是:将大问题划分为小问题进行解决,从而一步步获取最优解的处理算法。动态规划算法与分治算法类似ÿ…...
Java异常
异常的体系结构 在java的Throwable下有Error和Exception两个子类 Error(错误):程序运行中出现了严重的问题,非代码性错误,无法处理,常见的有虚拟机运行错误和内存溢出等Exception(异常):是由于代码本身造成的问题,可以进行处理,异常一个可以分为运行时异常和编译时异常 运行…...
别克GL8改装完工,一起来看看效果
①豪华商务头等舱 别克GL8作为商务车,不管是家用还是商务接待,原车内饰都太掉档次了,所以车主要求全部换掉。>>织布座椅换成航空座椅 主副驾:改装纳帕皮 中排:改装水晶宝座豪华版航空座椅,带通风、加…...
mac 中 shell 一些知识
mac 设置环境变量首先得看你所使用的 shell shell 是一个命令行解释器,顾名思义就是机器外面的一层壳,用于人机交互,只要是人与电脑之间交互的接口,就可以称为 shell。表现为其作用是用户输入一条命令,shell 就立即解…...
CentOS 配置FTP(开启VSFTPD服务)
网上已经有很多关于VSFTPD的配置,但有两个通病,要么就是原理介绍太多,要么就是不完整,操作下来又要查询多篇文章才能用。 我这里不讲原理,只记录操作,尽可能通过复制下面的操作可以实现FTP读写功能。方便自…...
Http的请求方法
Http的请求方法对应的数据传输能力把Http请求分为Url类请求和Body类请求 1.Url类请求包括但不限于GET、HEAD、OPTIONS、TRACE 等请求方法 2.Body类请求包括但不限于POST、PUSH、PATCH、DELETE 等请求方法。 3.原因:get请求没有请求体(好像也可以…...
Python字典-- 内附蓝桥题:统计数字
字典 ~~不定时更新🎃,上次更新:2023/02/28 🗡常用函数(方法) 1. dic.get(key) --> 判断字典 dic 是否有 key,有返回其对应的值,没有返回 None 举个栗子🌰 dic …...
文本处理工具
Grep工具的基本使用grep作用:grep是行过滤工具;用于根据关键字进行行过滤提示:通过alias命令设置grep别名,搜索参数时带颜色显示alias grepgrep colorauto 命令语法格式:grep [选项] 参数 文件名grep命令选项ÿ…...
C++STL详解(三)——vector的介绍和使用
文章目录vector的介绍vector的使用vector的定义方式vector的空间增长问题reserve和resizevector的迭代器使用begin 和endrbegin和rendinsert 和erasefind函数元素访问vector迭代器失效问题1:inserse插入扩容时空间销毁造成野指针问题2:erase删除或者inse…...
GEBCO海洋数据下载
一、数据集简介 GEBCO(General Bathymetric chart of the Oceans)旨在为世界海洋提供最权威的、可公开获取的测深数据集。 目前的网格化测深数据集,即GEBCO_2022网格,是一个全球海洋和陆地的地形模型,在15角秒间隔的…...
【C++容器】vector、map、hash_map、unordered_map四大容器的性能分析【2023.02.28】
摘要 vector是标准容器对数组的封装,是一段连续的线性的内存。map底层是二叉排序树。hash_map是C11之前的无序map,unordered_map底层是hash表,涉及桶算法。现对各个容器的查询与”插入“性能做对比分析,方便后期选择。 测试方案…...
ACM-蓝桥杯训练第一周
🚀write in front🚀 📝个人主页:认真写博客的夏目浅石.CSDN 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝 📣系列专栏:ACM周训练题目合集.CSDN 💬总结:…...
python基础—字符串操作
(1)字符串: Python内置了一系列的数据类型,其中最主要的内置类型是数值类型、文本序列(字符串)类型、序列(列表、元组和range)类型、集合类型、映射(字典)类型…...
【Spring】通过JdbcTemplate实现CRUD操作
个人简介:Java领域新星创作者;阿里云技术博主、星级博主、专家博主;正在Java学习的路上摸爬滚打,记录学习的过程~ 个人主页:.29.的博客 学习社区:进去逛一逛~ 通过JdbcTemplate实现 增删查改一、添加相关依…...
实战|掌握Linux内存监视:free命令详解与使用技巧
文章目录前言一. free命令介绍二. 语法格式及常用选项三. 参考案例3.1 查看free相关的信息3.2 以MB的形式显示内存的使用情况3.3 以总和的形式显示内存的使用情况3.4 周期性的查询内存的使用情况3.5 以更人性化的形式来查看内存的结果输出四. free在脚本中的应用总结前言 大家…...
嵌入式入门必看!调试工具安装——基于 AM64x核心板
本章节内容是为评估板串口安装USB转串口驱动程序。驱动适用于CH340、CH341等USB转串口芯片。 USB转串口驱动安装 适用安装环境:Windows 7 64bit、Windows 10 64bit。 本文测试板卡为创龙科技SOM-TL64x核心板,它是一款基于TI Sitara系列AM64x双核ARM Cortex-A53 + 单/四核Cort…...
JAVA开发(java类加载过程)
1、java语言的平台无关性。 因为java语言可以跑在java虚拟机上,所以只要能装java虚拟机的地方就能跑java程序。java语言以后缀名 .java为文件扩展名。通过java编译器javac编译成字节码文件.class 。java字节码文件通过java虚拟机解析运行。所以java语言可以说是编译…...
【vulhub漏洞复现】Thinkphp 2.x 任意代码执行
一、漏洞详情影响版本 thinkphp 2.x但是由于thinkphp 3.0版本在Lite模式下没有修复该漏洞,所以也存在该漏洞漏洞原因:e 和 /e模式匹配路由:e 配合函数preg_replace()使用, 可以把匹配来的字符串当作正则表达式执行; /e 可执行模式,…...
LeetCode 1145. 二叉树着色游戏 -- 简单搜索
二叉树着色游戏 提示 中等 199 相关企业 有两位极客玩家参与了一场「二叉树着色」的游戏。游戏中,给出二叉树的根节点 root,树上总共有 n 个节点,且 n 为奇数,其中每个节点上的值从 1 到 n 各不相同。 最开始时: 「一…...
HyperGBM的三种Early Stopping方式
本文作者:杨健,九章云极 DataCanvas 主任架构师 很多机器学习框架如都提供了Early Stopping策略,主要用来防止模型过拟合。和模型训练提前停止的目标不同,AutoML的Early Stopping策略更多考虑的是算力消耗和模型质量的平衡。 通…...
心系区域发展,高德用一体化出行服务平台“聚”力区域未来
交通,是城市的血脉。通过对人、资源、产业的连接,交通建设往往是城市和区域经济发展的前提。不过,在度过了“要想富,先修路”的初级建设阶段后,交通产业内部也出现了挑战,诸如城市秩序、发展成本、用户使用…...
php网站怎么做seo/市场调查报告模板及范文
什么是 HomeLists ? HomeLists 是一款自托管耗材统计软件,能通过提醒等帮助您跟踪家庭消耗品。 安装 在群晖上以 Docker 方式安装。 在注册表中搜索 homelists ,选择第一个 aceberg/homelists,版本选择 latest。 本文写作时&…...
网上虚拟银行注册网站/淘宝产品关键词排名查询
对于刚接触linux系统的学员来说,确实是一件比较困难的事情,造成这种局面主要原因之一是windows的设计考虑到用户的体验效果,提供了更好的用户操作效果。以至于用户接触的最多的系统,所以刚接触linux的时候会感觉很不适应ÿ…...
网站如何优化关键词排名/seo排名培训
783. 二叉搜索树节点最小距离 难度简单106 给定一个二叉搜索树的根节点 root,返回树中任意两节点的差的最小值。 示例: 输入: root [4,2,6,1,3,null,null] 输出: 1 解释: 注意,root是树节点对象(TreeNode object),而不是数组。给…...
web前端毕业论文/黑锋网seo
来源 | www.iteye.com/blog/josh-persistence-2161848现实企业级Java应用开发、维护中,有时候我们会碰到下面这些问题:OutOfMemoryError,内存不足内存泄露线程死锁锁争用(Lock Contention)Java进程消耗CPU过高......这…...
c 转网站开发/2022年今天新闻联播
作者 | 丁彦军责编 | 仲培艺近日,有位粉丝向我请教,在爬取某网站时,网页的源代码出现了中文乱码问题,本文就将与大家一起总结下关于网络爬虫的乱码处理。注意,这里不仅是中文乱码,还包括一些如日文、韩文 、…...
bootstrap 网站源代码/网站快速收录入口
全部分页查询获取计数根据用户名查询根据主键批量修改新增全部分页查询 //sql select * from tb_user limit #{pageSize},#{pageNumber}//java PageHelper.startPage(pageSize,PageNumber); userMapper.selectByExample(null);Service public class TbUserDubboServiceImpl im…...