手写数字识别实战
全部代码:
import matplotlib.pyplot
import torch
from torch import nn # nn是完成神经网络相关的一些工作
from torch.nn import functional as F # functional是常用的一些函数
from torch import optim # 优化的工具包import torchvision
from matplotlib import pyplot as plt
from utils import plot_images, plot_curve, one_hot# step1 : load dataset
# 指定了每次梯度更新时用于训练模型的数据样本数量
batch_size = 512 # 一次处理的图片数量 我们的处理的图片是28×28像素
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), # 把numpy格式转换为tensortorchvision.transforms.Normalize( # 图像像素分布在0-1,所以要-0.1307,除以标准差0.3801,使得数据能够在0附近均匀分布(0.1307,), (0.3081,))])),# 1.torchvision.transforms.Normalize 是 PyTorch 中的一个非常有用的图像预处理转换(transform),# 它主要用于将图像数据标准化到特定的均值(mean)和标准差(std)上。这个转换通常用于训练深度学习模型之前,# 特别是卷积神经网络(CNN)模型,因为标准化有助于模型更快地收敛并提高模型的性能。# 2.这里,(0.1307,) 和 (0.3081,) 分别指定了用于标准化的均值和标准差。注意,虽然这两个元组只包含一个元素,# 但它们实际上是为每个通道(channel)指定的。在这个特定的例子中,由于这些值是针对MNIST数据集的,而MNIST数据集是灰度图像,所以只有一个通道。# 3.虽然这里直接给出了均值(0.1307)和标准差(0.3081),但在实际应用中,这些值通常是通过计算整个训练数据集的像素值的统计量来获得的。# 对于MNIST这样的灰度数据集,计算整个数据集的像素均值和标准差,然后用于所有图像的标准化。batch_size=batch_size, shuffle=True) # batch_size一次行处理多少张图片,shuffle意味着加载时要做一个随机的打散test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)x, y = next(iter(train_loader))
# x代表当前批次(batch)中的输入数据,即图像数据。对于MNIST数据集来说,x的形状通常是[batch_size, 1, 28, 28](如果数据没有被转换为灰度图并归一化,
# 则可能是[batch_size, 3, 28, 28],但MNIST是灰度图,所以通道数为1)。这里的batch_size是你在创建DataLoader时指定的每个批次中的样本数。# y代表当前批次中每个输入数据对应的标签(label),即每个图像对应的数字(0-9之间的整数)。y的形状通常是[batch_size],表示每个样本的类别标签。
print(x.shape, y.shape, x.min(), x.max())
plot_images(x, y, 'image sample')matplotlib.pyplot.show()# step2 : bulid a network
class Net(nn.Module):def __init__(self):super(Net,self).__init__()#wx+bself.fc1 = nn.Linear(28*28,256) #,28*28是x的维度,256一般根据经验随机决定,大维变成小维self.fc2 = nn.Linear(256,64) #第二层的输入与上一层的输出相同self.fc3 = nn.Linear(64,10) #10分类,此处不是根据经验#计算过程def forward(self,x):# x: [b,1,28,28]# h1 =relu(xw1+b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3,最后一层看情况添加激活函数x = self.fc3(x) # 激活函数加不加取决于你的任务return x# step3 : 训练。 训练的逻辑是:每一次求导,然后再去更新
# net.parameters()返回[w1,b1,w2,b2,w3,b3],这就是我们要优化的; lr是学习步长 ;momentum帮助更好的优化
net = Net()
# 使用SGD优化器,学习率为0.01,动量为0.9
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
# 把loss保存起来
train_loss = []# epoch 是整个数据集的训练轮数。在这个例子中,数据集将被遍历3次。
# batch_idx 是当前批次的索引。
for epoch in range(3):for batch_idx, (x,y) in enumerate(train_loader):# x: [b,1,28,28] y : [512]# 将图像数据从[b,1,28,28]打平成[b,feature],size(0)是batch,因为网络期望的输入是一个一维的特征向量。x = x.view(x.size(0),28*28)# [b,10]# one_hot是一个自定义函数,用于将类别标签转换为one-hot编码out = net(x)# [b,10],真实的yy_onehot= one_hot(y)# loss=mse(out,y_onehot),求其均方差loss = F.mse_loss(out,y_onehot)#清零梯度optimizer.zero_grad()#计算梯度loss.backward()# 更新梯度:w‘ = w-lr*gradoptimizer.step()#进行梯度下降的可视化,把数据记录下来train_loss.append(loss.item())# 每隔10个批次,打印当前轮次、批次索引和损失值,以便于监控训练过程。if batch_idx % 10 == 0:print(epoch,batch_idx,loss.item())# 将训练损失绘制成曲线图
plot_curve(train_loss)
#we can get optimal [w1,b1,w2,b2,w3,b3]# step4 : 测试test
total_correct = 0
# 打印loss
for x, y in test_loader:x = x.view(x.size(0), 28 * 28)out = net(x) # 得到网络的输出# out: [b, 10] => pred: [b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item() # item()取数值 当前batch正确的个数total_correct += correct
total_num = len(test_loader.dataset) # 总的测试的数量
acc = total_correct / total_num # 准确率
print('test acc:', acc)
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_images(x, pred, 'test')
# 后期可进行的工作:
# def net()中增加网络层数
# def forward()中最后一层可以用softmax()
# loss:F.mse_loss()改成交叉熵函数
utils工具包:
# 四个步骤:load data; bulid model; train; test
import torch
from matplotlib import pyplot as pltdef plot_curve(data): # 绘制loss下降的曲线图fig = plt.figure()plt.plot(range(len(data)), data, color='blue')plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_images(img, label, name): # 画图片(因为这里涉及到一个图片的识别),这个地方可以方便地看到图片的识别结果fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')plt.title("{}:{}".format(name, label[i].item()))plt.xticks([])plt.yticks([])plt.showdef one_hot(label, depth=10): # 需要通过scatter()完成one_hot编码out = torch.zeros(label.size(0), depth)idx = torch.LongTensor(label).view(-1, 1)out.scatter_(dim=1, index=idx, value=1)return out
相关文章:

手写数字识别实战
全部代码: import matplotlib.pyplot import torch from torch import nn # nn是完成神经网络相关的一些工作 from torch.nn import functional as F # functional是常用的一些函数 from torch import optim # 优化的工具包import torchvision from matplotlib …...

二叉树遍历
二叉树的遍历是二叉树操作中的一个基本且重要的概念,它指的是按照一定的规则访问二叉树中的每个节点,并且每个节点仅被访问一次。常见的二叉树遍历方式有四种:前序遍历(Pre-order Traversal)、中序遍历(In-…...

uni app 调用前置摄像头
uniapp开发app并没有相关Api调用前置摄像头。只能使用5app的api 调用前置摄像头拍照 plus.camera.getCamera(index) 获取需要操作的摄像头对象,如果要进行拍照或摄像操作,需先通过此方法获取摄像头对象 index指定要获取摄像头的索引值,1表…...

哈工大李治军老师OS课程笔记(4)——内存管理
一 内存使用与分段(实验六) 内存是如何用起来的? 内存使用:将程序放在内存中,PC指向开始地址 重定位:修改程序中的地址(是相对地址) 什么时候完成重定位? 编译时加基址…...

代码随想录算法训练营第43天:动态规划part10:子序列问题
300.最长递增子序列 力扣题目链接(opens new window) 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2…...

传智教育引通义灵码进课堂,为技术人才教育学习提效
7 月 17 日,阿里云与传智教育在阿里巴巴云谷园区签署合作协议,双方将基于阿里云智能编程助手通义灵码在课程共建、品牌合作及产教融合等多个领域展开合作,共同推进 AI 教育及相关业务的发展,致力于培养适应未来社会需求的高素质技…...

企业信息化建设搞得好了叫系统工程,搞不好叫面子工程
2024-06-13 09:26贝格前端工场...

程序员如何平衡日常编码工作与提升式学习?
在快速变化的编程领域中,平衡日常编码工作与个人成长确实是一个重要且富有挑战性的议题。以下是我对这一问题的看法和建议: 1. 认识到平衡的重要性 首先,理解两者之间的平衡并非零和游戏,而是相辅相成的。高效的编码工作能够为个…...

Linux---文件系统和日志分析
文章目录 文件系统和日志分析inode和block概述inode包含文件的元信息用stat命令可以查看某个文件的inode信息Linux系统文件三个主要的时间属性 目录文件的结构用户通过文件名打开文件时,系统内部的过程查看inode号码的方法硬盘分区后的结构访问文件的简单流程inode的…...

MySQL 体系架构
文章目录 一. MySQL 分支与变种1. Drizzle2. MariaDB3. Percona Server 二. MySQL的替代1. Postgre SQL2. SQLite 三. MySQL 体系架构1.连接层2 Server层(SQL处理层)3. 存储引擎层1)MySQL官方存储引擎概要2)第三方引擎3࿰…...

跨站脚本攻击漏洞
1.JavaScript JavaScript 是一种脚本,一门编程语言,它可以在网页上实现复杂的功能,网页展现给你的不再是简单的静态信息,而是实时的内容更新,交互式的地图,2D/3D动画,滚动播放的视频等等。 &a…...

RabbitMQ入门与进阶
RabbitMQ入门与进阶 基础篇1. 为什么需要消息队列?2. 什么是消息队列?3. RabbitMQ体系结构介绍4. RabbitMQ安装5. HelloWorld6. RabbitMQ经典用法(工作模式)7. Work Queues8. Publish/Subscribe9. Routing10. Topics 进阶篇1. RabbitMQ整合SpringBoot2. 消息可靠性投递故障情…...

Unity新输入系统 之 InputActions(输入配置文件)
本文仅作笔记学习和分享,不用做任何商业用途 本文包括但不限于unity官方手册,unity唐老狮等教程知识,如有不足还请斧正 首先你应该了解新输入系统的基本单位Unity新输入系统 之 InputAction(输入配置文件最基本的单位࿰…...

Linux运维篇-误删/bin,/sbin目录怎么修复系统
这里写自定义目录标题 前言实例挂载镜像,重启系统进入救援模式拷贝镜像系统中的/bin和/sbin目录到原系统重启系统 总结 前言 当你看到这篇文章的时候,你的系统可能已经无法登录,或者正在处于登录状态但是不能执行任何常规的命令,…...

构建高效外贸电商系统的技术探索与源码开发
在当今全球化的经济浪潮中,外贸电商作为连接国内外市场的桥梁,其重要性日益凸显。一个高效、稳定、功能全面的外贸电商系统,不仅能够助力企业突破地域限制,拓宽销售渠道,还能提升客户体验,增强品牌竞争力。…...

Java设计模式:中介者模式详解与最佳实践
Java设计模式:中介者模式详解与最佳实践 1. 引言 在软件开发过程中,特别是复杂系统的构建中,模块间的交互往往成为影响代码质量的重要因素。当模块之间耦合度过高时,系统的维护、扩展和理解成本都会显著增加。为了降低模块之间的…...

Matlab绘制像素风字母颜色及透明度随机变化动画
本文是使用 Matlab 绘制像素风字母颜色及透明度随机变化动画的教程 实现效果 实现代码 如果需要更改为其他字母组合,在下面代码的基础上简单修改就可以使用。 步骤:(1) 定义字母形状;(2) 给出字母组合顺序;(3) 重新运行程序&#…...

C:每日一题:二分查找
1、知识介绍: 1.1 概念: 二分查找是一种在有序数组中查找某一特定元素的搜索算法 1.2 基本思想: 每次将待查找的范围缩小一半,通过比较中间元素与目标元素的大小,来决定是在左半部分还是右半部分继续查找。 举个生…...

python Django中使用ORM进行分组统计并降序排列
python Django中使用ORM进行分组统计并降序排列 # 使用supplier和Count进行分组统计,其中supplier为MyModel的一个字段 supplier_counts MyModel.objects.values(supplier).annotate(countCount(supplier)).order_by(-count) # 输出统计结果 for supplier_count in supplier_…...

QT C++ 编写modbus 总结
[开源库的使用]libModbus编译及使用_libmodbus库-CSDN博客 libmodbus的下载与编译_modbus库文件下载-CSDN博客 【QT5】解决 QT 界面中文显示乱码问题_qt5输出中文乱码解决方法-CSDN博客 Qt:解决qt修改完ui文件起不到作用_qt ui文件修改后不生效-CSDN博客...

基于SpringBoot的网络海鲜市场系统的设计与实现
TOC springboot219基于SpringBoot的网络海鲜市场系统的设计与实现 绪论 1.1 选题背景 当人们发现随着生产规模的不断扩大,人为计算方面才是一个巨大的短板,所以发明了各种计算设备,从结绳记事,到算筹,以及算盘&…...

c#相关基础知识
c#参数4种种别 值参:像Java的正常数据的传输 ref:对参数的指向是参数本身的地址,而不是数据的副本,所以可以对数据进行直接操作 out: 绑定控件,控件传输值赋值给类中的内部类 待定...

注意力机制 — 它是什么以及它是如何工作的
一、说明 注意力机制是深度学习领域的一个突破。它们帮助模型专注于数据的重要部分,并提高语言处理和计算机视觉等任务的理解和性能。这篇文章将深入探讨深度学习中注意力的基础知识,并展示其背后的主要思想。 二、注意力机制回顾 在我们谈论注意力之前&…...

学习嵌入式第二十六天
进程线程 1.进程的概念 2.进程 和 程序 硬盘中程序 ,加载到内存中,运行起来,就是进程 创建线程 pthread_create posix thread create 线程执行 ---体现在线程执行函数 (回调函数) 线程退出 ---pthread_exit() …...

speech语音audio音频
在信号处理和语言技术领域,speech 和 audio 是两个相关但不同的概念。它们有各自的定义和应用场景。以下是对这两个术语的详细解释: 1. Speech(语音) Speech 主要指的是人类说话时产生的声音。它是人类语言交流的一种主要形式&a…...

最常用的正则表达式规则和语法
正则表达式(Regular Expression,简称 regex)是一种用于匹配字符串的强大工具。它使用特定的语法规则来定义字符串模式,可以用来搜索、替换、验证字符串等。以下是一些常用的正则表达式规则和语法: 1. 基本字符匹配 . :匹配任意单个字符(除了换行符)。 示例:a.c 可以匹…...

Datawhale X 魔搭 AI夏令营第四期-魔搭生图task1学习笔记
根据教程提供的链接,进入相应文章了解魔搭生图的主要工作是通过对大量图片的训练,生成自己的模型,然后使用不同的正向、反向提示词使模型输出对应的图片 1.官方跑baseline教程链接:Task 1 从零入门AI生图原理&实践 2.简单列举一下赛事的…...

WPF中XAML相对路径表示方法
在WPF XAML中,相对路径是一种非常实用的方式来引用资源文件,如图像、样式表和其他XAML文件。相对路径可以帮助您构建更加灵活和可移植的应用程序,因为它允许资源文件的位置相对于XAML文件的位置进行定位。 相对路径的表示方法 在XAML中&…...

操作系统内存管理技术详解
操作系统内存管理技术详解:第一部分 引言 操作系统作为计算机系统的核心组件,负责管理硬件资源、提供用户接口和运行应用程序。在操作系统的众多功能中,内存管理无疑是最为关键的技术之一。本文将深入探讨操作系统内存管理的背后技术&…...

python之numpy(2 创建矩阵)
numpy创建矩阵 前面提到,numpy主要是针对数组和矩阵的操作。下面我们分别创建数组和矩阵。 import numpy as np x0np.array([1,2,3,4]) x1np.array([[1,2,3,4],[1,2,3,4]]) print(x0,x1,sep\n) 在numpy中,使用array创建数组和矩阵。其中,创…...