当前位置: 首页 > news >正文

现代卷积网络实战系列2:训练函数、PyTorch构建LeNet网络

4、训练函数

4.1 调用训练函数

train(epochs, net, train_loader, device, optimizer, test_loader, true_value)

因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的参数,训练函数都需要
这七个参数,是训练一个神经网络所需要的最少参数

4.2 训练函数

训练函数中,所有训练集进行多次迭代,而每次迭代又会将数据分成多个批次进行迭代

def train(epochs, net, train_loader, device, optimizer, test_loader, true_value):for epoch in range(1, epochs + 1):net.train()all_train_loss = []for batch_idx, (data, target) in enumerate(train_loader):data = data.to(device)target = target.to(device)optimizer.zero_grad()output = net(data)loss = F.cross_entropy(output, target)loss.backward()optimizer.step()cur_train_loss = loss.item()all_train_loss.append(cur_train_loss)train_loss = np.round(np.mean(all_train_loss) * 1000, 2)print('\nepoch step:', epoch)print('training loss: ', train_loss)test(net, test_loader, device, true_value, epoch)print("\nTraining finished")
  1. 定义训练函数
  2. 安装epochs迭代数据
  3. 进入pytorch的训练模式
  4. all_train_loss 存放训练集5万张图片的损失值
  5. 按照batch取数据
  6. 数据进入GPU
  7. 标签进入GPU
  8. 梯度清零
  9. 当前batch进入网络后得到输出
  10. 根据输出得到当前损失
  11. 反向传播
  12. 梯度下降
  13. 获取损失的损失值(PyTorch框架中的数据)
  14. 把当前batch的损失加入all_train_loss数组中,结束batch的迭代
  15. 将5张图片的损失计算出来并且进行求平均,这里乘以1000是因为我觉得计算出的损失太小了,所以乘以1000,方便看损失的变化,保留两位有效数字
  16. 打印当前epoch
  17. 打印损失
  18. 调用测试函数,测试当前训练的网络的性能,结束epoch的迭代
  19. 打印训练完成

5、LeNet

5.1 网络结构

LeNet可以说是首次提出卷积神经网络的模型
主要包含下面的网络层:

  1. 5*5的二维卷积
  2. sigmoid激活函数(这里使用了relu)
  3. 5*5的二维卷积
  4. sigmoid激活函数
  5. 数据一维化
  6. 全连接层
  7. 全连接层
  8. softmax分类器

将网络结构打印出来:

LeNet(
-------(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
-------(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
-------(conv2_drop): Dropout2d(p=0.5, inplace=False)
-------(fc1): Linear(in_features=320, out_features=50, bias=True)
-------(fc2): Linear(in_features=50, out_features=10, bias=True)
)

5.2 PyTorch构建LeNet

class LeNet(nn.Module):def __init__(self, num_classes):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, num_classes)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)

这个时候已经是一个完整的项目了,看看10个epoch训练过程的打印:

D:\conda\envs\pytorch\python.exe A:\0_MNIST\train.py

Reading data…
train_data: (60000, 28, 28) train_label (60000,)
test_data: (10000, 28, 28) test_label (10000,)

Initialize neural network
test loss: 2301.68
test accuracy: 11.3 %

epoch step: 1
training loss: 634.74
test loss: 158.03
test accuracy: 95.29 %

epoch step: 2
training loss: 324.04
test loss: 107.62
test accuracy: 96.55 %

epoch step: 3
training loss: 271.25
test loss: 88.43
test accuracy: 97.04 %

epoch step: 4
training loss: 236.69
test loss: 70.94
test accuracy: 97.61 %

epoch step: 5
training loss: 211.05
test loss: 69.69
test accuracy: 97.72 %

epoch step: 6
training loss: 199.28
test loss: 62.04
test accuracy: 97.98 %

epoch step: 7
training loss: 187.11
test loss: 59.65
test accuracy: 97.98 %

epoch step: 8
training loss: 178.79
test loss: 53.89
test accuracy: 98.2 %

epoch step: 9
training loss: 168.75
test loss: 51.83
test accuracy: 98.43 %

epoch step: 10
training loss: 160.83
test loss: 50.35
test accuracy: 98.4 %

Training finished
进程已结束,退出代码为 0

可以看出基本上只要一个epoch就可以得到很好的训练效果了,后续的epoch中的提升比较小

相关文章:

现代卷积网络实战系列2:训练函数、PyTorch构建LeNet网络

4、训练函数 4.1 调用训练函数 train(epochs, net, train_loader, device, optimizer, test_loader, true_value)因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的…...

rust特性

特性,也叫特质,英文是trait。 trait是一种特殊的类型,用于抽象某些方法。trait类似于其他编程语言中的接口,但又有所不同。 trait定义了一组方法,其他类型可以各自实现这个trait的方法,从而形成多态。 一、…...

TouchGFX之画布控件

TouchGFX的画布控件,在使用相对较小的存储空间的同时保持高性能,可提供平滑、抗锯齿效果良好的几何图形绘制。 TouchGFX 设计器中可用的画布控件: LineCircleShapeLine Progress圆形进度条 存储空间分配和使用​ 为了生成反锯齿效果良好的…...

STM32F103RCT6学习笔记2:串口通信

今日开始快速掌握这款STM32F103RCT6芯片的环境与编程开发,有关基础知识的部分不会多唠,直接实践与运用!文章贴出代码测试工程与测试效果图: 目录 串口通信实验计划: 串口通信配置代码: 测试效果图&#…...

Opencv-图像噪声(均值滤波、高斯滤波、中值滤波)

图像的噪声 图像的平滑 均值滤波 均值滤波代码实现 import cv2 as cv import numpy as np import matplotlib.pyplot as plt from pylab import mplmpl.rcParams[font.sans-serif] [SimHei]img cv.imread("dog.png")#均值滤波cv.blur(img, (5, 5))将对图像img进行…...

MasterAlign相机参数设置-增益调节

相机参数设置-曝光时间调节操作说明 相机参数的设置对于获取清晰、准确的图像至关重要。曝光时间是其中一个关键参数,它直接影响图像的亮度和清晰度。以下是关于曝光时间调节的详细操作步骤,以帮助您轻松进行设置。 步骤一:登录系统 首先&…...

9月22日,每日信息差

今天是2023年09月22日,以下是为您准备的14条信息差 第一、亚马逊将于2024年初在Prime Video中加入广告。Prime Video内容中的广告将于2024年初在美国、英国、德国和加拿大推出,随后晚些时候在法国、意大利、西班牙、墨西哥和澳大利亚推出 第二、中国移…...

Java版本企业工程项目管理系统源码+spring cloud 系统管理+java 系统设置+二次开发

工程项目各模块及其功能点清单 一、系统管理 1、数据字典:实现对数据字典标签的增删改查操作 2、编码管理:实现对系统编码的增删改查操作 3、用户管理:管理和查看用户角色 4、菜单管理:实现对系统菜单的增删改查操…...

Android studio中如何下载sdk

打开 file -> settings 这个页面, 在要下载的 SDK 前面勾上, 然后点 apply 在 platforms 中就可以看到下载好的 SDK: Android SDK目录结构详细介绍可以参考这篇文章: 51CTO博客- Android SDK目录结构...

STM32单片机中国象棋TFT触摸屏小游戏

实践制作DIY- GC0167-中国象棋 一、功能说明: 基于STM32单片机设计-中国象棋 二、功能介绍: 硬件组成:STM32F103RCT6最小系统2.8寸TFT电阻触摸屏24C02存储器1个按键(悔棋) 游戏规则: 1.有悔棋键&…...

【PHP图片托管】CFimagehost搭建私人图床 - 无需数据库支持

文章目录 1.前言2. CFImagehost网站搭建2.1 CFImagehost下载和安装2.2 CFImagehost网页测试2.3 cpolar的安装和注册 3.本地网页发布3.1 Cpolar临时数据隧道3.2 Cpolar稳定隧道(云端设置)3.3.Cpolar稳定隧道(本地设置) 4.公网访问测…...

CCITT 标准的CRC-16检验算法

/******该文件使用查表法计算CCITT 标准的CRC-16检验码,并附测试代码********/ #include #define CRC_INIT 0xffff //CCITT初始CRC为全1 #define GOOD_CRC 0xf0b8 //校验时计算出的固定结果值 /****下表是常用ccitt 16,生成式1021反转成8408后的查询表格****/ u…...

docker启动mysql服务

创建基础文件 mkdir mysql mkdir -p mysql/data获取默认的my.cnf docker run -name mysql -d -p 3306:3306 mysql:latest docker cp mysql:/etc/my.cnf ./vim mysql/my.cnf # For advice on how to change settings please see # http://dev.mysql.com/doc/refman/8.1/en/se…...

Postman应用——Request数据导入导出

文章目录 导入请求数据导出请求数据导出Collection导出Environments 导出所有请求数据导出请求响应数据 Postman可以导入导出Request和Variable变量配置,可以通过文本方式(JOSN文本)或链接方式进行导入导出。 导入请求数据 可以通过JSON文件…...

十四、MySql的用户管理

文章目录 一、用户管理二、用户(一)用户信息(二)创建用户1.语法:2.案例: (三) 删除用户1.语法:2.示例: (四)修改用户密码1.语法&#…...

01.自动化交易综述

算法交易的概念: 利用自动化平台,执行预先设置的一系列规则完成交易行为。 算法交易的优势 1.历史数据评估 2.执行高效 3.无主观情绪输入 4.可度量评价 5.交易频率 算法交易的劣势 1.成本,成本低难以体现收益 2.技巧 算法交易流程 大前…...

基于SpringBoot的网上超市系统的设计与实现

目录 前言 一、技术栈 二、系统功能介绍 管理员功能实现 用户功能实现 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 网络技术和计算机技术发展至今,已经拥有了深厚的理论基础,并在现实中进行了充分运用,尤其是基于计…...

国内首家!阿里云 Elasticsearch 8.9 版本释放 AI 搜索新动能

简介: 阿里云作为国内首家上线 Elasticsearch 8.9版本的厂商,在提供 Elasticsearch Relevance Engine™ (ESRE™) 引擎的基础上,提供增强 AI 的最佳实践与 ES 本身的混合搜索能力,为用户带来了更多创新和探索的可能性。 近年来&a…...

uniapp获取一周日期和星期

UniApp可以使用JavaScript中的Date对象来获取当前日期和星期几。以下是一个示例代码,可以获取当前日期和星期几,并输出在一周内的每天早上和晚上: // 获取当前日期和星期 let date new Date(); let weekdays ["Sunday", "M…...

QT之QListWidget的介绍

QListWidget常用成员函数 1、成员函数介绍2、例子显示图片和按钮的例子 1、成员函数介绍 1)QListWidget(QWidget *parent nullptr) 构造函数,创建一个新的QListWidget对象。 2)void addItem(const QString &label) 在列表末尾添加一个项目,项目标…...

数据结构--排序(1)

文章目录 排序概念直接插入排序希尔排序冒泡排序堆排序选择排序验证不同排序的运行时间 排序概念 排序指的是通过某一特征关键字(如信息量大小,首字母等)来对一连串的数据进行重新排列的操作,实现递增或者递减的数据排序。 稳定…...

【AI视野·今日NLP 自然语言处理论文速览 第三十七期】Thu, 21 Sep 2023

AI视野今日CS.NLP 自然语言处理论文速览 Thu, 21 Sep 2023 Totally 57 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Chain-of-Verification Reduces Hallucination in Large Language Models Authors Shehzaad Dhuliawala, Mojt…...

高防服务器防护效果怎么样?

对于很多拥有在线业务的公司,数据是非常重要,如果遭到网络攻击会导致很严重的后果,所以很多公司选择高防服务器,那么高防服务器防护效果是怎么样的呢?今天就让小编带大家看一看吧! 弹性带宽。高防服务器一…...

tomcat架构概览

https://blog.csdn.net/ldw201510803006/article/details/119880100 前言 Tomcat 要实现 2 个核心功能: 处理 Socket 连接,负责网络字节流与 Request 和 Response 对象的转化。加载和管理 Servlet,以及具体处理 Request 请求。 因此 Tomc…...

海康的资料

系列文章目录 文章目录 系列文章目录前言一、海康二、使用步骤1.引入库2.读入数据 总结 前言 提示:这里可以添加本文要记录的大概内容: 例如:随着人工智能的不断发展,机器学习这门技术也越来越重要,很多人都开启了学…...

【ELFK】之消息队列kafka

本章结构: 1、为什么要使用消息队列MQ 2、使用消息队列的好处 3、消息队列的两种模式 4、对Kafka的概述 5、Kafka的特性 6、Kafka的系统架构 7、部署Kafka Kafka 定义 Kafka 是一个分布式的基于发布/订阅模式的消息队列(MQ,Message Qu…...

Qt核心:元对象系统、属性系统、对象树、信号槽

一、元对象系统 1、Qt 的元对象系统提供的功能有:对象间通信的信号和槽机制、运行时类型信息和动态属性系统等。 2、元对象系统是 Qt 对原有的 C进行的一些扩展,主要是为实现信号和槽机制而引入的, 信号和槽机制是 Qt 的核心特征。 3、要使…...

【若依框架2】前后端分离版本添加功能页

在VSCode的src/views下新建个文件平example,在example下创建test文件夹&#xff0c;在test里创建index.vue文件 <template> <h1>Hello world</h1> </template><script> export default {name: "index" } </script><style s…...

Unity Bolt模块间通信

使用Bolt无代码设计开发的时候&#xff0c;我们不能简单的认为只需要一个FlowMachine就可以完成所有流程的开发。我们需要不同的模块进行拆分&#xff0c;以便更好的管理和协作。这就需要不同模块之间的通信处理。经过研究与使用&#xff0c;将常用的通信方式总结如下&#xff…...

please choose a certificate and try again.(-5)报错怎么解决

the server you want to connect to requests identification,please choose a certificate and try again.(-5)...

电子网站有哪些/网络营销推广公司名称

前言Redis提供了5种数据类型&#xff1a;String(字符串)、Hash(哈希)、List(列表)、Set(集合)、Zset(有序集合)&#xff0c;理解每种数据类型的特点对于redis的开发和运维非常重要。Redis中的list是我们经常使用到的一种数据类型&#xff0c;根据使用方式的不同&#xff0c;可以…...

盗取dede系统做的网站模板/百度指数使用方法

获取样式用.attr(class,className); 追加样式用&#xff1a;.addClass(className),增加样式&#xff0c;原有的样式不会消失&#xff0c;而是再添加一个样式 移除样式&#xff1a;.removeClass(className)&#xff0c;删除有的样式 切换样式&#xff1a;.toggleClass(className…...

怎么创建音乐网站/推广关键词怎么设置

最短路径分析属于ArcGIS的网络分析范畴。而ArcGIS的网络分析分为两类&#xff0c;分别是基于几何网络和网络数据集的网络分析。它们都可以实现最短路径功能。下面先介绍基于几何网络的最短路径分析的实现。以后会陆续介绍基于网络数据集的最短路径分析以及这两种方法的区别。 几…...

做游戏网站用什么系统做/百度模拟点击软件判刑了

thinkphp3.2.3(5以下)的addAll返回值问题thinkphp3.2.3(5以下)的addAll返回值问题[var1]我们都知道mysql支持一次插入多条数据&#xff0c;如下&#xff1a;以用户表user为例&#xff0c;表结构自增主键id、账号username、密码password。insert into user(username,password) v…...

做柱状图 饼状图的网站/软文营销的本质

上一篇文章中已经配置缓存服务器&#xff0c;这里说说主DNS服务器的配置&#xff1a;涉及相关知识&#xff1a;资源记录&#xff1a;SOA&#xff1a;资源起始记录&#xff0c;放在配置文件的第一条A记录&#xff1a;域名指向IP地址AAAA记录&#xff1a;域名指向IPV6地址PTR记录…...

徐州做网站的哪个好/公司网站设计模板

阅读文本大概需要 3 分钟。清楚地认识自己最近几天朋友圈被程序员因同事没写代码注释而惨遭枪击&#xff0c;刷屏了&#xff0c;好几个技术圈的大号都转发了&#xff0c;并且都是 10w 的阅读量。那篇英文原文根本没有提及动机是因为没写代码注释&#xff0c;全是作者自己的猜想…...