当前位置: 首页 > news >正文

手写数字识别之优化算法:观察Loss下降的情况判断合理的学习率

目录

手写数字识别之优化算法:观察Loss下降的情况判断合理的学习率

前提条件

设置学习率

学习率的主流优化算法


手写数字识别之优化算法:观察Loss下降的情况判断合理的学习率

我们明确了分类任务的损失函数(优化目标)的相关概念和实现方法,本节我们依旧横向展开"横纵式"教学法,如 图1 所示,本节主要探讨在手写数字识别任务中,使得损失达到最小的参数取值的实现方法。

图1:“横纵式”教学法 — 优化算法



前提条件

在优化算法之前,需要进行数据处理、设计神经网络结构,代码与上一节保持一致,如下所示。

# 加载相关库
import os
import random
import paddle
from paddle.nn import Conv2D, MaxPool2D, Linear
import numpy as np
from PIL import Image
import gzip
import json# 定义数据集读取器
def load_data(mode='train'):# 读取数据文件datafile = './work/mnist.json.gz'print('loading mnist dataset from {} ......'.format(datafile))data = json.load(gzip.open(datafile))# 读取数据集中的训练集,验证集和测试集train_set, val_set, eval_set = data# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLSIMG_ROWS = 28IMG_COLS = 28# 根据输入mode参数决定使用训练集,验证集还是测试if mode == 'train':imgs = train_set[0]labels = train_set[1]elif mode == 'valid':imgs = val_set[0]labels = val_set[1]elif mode == 'eval':imgs = eval_set[0]labels = eval_set[1]# 获得所有图像的数量imgs_length = len(imgs)# 验证图像数量和标签数量是否一致assert len(imgs) == len(labels), \"length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(labels))index_list = list(range(imgs_length))# 读入数据时用到的batchsizeBATCHSIZE = 100# 定义数据生成器def data_generator():# 训练模式下,打乱训练数据if mode == 'train':random.shuffle(index_list)imgs_list = []labels_list = []# 按照索引读取数据for i in index_list:# 读取图像和标签,转换其尺寸和类型img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')label = np.reshape(labels[i], [1]).astype('int64')imgs_list.append(img) labels_list.append(label)# 如果当前数据缓存达到了batch size,就返回一个批次数据if len(imgs_list) == BATCHSIZE:yield np.array(imgs_list), np.array(labels_list)# 清空数据缓存列表imgs_list = []labels_list = []# 如果剩余数据的数目小于BATCHSIZE,# 则剩余数据一起构成一个大小为len(imgs_list)的mini-batchif len(imgs_list) > 0:yield np.array(imgs_list), np.array(labels_list)return data_generator# 定义模型结构
import paddle.nn.functional as F
# 多层卷积神经网络实现
class MNIST(paddle.nn.Layer):def __init__(self):super(MNIST, self).__init__()# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)# 定义池化层,池化核的大小kernel_size为2,池化步长为2self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)# 定义池化层,池化核的大小kernel_size为2,池化步长为2self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)# 定义一层全连接层,输出维度是10self.fc = Linear(in_features=980, out_features=10)# 定义网络前向计算过程,卷积后紧接着使用池化层,最后使用全连接层计算最终输出# 卷积层激活函数使用Relu,全连接层激活函数使用softmaxdef forward(self, inputs):x = self.conv1(inputs)x = F.relu(x)x = self.max_pool1(x)x = self.conv2(x)x = F.relu(x)x = self.max_pool2(x)x = paddle.reshape(x, [x.shape[0], -1])x = self.fc(x)return x

设置学习率

在深度学习神经网络模型中,通常使用标准的随机梯度下降算法更新参数,学习率代表参数更新幅度的大小,即步长。当学习率最优时,模型的有效容量最大,最终能达到的效果最好。学习率和深度学习任务类型有关,合适的学习率往往需要大量的实验和调参经验。探索学习率最优值时需要注意如下两点:

  • 学习率不是越小越好。学习率越小,损失函数的变化速度越慢,意味着我们需要花费更长的时间进行收敛,如 图2 左图所示。
  • 学习率不是越大越好。只根据总样本集中的一个批次计算梯度,抽样误差会导致计算出的梯度不是全局最优的方向,且存在波动。在接近最优解时,过大的学习率会导致参数在最优解附近震荡,损失难以收敛,如 图2 右图所示。


图2: 不同学习率(步长过大/过小)的示意图
 

在训练前,我们往往不清楚一个特定问题设置成怎样的学习率是合理的,因此在训练时可以尝试调小或调大,通过观察Loss下降的情况判断合理的学习率,设置学习率的代码如下所示。

#仅优化算法的设置有所差别
def train(model):model.train()#调用加载数据的函数train_loader = load_data('train')#设置不同初始学习率opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())# opt = paddle.optimizer.SGD(learning_rate=0.0001, parameters=model.parameters())# opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())EPOCH_NUM = 10for epoch_id in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):#准备数据images, labels = dataimages = paddle.to_tensor(images)labels = paddle.to_tensor(labels)#前向计算的过程predicts = model(images)#计算损失,取一个批次样本损失的平均值loss = F.cross_entropy(predicts, labels)avg_loss = paddle.mean(loss)#每训练了100批次的数据,打印下当前Loss的情况if batch_id % 200 == 0:print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))#后向传播,更新参数的过程avg_loss.backward()# 最小化loss,更新参数opt.step()# 清除梯度opt.clear_grad()#保存模型参数paddle.save(model.state_dict(), 'mnist.pdparams')#创建模型    
model = MNIST()
#启动训练过程
train(model)

学习率的主流优化算法

学习率是优化器的一个参数,调整学习率看似是一件非常麻烦的事情,需要不断的调整步长,观察训练时间和Loss的变化。经过研究员的不断的实验,当前已经形成了四种比较成熟的优化算法:SGD、Momentum、AdaGrad和Adam,效果如 图3 所示。


图3: 不同学习率算法效果示意图
 

  • SGD: 随机梯度下降算法,每次训练少量数据,抽样偏差导致的参数收敛过程中震荡。

  • Momentum: 引入物理“动量”的概念,累积速度,减少震荡,使参数更新的方向更稳定。

每个批次的数据含有抽样误差,导致梯度更新的方向波动较大。如果我们引入物理动量的概念,给梯度下降的过程加入一定的“惯性”累积,就可以减少更新路径上的震荡,即每次更新的梯度由“历史多次梯度的累积方向”和“当次梯度”加权相加得到。历史多次梯度的累积方向往往是从全局视角更正确的方向,这与“惯性”的物理概念很像,也是为何其起名为“Momentum”的原因。类似不同品牌和材质的篮球有一定的重量差别,街头篮球队中的投手(擅长中远距离投篮)喜欢稍重篮球的比例较高。一个很重要的原因是,重的篮球惯性大,更不容易受到手势的小幅变形或风吹的影响。

  • AdaGrad: 根据不同参数距离最优解的远近,动态调整学习率。学习率逐渐下降,依据各参数变化大小调整学习率。

通过调整学习率的实验可以发现:当某个参数的现值距离最优解较远时(表现为梯度的绝对值较大),我们期望参数更新的步长大一些,以便更快收敛到最优解。当某个参数的现值距离最优解较近时(表现为梯度的绝对值较小),我们期望参数的更新步长小一些,以便更精细的逼近最优解。类似于打高尔夫球,专业运动员第一杆开球时,通常会大力打一个远球,让球尽量落在洞口附近。当第二杆面对离洞口较近的球时,他会更轻柔而细致的推杆,避免将球打飞。与此类似,参数更新的步长应该随着优化过程逐渐减少,减少的程度与当前梯度的大小有关。根据这个思想编写的优化算法称为“AdaGrad”,Ada是Adaptive的缩写,表示“适应环境而变化”的意思。RMSProp是在AdaGrad基础上的改进,学习率随着梯度变化而适应,解决AdaGrad学习率急剧下降的问题。

  • Adam: 由于动量和自适应学习率两个优化思路是正交的,因此可以将两个思路结合起来,这就是当前广泛应用的算法。

说明:

每种优化算法均有更多的参数设置。理论最合理的未必在具体案例中最有效,所以模型调参是很有必要的,最优的模型配置往往是在一定“理论”和“经验”的指导下实验出来的。


我们可以尝试选择不同的优化算法训练模型,观察训练时间和损失变化的情况,代码实现如下。

#仅优化算法的设置有所差别
def train(model):model.train()#调用加载数据的函数train_loader = load_data('train')#四种优化算法的设置方案,可以逐一尝试效果opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())# opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9, parameters=model.parameters())# opt = paddle.optimizer.Adagrad(learning_rate=0.01, parameters=model.parameters())# opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())EPOCH_NUM = 3for epoch_id in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):#准备数据images, labels = dataimages = paddle.to_tensor(images)labels = paddle.to_tensor(labels)#前向计算的过程predicts = model(images)#计算损失,取一个批次样本损失的平均值loss = F.cross_entropy(predicts, labels)avg_loss = paddle.mean(loss)#每训练了100批次的数据,打印下当前Loss的情况if batch_id % 200 == 0:print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))#后向传播,更新参数的过程avg_loss.backward()# 最小化loss,更新参数opt.step()# 清除梯度opt.clear_grad()#保存模型参数paddle.save(model.state_dict(), 'mnist.pdparams')#创建模型    
model = MNIST()
#启动训练过程
train(model)

相关文章:

手写数字识别之优化算法:观察Loss下降的情况判断合理的学习率

目录 手写数字识别之优化算法:观察Loss下降的情况判断合理的学习率 前提条件 设置学习率 学习率的主流优化算法 手写数字识别之优化算法:观察Loss下降的情况判断合理的学习率 我们明确了分类任务的损失函数(优化目标)的相关概念和实现方法&#xff…...

软件工程(二十) 系统运行与软件维护

1、系统转换计划 1.1、遗留系统的演化策略 时至今日,你想去开发一个系统,想完全不涉及到已有的系统,基本是不可能的事情。但是对于已有系统我们有一个策略。 比如我们是淘汰掉已有系统,还是继承已有系统,或者集成已有系统,或者改造遗留的系统呢,都是不同的策略。 技术…...

蓝蓝设计ui设计公司作品--泛亚高科-光伏电站控制系统界面设计

泛亚高科(北京)科技有限公司(以下简称“泛亚高科”),一个以实时监控、高精度数值计算为基础的科技公司, 自成立以来,组成了以博士、硕士为核心的技术团队,整合了华北电力大学等高校资源,凭借在电…...

软考高级系统架构设计师系列论文七十:论信息系统的安全体系

软考高级系统架构设计师系列论文七十:论信息系统的安全体系 一、信息系统相关知识点二、摘要三、正文四、总结一、信息系统相关知识点 软考高级信息系统项目管理师系列之四十三:信息系统安全管理...

​Softing dataFEED OPC Suite——助力数字孪生技术发展

一 行业概览 数字孪生技术是充分利用物理模型、传感器更新、运行历史等数据,集成多学科、多物理量、多尺度、多概率的仿真过程,在虚拟空间中完成映射,从而反映相对应的实体装备的全生命周期过程。数字孪生技术已经应用在众多领域&#xff1a…...

LLaMA中ROPE位置编码实现源码解析

1、Attention中q,经下式,生成新的q。m为句长length,d为embedding_dim/head θ i 1 1000 0 2 i d \theta_i\frac{1}{10000^\frac{2i}{d}} θi​10000d2i​1​ 2、LLaMA中RoPE源码 import torchdef precompute_freqs_cis(dim: int, end: i…...

在c++ 20下使用微软的proxy库替代传统的virtual动态多态

传统的virtual动态多态&#xff0c;经常会有下面这样的使用需求&#xff1a; #include <iostream> #include <vector>// 声明一个包含virtual虚函数的基类 struct shape {virtual ~shape() {}virtual void draw() 0; };// 派生&#xff0c;实现virtual虚函数 str…...

Spring MVC:@RequestMapping

Spring MVC RequestMapping属性 RequestMapping RequestMapping&#xff0c; 是 Spring Web 应用程序中最常用的注解之一&#xff0c;主要用于映射 HTTP 请求 URL 与处理请求的处理器 Controller 方法上。使用 RequestMapping 注解可以方便地定义处理器 Controller 的方法来处…...

【vue3+ts项目】配置eslint校验代码工具,eslint+prettier+stylelint

1、运行好后自动打开浏览器 package.json中 vite后面加上 --open 2、安装eslint npm i eslint -D3、运行 eslint --init 之后&#xff0c;回答一些问题&#xff0c; 自动创建 .eslintrc 配置文件。 npx eslint --init回答问题如下&#xff1a; 使用eslint仅检查语法&…...

PHP之ZipArchive打包压缩文件

1、Linux 安装 nginx 安装zlib库 2、使用&#xff0c;目前我这边的需求是。 1、材料图片、单据图片&#xff0c;分别压缩打包到“材料.zip”和“单据.zip”。 2、“材料.zip”和“单据.zip”在压缩打包到“订单.zip” 3、支持批量导出多个订单的图片信息所有订单的压缩文件&…...

面试之快速学习C++14

文章参考&#xff1a;https://zhuanlan.zhihu.com/p/588826142?utm_id0 最近学了一会感慨到找工作好难&#xff0c;上周面试了一家医疗公司&#xff0c;准备攒攒经验但是不去&#xff0c;结果三天了没消息&#xff0c;感觉一面都没过… 本来自傲看不上&#xff0c;结果人家也…...

【算法专题突破】双指针 - 快乐数(3)

目录 1. 题目解析 2. 算法原理 3. 代码编写 写在最后&#xff1a; 1. 题目解析 题目链接&#xff1a;202. 快乐数 - 力扣&#xff08;Leetcode&#xff09; 这道题的题目也很容易理解&#xff0c; 看一下题目给的示例就能很容易明白&#xff0c; 但是要注意一个点&#…...

【javaweb】学习日记Day4 - Maven 依赖管理 Web入门

目录 一、Maven入门 - 管理和构建java项目的工具 1、IDEA如何构建Maven项目 2、Maven 坐标 &#xff08;1&#xff09;定义 &#xff08;2&#xff09;主要组成 3、IDEA如何导入和删除项目 二、Maven - 依赖管理 1、依赖配置 2、依赖传递 &#xff08;1&#xff09;查…...

C++信息学奥赛1144:单词翻转

#include <iostream> #include <string> using namespace std; int main() {string str;// 输入一行字符串getline(cin, str);string arr;for (int i 0; i < str.length(); i){if (str[i] ! ){arr str[i]; // 将非空格字符添加到临时存储的字符串中}else{for…...

qt检查文件夹是否有写权限

Qt 使用如下函数能够判断路径或者文件是否可写&#xff1a; bool QFileInfo::isWritable() const 对于win10系统实测&#xff0c;结果不准确。继续排查&#xff0c;官方文档描述&#xff1a;a&#xff09;如果未启用 NTFS 权限检查&#xff0c;Windows 上的结果将仅反映文件是…...

LSF 安装目录,快速参考 LSF 命令、守护程序、配置文件、日志文件和重要集群配置参数

样本 UNIX 和 Linux 安装目录 守护程序错误日志文件 守护程序错误日志文件存储在 LSF_LOGDIR 在 lsf.conf 文件中定义的目录中。 LSF 基本系统守护程序日志文件LSF 批处理系统守护程序日志文件pim.log.host_namembatchd.log.host_namembatchd.log.host_namesbatchd.log.host_…...

在Mybatis中写动态sql这些标签:if、where、set、trim、foreach、choose的作用是什么,怎么用?

在 MyBatis 中&#xff0c;您可以使用动态 SQL 标签来构建灵活的 SQL 查询&#xff0c;以根据不同的条件生成不同的查询语句。以下是这些标签的作用和用法&#xff1a; 1. **<if> 标签&#xff1a;** 用于根据某个条件动态地包含或排除 SQL 片段&#xff0c;test:可以写…...

7 Python的模块和包

概述 在上一节&#xff0c;我们介绍了Python的异常处理&#xff0c;包括&#xff1a;异常、异常处理、抛出异常、用户自定义异常等内容。在这一节中&#xff0c;我们将介绍Python的模块和包。Python的模块&#xff08;Module&#xff09;和包&#xff08;Package&#xff09;是…...

【JavaWeb 篇】使用Servlet、JdbcTemplate和Durid连接池实现用户登录功能与测试

在现代Web应用程序开发中&#xff0c;用户登录功能是基础中的基础。它为用户提供了安全访问系统的途径。本篇博客将引导您通过使用Servlet、Spring框架的JdbcTemplate以及Durid连接池&#xff0c;来构建一个完整的用户登录功能。我们将详细展示每个部分的代码&#xff0c;并解释…...

【Unity3D赛车游戏】【六】如何在Unity中为汽车添加发动机和手动挡变速?

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a;Uni…...

【Go 基础篇】切片:Go语言中的灵活数据结构

在Go语言中&#xff0c;切片&#xff08;Slice&#xff09;是一种强大且灵活的数据结构&#xff0c;用于管理和操作一系列元素。与数组相比&#xff0c;切片的大小可以动态调整&#xff0c;这使得它成为处理动态数据集合的理想选择。本文将围绕Go语言中切片的引入&#xff0c;介…...

龙芯2K1000LA移植交叉编译环境以及QT

嵌入式大赛结束了&#xff0c;根据这次比赛中记的凌乱的笔记&#xff0c;整理了一份龙芯2K1000LA的环境搭建过程&#xff0c;可能笔记缺少了一部分步骤或者错误&#xff0c;但是大致步骤可以当作参考。 一、交叉编译工具链 下载连接&#xff1a;龙芯 GNU 编译工具链 | 龙芯开…...

javaee spring依赖注入之spel方式

spring依赖注入之spel方式 <dependency><groupId>org.springframework</groupId><artifactId>spring-expression</artifactId><version>4.3.18.RELEASE</version></dependency>package com.test.pojo;import java.util.List; …...

【Java集合学习1】ArrayList集合学习及集合概述分析

JavaArrayList集合学习及集合学习概述 一、Java集合概述 Java 集合&#xff0c; 也叫作容器&#xff0c;主要是由两大接口派生而来&#xff1a;一个是 Collection接口&#xff0c;主要用于存放单一元素&#xff1b;另一个是 Map 接口&#xff0c;主要用于存放键值对。对于Col…...

TouchGFX之调试

DebugPrinter类是一种在显示屏上打印调试消息的简单方法&#xff0c;无需向屏幕添加控件。 在使用DebugPrinter之前&#xff0c;需要分配一个实例并将其传递给Application类&#xff0c;且DebugPrinter实例必须兼容所使用的LCD类。 该表列出了DebugPrinter类名称&#xff1a; …...

C# winform加载yolov8模型测试(附例程)

第一步&#xff1a;在NuGet中下载Yolov8.Net 第二步&#xff1a;引用 using Yolov8Net; 第三步&#xff1a;加载模型 private IPredictor yolov8 YoloV8Predictor.Create("D:\\0MyWork\\Learn\\vs2022\\yolov_onnx\\best.onnx", mylabel); 第四步&#xff1a;图…...

浙大陈越何钦铭数据结构07-图6 旅游规划

题目: 有了一张自驾旅游路线图&#xff0c;你会知道城市间的高速公路长度、以及该公路要收取的过路费。现在需要你写一个程序&#xff0c;帮助前来咨询的游客找一条出发地和目的地之间的最短路径。如果有若干条路径都是最短的&#xff0c;那么需要输出最便宜的一条路径。 输入…...

VUE笔记(七)项目登录

1、安装elementui 在终端执行 vue add element 注册组件 如果要使用哪个组件&#xff0c;大家需要在plugins/element.js中注册该组件 import Vue from vue import { Button } from element-ui Vue.use(Button) 在页面组件中使用 <el-button type"primary"&…...

大语言模型之六- LLM之企业私有化部署

数据安全是每个公司不得不慎重对待的&#xff0c;为了提高生产力&#xff0c;降本增效又不得不接受新技术带来的工具&#xff0c;私有化部署对于公司还是非常有吸引力的。大语言模型这一工具结合公司的数据可以大大提高公司生产率。 私有化LLM需要处理的问题 企业内私有化LLM…...

Python3 列表

Python3 列表 序列是 Python 中最基本的数据结构。 序列中的每个值都有对应的位置值&#xff0c;称之为索引&#xff0c;第一个索引是 0&#xff0c;第二个索引是 1&#xff0c;依此类推。 Python 有 6 个序列的内置类型&#xff0c;但最常见的是列表和元组。 列表都可以进…...

猎场第几集做的网站推广/推广方案怎么做

1. 找样本文章好辛苦啊&#xff0c;都没有批量下载&#xff0c;要一篇一篇下载&#xff0c;找到一个630多篇英语小说的网站&#xff0c;现在还有160多篇没下载。但没有办法了&#xff0c;要研究必须先要有样本数据。 2. 终于解出了a^logb(n) <> n^logb(a)的转换方式。 转…...

开江建设局网站/seo综合查询怎么关闭

htons(), ntohl(), ntohs()&#xff0c;htons() 函数&#xff1a; 转载自&#xff1a;https://blog.csdn.net/myyllove/article/details/83380209 atoi()和itoa()函数 转载自&#xff1a;https://www.cnblogs.com/ralap7/p/9171613.html...

做门面商铺比较好的网站/百度统计官网

最近项目中通过Kubernetes部署Prometheus完成可视化大屏数据采集&#xff0c;特此记录便于日后查阅。 一、Prometheus部署 1、deploy.yaml apiVersion: apps/v1 kind: Deployment metadata:labels:name: prometheus-deploymentname: prometheusnamespace: monitoring spec:re…...

网站超链接怎么做 word文档/百度推广代理公司哪家好

1 介绍 1.1 实现流程 让左边排好序让右边排好序合并后整体排好序 1.2 特点 时间复杂度O(nlogn)空间复杂度O(n)稳定 2 实现 2.1 递归 public class MergeSort {public static void mergeSort(int[] arr) {if (arr null || arr.length < 2) {return;}process(arr, 0,…...

做网站高流量赚广告费/北京seo优化方案

概述 什么是分库分表 数据数量是不可控的&#xff0c;随着时间和业务发展&#xff0c;造成表里面数据越来越多&#xff0c;如果再去对数据库表CURD操作时&#xff0c;就会有性能问题。 解决方案 为了解决由于数据量过大而造成数据库性能降低问题&#xff0c;主要有下面两种…...

互联网网站 权限/上海最专业的seo公司

百度 紫光 大疆 爱奇艺 科大讯飞 cvte 蔚来 大华 乐鑫 联发科 20道选择&#xff0c;3道编程 注&#xff1a;以下为个人认为笔试中较难的题目和涉及的知识点 (1)KMP算法&#xff0c;哈夫曼编码&#xff1f; (2)sed指令 (3)二叉排序树 (4)双亲表示法 (5)平均有效内存访问时间…...