现代卷积网络实战系列2:训练函数、PyTorch构建LeNet网络
4、训练函数
4.1 调用训练函数
train(epochs, net, train_loader, device, optimizer, test_loader, true_value)
因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的参数,训练函数都需要
这七个参数,是训练一个神经网络所需要的最少参数
4.2 训练函数
训练函数中,所有训练集进行多次迭代,而每次迭代又会将数据分成多个批次进行迭代
def train(epochs, net, train_loader, device, optimizer, test_loader, true_value):for epoch in range(1, epochs + 1):net.train()all_train_loss = []for batch_idx, (data, target) in enumerate(train_loader):data = data.to(device)target = target.to(device)optimizer.zero_grad()output = net(data)loss = F.cross_entropy(output, target)loss.backward()optimizer.step()cur_train_loss = loss.item()all_train_loss.append(cur_train_loss)train_loss = np.round(np.mean(all_train_loss) * 1000, 2)print('\nepoch step:', epoch)print('training loss: ', train_loss)test(net, test_loader, device, true_value, epoch)print("\nTraining finished")
- 定义训练函数
- 安装epochs迭代数据
- 进入pytorch的训练模式
- all_train_loss 存放训练集5万张图片的损失值
- 按照batch取数据
- 数据进入GPU
- 标签进入GPU
- 梯度清零
- 当前batch进入网络后得到输出
- 根据输出得到当前损失
- 反向传播
- 梯度下降
- 获取损失的损失值(PyTorch框架中的数据)
- 把当前batch的损失加入all_train_loss数组中,结束batch的迭代
- 将5张图片的损失计算出来并且进行求平均,这里乘以1000是因为我觉得计算出的损失太小了,所以乘以1000,方便看损失的变化,保留两位有效数字
- 打印当前epoch
- 打印损失
- 调用测试函数,测试当前训练的网络的性能,结束epoch的迭代
- 打印训练完成
5、LeNet
5.1 网络结构
LeNet可以说是首次提出卷积神经网络的模型
主要包含下面的网络层:
- 5*5的二维卷积
- sigmoid激活函数(这里使用了relu)
- 5*5的二维卷积
- sigmoid激活函数
- 数据一维化
- 全连接层
- 全连接层
- softmax分类器
将网络结构打印出来:
LeNet(
-------(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
-------(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
-------(conv2_drop): Dropout2d(p=0.5, inplace=False)
-------(fc1): Linear(in_features=320, out_features=50, bias=True)
-------(fc2): Linear(in_features=50, out_features=10, bias=True)
)
5.2 PyTorch构建LeNet
class LeNet(nn.Module):def __init__(self, num_classes):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, num_classes)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)
这个时候已经是一个完整的项目了,看看10个epoch训练过程的打印:
D:\conda\envs\pytorch\python.exe A:\0_MNIST\train.py
Reading data…
train_data: (60000, 28, 28) train_label (60000,)
test_data: (10000, 28, 28) test_label (10000,)
Initialize neural network
test loss: 2301.68
test accuracy: 11.3 %
epoch step: 1
training loss: 634.74
test loss: 158.03
test accuracy: 95.29 %
epoch step: 2
training loss: 324.04
test loss: 107.62
test accuracy: 96.55 %
epoch step: 3
training loss: 271.25
test loss: 88.43
test accuracy: 97.04 %
epoch step: 4
training loss: 236.69
test loss: 70.94
test accuracy: 97.61 %
epoch step: 5
training loss: 211.05
test loss: 69.69
test accuracy: 97.72 %
epoch step: 6
training loss: 199.28
test loss: 62.04
test accuracy: 97.98 %
epoch step: 7
training loss: 187.11
test loss: 59.65
test accuracy: 97.98 %
epoch step: 8
training loss: 178.79
test loss: 53.89
test accuracy: 98.2 %
epoch step: 9
training loss: 168.75
test loss: 51.83
test accuracy: 98.43 %
epoch step: 10
training loss: 160.83
test loss: 50.35
test accuracy: 98.4 %
Training finished
进程已结束,退出代码为 0
可以看出基本上只要一个epoch就可以得到很好的训练效果了,后续的epoch中的提升比较小
相关文章:
现代卷积网络实战系列2:训练函数、PyTorch构建LeNet网络
4、训练函数 4.1 调用训练函数 train(epochs, net, train_loader, device, optimizer, test_loader, true_value)因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的…...
rust特性
特性,也叫特质,英文是trait。 trait是一种特殊的类型,用于抽象某些方法。trait类似于其他编程语言中的接口,但又有所不同。 trait定义了一组方法,其他类型可以各自实现这个trait的方法,从而形成多态。 一、…...
TouchGFX之画布控件
TouchGFX的画布控件,在使用相对较小的存储空间的同时保持高性能,可提供平滑、抗锯齿效果良好的几何图形绘制。 TouchGFX 设计器中可用的画布控件: LineCircleShapeLine Progress圆形进度条 存储空间分配和使用 为了生成反锯齿效果良好的…...
STM32F103RCT6学习笔记2:串口通信
今日开始快速掌握这款STM32F103RCT6芯片的环境与编程开发,有关基础知识的部分不会多唠,直接实践与运用!文章贴出代码测试工程与测试效果图: 目录 串口通信实验计划: 串口通信配置代码: 测试效果图&#…...
Opencv-图像噪声(均值滤波、高斯滤波、中值滤波)
图像的噪声 图像的平滑 均值滤波 均值滤波代码实现 import cv2 as cv import numpy as np import matplotlib.pyplot as plt from pylab import mplmpl.rcParams[font.sans-serif] [SimHei]img cv.imread("dog.png")#均值滤波cv.blur(img, (5, 5))将对图像img进行…...
MasterAlign相机参数设置-增益调节
相机参数设置-曝光时间调节操作说明 相机参数的设置对于获取清晰、准确的图像至关重要。曝光时间是其中一个关键参数,它直接影响图像的亮度和清晰度。以下是关于曝光时间调节的详细操作步骤,以帮助您轻松进行设置。 步骤一:登录系统 首先&…...
9月22日,每日信息差
今天是2023年09月22日,以下是为您准备的14条信息差 第一、亚马逊将于2024年初在Prime Video中加入广告。Prime Video内容中的广告将于2024年初在美国、英国、德国和加拿大推出,随后晚些时候在法国、意大利、西班牙、墨西哥和澳大利亚推出 第二、中国移…...
Java版本企业工程项目管理系统源码+spring cloud 系统管理+java 系统设置+二次开发
工程项目各模块及其功能点清单 一、系统管理 1、数据字典:实现对数据字典标签的增删改查操作 2、编码管理:实现对系统编码的增删改查操作 3、用户管理:管理和查看用户角色 4、菜单管理:实现对系统菜单的增删改查操…...
Android studio中如何下载sdk
打开 file -> settings 这个页面, 在要下载的 SDK 前面勾上, 然后点 apply 在 platforms 中就可以看到下载好的 SDK: Android SDK目录结构详细介绍可以参考这篇文章: 51CTO博客- Android SDK目录结构...
STM32单片机中国象棋TFT触摸屏小游戏
实践制作DIY- GC0167-中国象棋 一、功能说明: 基于STM32单片机设计-中国象棋 二、功能介绍: 硬件组成:STM32F103RCT6最小系统2.8寸TFT电阻触摸屏24C02存储器1个按键(悔棋) 游戏规则: 1.有悔棋键&…...
【PHP图片托管】CFimagehost搭建私人图床 - 无需数据库支持
文章目录 1.前言2. CFImagehost网站搭建2.1 CFImagehost下载和安装2.2 CFImagehost网页测试2.3 cpolar的安装和注册 3.本地网页发布3.1 Cpolar临时数据隧道3.2 Cpolar稳定隧道(云端设置)3.3.Cpolar稳定隧道(本地设置) 4.公网访问测…...
CCITT 标准的CRC-16检验算法
/******该文件使用查表法计算CCITT 标准的CRC-16检验码,并附测试代码********/ #include #define CRC_INIT 0xffff //CCITT初始CRC为全1 #define GOOD_CRC 0xf0b8 //校验时计算出的固定结果值 /****下表是常用ccitt 16,生成式1021反转成8408后的查询表格****/ u…...
docker启动mysql服务
创建基础文件 mkdir mysql mkdir -p mysql/data获取默认的my.cnf docker run -name mysql -d -p 3306:3306 mysql:latest docker cp mysql:/etc/my.cnf ./vim mysql/my.cnf # For advice on how to change settings please see # http://dev.mysql.com/doc/refman/8.1/en/se…...
Postman应用——Request数据导入导出
文章目录 导入请求数据导出请求数据导出Collection导出Environments 导出所有请求数据导出请求响应数据 Postman可以导入导出Request和Variable变量配置,可以通过文本方式(JOSN文本)或链接方式进行导入导出。 导入请求数据 可以通过JSON文件…...
十四、MySql的用户管理
文章目录 一、用户管理二、用户(一)用户信息(二)创建用户1.语法:2.案例: (三) 删除用户1.语法:2.示例: (四)修改用户密码1.语法&#…...
01.自动化交易综述
算法交易的概念: 利用自动化平台,执行预先设置的一系列规则完成交易行为。 算法交易的优势 1.历史数据评估 2.执行高效 3.无主观情绪输入 4.可度量评价 5.交易频率 算法交易的劣势 1.成本,成本低难以体现收益 2.技巧 算法交易流程 大前…...
基于SpringBoot的网上超市系统的设计与实现
目录 前言 一、技术栈 二、系统功能介绍 管理员功能实现 用户功能实现 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 网络技术和计算机技术发展至今,已经拥有了深厚的理论基础,并在现实中进行了充分运用,尤其是基于计…...
国内首家!阿里云 Elasticsearch 8.9 版本释放 AI 搜索新动能
简介: 阿里云作为国内首家上线 Elasticsearch 8.9版本的厂商,在提供 Elasticsearch Relevance Engine™ (ESRE™) 引擎的基础上,提供增强 AI 的最佳实践与 ES 本身的混合搜索能力,为用户带来了更多创新和探索的可能性。 近年来&a…...
uniapp获取一周日期和星期
UniApp可以使用JavaScript中的Date对象来获取当前日期和星期几。以下是一个示例代码,可以获取当前日期和星期几,并输出在一周内的每天早上和晚上: // 获取当前日期和星期 let date new Date(); let weekdays ["Sunday", "M…...
QT之QListWidget的介绍
QListWidget常用成员函数 1、成员函数介绍2、例子显示图片和按钮的例子 1、成员函数介绍 1)QListWidget(QWidget *parent nullptr) 构造函数,创建一个新的QListWidget对象。 2)void addItem(const QString &label) 在列表末尾添加一个项目,项目标…...
电脑插入多块移动硬盘后经常出现卡顿和蓝屏
当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
OD 算法题 B卷【正整数到Excel编号之间的转换】
文章目录 正整数到Excel编号之间的转换 正整数到Excel编号之间的转换 excel的列编号是这样的:a b c … z aa ab ac… az ba bb bc…yz za zb zc …zz aaa aab aac…; 分别代表以下的编号1 2 3 … 26 27 28 29… 52 53 54 55… 676 677 678 679 … 702 703 704 705;…...
企业大模型服务合规指南:深度解析备案与登记制度
伴随AI技术的爆炸式发展,尤其是大模型(LLM)在各行各业的深度应用和整合,企业利用AI技术提升效率、创新服务的步伐不断加快。无论是像DeepSeek这样的前沿技术提供者,还是积极拥抱AI转型的传统企业,在面向公众…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现指南针功能
指南针功能是许多位置服务应用的基础功能之一。下面我将详细介绍如何在HarmonyOS 5中使用DevEco Studio实现指南针功能。 1. 开发环境准备 确保已安装DevEco Studio 3.1或更高版本确保项目使用的是HarmonyOS 5.0 SDK在项目的module.json5中配置必要的权限 2. 权限配置 在mo…...
Vue 3 + WebSocket 实战:公司通知实时推送功能详解
📢 Vue 3 WebSocket 实战:公司通知实时推送功能详解 📌 收藏 点赞 关注,项目中要用到推送功能时就不怕找不到了! 实时通知是企业系统中常见的功能,比如:管理员发布通知后,所有用户…...
【题解-洛谷】P10480 可达性统计
题目:P10480 可达性统计 题目描述 给定一张 N N N 个点 M M M 条边的有向无环图,分别统计从每个点出发能够到达的点的数量。 输入格式 第一行两个整数 N , M N,M N,M,接下来 M M M 行每行两个整数 x , y x,y x,y,表示从 …...
AT模式下的全局锁冲突如何解决?
一、全局锁冲突解决方案 1. 业务层重试机制(推荐方案) Service public class OrderService {GlobalTransactionalRetryable(maxAttempts 3, backoff Backoff(delay 100))public void createOrder(OrderDTO order) {// 库存扣减(自动加全…...
性能优化中,多面体模型基本原理
1)多面体编译技术是一种基于多面体模型的程序分析和优化技术,它将程序 中的语句实例、访问关系、依赖关系和调度等信息映射到多维空间中的几何对 象,通过对这些几何对象进行几何操作和线性代数计算来进行程序的分析和优 化。 其中࿰…...
从数据报表到决策大脑:AI重构电商决策链条
在传统电商运营中,决策链条往往止步于“数据报表层”:BI工具整合历史数据,生成滞后一周甚至更久的销售分析,运营团队凭经验预判需求。当爆款突然断货、促销库存积压时,企业才惊觉标准化BI的决策时差正成为增长瓶颈。 一…...
