【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集
课程链接
1.读取图像分类数据集
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
#读取数据集
trans=transforms.ToTensor()
mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)
print('训练数据集:',len(mnist_train),'测试数据集:',len(mnist_test))
print('训练数据集图片大小:',mnist_train[0][0].shape)#两个可视化数据集的函数
def get_fashion_mnist_labels(labels): #返回fashion_mnist数据集的文本标签text_labels=['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels ]
def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5):figsize=(num_rows*scale,num_cols*scale)_,axes=d2l.plt.subplots(num_rows,num_cols,figsize=figsize)axes=axes.flatten()for i ,(ax,img) in enumerate(zip(axes,imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
#几个样本的图像及其相应的标签
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()#读取一小批量数据,大小为batchsize
batch_size=256
def get_dataloader_workers(): #使用4个进程来读取数据return 4
train_iter=data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
timer=d2l.Timer()
for X,y in train_iter:continue
print(f'{timer.stop():.2f}sec')
# 便于重用函数
def load_data_fasion_mnist(batch_size,resize:None):trans = [transforms.ToTensor()]if resize:trans.insert(0,transforms.Resize(resize))trans=transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return(data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=get_dataloader_workers()))
运行结果


2.Softmax 回归从零开始实现
import torch
from IPython import display
from d2l import torch as d2l
import matplotlib.pyplot as plt
import torchvision
from torch.utils import data
from torchvision import transforms
import numpy as npbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)
num_inputs=784 #展平图像为向量
num_outputs=10 # 有10个类所以模型输出为10
w=torch.normal(0,0.01,size=(num_inputs,num_outputs),requires_grad=True)#定义权重w
b=torch.zeros(num_outputs,requires_grad=True)# 定义softmax
def softmax(X):X_exp=torch.exp(X)#对每个元素做指数运算partition =X_exp.sum(1,keepdim=True)#按照行求和return X_exp/partition #矩阵中的各个元素/对应行元素之和
#验证一下是否是正确的
X=torch.normal(0,0.01,(2,5))# 创建均值为0方差为1的两行五列的X
X_prob=softmax(X)
print('1.验证softmax:',X_prob,X_prob.sum(1))
#实现softmax回归模型
def net(X):return softmax(torch.matmul(X.reshape((-1,w.shape[0])),w)+b) # -1,每次喂数据的量,就是batchsizey=torch.tensor([0,2])
y_hat=torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])
print('2.根据标号拿出预测值:',y_hat[[0,1],y])
# 实现交叉熵损失
def cross_entropy(y_hat,y): #给定预测和真实标号Yreturn -torch.log(y_hat[range(len(y_hat)),y])# 锁定y轴在x轴上根据labels收取预测值,交叉熵损失中除了真值=1,其他都是0,这里直接算针织对应的预测概率
print('3.交叉熵损失:',cross_entropy(y_hat,y))#将预测类别与真实元素y进行比较
def accuracy(y_hat,y):if len(y_hat.shape)>1 and y_hat.shape[1]>1: #shape和列数大于1的时候y_hat=y_hat.argmax(axis=1)#把每一行元素最大的下标存到y_hatcmp=y_hat.type(y.dtype)==y #y_hat和y的数据类型转换,作比较变成布尔return float(cmp.type(y.dtype).sum())#转换成和y一样的形状求和
print('4.预测正确的概率:',accuracy(y_hat,y)/len(y))# 预测正确的样本数除以y的长度就是预测正确的概率#计算模型在数据迭代器上的精度
def evaluate_accuracy(net,data_iter):if isinstance(net,torch.nn.Module):net.eval()#将模型设置为评估模式,输入后得出的结果用来评估模型的准确率,不做反向传播metric =Accumulator(2) # 累加器for X,y in data_iter:metric.add(accuracy(net(X),y),y.numel())return metric[0]/metric[1] #返回分类正确的样本数和总样本数# accumulator的实现
class Accumulator: #作用是累加def __init__(self,n):self.data=[0.0]*ndef add(self,*args):self.data=[a+float(b) for a,b in zip(self.data,args)]def reset(self):self.data=[0.0]*len(self.data)def __getitem__(self, idx):return self.data[idx]
if __name__=='__main__':print(evaluate_accuracy(net,test_iter))# softmax回归的训练
def train_epoch_ch3(net,train_iter,loss,updater):if isinstance(net,torch.nn.Module):net.train()metric=Accumulator(3)for X,y in train_iter:y_hat=net(X)l=loss(y_hat,y)if isinstance(updater,torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()metric.add(float(l)*len(y),accuracy(y_hat,y),y.size().numel())else:l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y),y.numel())return metric[0]/metric[2],metric[1]/metric[2]class Animator:def __init__(self,xlabel=None,ylabel=None,legend=None,xlim=None,ylim=None,xscale='linear',yscale='linear',fmts=('-','m--','g-.','r:'),nrows=1,ncols=1,figsize=(3.5,2.5)):if legend is None:legend=[]d2l.use_svg_display()self.fig,self.axes=d2l.plt.subplots(nrows,ncols,figsize=figsize)if nrows*ncols==1:self.axes=[self.axes, ]self.config_axes=lambda :d2l.set_axes(self.axes[0],xlabel,ylabel,xlim,ylim,xscale,yscale,legend)self.X,self.Y,self.fmts=None,None,fmtsdef add(self,x,y):if not hasattr(y,"__len__"):y=[y]n=len(y)if not hasattr(x, "__len__"):x=[x]*nif not self.X:self.X=[[]for _ in range(n)]if not self.Y:self.Y=[[]for _ in range(n)]for i ,(a,b) in enumerate(zip(x,y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x,y,fmt in zip(self.X,self.Y,self.fmts):self.axes[0].plot(x,y,fmt)self.config_axes()plt.draw()plt.pause(0.001)display.display(self.fig)display.clear_output(wait=True)def train_ch3(net,train_iter,test_iter,loss,num_epochs,updater):animator=Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0.3,0.9],legend=['train loss','train acc','test acc'])for epoch in range(num_epochs):train_metrics=train_epoch_ch3(net,train_iter,loss,updater)test_acc=evaluate_accuracy(net,test_iter)animator.add(epoch+1, train_metrics+(test_acc,))train_loss,train_acc=train_metricslr = 0.1
def updater(batch_size):return d2l.sgd([w,b],lr,batch_size)if __name__ == '__main__':num_epochs=10train_ch3(net,train_iter,test_iter,cross_entropy,num_epochs,updater)# 对图像进行分类的预测def predict_ch3(net,test_iter,n=6):for X,y in test_iter:breaktrues=d2l.get_fashion_mnist_labels(y)preds=d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles=[true+'\n'+pred for true,pred in zip(trues,preds)]d2l.show_images(X[0:n].reshape((n,28,28)),1,n,titles=titles[0:n])d2l.plt.show()
if __name__ == '__main__':predict_ch3(net,test_iter)
运行结果



3.Softmax 回归简洁实现
import torch
from torch import nn
from d2l import torch as d2lbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)# 初始化模型参数
net =nn.Sequential(nn.Flatten(),nn.Linear(784,10))def init_weights(m):if type(m)==nn.Linear:nn.init.normal_(m.weight,std=0.01)
net.apply(init_weights);loss=nn.CrossEntropyLoss(reduction='none')
trainer=torch.optim.SGD(net.parameters(),lr=0.1)
num_epochs=10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)d2l.plt.show()
运行结果

相关文章:
【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集
课程链接 1.读取图像分类数据集 import matplotlib.pyplot as plt import torch import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2l d2l.use_svg_display() #读取数据集 transtransforms.ToTensor() mnist_…...
设计模式:24、访问者模式
目录 0、定义 1、访问者模式的五种角色 2、访问者模式的UML类图 3、示例代码 0、定义 表示一个作用于某对象结构中的各个元素的操作。它可以在不改变各个元素的类的前提下,定义作用于这些元素的新操作。 1、访问者模式的五种角色 抽象元素(Element…...
基于JAVA的旅游网站系统设计
摘要 随着信息技术和网络技术的迅速发展,人们的生活质量和观念也在发生着改变,各地争相发展旅游业,传统的 旅游社已经无法满足人们的需求,旅游网站将突破传统在时间和地域的限制,成为方便、快捷、安全、可靠的旅游 方…...
网络安全产品之认识防火墙
防火墙是一种网络安全产品,它设置在不同网络(如可信任的企业内部网和不可信的公共网)或网络安全域之间,通过监测、限制、更改跨越防火墙的数据流,尽可能地对外部屏蔽网络内部的信息、结构和运行状况,以此来…...
nginx反向代理(负载均衡)和tomcat介绍
nginx的代理 负载均衡 负载均衡的算法 负载均衡的架构 基于ip的七层代理 upstream模块要写在http模块中 七层代理的调用要写在location模块中 轮询 加权轮询 最小连接数 ip_Hash URL_HASH 基于域名的七层代理 配置主机 给其余客户机配置域名 给所有机器做域名映射 四层代理…...
Microsoft Azure 在线技术公开课:生成式 AI 基础知识
课程介绍 参加我们的生成式 AI 基础知识公开课,了解如何将最新 AI 进展应用到你的工作中。你将了解有关语言模型和生成式 AI 应用程序的基础知识。此外,你还将了解 Azure OpenAI 服务如何通过文本、代码、图像生成、自然语言摘要和语义搜索助你实现成果…...
lnmp+discuz论坛 附实验:搭建discuz论坛
Inmpdiscuz论坛 Inmp: t: linux操作系统 nr: nginx前端页面 me: mysql数据库 账号密码,等等都是保存在这个数据库里面 p: php——nginx擅长处理的是静态页面,页面登录账户,需要请求到数据库,通过php把动态请求转发到数据库 n…...
谷粒商城—分布式高级①.md
1. ELASTICSEARCH 1、安装elastic search dokcer中安装elastic search (1)下载ealastic search和kibana docker pull elasticsearch:7.6.2 docker pull kibana:7.6.2(2)配置 mkdir -p /mydata/elasticsearch/config mkdir -p /mydata/elasticsearch/data echo "h…...
Unity开发配置不足,卡顿崩溃怎么办?
在游戏开发和虚拟现实等领域,Unity 软件以其强大的功能和广泛的适用性成为了众多开发者的首选。然而,要充分发挥 Unity 的性能,一台高性能的电脑设备是必不可少的。今天,我要向大家介绍川翔云电脑,它为 Unity 开发者提…...
在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1
KubeSphere4.1安装文档 在 Kubernetes 上快速安装 KubeSphere 在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1 官方文档:在 Linux 上以 All-in-One 模式安装 KubeSphere 下载文件 KubeKey git地址Releases kubesphere/kubekey 或 …...
网络安全自学是一项需要耐心和恒心的任务
网络安全自学是一项需要耐心和恒心的任务,但只要你按照正确的学习路线图去努力,就能够逐步掌握这一领域的知识和技能。下面是一份详细的学习路线图,它将帮助你从零基础开始,逐步成为网络安全领域的专家。 第一阶段:基…...
Python+OpenCV系列:图像的几何变换
Python OpenCV 系列:图像的几何变换 引言 在图像处理领域,几何变换是一个非常重要的操作,它可以改变图像的位置、大小、方向或形状。在计算机视觉中,这些操作对于图像预处理、特征提取和图像增强至关重要。本文将介绍如何利用 …...
第P1周:Pytorch实现mnist手写数字识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 1. 实现pytorch环境配置 2. 实现mnist手写数字识别 3. 自己写几个数字识别试试具体实现 (一)环境 语言环境:Python…...
使用EventLog Analyzer进行Apache日志监控和日志分析
一、什么是Apache日志分析 Apache日志分析是网站管理和维护的重要部分,通过分析Apache服务器生成的日志文件,可以了解网站的访问情况、识别潜在的安全问题、优化网站性能等。 二、Apache日志类型 Apache日志主要有两种类型:访问日志&a…...
PaddleOCR模型ch_PP-OCRv3文本检测模型研究(二)颈部网络
上节研究了PaddleOCR文本检测v3模型的骨干网,本文接着研究其颈部网络。 文章目录 研究起点残注层颈部网络代码实验小结 研究起点 摘取开源yml配置文件,摘取网络架构Architecture中颈部网络的配置如下 Neck:name: RSEFPNout_channels: 96shortcut: True可…...
360极速浏览器不支持看PDF
360安全浏览器采用的是基于IE内核和Chrome内核的双核浏览器。360极速浏览器是源自Chromium开源项目的浏览器,不但完美融合了IE内核引擎,而且实现了双核引擎的无缝切换。因此在速度上,360极速浏览器的极速体验感更佳。 展示自己的时候要在有优…...
【深度学习】深刻理解ViT
ViT(Vision Transformer)是谷歌研究团队于2020年提出的一种新型图像识别模型,首次将Transformer架构成功应用于计算机视觉任务中。Transformer最初应用于自然语言处理(如BERT和GPT),而ViT展示了其在视觉任务…...
解决vue2中更新列表数据,页面dom没有重新渲染的问题
在 Vue 2 中,直接修改数组的某个项可能不会触发视图的更新。这是因为 Vue 不能检测到数组的索引变化或对象属性的直接赋值。为了确保 Vue 能够正确地响应数据变化,你可以使用以下几种方法: 1. 使用 Vue.set() 使用 Vue.set() 方法可以确保 …...
vscode通过ssh连接远程服务器(实习心得)
一、连接ssh服务器 1.打开Visual Studio Code,进入拓展市场(CtrlShiftX),下载拓展Remote - SSH 2. 点击远程资源管理器选项卡,并选择远程(隧道/SSH)类别 3. 点击ssh配置:输入你的账号主机ip地址 4.在弹出的选择配置文件中…...
知识图谱9:知识图谱的展示
1、知识图谱的展示有很多工具 Neo4j Browser - - - - 浏览器版本 Neo4j Desktop - - - - 桌面版本 graphX - - - - 可以集成到Neo4j Desktop Neo4j 提供的 Neo4j Bloom 是用户友好的可视化工具,适合非技术用户直观地浏览图数据。Cypher 是其核心查询语言&#x…...
树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...
ubuntu搭建nfs服务centos挂载访问
在Ubuntu上设置NFS服务器 在Ubuntu上,你可以使用apt包管理器来安装NFS服务器。打开终端并运行: sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享,例如/shared: sudo mkdir /shared sud…...
中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试
作者:Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位:中南大学地球科学与信息物理学院论文标题:BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接:https://arxiv.…...
最新SpringBoot+SpringCloud+Nacos微服务框架分享
文章目录 前言一、服务规划二、架构核心1.cloud的pom2.gateway的异常handler3.gateway的filter4、admin的pom5、admin的登录核心 三、code-helper分享总结 前言 最近有个活蛮赶的,根据Excel列的需求预估的工时直接打骨折,不要问我为什么,主要…...
什么是库存周转?如何用进销存系统提高库存周转率?
你可能听说过这样一句话: “利润不是赚出来的,是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业,很多企业看着销售不错,账上却没钱、利润也不见了,一翻库存才发现: 一堆卖不动的旧货…...
论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)
笔记整理:刘治强,浙江大学硕士生,研究方向为知识图谱表示学习,大语言模型 论文链接:http://arxiv.org/abs/2407.16127 发表会议:ISWC 2024 1. 动机 传统的知识图谱补全(KGC)模型通过…...
Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...
什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南
文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...
C++使用 new 来创建动态数组
问题: 不能使用变量定义数组大小 原因: 这是因为数组在内存中是连续存储的,编译器需要在编译阶段就确定数组的大小,以便正确地分配内存空间。如果允许使用变量来定义数组的大小,那么编译器就无法在编译时确定数组的大…...
