当前位置: 首页 > news >正文

现代卷积网络实战系列2:训练函数、PyTorch构建LeNet网络

4、训练函数

4.1 调用训练函数

train(epochs, net, train_loader, device, optimizer, test_loader, true_value)

因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的参数,训练函数都需要
这七个参数,是训练一个神经网络所需要的最少参数

4.2 训练函数

训练函数中,所有训练集进行多次迭代,而每次迭代又会将数据分成多个批次进行迭代

def train(epochs, net, train_loader, device, optimizer, test_loader, true_value):for epoch in range(1, epochs + 1):net.train()all_train_loss = []for batch_idx, (data, target) in enumerate(train_loader):data = data.to(device)target = target.to(device)optimizer.zero_grad()output = net(data)loss = F.cross_entropy(output, target)loss.backward()optimizer.step()cur_train_loss = loss.item()all_train_loss.append(cur_train_loss)train_loss = np.round(np.mean(all_train_loss) * 1000, 2)print('\nepoch step:', epoch)print('training loss: ', train_loss)test(net, test_loader, device, true_value, epoch)print("\nTraining finished")
  1. 定义训练函数
  2. 安装epochs迭代数据
  3. 进入pytorch的训练模式
  4. all_train_loss 存放训练集5万张图片的损失值
  5. 按照batch取数据
  6. 数据进入GPU
  7. 标签进入GPU
  8. 梯度清零
  9. 当前batch进入网络后得到输出
  10. 根据输出得到当前损失
  11. 反向传播
  12. 梯度下降
  13. 获取损失的损失值(PyTorch框架中的数据)
  14. 把当前batch的损失加入all_train_loss数组中,结束batch的迭代
  15. 将5张图片的损失计算出来并且进行求平均,这里乘以1000是因为我觉得计算出的损失太小了,所以乘以1000,方便看损失的变化,保留两位有效数字
  16. 打印当前epoch
  17. 打印损失
  18. 调用测试函数,测试当前训练的网络的性能,结束epoch的迭代
  19. 打印训练完成

5、LeNet

5.1 网络结构

LeNet可以说是首次提出卷积神经网络的模型
主要包含下面的网络层:

  1. 5*5的二维卷积
  2. sigmoid激活函数(这里使用了relu)
  3. 5*5的二维卷积
  4. sigmoid激活函数
  5. 数据一维化
  6. 全连接层
  7. 全连接层
  8. softmax分类器

将网络结构打印出来:

LeNet(
-------(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
-------(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
-------(conv2_drop): Dropout2d(p=0.5, inplace=False)
-------(fc1): Linear(in_features=320, out_features=50, bias=True)
-------(fc2): Linear(in_features=50, out_features=10, bias=True)
)

5.2 PyTorch构建LeNet

class LeNet(nn.Module):def __init__(self, num_classes):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, num_classes)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)

这个时候已经是一个完整的项目了,看看10个epoch训练过程的打印:

D:\conda\envs\pytorch\python.exe A:\0_MNIST\train.py

Reading data…
train_data: (60000, 28, 28) train_label (60000,)
test_data: (10000, 28, 28) test_label (10000,)

Initialize neural network
test loss: 2301.68
test accuracy: 11.3 %

epoch step: 1
training loss: 634.74
test loss: 158.03
test accuracy: 95.29 %

epoch step: 2
training loss: 324.04
test loss: 107.62
test accuracy: 96.55 %

epoch step: 3
training loss: 271.25
test loss: 88.43
test accuracy: 97.04 %

epoch step: 4
training loss: 236.69
test loss: 70.94
test accuracy: 97.61 %

epoch step: 5
training loss: 211.05
test loss: 69.69
test accuracy: 97.72 %

epoch step: 6
training loss: 199.28
test loss: 62.04
test accuracy: 97.98 %

epoch step: 7
training loss: 187.11
test loss: 59.65
test accuracy: 97.98 %

epoch step: 8
training loss: 178.79
test loss: 53.89
test accuracy: 98.2 %

epoch step: 9
training loss: 168.75
test loss: 51.83
test accuracy: 98.43 %

epoch step: 10
training loss: 160.83
test loss: 50.35
test accuracy: 98.4 %

Training finished
进程已结束,退出代码为 0

可以看出基本上只要一个epoch就可以得到很好的训练效果了,后续的epoch中的提升比较小

相关文章:

现代卷积网络实战系列2:训练函数、PyTorch构建LeNet网络

4、训练函数 4.1 调用训练函数 train(epochs, net, train_loader, device, optimizer, test_loader, true_value)因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的…...

rust特性

特性,也叫特质,英文是trait。 trait是一种特殊的类型,用于抽象某些方法。trait类似于其他编程语言中的接口,但又有所不同。 trait定义了一组方法,其他类型可以各自实现这个trait的方法,从而形成多态。 一、…...

TouchGFX之画布控件

TouchGFX的画布控件,在使用相对较小的存储空间的同时保持高性能,可提供平滑、抗锯齿效果良好的几何图形绘制。 TouchGFX 设计器中可用的画布控件: LineCircleShapeLine Progress圆形进度条 存储空间分配和使用​ 为了生成反锯齿效果良好的…...

STM32F103RCT6学习笔记2:串口通信

今日开始快速掌握这款STM32F103RCT6芯片的环境与编程开发,有关基础知识的部分不会多唠,直接实践与运用!文章贴出代码测试工程与测试效果图: 目录 串口通信实验计划: 串口通信配置代码: 测试效果图&#…...

Opencv-图像噪声(均值滤波、高斯滤波、中值滤波)

图像的噪声 图像的平滑 均值滤波 均值滤波代码实现 import cv2 as cv import numpy as np import matplotlib.pyplot as plt from pylab import mplmpl.rcParams[font.sans-serif] [SimHei]img cv.imread("dog.png")#均值滤波cv.blur(img, (5, 5))将对图像img进行…...

MasterAlign相机参数设置-增益调节

相机参数设置-曝光时间调节操作说明 相机参数的设置对于获取清晰、准确的图像至关重要。曝光时间是其中一个关键参数,它直接影响图像的亮度和清晰度。以下是关于曝光时间调节的详细操作步骤,以帮助您轻松进行设置。 步骤一:登录系统 首先&…...

9月22日,每日信息差

今天是2023年09月22日,以下是为您准备的14条信息差 第一、亚马逊将于2024年初在Prime Video中加入广告。Prime Video内容中的广告将于2024年初在美国、英国、德国和加拿大推出,随后晚些时候在法国、意大利、西班牙、墨西哥和澳大利亚推出 第二、中国移…...

Java版本企业工程项目管理系统源码+spring cloud 系统管理+java 系统设置+二次开发

工程项目各模块及其功能点清单 一、系统管理 1、数据字典:实现对数据字典标签的增删改查操作 2、编码管理:实现对系统编码的增删改查操作 3、用户管理:管理和查看用户角色 4、菜单管理:实现对系统菜单的增删改查操…...

Android studio中如何下载sdk

打开 file -> settings 这个页面, 在要下载的 SDK 前面勾上, 然后点 apply 在 platforms 中就可以看到下载好的 SDK: Android SDK目录结构详细介绍可以参考这篇文章: 51CTO博客- Android SDK目录结构...

STM32单片机中国象棋TFT触摸屏小游戏

实践制作DIY- GC0167-中国象棋 一、功能说明: 基于STM32单片机设计-中国象棋 二、功能介绍: 硬件组成:STM32F103RCT6最小系统2.8寸TFT电阻触摸屏24C02存储器1个按键(悔棋) 游戏规则: 1.有悔棋键&…...

【PHP图片托管】CFimagehost搭建私人图床 - 无需数据库支持

文章目录 1.前言2. CFImagehost网站搭建2.1 CFImagehost下载和安装2.2 CFImagehost网页测试2.3 cpolar的安装和注册 3.本地网页发布3.1 Cpolar临时数据隧道3.2 Cpolar稳定隧道(云端设置)3.3.Cpolar稳定隧道(本地设置) 4.公网访问测…...

CCITT 标准的CRC-16检验算法

/******该文件使用查表法计算CCITT 标准的CRC-16检验码,并附测试代码********/ #include #define CRC_INIT 0xffff //CCITT初始CRC为全1 #define GOOD_CRC 0xf0b8 //校验时计算出的固定结果值 /****下表是常用ccitt 16,生成式1021反转成8408后的查询表格****/ u…...

docker启动mysql服务

创建基础文件 mkdir mysql mkdir -p mysql/data获取默认的my.cnf docker run -name mysql -d -p 3306:3306 mysql:latest docker cp mysql:/etc/my.cnf ./vim mysql/my.cnf # For advice on how to change settings please see # http://dev.mysql.com/doc/refman/8.1/en/se…...

Postman应用——Request数据导入导出

文章目录 导入请求数据导出请求数据导出Collection导出Environments 导出所有请求数据导出请求响应数据 Postman可以导入导出Request和Variable变量配置,可以通过文本方式(JOSN文本)或链接方式进行导入导出。 导入请求数据 可以通过JSON文件…...

十四、MySql的用户管理

文章目录 一、用户管理二、用户(一)用户信息(二)创建用户1.语法:2.案例: (三) 删除用户1.语法:2.示例: (四)修改用户密码1.语法&#…...

01.自动化交易综述

算法交易的概念: 利用自动化平台,执行预先设置的一系列规则完成交易行为。 算法交易的优势 1.历史数据评估 2.执行高效 3.无主观情绪输入 4.可度量评价 5.交易频率 算法交易的劣势 1.成本,成本低难以体现收益 2.技巧 算法交易流程 大前…...

基于SpringBoot的网上超市系统的设计与实现

目录 前言 一、技术栈 二、系统功能介绍 管理员功能实现 用户功能实现 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 网络技术和计算机技术发展至今,已经拥有了深厚的理论基础,并在现实中进行了充分运用,尤其是基于计…...

国内首家!阿里云 Elasticsearch 8.9 版本释放 AI 搜索新动能

简介: 阿里云作为国内首家上线 Elasticsearch 8.9版本的厂商,在提供 Elasticsearch Relevance Engine™ (ESRE™) 引擎的基础上,提供增强 AI 的最佳实践与 ES 本身的混合搜索能力,为用户带来了更多创新和探索的可能性。 近年来&a…...

uniapp获取一周日期和星期

UniApp可以使用JavaScript中的Date对象来获取当前日期和星期几。以下是一个示例代码,可以获取当前日期和星期几,并输出在一周内的每天早上和晚上: // 获取当前日期和星期 let date new Date(); let weekdays ["Sunday", "M…...

QT之QListWidget的介绍

QListWidget常用成员函数 1、成员函数介绍2、例子显示图片和按钮的例子 1、成员函数介绍 1)QListWidget(QWidget *parent nullptr) 构造函数,创建一个新的QListWidget对象。 2)void addItem(const QString &label) 在列表末尾添加一个项目,项目标…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

抖音增长新引擎:品融电商,一站式全案代运营领跑者

抖音增长新引擎&#xff1a;品融电商&#xff0c;一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中&#xff0c;品牌如何破浪前行&#xff1f;自建团队成本高、效果难控&#xff1b;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...

基于matlab策略迭代和值迭代法的动态规划

经典的基于策略迭代和值迭代法的动态规划matlab代码&#xff0c;实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...

​​企业大模型服务合规指南:深度解析备案与登记制度​​

伴随AI技术的爆炸式发展&#xff0c;尤其是大模型&#xff08;LLM&#xff09;在各行各业的深度应用和整合&#xff0c;企业利用AI技术提升效率、创新服务的步伐不断加快。无论是像DeepSeek这样的前沿技术提供者&#xff0c;还是积极拥抱AI转型的传统企业&#xff0c;在面向公众…...

规则与人性的天平——由高考迟到事件引发的思考

当那位身着校服的考生在考场关闭1分钟后狂奔而至&#xff0c;他涨红的脸上写满绝望。铁门内秒针划过的弧度&#xff0c;成为改变人生的残酷抛物线。家长声嘶力竭的哀求与考务人员机械的"这是规定"&#xff0c;构成当代中国教育最尖锐的隐喻。 一、刚性规则的必要性 …...

uni-app学习笔记三十五--扩展组件的安装和使用

由于内置组件不能满足日常开发需要&#xff0c;uniapp官方也提供了众多的扩展组件供我们使用。由于不是内置组件&#xff0c;需要安装才能使用。 一、安装扩展插件 安装方法&#xff1a; 1.访问uniapp官方文档组件部分&#xff1a;组件使用的入门教程 | uni-app官网 点击左侧…...

网页端 js 读取发票里的二维码信息(图片和PDF格式)

起因 为了实现在报销流程中&#xff0c;发票不能重用的限制&#xff0c;发票上传后&#xff0c;希望能读出发票号&#xff0c;并记录发票号已用&#xff0c;下次不再可用于报销。 基于上面的需求&#xff0c;研究了OCR 的方式和读PDF的方式&#xff0c;实际是可行的&#xff…...

起重机起升机构的安全装置有哪些?

起重机起升机构的安全装置是保障吊装作业安全的关键部件&#xff0c;主要用于防止超载、失控、断绳等危险情况。以下是常见的安全装置及其功能和原理&#xff1a; 一、超载保护装置&#xff08;核心安全装置&#xff09; 1. 起重量限制器 功能&#xff1a;实时监测起升载荷&a…...

MyBatis-Plus 常用条件构造方法

1.常用条件方法 方法 说明eq等于 ne不等于 <>gt大于 >ge大于等于 >lt小于 <le小于等于 <betweenBETWEEN 值1 AND 值2notBetweenNOT BETWEEN 值1 AND 值2likeLIKE %值%notLikeNOT LIKE %值%likeLeftLIKE %值likeRightLIKE 值%isNull字段 IS NULLisNotNull字段…...

Linux信号保存与处理机制详解

Linux信号的保存与处理涉及多个关键机制&#xff0c;以下是详细的总结&#xff1a; 1. 信号的保存 进程描述符&#xff08;task_struct&#xff09;&#xff1a;每个进程的PCB中包含信号相关信息。 pending信号集&#xff1a;记录已到达但未处理的信号&#xff08;未决信号&a…...