【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集
课程链接
1.读取图像分类数据集
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
#读取数据集
trans=transforms.ToTensor()
mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)
print('训练数据集:',len(mnist_train),'测试数据集:',len(mnist_test))
print('训练数据集图片大小:',mnist_train[0][0].shape)#两个可视化数据集的函数
def get_fashion_mnist_labels(labels): #返回fashion_mnist数据集的文本标签text_labels=['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels ]
def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5):figsize=(num_rows*scale,num_cols*scale)_,axes=d2l.plt.subplots(num_rows,num_cols,figsize=figsize)axes=axes.flatten()for i ,(ax,img) in enumerate(zip(axes,imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
#几个样本的图像及其相应的标签
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()#读取一小批量数据,大小为batchsize
batch_size=256
def get_dataloader_workers(): #使用4个进程来读取数据return 4
train_iter=data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
timer=d2l.Timer()
for X,y in train_iter:continue
print(f'{timer.stop():.2f}sec')
# 便于重用函数
def load_data_fasion_mnist(batch_size,resize:None):trans = [transforms.ToTensor()]if resize:trans.insert(0,transforms.Resize(resize))trans=transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return(data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=get_dataloader_workers()))
运行结果
2.Softmax 回归从零开始实现
import torch
from IPython import display
from d2l import torch as d2l
import matplotlib.pyplot as plt
import torchvision
from torch.utils import data
from torchvision import transforms
import numpy as npbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)
num_inputs=784 #展平图像为向量
num_outputs=10 # 有10个类所以模型输出为10
w=torch.normal(0,0.01,size=(num_inputs,num_outputs),requires_grad=True)#定义权重w
b=torch.zeros(num_outputs,requires_grad=True)# 定义softmax
def softmax(X):X_exp=torch.exp(X)#对每个元素做指数运算partition =X_exp.sum(1,keepdim=True)#按照行求和return X_exp/partition #矩阵中的各个元素/对应行元素之和
#验证一下是否是正确的
X=torch.normal(0,0.01,(2,5))# 创建均值为0方差为1的两行五列的X
X_prob=softmax(X)
print('1.验证softmax:',X_prob,X_prob.sum(1))
#实现softmax回归模型
def net(X):return softmax(torch.matmul(X.reshape((-1,w.shape[0])),w)+b) # -1,每次喂数据的量,就是batchsizey=torch.tensor([0,2])
y_hat=torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])
print('2.根据标号拿出预测值:',y_hat[[0,1],y])
# 实现交叉熵损失
def cross_entropy(y_hat,y): #给定预测和真实标号Yreturn -torch.log(y_hat[range(len(y_hat)),y])# 锁定y轴在x轴上根据labels收取预测值,交叉熵损失中除了真值=1,其他都是0,这里直接算针织对应的预测概率
print('3.交叉熵损失:',cross_entropy(y_hat,y))#将预测类别与真实元素y进行比较
def accuracy(y_hat,y):if len(y_hat.shape)>1 and y_hat.shape[1]>1: #shape和列数大于1的时候y_hat=y_hat.argmax(axis=1)#把每一行元素最大的下标存到y_hatcmp=y_hat.type(y.dtype)==y #y_hat和y的数据类型转换,作比较变成布尔return float(cmp.type(y.dtype).sum())#转换成和y一样的形状求和
print('4.预测正确的概率:',accuracy(y_hat,y)/len(y))# 预测正确的样本数除以y的长度就是预测正确的概率#计算模型在数据迭代器上的精度
def evaluate_accuracy(net,data_iter):if isinstance(net,torch.nn.Module):net.eval()#将模型设置为评估模式,输入后得出的结果用来评估模型的准确率,不做反向传播metric =Accumulator(2) # 累加器for X,y in data_iter:metric.add(accuracy(net(X),y),y.numel())return metric[0]/metric[1] #返回分类正确的样本数和总样本数# accumulator的实现
class Accumulator: #作用是累加def __init__(self,n):self.data=[0.0]*ndef add(self,*args):self.data=[a+float(b) for a,b in zip(self.data,args)]def reset(self):self.data=[0.0]*len(self.data)def __getitem__(self, idx):return self.data[idx]
if __name__=='__main__':print(evaluate_accuracy(net,test_iter))# softmax回归的训练
def train_epoch_ch3(net,train_iter,loss,updater):if isinstance(net,torch.nn.Module):net.train()metric=Accumulator(3)for X,y in train_iter:y_hat=net(X)l=loss(y_hat,y)if isinstance(updater,torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()metric.add(float(l)*len(y),accuracy(y_hat,y),y.size().numel())else:l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y),y.numel())return metric[0]/metric[2],metric[1]/metric[2]class Animator:def __init__(self,xlabel=None,ylabel=None,legend=None,xlim=None,ylim=None,xscale='linear',yscale='linear',fmts=('-','m--','g-.','r:'),nrows=1,ncols=1,figsize=(3.5,2.5)):if legend is None:legend=[]d2l.use_svg_display()self.fig,self.axes=d2l.plt.subplots(nrows,ncols,figsize=figsize)if nrows*ncols==1:self.axes=[self.axes, ]self.config_axes=lambda :d2l.set_axes(self.axes[0],xlabel,ylabel,xlim,ylim,xscale,yscale,legend)self.X,self.Y,self.fmts=None,None,fmtsdef add(self,x,y):if not hasattr(y,"__len__"):y=[y]n=len(y)if not hasattr(x, "__len__"):x=[x]*nif not self.X:self.X=[[]for _ in range(n)]if not self.Y:self.Y=[[]for _ in range(n)]for i ,(a,b) in enumerate(zip(x,y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x,y,fmt in zip(self.X,self.Y,self.fmts):self.axes[0].plot(x,y,fmt)self.config_axes()plt.draw()plt.pause(0.001)display.display(self.fig)display.clear_output(wait=True)def train_ch3(net,train_iter,test_iter,loss,num_epochs,updater):animator=Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0.3,0.9],legend=['train loss','train acc','test acc'])for epoch in range(num_epochs):train_metrics=train_epoch_ch3(net,train_iter,loss,updater)test_acc=evaluate_accuracy(net,test_iter)animator.add(epoch+1, train_metrics+(test_acc,))train_loss,train_acc=train_metricslr = 0.1
def updater(batch_size):return d2l.sgd([w,b],lr,batch_size)if __name__ == '__main__':num_epochs=10train_ch3(net,train_iter,test_iter,cross_entropy,num_epochs,updater)# 对图像进行分类的预测def predict_ch3(net,test_iter,n=6):for X,y in test_iter:breaktrues=d2l.get_fashion_mnist_labels(y)preds=d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles=[true+'\n'+pred for true,pred in zip(trues,preds)]d2l.show_images(X[0:n].reshape((n,28,28)),1,n,titles=titles[0:n])d2l.plt.show()
if __name__ == '__main__':predict_ch3(net,test_iter)
运行结果
3.Softmax 回归简洁实现
import torch
from torch import nn
from d2l import torch as d2lbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)# 初始化模型参数
net =nn.Sequential(nn.Flatten(),nn.Linear(784,10))def init_weights(m):if type(m)==nn.Linear:nn.init.normal_(m.weight,std=0.01)
net.apply(init_weights);loss=nn.CrossEntropyLoss(reduction='none')
trainer=torch.optim.SGD(net.parameters(),lr=0.1)
num_epochs=10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)d2l.plt.show()
运行结果
相关文章:

【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集
课程链接 1.读取图像分类数据集 import matplotlib.pyplot as plt import torch import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2l d2l.use_svg_display() #读取数据集 transtransforms.ToTensor() mnist_…...

设计模式:24、访问者模式
目录 0、定义 1、访问者模式的五种角色 2、访问者模式的UML类图 3、示例代码 0、定义 表示一个作用于某对象结构中的各个元素的操作。它可以在不改变各个元素的类的前提下,定义作用于这些元素的新操作。 1、访问者模式的五种角色 抽象元素(Element…...

基于JAVA的旅游网站系统设计
摘要 随着信息技术和网络技术的迅速发展,人们的生活质量和观念也在发生着改变,各地争相发展旅游业,传统的 旅游社已经无法满足人们的需求,旅游网站将突破传统在时间和地域的限制,成为方便、快捷、安全、可靠的旅游 方…...
网络安全产品之认识防火墙
防火墙是一种网络安全产品,它设置在不同网络(如可信任的企业内部网和不可信的公共网)或网络安全域之间,通过监测、限制、更改跨越防火墙的数据流,尽可能地对外部屏蔽网络内部的信息、结构和运行状况,以此来…...

nginx反向代理(负载均衡)和tomcat介绍
nginx的代理 负载均衡 负载均衡的算法 负载均衡的架构 基于ip的七层代理 upstream模块要写在http模块中 七层代理的调用要写在location模块中 轮询 加权轮询 最小连接数 ip_Hash URL_HASH 基于域名的七层代理 配置主机 给其余客户机配置域名 给所有机器做域名映射 四层代理…...
Microsoft Azure 在线技术公开课:生成式 AI 基础知识
课程介绍 参加我们的生成式 AI 基础知识公开课,了解如何将最新 AI 进展应用到你的工作中。你将了解有关语言模型和生成式 AI 应用程序的基础知识。此外,你还将了解 Azure OpenAI 服务如何通过文本、代码、图像生成、自然语言摘要和语义搜索助你实现成果…...

lnmp+discuz论坛 附实验:搭建discuz论坛
Inmpdiscuz论坛 Inmp: t: linux操作系统 nr: nginx前端页面 me: mysql数据库 账号密码,等等都是保存在这个数据库里面 p: php——nginx擅长处理的是静态页面,页面登录账户,需要请求到数据库,通过php把动态请求转发到数据库 n…...

谷粒商城—分布式高级①.md
1. ELASTICSEARCH 1、安装elastic search dokcer中安装elastic search (1)下载ealastic search和kibana docker pull elasticsearch:7.6.2 docker pull kibana:7.6.2(2)配置 mkdir -p /mydata/elasticsearch/config mkdir -p /mydata/elasticsearch/data echo "h…...

Unity开发配置不足,卡顿崩溃怎么办?
在游戏开发和虚拟现实等领域,Unity 软件以其强大的功能和广泛的适用性成为了众多开发者的首选。然而,要充分发挥 Unity 的性能,一台高性能的电脑设备是必不可少的。今天,我要向大家介绍川翔云电脑,它为 Unity 开发者提…...

在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1
KubeSphere4.1安装文档 在 Kubernetes 上快速安装 KubeSphere 在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1 官方文档:在 Linux 上以 All-in-One 模式安装 KubeSphere 下载文件 KubeKey git地址Releases kubesphere/kubekey 或 …...
网络安全自学是一项需要耐心和恒心的任务
网络安全自学是一项需要耐心和恒心的任务,但只要你按照正确的学习路线图去努力,就能够逐步掌握这一领域的知识和技能。下面是一份详细的学习路线图,它将帮助你从零基础开始,逐步成为网络安全领域的专家。 第一阶段:基…...
Python+OpenCV系列:图像的几何变换
Python OpenCV 系列:图像的几何变换 引言 在图像处理领域,几何变换是一个非常重要的操作,它可以改变图像的位置、大小、方向或形状。在计算机视觉中,这些操作对于图像预处理、特征提取和图像增强至关重要。本文将介绍如何利用 …...

第P1周:Pytorch实现mnist手写数字识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 1. 实现pytorch环境配置 2. 实现mnist手写数字识别 3. 自己写几个数字识别试试具体实现 (一)环境 语言环境:Python…...

使用EventLog Analyzer进行Apache日志监控和日志分析
一、什么是Apache日志分析 Apache日志分析是网站管理和维护的重要部分,通过分析Apache服务器生成的日志文件,可以了解网站的访问情况、识别潜在的安全问题、优化网站性能等。 二、Apache日志类型 Apache日志主要有两种类型:访问日志&a…...

PaddleOCR模型ch_PP-OCRv3文本检测模型研究(二)颈部网络
上节研究了PaddleOCR文本检测v3模型的骨干网,本文接着研究其颈部网络。 文章目录 研究起点残注层颈部网络代码实验小结 研究起点 摘取开源yml配置文件,摘取网络架构Architecture中颈部网络的配置如下 Neck:name: RSEFPNout_channels: 96shortcut: True可…...

360极速浏览器不支持看PDF
360安全浏览器采用的是基于IE内核和Chrome内核的双核浏览器。360极速浏览器是源自Chromium开源项目的浏览器,不但完美融合了IE内核引擎,而且实现了双核引擎的无缝切换。因此在速度上,360极速浏览器的极速体验感更佳。 展示自己的时候要在有优…...

【深度学习】深刻理解ViT
ViT(Vision Transformer)是谷歌研究团队于2020年提出的一种新型图像识别模型,首次将Transformer架构成功应用于计算机视觉任务中。Transformer最初应用于自然语言处理(如BERT和GPT),而ViT展示了其在视觉任务…...
解决vue2中更新列表数据,页面dom没有重新渲染的问题
在 Vue 2 中,直接修改数组的某个项可能不会触发视图的更新。这是因为 Vue 不能检测到数组的索引变化或对象属性的直接赋值。为了确保 Vue 能够正确地响应数据变化,你可以使用以下几种方法: 1. 使用 Vue.set() 使用 Vue.set() 方法可以确保 …...

vscode通过ssh连接远程服务器(实习心得)
一、连接ssh服务器 1.打开Visual Studio Code,进入拓展市场(CtrlShiftX),下载拓展Remote - SSH 2. 点击远程资源管理器选项卡,并选择远程(隧道/SSH)类别 3. 点击ssh配置:输入你的账号主机ip地址 4.在弹出的选择配置文件中…...

知识图谱9:知识图谱的展示
1、知识图谱的展示有很多工具 Neo4j Browser - - - - 浏览器版本 Neo4j Desktop - - - - 桌面版本 graphX - - - - 可以集成到Neo4j Desktop Neo4j 提供的 Neo4j Bloom 是用户友好的可视化工具,适合非技术用户直观地浏览图数据。Cypher 是其核心查询语言&#x…...
浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)
✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义(Task Definition&…...

网络编程(Modbus进阶)
思维导图 Modbus RTU(先学一点理论) 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议,由 Modicon 公司(现施耐德电气)于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?
编辑:陈萍萍的公主一点人工一点智能 未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战,在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

Chapter03-Authentication vulnerabilities
文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...
【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15
缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下: struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...

YSYX学习记录(八)
C语言,练习0: 先创建一个文件夹,我用的是物理机: 安装build-essential 练习1: 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件,随机修改或删除一部分,之后…...
解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错
出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上,所以报错,到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本,cu、torch、cp 的版本一定要对…...
#Uniapp篇:chrome调试unapp适配
chrome调试设备----使用Android模拟机开发调试移动端页面 Chrome://inspect/#devices MuMu模拟器Edge浏览器:Android原生APP嵌入的H5页面元素定位 chrome://inspect/#devices uniapp单位适配 根路径下 postcss.config.js 需要装这些插件 “postcss”: “^8.5.…...

MySQL 知识小结(一)
一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库,分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷,但是文件存放起来数据比较冗余,用二进制能够更好管理咱们M…...

淘宝扭蛋机小程序系统开发:打造互动性强的购物平台
淘宝扭蛋机小程序系统的开发,旨在打造一个互动性强的购物平台,让用户在购物的同时,能够享受到更多的乐趣和惊喜。 淘宝扭蛋机小程序系统拥有丰富的互动功能。用户可以通过虚拟摇杆操作扭蛋机,实现旋转、抽拉等动作,增…...