【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集
课程链接
1.读取图像分类数据集
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
#读取数据集
trans=transforms.ToTensor()
mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)
print('训练数据集:',len(mnist_train),'测试数据集:',len(mnist_test))
print('训练数据集图片大小:',mnist_train[0][0].shape)#两个可视化数据集的函数
def get_fashion_mnist_labels(labels): #返回fashion_mnist数据集的文本标签text_labels=['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels ]
def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5):figsize=(num_rows*scale,num_cols*scale)_,axes=d2l.plt.subplots(num_rows,num_cols,figsize=figsize)axes=axes.flatten()for i ,(ax,img) in enumerate(zip(axes,imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
#几个样本的图像及其相应的标签
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()#读取一小批量数据,大小为batchsize
batch_size=256
def get_dataloader_workers(): #使用4个进程来读取数据return 4
train_iter=data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
timer=d2l.Timer()
for X,y in train_iter:continue
print(f'{timer.stop():.2f}sec')
# 便于重用函数
def load_data_fasion_mnist(batch_size,resize:None):trans = [transforms.ToTensor()]if resize:trans.insert(0,transforms.Resize(resize))trans=transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return(data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=get_dataloader_workers()))
运行结果


2.Softmax 回归从零开始实现
import torch
from IPython import display
from d2l import torch as d2l
import matplotlib.pyplot as plt
import torchvision
from torch.utils import data
from torchvision import transforms
import numpy as npbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)
num_inputs=784 #展平图像为向量
num_outputs=10 # 有10个类所以模型输出为10
w=torch.normal(0,0.01,size=(num_inputs,num_outputs),requires_grad=True)#定义权重w
b=torch.zeros(num_outputs,requires_grad=True)# 定义softmax
def softmax(X):X_exp=torch.exp(X)#对每个元素做指数运算partition =X_exp.sum(1,keepdim=True)#按照行求和return X_exp/partition #矩阵中的各个元素/对应行元素之和
#验证一下是否是正确的
X=torch.normal(0,0.01,(2,5))# 创建均值为0方差为1的两行五列的X
X_prob=softmax(X)
print('1.验证softmax:',X_prob,X_prob.sum(1))
#实现softmax回归模型
def net(X):return softmax(torch.matmul(X.reshape((-1,w.shape[0])),w)+b) # -1,每次喂数据的量,就是batchsizey=torch.tensor([0,2])
y_hat=torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])
print('2.根据标号拿出预测值:',y_hat[[0,1],y])
# 实现交叉熵损失
def cross_entropy(y_hat,y): #给定预测和真实标号Yreturn -torch.log(y_hat[range(len(y_hat)),y])# 锁定y轴在x轴上根据labels收取预测值,交叉熵损失中除了真值=1,其他都是0,这里直接算针织对应的预测概率
print('3.交叉熵损失:',cross_entropy(y_hat,y))#将预测类别与真实元素y进行比较
def accuracy(y_hat,y):if len(y_hat.shape)>1 and y_hat.shape[1]>1: #shape和列数大于1的时候y_hat=y_hat.argmax(axis=1)#把每一行元素最大的下标存到y_hatcmp=y_hat.type(y.dtype)==y #y_hat和y的数据类型转换,作比较变成布尔return float(cmp.type(y.dtype).sum())#转换成和y一样的形状求和
print('4.预测正确的概率:',accuracy(y_hat,y)/len(y))# 预测正确的样本数除以y的长度就是预测正确的概率#计算模型在数据迭代器上的精度
def evaluate_accuracy(net,data_iter):if isinstance(net,torch.nn.Module):net.eval()#将模型设置为评估模式,输入后得出的结果用来评估模型的准确率,不做反向传播metric =Accumulator(2) # 累加器for X,y in data_iter:metric.add(accuracy(net(X),y),y.numel())return metric[0]/metric[1] #返回分类正确的样本数和总样本数# accumulator的实现
class Accumulator: #作用是累加def __init__(self,n):self.data=[0.0]*ndef add(self,*args):self.data=[a+float(b) for a,b in zip(self.data,args)]def reset(self):self.data=[0.0]*len(self.data)def __getitem__(self, idx):return self.data[idx]
if __name__=='__main__':print(evaluate_accuracy(net,test_iter))# softmax回归的训练
def train_epoch_ch3(net,train_iter,loss,updater):if isinstance(net,torch.nn.Module):net.train()metric=Accumulator(3)for X,y in train_iter:y_hat=net(X)l=loss(y_hat,y)if isinstance(updater,torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()metric.add(float(l)*len(y),accuracy(y_hat,y),y.size().numel())else:l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y),y.numel())return metric[0]/metric[2],metric[1]/metric[2]class Animator:def __init__(self,xlabel=None,ylabel=None,legend=None,xlim=None,ylim=None,xscale='linear',yscale='linear',fmts=('-','m--','g-.','r:'),nrows=1,ncols=1,figsize=(3.5,2.5)):if legend is None:legend=[]d2l.use_svg_display()self.fig,self.axes=d2l.plt.subplots(nrows,ncols,figsize=figsize)if nrows*ncols==1:self.axes=[self.axes, ]self.config_axes=lambda :d2l.set_axes(self.axes[0],xlabel,ylabel,xlim,ylim,xscale,yscale,legend)self.X,self.Y,self.fmts=None,None,fmtsdef add(self,x,y):if not hasattr(y,"__len__"):y=[y]n=len(y)if not hasattr(x, "__len__"):x=[x]*nif not self.X:self.X=[[]for _ in range(n)]if not self.Y:self.Y=[[]for _ in range(n)]for i ,(a,b) in enumerate(zip(x,y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x,y,fmt in zip(self.X,self.Y,self.fmts):self.axes[0].plot(x,y,fmt)self.config_axes()plt.draw()plt.pause(0.001)display.display(self.fig)display.clear_output(wait=True)def train_ch3(net,train_iter,test_iter,loss,num_epochs,updater):animator=Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0.3,0.9],legend=['train loss','train acc','test acc'])for epoch in range(num_epochs):train_metrics=train_epoch_ch3(net,train_iter,loss,updater)test_acc=evaluate_accuracy(net,test_iter)animator.add(epoch+1, train_metrics+(test_acc,))train_loss,train_acc=train_metricslr = 0.1
def updater(batch_size):return d2l.sgd([w,b],lr,batch_size)if __name__ == '__main__':num_epochs=10train_ch3(net,train_iter,test_iter,cross_entropy,num_epochs,updater)# 对图像进行分类的预测def predict_ch3(net,test_iter,n=6):for X,y in test_iter:breaktrues=d2l.get_fashion_mnist_labels(y)preds=d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles=[true+'\n'+pred for true,pred in zip(trues,preds)]d2l.show_images(X[0:n].reshape((n,28,28)),1,n,titles=titles[0:n])d2l.plt.show()
if __name__ == '__main__':predict_ch3(net,test_iter)
运行结果



3.Softmax 回归简洁实现
import torch
from torch import nn
from d2l import torch as d2lbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)# 初始化模型参数
net =nn.Sequential(nn.Flatten(),nn.Linear(784,10))def init_weights(m):if type(m)==nn.Linear:nn.init.normal_(m.weight,std=0.01)
net.apply(init_weights);loss=nn.CrossEntropyLoss(reduction='none')
trainer=torch.optim.SGD(net.parameters(),lr=0.1)
num_epochs=10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)d2l.plt.show()
运行结果

相关文章:
【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集
课程链接 1.读取图像分类数据集 import matplotlib.pyplot as plt import torch import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2l d2l.use_svg_display() #读取数据集 transtransforms.ToTensor() mnist_…...
设计模式:24、访问者模式
目录 0、定义 1、访问者模式的五种角色 2、访问者模式的UML类图 3、示例代码 0、定义 表示一个作用于某对象结构中的各个元素的操作。它可以在不改变各个元素的类的前提下,定义作用于这些元素的新操作。 1、访问者模式的五种角色 抽象元素(Element…...
基于JAVA的旅游网站系统设计
摘要 随着信息技术和网络技术的迅速发展,人们的生活质量和观念也在发生着改变,各地争相发展旅游业,传统的 旅游社已经无法满足人们的需求,旅游网站将突破传统在时间和地域的限制,成为方便、快捷、安全、可靠的旅游 方…...
网络安全产品之认识防火墙
防火墙是一种网络安全产品,它设置在不同网络(如可信任的企业内部网和不可信的公共网)或网络安全域之间,通过监测、限制、更改跨越防火墙的数据流,尽可能地对外部屏蔽网络内部的信息、结构和运行状况,以此来…...
nginx反向代理(负载均衡)和tomcat介绍
nginx的代理 负载均衡 负载均衡的算法 负载均衡的架构 基于ip的七层代理 upstream模块要写在http模块中 七层代理的调用要写在location模块中 轮询 加权轮询 最小连接数 ip_Hash URL_HASH 基于域名的七层代理 配置主机 给其余客户机配置域名 给所有机器做域名映射 四层代理…...
Microsoft Azure 在线技术公开课:生成式 AI 基础知识
课程介绍 参加我们的生成式 AI 基础知识公开课,了解如何将最新 AI 进展应用到你的工作中。你将了解有关语言模型和生成式 AI 应用程序的基础知识。此外,你还将了解 Azure OpenAI 服务如何通过文本、代码、图像生成、自然语言摘要和语义搜索助你实现成果…...
lnmp+discuz论坛 附实验:搭建discuz论坛
Inmpdiscuz论坛 Inmp: t: linux操作系统 nr: nginx前端页面 me: mysql数据库 账号密码,等等都是保存在这个数据库里面 p: php——nginx擅长处理的是静态页面,页面登录账户,需要请求到数据库,通过php把动态请求转发到数据库 n…...
谷粒商城—分布式高级①.md
1. ELASTICSEARCH 1、安装elastic search dokcer中安装elastic search (1)下载ealastic search和kibana docker pull elasticsearch:7.6.2 docker pull kibana:7.6.2(2)配置 mkdir -p /mydata/elasticsearch/config mkdir -p /mydata/elasticsearch/data echo "h…...
Unity开发配置不足,卡顿崩溃怎么办?
在游戏开发和虚拟现实等领域,Unity 软件以其强大的功能和广泛的适用性成为了众多开发者的首选。然而,要充分发挥 Unity 的性能,一台高性能的电脑设备是必不可少的。今天,我要向大家介绍川翔云电脑,它为 Unity 开发者提…...
在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1
KubeSphere4.1安装文档 在 Kubernetes 上快速安装 KubeSphere 在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1 官方文档:在 Linux 上以 All-in-One 模式安装 KubeSphere 下载文件 KubeKey git地址Releases kubesphere/kubekey 或 …...
网络安全自学是一项需要耐心和恒心的任务
网络安全自学是一项需要耐心和恒心的任务,但只要你按照正确的学习路线图去努力,就能够逐步掌握这一领域的知识和技能。下面是一份详细的学习路线图,它将帮助你从零基础开始,逐步成为网络安全领域的专家。 第一阶段:基…...
Python+OpenCV系列:图像的几何变换
Python OpenCV 系列:图像的几何变换 引言 在图像处理领域,几何变换是一个非常重要的操作,它可以改变图像的位置、大小、方向或形状。在计算机视觉中,这些操作对于图像预处理、特征提取和图像增强至关重要。本文将介绍如何利用 …...
第P1周:Pytorch实现mnist手写数字识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 1. 实现pytorch环境配置 2. 实现mnist手写数字识别 3. 自己写几个数字识别试试具体实现 (一)环境 语言环境:Python…...
使用EventLog Analyzer进行Apache日志监控和日志分析
一、什么是Apache日志分析 Apache日志分析是网站管理和维护的重要部分,通过分析Apache服务器生成的日志文件,可以了解网站的访问情况、识别潜在的安全问题、优化网站性能等。 二、Apache日志类型 Apache日志主要有两种类型:访问日志&a…...
PaddleOCR模型ch_PP-OCRv3文本检测模型研究(二)颈部网络
上节研究了PaddleOCR文本检测v3模型的骨干网,本文接着研究其颈部网络。 文章目录 研究起点残注层颈部网络代码实验小结 研究起点 摘取开源yml配置文件,摘取网络架构Architecture中颈部网络的配置如下 Neck:name: RSEFPNout_channels: 96shortcut: True可…...
360极速浏览器不支持看PDF
360安全浏览器采用的是基于IE内核和Chrome内核的双核浏览器。360极速浏览器是源自Chromium开源项目的浏览器,不但完美融合了IE内核引擎,而且实现了双核引擎的无缝切换。因此在速度上,360极速浏览器的极速体验感更佳。 展示自己的时候要在有优…...
【深度学习】深刻理解ViT
ViT(Vision Transformer)是谷歌研究团队于2020年提出的一种新型图像识别模型,首次将Transformer架构成功应用于计算机视觉任务中。Transformer最初应用于自然语言处理(如BERT和GPT),而ViT展示了其在视觉任务…...
解决vue2中更新列表数据,页面dom没有重新渲染的问题
在 Vue 2 中,直接修改数组的某个项可能不会触发视图的更新。这是因为 Vue 不能检测到数组的索引变化或对象属性的直接赋值。为了确保 Vue 能够正确地响应数据变化,你可以使用以下几种方法: 1. 使用 Vue.set() 使用 Vue.set() 方法可以确保 …...
vscode通过ssh连接远程服务器(实习心得)
一、连接ssh服务器 1.打开Visual Studio Code,进入拓展市场(CtrlShiftX),下载拓展Remote - SSH 2. 点击远程资源管理器选项卡,并选择远程(隧道/SSH)类别 3. 点击ssh配置:输入你的账号主机ip地址 4.在弹出的选择配置文件中…...
知识图谱9:知识图谱的展示
1、知识图谱的展示有很多工具 Neo4j Browser - - - - 浏览器版本 Neo4j Desktop - - - - 桌面版本 graphX - - - - 可以集成到Neo4j Desktop Neo4j 提供的 Neo4j Bloom 是用户友好的可视化工具,适合非技术用户直观地浏览图数据。Cypher 是其核心查询语言&#x…...
jquery.inputmask插件介绍
目录 一、什么是 jQuery.inputmask? 主要应用场景 二、快速上手 1. 引入依赖文件 2. 基础用法 3. 掩码字符定义 三、高级功能 1. 自定义占位符 2. 完成回调 3. 扩展自定义字符 4. 重复掩码 5. 移除默认占位符 四、配合 Vue.js 使用 五、更多实用示例 …...
C++内联函数性能分析
C内联函数性能分析内联函数通过在调用点展开函数体来消除函数调用开销。理解内联机制和使用场景对于编写高性能代码至关重要。inline关键字建议编译器内联函数。#include #includeinline int add(int a, int b) { return a b; }inline int multiply(int a, int b) { return a …...
从Noise2Noise到Neighbor2Neighbor:图解自监督去噪的演进与核心‘采样’技巧
从Noise2Noise到Neighbor2Neighbor:自监督去噪技术的范式跃迁与工程实践 当你在昏暗环境下用手机拍摄一张照片时,那些恼人的彩色颗粒可能让你直接点击删除键。传统去噪方法需要大量"干净-噪声"图像对进行训练,而真实世界中获取完美…...
如何快速安装xfce-winxp-tc:10分钟打造XP风格的Linux桌面
如何快速安装xfce-winxp-tc:10分钟打造XP风格的Linux桌面 【免费下载链接】xfce-winxp-tc Windows XP stuff for XFCE 项目地址: https://gitcode.com/gh_mirrors/xf/xfce-winxp-tc 你是否怀念经典的Windows XP界面?xfce-winxp-tc项目让你在Linux…...
2026.5.12【芯片设计面试经验分享】上海车载芯片设计公司
一、主管面试 1、介绍下负责的cpu的九级流水线都有哪级? 指令预取、PC取指、指令译码、发射(双发射)、执行1(alu、运算)、执行2(乘法、移位)、访存、写回、提交/重排 2、负责的spyglass cdc 一般…...
STM32矩阵按键详解——4×4行列扫描与非阻塞消抖(硬件总结六)
前言 独立按键虽然简单,但当产品需要十几个按键时,每个按键独占一个GPIO的接法就变得很不经济。矩阵按键通过“行列”的交叉结构,仅用NM个GPIO即可驱动NM个按键。以最常见的44矩阵为例,16个按键仅需8个GPIO,引脚利用率…...
中国分县林地面积统计数据
一、数据简介 林地是指生长乔木、竹类、灌木及其他林业植物的土地,是陆地生态系统的重要组成部分,也是森林资源的核心载体。CnOpenData中国分县林地面积统计数据基于中国国土三调及国土年度变更调查汇总统计成果整合形成,包括全国、分省、分市…...
免费图片去水印工具有哪些?2026年在线网站、APP软件完整盘点与推荐
处理图片水印已经成为很多工作和生活场景的常见需求。无论是自媒体运营者整理素材、设计师进行后期处理,还是普通用户保存喜欢的图片,找到一个好用的去水印工具都能显著提高效率。在2026年,市场上涌现出许多免费的图片去水印工具,…...
ElevenLabs河南话合成效果翻车?5大本地化陷阱与97.3%可听度提升实测方案
更多请点击: https://codechina.net 第一章:ElevenLabs河南话语音合成效果翻车现象全景扫描 近期多位河南本地开发者及方言内容创作者反馈,ElevenLabs官方API在调用其“multilingual v2”模型尝试生成河南话(中原官话郑开片&…...
实测在ubuntu环境下调用taotoken api的延迟与稳定性表现
🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 实测在ubuntu环境下调用taotoken api的延迟与稳定性表现 本文旨在分享在Ubuntu 22.04 LTS系统环境下,使用Python脚本持…...
