深度学习理论基础(三)封装数据集及手写数字识别
目录
- 前期准备
- 一、制作数据集
- 1. excel表格数据
- 2. 代码
- 二、手写数字识别
- 1. 下载数据集
- 2. 搭建模型
- 3. 训练网络
- 4. 测试网络
- 5. 保存训练模型
- 6. 导入已经训练好的模型文件
- 7. 完整代码
前期准备
必须使用 3 个 PyTorch 内置的实用工具(utils):
⚫ DataSet 用于封装数据集;
⚫ DataLoader 用于加载数据不同的批次;
⚫ random_split 用于划分训练集与测试集。
一、制作数据集
在封装我们的数据集时,必须继承实用工具(utils)中的 DataSet 的类,这个过程需要重写__init__和__getitem__、__len__三个方法,分别是为了加载数据集、获取数据索引、获取数据总量。我们通过代码读取excel表格里面的数据作为数据集。
1. excel表格数据

2. 代码
为了简单演示,我们将表格的第0列作为输入特征,第1列作为输出特征。
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import matplotlib.pyplot as plt# 制作数据集
class MyData(Dataset): """继承 Dataset 类"""def __init__(self, filepath):super().__init__()df = pd.read_excel(filepath).values """ 读取excel数据"""arr = df.astype(np.int32) """转为 int32 类型数组"""ts = torch.tensor(arr) """数组转为张量"""ts = ts.to('cuda') """把训练集搬到 cuda 上"""self.X = ts[:, :1] """获取第0列的所有行做为输入特征"""self.Y = ts[:, 1:2] """获取第1列的所有行为输出特征"""self.len = ts.shape[0] """样本的总数""" def __getitem__(self, index):return self.X[index], self.Y[index]def __len__(self):return self.lenif __name__ == '__main__': """获取数据集"""Data = MyData('label.xlsx')print(Data.X[0]) """输出为:tensor([1020741172], device='cuda:0', dtype=torch.int32)"""print(Data.Y[0]) """输出为:tensor([1], device='cuda:0', dtype=torch.int32) """print(Data.__len__()) """输出为:233 """"""划分训练集与测试集"""train_size = int(len(Data) * 0.7) # 训练集的样本数量test_size = len(Data) - train_size # 测试集的样本数量train_Data, test_Data = random_split(Data, [train_size, test_size])"""批次加载器"""""" 第一个参数:表示要加载的数据集,即之前划分好的 train_Data或test_Data 。"""""" 第二个参数:表示在每个 epoch(训练周期)开始之前是否重新洗牌数据。在训练过程中,通常会将数据进行洗牌,以确保模型能够学习到更加泛化的特征。而测试数据不需要重新洗牌,因为测试集仅用于评估模型的性能,不涉及模型参数的更新"""""" 第三个参数:表示每个批次中的样本数量为 32。也就是说,每次迭代加载器时,它会从训练数据集中加载128个样本。"""train_loader = DataLoader(train_Data, shuffle=True, batch_size=128)test_loader = DataLoader(test_Data, shuffle=False, batch_size=64)"""打印第一个批次的输入与输出特征"""for inputs, targets in train_loader:print(inputs)print(targets)
二、手写数字识别
1. 下载数据集
在下载数据集之前,要设定转换参数:transform,该参数里解决两个问题:
⚫ ToTensor:将图像数据转为张量,且调整三个维度的顺序为 (C-W-H);C表示通道数,二维灰度图像的通道数为 1,三维 RGB 彩图的通道数为 3。
⚫ Normalize:将神经网络的输入数据转化为标准正态分布,训练更好;根据统计计算,MNIST 训练集所有像素的均值是 0.1307、标准差是 0.3081
"""数据转换为tensor数据"""
transform_data = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.1307, 0.3081)
])"""下载训练集与测试集"""
train_Data = datasets.MNIST(root = 'E:/Desktop/Document/4. Python/例程代码/dataset/mnist/', """下载路径"""train = True, """训练集"""download = True, """如果该路径没有该数据集,就下载"""transform = transform_data """数据集转换参数"""
)
test_Data = datasets.MNIST(root = 'E:/Desktop/Document/4. Python/例程代码/dataset/mnist_test/', """下载路径"""train = False, """非训练集,也就是测试集"""download = True, """如果该路径没有该数据集,就下载"""transform = transform_data """数据集转换参数"""
)"""批次加载器"""
train_loader = DataLoader(train_Data, shuffle=True, batch_size=64)
test_loader = DataLoader(test_Data, shuffle=False, batch_size=64)

2. 搭建模型
class DNN(nn.Module):def __init__(self):''' 搭建神经网络各层 '''super(DNN,self).__init__()self.net = nn.Sequential( # 按顺序搭建各层nn.Flatten(), # 把图像铺平成一维nn.Linear(784, 512), nn.ReLU(), # 第 1 层:全连接层nn.Linear(512, 256), nn.ReLU(), # 第 2 层:全连接层nn.Linear(256, 128), nn.ReLU(), # 第 3 层:全连接层nn.Linear(128, 64), nn.ReLU(), # 第 4 层:全连接层nn.Linear(64, 10) # 第 5 层:全连接层)def forward(self, x):''' 前向传播 '''y = self.net(x) # x 即输入数据return y # y 即输出数据
3. 训练网络
"""实例化模型"""
model = DNN().to('cuda:0') def train_net():"""1.损失函数的选择"""loss_fn = nn.CrossEntropyLoss() # 自带 softmax 激活函数"""2.优化算法的选择"""learning_rate = 0.01 # 设置学习率optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate,momentum=0.5 # momentum(动量),它使梯度下降算法有了力与惯性)"""3.训练"""epochs = 5losses = [] """记录损失函数变化的列表"""for epoch in range(epochs):for (x, y) in train_loader: """从批次加载器中获取小批次的x与y"""x, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) #将样本放入实例化的模型中,这里自动调用forward方法。loss = loss_fn(Pred, y) # 计算损失函数losses.append(loss.item()) # 记录损失函数的变化optimizer.zero_grad() # 清理上一轮滞留的梯度loss.backward() # 一次反向传播optimizer.step() # 优化内部参数"""4.画损失图"""Fig = plt.figure()plt.plot(range(len(losses)), losses)plt.show()
损失图如下:

4. 测试网络
测试网络不需要回传梯度。
"""实例化模型"""
model = DNN().to('cuda:0') def test_net():correct = 0total = 0with torch.no_grad(): #该局部关闭梯度计算功能for (x, y) in test_loader: #从批次加载器中获取小批次的x与yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model (x) #将样本放入实例化的模型中,这里自动调用forward方法。_, predicted = torch.max(Pred.data, dim=1)correct += torch.sum((predicted == y))total += y.size(0)print(f'测试集精准度: {100 * correct / total} %')

5. 保存训练模型
在保存模型前,必须要先进行训练网络去获取和优化模型参数。
if __name__ == '__main__':model = DNN().to('cuda:0') train_net()torch.save(model,'old_model.pth')
6. 导入已经训练好的模型文件
导入训练好的模型文件,我们就不需要再进行训练网络,直接使用测试网络来测试即可。
new_model使用了原有模型文件,我们就需要在测试网络的前向传播中的模型修改为 new_model去进行测试。如下:
""" 假设我们之前保存好的模型文件为:'old_model.pth' """def test_net():correct = 0total = 0with torch.no_grad(): #该局部关闭梯度计算功能for (x, y) in test_loader: #从批次加载器中获取小批次的x与yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = new_model (x) #将样本放入实例化的模型中,这里自动调用forward方法。_, predicted = torch.max(Pred.data, dim=1)correct += torch.sum((predicted == y))total += y.size(0)print(f'测试集精准度: {100 * correct / total} %')if __name__ == '__main__':new_model = torch.load('old_model.pth')test_net()
7. 完整代码
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt"""------------1.下载数据集----------"""
"""数据转换为tensor数据"""
transform_data = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.1307, 0.3081)
])"""下载训练集与测试集"""
train_Data = datasets.MNIST(root = 'E:/Desktop/Document/4. Python/例程代码/dataset/mnist/', """下载路径"""train = True, """训练集"""download = True, """如果该路径没有该数据集,就下载"""transform = transform_data """数据集转换参数"""
)
test_Data = datasets.MNIST(root = 'E:/Desktop/Document/4. Python/例程代码/dataset/mnist_test/', """下载路径"""train = False, """非训练集,也就是测试集"""download = True, """如果该路径没有该数据集,就下载"""transform = transform_data """数据集转换参数"""
)"""批次加载器"""
train_loader = DataLoader(train_Data, shuffle=True, batch_size=64)
test_loader = DataLoader(test_Data, shuffle=False, batch_size=64)"""---------------2.定义模型------------"""
class DNN(nn.Module):def __init__(self):''' 搭建神经网络各层 '''super(DNN,self).__init__()self.net = nn.Sequential( # 按顺序搭建各层nn.Flatten(), # 把图像铺平成一维nn.Linear(784, 512), nn.ReLU(), # 第 1 层:全连接层nn.Linear(512, 256), nn.ReLU(), # 第 2 层:全连接层nn.Linear(256, 128), nn.ReLU(), # 第 3 层:全连接层nn.Linear(128, 64), nn.ReLU(), # 第 4 层:全连接层nn.Linear(64, 10) # 第 5 层:全连接层)def forward(self, x):''' 前向传播 '''y = self.net(x) # x 即输入数据return y # y 即输出数据"""-------------3.训练网络-----------"""
def train_net():# 损失函数的选择loss_fn = nn.CrossEntropyLoss() # 自带 softmax 激活函数# 优化算法的选择learning_rate = 0.01 # 设置学习率optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate,momentum=0.5)epochs = 5losses = [] # 记录损失函数变化的列表for epoch in range(epochs):for (x, y) in train_loader: # 获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) # 一次前向传播(小批量)loss = loss_fn(Pred, y) # 计算损失函数losses.append(loss.item()) # 记录损失函数的变化optimizer.zero_grad() # 清理上一轮滞留的梯度loss.backward() # 一次反向传播optimizer.step() # 优化内部参数"""Fig = plt.figure()""""""plt.plot(range(len(losses)), losses)""""""plt.show()""""""--------------------4.测试网络-----------"""
def test_net():correct = 0total = 0with torch.no_grad(): #该局部关闭梯度计算功能for (x, y) in test_loader: #获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = new_model(x) #一次前向传播(小批量)_, predicted = torch.max(Pred.data, dim=1)correct += torch.sum((predicted == y))total += y.size(0)print(f'测试集精准度: {100 * correct / total} %')if __name__ == '__main__':""" ------- 5.保存模型文件------"""""" model = DNN().to('cuda:0') """""" train_net() """""" torch.save(model,'old_model.pth') """""" ------- 6.加载模型文件 ----- """new_model = torch.load('old_model.pth')test_net()
相关文章:
深度学习理论基础(三)封装数据集及手写数字识别
目录 前期准备一、制作数据集1. excel表格数据2. 代码 二、手写数字识别1. 下载数据集2. 搭建模型3. 训练网络4. 测试网络5. 保存训练模型6. 导入已经训练好的模型文件7. 完整代码 前期准备 必须使用 3 个 PyTorch 内置的实用工具(utils): ⚫…...
vue3+eachrts饼图轮流切换显示高亮数据
<template><div class"charts-box"><div class"charts-instance" ref"chartRef"></div>// 自定义legend 样式<div class"charts-note"><span v-for"(items, index) in data.dataList" cla…...
UTONMOS:AI+Web3+元宇宙数字化“三位一体”将触发经济新爆点
人工智能、元宇宙、Web3,被称为数字化的“三位一体”,如何看待这三大技术所扮演的角色? 3月24日,2024全球开发者先锋大会“数字化的三位一体——人工智能、元宇宙、Web3.0”论坛在上海漕河泾开发区举行,首次提出&…...
开始焦虑了
大家好,我是洋子,25届的暑期实习自从3月份开始招聘有一段时间了,最近接到了几个25届同学的咨询,其中一个学妹印象比较深刻,她的情况如下 个人情况 学历是双非本,计算机专业,学习方向是Java&…...
数据结构和算法:十大排序
排序算法 排序算法用于对一组数据按照特定顺序进行排列。排序算法有着广泛的应用,因为有序数据通常能够被更高效地查找、分析和处理。 排序算法中的数据类型可以是整数、浮点数、字符或字符串等。排序的判断规则可根据需求设定,如数字大小、字符 ASCII…...
LLaMA-Factory微调(sft)ChatGLM3-6B保姆教程
LLaMA-Factory微调(sft)ChatGLM3-6B保姆教程 准备 1、下载 下载LLaMA-Factory下载ChatGLM3-6B下载ChatGLM3windows下载CUDA ToolKit 12.1 (本人是在windows进行训练的,显卡GTX 1660 Ti) CUDA安装完毕后,…...
Web安全-浏览器安全策略及跨站脚本攻击与请求伪造漏洞原理
Web安全-浏览器安全策略及跨站脚本攻击与请求伪造漏洞原理 Web服务组件分层概念 静态层 :web前端框架:Bootstrap,jQuery,HTML5框架等,主要存在跨站脚本攻击脚本层:web应用,web开发框架,web服务…...
蓝桥杯B组C++省赛——飞机降落(DFS)
题目连接:https://www.lanqiao.cn/problems/3511/learning/ 思路:由于数据范围很小,所有选择用DFS枚举所有飞机的所有的降落顺序,看哪个顺序可以让所有飞机顺利降落,有的话就算成功方案,输出了“YES”。 …...
Java 中的 Map集合
文章目录 添加和修改元素获取元素检查元素删除元素获取所有键 / 值 / 键值对大小 在 Java 中,Map 接口是 Java 集合框架的一部分,它存储键值对(key-value pairs)。Map 接口有许多常用的方法,用于添加、删除、获取元素&…...
基于springboot大学生兼职平台管理系统(完整源码+数据库)
一、项目简介 本项目是一套基于springboot大学生兼职平台管理系统 包含:项目源码、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试,eclipse或者idea 确保可以运行! 该系统功能完善、界面美观、操作简单、功…...
C#学生信息管理系统
一、引言 学生信息管理系统是现代学校管理的重要组成部分,它能够有效地管理学生的基本信息、课程信息、成绩信息等,提高学校管理的效率和质量。本文将介绍如何使用SQL Server数据库和C#语言在.NET平台上开发一个学生信息管理系统的课程设计项目。 二、项…...
双机 Cartogtapher 建图文件配置
双机cartogtapher建图 最近在做硕士毕设的最后一个实验,其中涉及到多机建图,经过调研最终采用cartographer建图算法,其中配置多机建图的文件有些麻烦,特此博客以记录 非常感谢我的同门 ”叶少“ 山上的稻草人-CSDN博客的帮助&am…...
VMware提示 该虚拟机似乎正在使用中,如何解决?
VMware提示 该虚拟机似乎正在使用中,如何解决? 问题描述解决方法1.找到安装VMware的文件目录2.在VMware目录下.lck后缀的文件夹删除或重命名3.运行VMware 问题描述 该虚拟机似乎正在使用中。 如果该虚拟机未在使用,请按“获取所有权(T)”按钮获取它的所…...
阿里云短信服务业务
一、了解阿里云用户权限操作 1.注册账号、实名认证; 2.使用AccessKey 步骤一 点击头像,权限安全的AccessKey 步骤二 设置子用户AccessKey 步骤三 添加用户组和用户 步骤四 添加用户组记得绑定短信服务权限 步骤五 添加用户记得勾选openApi访问 添加…...
ElasticSearch的DSL查询
ElasticSearch的DSL查询 准备工作 创建测试方法,初始化测试结构。 import org.apache.http.HttpHost; import org.apache.lucene.search.TotalHits; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRespo…...
每天定时杀spark进程
##编写shell脚本 #!/bin/bash arr(“zhangsan” “lisi” “wangwu”) for i in “${arr[]}” do processps -ef|grep ${i}| grep -v "grep"| awk {print $2} kill -9 ${process} done ##每日定时杀手动启动的进程 0 19 * * * cd /kill_process && sh kil…...
win10 安装kubectl,配置config连接k8s集群
安装kubectl 按照官方文档安装:https://kubernetes.io/docs/tasks/tools/install-kubectl-windows/ curl安装 (1)下载curl安装压缩包: curl for Windows (2)配置环境变量: 用户变量: Path变…...
Calico IPIP和BGP TOR的数据包走向
IPIP Mesh全网互联 文字描述 APOD eth0 10.7.75.132 -----> APOD 网关 -----> A宿主机 cali76174826315网卡 -----> Atunl0 10.7.75.128 封装 ----> Aeth0 10.120.181.20 -----> 通过网关 10.120.181.254 -----> 下一跳 BNODE eth0 10.120.179.8 解封装 --…...
静态成员主要用于提供与类本身相关的功能或数据,有什么应用场景
静态成员(包括静态方法和静态属性)在JavaScript中常用于多种应用场景,它们为类提供了与类本身直接相关而不是与实例相关的功能或数据。以下是一些常见的应用场景: 工厂方法 静态方法可以作为工厂方法,用于创建类的实…...
在线考试|基于Springboot的在线考试管理系统设计与实现(源码+数据库+文档)
在线考试管理系统目录 目录 基于Springboot的在线考试管理系统设计与实现 一、前言 二、系统设计 三、系统功能设计 1、前台: 2、后台 管理员功能 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取: 博主…...
SpringBoot-17-MyBatis动态SQL标签之常用标签
文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...
XML Group端口详解
在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...
C++实现分布式网络通信框架RPC(3)--rpc调用端
目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中,我们已经大致实现了rpc服务端的各项功能代…...
css实现圆环展示百分比,根据值动态展示所占比例
代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...
深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法
深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...
vscode(仍待补充)
写于2025 6.9 主包将加入vscode这个更权威的圈子 vscode的基本使用 侧边栏 vscode还能连接ssh? debug时使用的launch文件 1.task.json {"tasks": [{"type": "cppbuild","label": "C/C: gcc.exe 生成活动文件"…...
条件运算符
C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...
深入理解JavaScript设计模式之单例模式
目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式(Singleton Pattern&#…...
第25节 Node.js 断言测试
Node.js的assert模块主要用于编写程序的单元测试时使用,通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试,通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...
【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具
第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...
