当前位置: 首页 > news >正文

【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集

课程链接

1.读取图像分类数据集

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
#读取数据集
trans=transforms.ToTensor()
mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)
print('训练数据集:',len(mnist_train),'测试数据集:',len(mnist_test))
print('训练数据集图片大小:',mnist_train[0][0].shape)#两个可视化数据集的函数
def get_fashion_mnist_labels(labels): #返回fashion_mnist数据集的文本标签text_labels=['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels ]
def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5):figsize=(num_rows*scale,num_cols*scale)_,axes=d2l.plt.subplots(num_rows,num_cols,figsize=figsize)axes=axes.flatten()for i ,(ax,img) in enumerate(zip(axes,imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
#几个样本的图像及其相应的标签
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28,  28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()#读取一小批量数据,大小为batchsize
batch_size=256
def get_dataloader_workers(): #使用4个进程来读取数据return 4
train_iter=data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
timer=d2l.Timer()
for X,y in train_iter:continue
print(f'{timer.stop():.2f}sec')
# 便于重用函数
def load_data_fasion_mnist(batch_size,resize:None):trans = [transforms.ToTensor()]if resize:trans.insert(0,transforms.Resize(resize))trans=transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return(data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=get_dataloader_workers()))

运行结果

在这里插入图片描述
在这里插入图片描述

2.Softmax 回归从零开始实现

import torch
from IPython import display
from d2l import torch as d2l
import matplotlib.pyplot as plt
import torchvision
from torch.utils import data
from torchvision import transforms
import numpy as npbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)
num_inputs=784 #展平图像为向量
num_outputs=10 # 有10个类所以模型输出为10
w=torch.normal(0,0.01,size=(num_inputs,num_outputs),requires_grad=True)#定义权重w
b=torch.zeros(num_outputs,requires_grad=True)# 定义softmax
def softmax(X):X_exp=torch.exp(X)#对每个元素做指数运算partition =X_exp.sum(1,keepdim=True)#按照行求和return X_exp/partition #矩阵中的各个元素/对应行元素之和
#验证一下是否是正确的
X=torch.normal(0,0.01,(2,5))# 创建均值为0方差为1的两行五列的X
X_prob=softmax(X)
print('1.验证softmax:',X_prob,X_prob.sum(1))
#实现softmax回归模型
def net(X):return softmax(torch.matmul(X.reshape((-1,w.shape[0])),w)+b) # -1,每次喂数据的量,就是batchsizey=torch.tensor([0,2])
y_hat=torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])
print('2.根据标号拿出预测值:',y_hat[[0,1],y])
# 实现交叉熵损失
def cross_entropy(y_hat,y): #给定预测和真实标号Yreturn -torch.log(y_hat[range(len(y_hat)),y])# 锁定y轴在x轴上根据labels收取预测值,交叉熵损失中除了真值=1,其他都是0,这里直接算针织对应的预测概率
print('3.交叉熵损失:',cross_entropy(y_hat,y))#将预测类别与真实元素y进行比较
def accuracy(y_hat,y):if len(y_hat.shape)>1 and y_hat.shape[1]>1: #shape和列数大于1的时候y_hat=y_hat.argmax(axis=1)#把每一行元素最大的下标存到y_hatcmp=y_hat.type(y.dtype)==y #y_hat和y的数据类型转换,作比较变成布尔return float(cmp.type(y.dtype).sum())#转换成和y一样的形状求和
print('4.预测正确的概率:',accuracy(y_hat,y)/len(y))# 预测正确的样本数除以y的长度就是预测正确的概率#计算模型在数据迭代器上的精度
def evaluate_accuracy(net,data_iter):if isinstance(net,torch.nn.Module):net.eval()#将模型设置为评估模式,输入后得出的结果用来评估模型的准确率,不做反向传播metric =Accumulator(2) # 累加器for X,y in data_iter:metric.add(accuracy(net(X),y),y.numel())return metric[0]/metric[1] #返回分类正确的样本数和总样本数# accumulator的实现
class Accumulator: #作用是累加def __init__(self,n):self.data=[0.0]*ndef add(self,*args):self.data=[a+float(b) for a,b in zip(self.data,args)]def reset(self):self.data=[0.0]*len(self.data)def __getitem__(self, idx):return self.data[idx]
if __name__=='__main__':print(evaluate_accuracy(net,test_iter))# softmax回归的训练
def train_epoch_ch3(net,train_iter,loss,updater):if isinstance(net,torch.nn.Module):net.train()metric=Accumulator(3)for X,y in train_iter:y_hat=net(X)l=loss(y_hat,y)if isinstance(updater,torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()metric.add(float(l)*len(y),accuracy(y_hat,y),y.size().numel())else:l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y),y.numel())return metric[0]/metric[2],metric[1]/metric[2]class Animator:def __init__(self,xlabel=None,ylabel=None,legend=None,xlim=None,ylim=None,xscale='linear',yscale='linear',fmts=('-','m--','g-.','r:'),nrows=1,ncols=1,figsize=(3.5,2.5)):if legend is None:legend=[]d2l.use_svg_display()self.fig,self.axes=d2l.plt.subplots(nrows,ncols,figsize=figsize)if nrows*ncols==1:self.axes=[self.axes, ]self.config_axes=lambda :d2l.set_axes(self.axes[0],xlabel,ylabel,xlim,ylim,xscale,yscale,legend)self.X,self.Y,self.fmts=None,None,fmtsdef add(self,x,y):if not hasattr(y,"__len__"):y=[y]n=len(y)if not hasattr(x, "__len__"):x=[x]*nif not self.X:self.X=[[]for _ in range(n)]if not self.Y:self.Y=[[]for _ in range(n)]for i ,(a,b) in enumerate(zip(x,y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x,y,fmt in zip(self.X,self.Y,self.fmts):self.axes[0].plot(x,y,fmt)self.config_axes()plt.draw()plt.pause(0.001)display.display(self.fig)display.clear_output(wait=True)def train_ch3(net,train_iter,test_iter,loss,num_epochs,updater):animator=Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0.3,0.9],legend=['train loss','train acc','test acc'])for epoch in range(num_epochs):train_metrics=train_epoch_ch3(net,train_iter,loss,updater)test_acc=evaluate_accuracy(net,test_iter)animator.add(epoch+1, train_metrics+(test_acc,))train_loss,train_acc=train_metricslr = 0.1
def updater(batch_size):return d2l.sgd([w,b],lr,batch_size)if __name__ == '__main__':num_epochs=10train_ch3(net,train_iter,test_iter,cross_entropy,num_epochs,updater)# 对图像进行分类的预测def predict_ch3(net,test_iter,n=6):for X,y in test_iter:breaktrues=d2l.get_fashion_mnist_labels(y)preds=d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles=[true+'\n'+pred for true,pred in zip(trues,preds)]d2l.show_images(X[0:n].reshape((n,28,28)),1,n,titles=titles[0:n])d2l.plt.show()
if __name__ == '__main__':predict_ch3(net,test_iter)

运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.Softmax 回归简洁实现

import torch
from torch import nn
from d2l import torch as d2lbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)# 初始化模型参数
net =nn.Sequential(nn.Flatten(),nn.Linear(784,10))def init_weights(m):if type(m)==nn.Linear:nn.init.normal_(m.weight,std=0.01)
net.apply(init_weights);loss=nn.CrossEntropyLoss(reduction='none')
trainer=torch.optim.SGD(net.parameters(),lr=0.1)
num_epochs=10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)d2l.plt.show()

运行结果

在这里插入图片描述

相关文章:

【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集

课程链接 1.读取图像分类数据集 import matplotlib.pyplot as plt import torch import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2l d2l.use_svg_display() #读取数据集 transtransforms.ToTensor() mnist_…...

设计模式:24、访问者模式

目录 0、定义 1、访问者模式的五种角色 2、访问者模式的UML类图 3、示例代码 0、定义 表示一个作用于某对象结构中的各个元素的操作。它可以在不改变各个元素的类的前提下,定义作用于这些元素的新操作。 1、访问者模式的五种角色 抽象元素(Element…...

基于JAVA的旅游网站系统设计

摘要 随着信息技术和网络技术的迅速发展,人们的生活质量和观念也在发生着改变,各地争相发展旅游业,传统的 旅游社已经无法满足人们的需求,旅游网站将突破传统在时间和地域的限制,成为方便、快捷、安全、可靠的旅游 方…...

网络安全产品之认识防火墙

防火墙是一种网络安全产品,它设置在不同网络(如可信任的企业内部网和不可信的公共网)或网络安全域之间,通过监测、限制、更改跨越防火墙的数据流,尽可能地对外部屏蔽网络内部的信息、结构和运行状况,以此来…...

nginx反向代理(负载均衡)和tomcat介绍

nginx的代理 负载均衡 负载均衡的算法 负载均衡的架构 基于ip的七层代理 upstream模块要写在http模块中 七层代理的调用要写在location模块中 轮询 加权轮询 最小连接数 ip_Hash URL_HASH 基于域名的七层代理 配置主机 给其余客户机配置域名 给所有机器做域名映射 四层代理…...

Microsoft Azure 在线技术公开课:生成式 AI 基础知识

课程介绍 参加我们的生成式 AI 基础知识公开课,了解如何将最新 AI 进展应用到你的工作中。你将了解有关语言模型和生成式 AI 应用程序的基础知识。此外,你还将了解 Azure OpenAI 服务如何通过文本、代码、图像生成、自然语言摘要和语义搜索助你实现成果…...

lnmp+discuz论坛 附实验:搭建discuz论坛

Inmpdiscuz论坛 Inmp: t: linux操作系统 nr: nginx前端页面 me: mysql数据库 账号密码,等等都是保存在这个数据库里面 p: php——nginx擅长处理的是静态页面,页面登录账户,需要请求到数据库,通过php把动态请求转发到数据库 n…...

谷粒商城—分布式高级①.md

1. ELASTICSEARCH 1、安装elastic search dokcer中安装elastic search (1)下载ealastic search和kibana docker pull elasticsearch:7.6.2 docker pull kibana:7.6.2(2)配置 mkdir -p /mydata/elasticsearch/config mkdir -p /mydata/elasticsearch/data echo "h…...

Unity开发配置不足,卡顿崩溃怎么办?

在游戏开发和虚拟现实等领域,Unity 软件以其强大的功能和广泛的适用性成为了众多开发者的首选。然而,要充分发挥 Unity 的性能,一台高性能的电脑设备是必不可少的。今天,我要向大家介绍川翔云电脑,它为 Unity 开发者提…...

在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1

KubeSphere4.1安装文档 在 Kubernetes 上快速安装 KubeSphere 在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1 官方文档:在 Linux 上以 All-in-One 模式安装 KubeSphere 下载文件 KubeKey git地址Releases kubesphere/kubekey 或 …...

网络安全自学是一项需要耐心和恒心的任务

网络安全自学是一项需要耐心和恒心的任务,但只要你按照正确的学习路线图去努力,就能够逐步掌握这一领域的知识和技能。下面是一份详细的学习路线图,它将帮助你从零基础开始,逐步成为网络安全领域的专家。 第一阶段:基…...

Python+OpenCV系列:图像的几何变换

Python OpenCV 系列:图像的几何变换 引言 在图像处理领域,几何变换是一个非常重要的操作,它可以改变图像的位置、大小、方向或形状。在计算机视觉中,这些操作对于图像预处理、特征提取和图像增强至关重要。本文将介绍如何利用 …...

第P1周:Pytorch实现mnist手写数字识别

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 1. 实现pytorch环境配置 2. 实现mnist手写数字识别 3. 自己写几个数字识别试试具体实现 (一)环境 语言环境:Python…...

使用EventLog Analyzer进行Apache日志监控和日志分析

一、什么是Apache日志分析 Apache日志分析是网站管理和维护的重要部分,通过分析Apache服务器生成的日志文件,可以了解网站的访问情况、识别潜在的安全问题、优化网站性能等。 二、Apache日志类型 Apache日志主要有两种类型:访问日志&a…...

PaddleOCR模型ch_PP-OCRv3文本检测模型研究(二)颈部网络

上节研究了PaddleOCR文本检测v3模型的骨干网,本文接着研究其颈部网络。 文章目录 研究起点残注层颈部网络代码实验小结 研究起点 摘取开源yml配置文件,摘取网络架构Architecture中颈部网络的配置如下 Neck:name: RSEFPNout_channels: 96shortcut: True可…...

360极速浏览器不支持看PDF

360安全浏览器采用的是基于IE内核和Chrome内核的双核浏览器。360极速浏览器是源自Chromium开源项目的浏览器,不但完美融合了IE内核引擎,而且实现了双核引擎的无缝切换。因此在速度上,360极速浏览器的极速体验感更佳。 展示自己的时候要在有优…...

【深度学习】深刻理解ViT

ViT(Vision Transformer)是谷歌研究团队于2020年提出的一种新型图像识别模型,首次将Transformer架构成功应用于计算机视觉任务中。Transformer最初应用于自然语言处理(如BERT和GPT),而ViT展示了其在视觉任务…...

解决vue2中更新列表数据,页面dom没有重新渲染的问题

在 Vue 2 中,直接修改数组的某个项可能不会触发视图的更新。这是因为 Vue 不能检测到数组的索引变化或对象属性的直接赋值。为了确保 Vue 能够正确地响应数据变化,你可以使用以下几种方法: 1. 使用 Vue.set() 使用 Vue.set() 方法可以确保 …...

vscode通过ssh连接远程服务器(实习心得)

一、连接ssh服务器 1.打开Visual Studio Code,进入拓展市场(CtrlShiftX),下载拓展Remote - SSH 2. 点击远程资源管理器选项卡,并选择远程(隧道/SSH)类别 3. 点击ssh配置:输入你的账号主机ip地址 4.在弹出的选择配置文件中&#xf…...

知识图谱9:知识图谱的展示

1、知识图谱的展示有很多工具 Neo4j Browser - - - - 浏览器版本 Neo4j Desktop - - - - 桌面版本 graphX - - - - 可以集成到Neo4j Desktop Neo4j 提供的 Neo4j Bloom 是用户友好的可视化工具,适合非技术用户直观地浏览图数据。Cypher 是其核心查询语言&#x…...

leetcode 面试经典 150 题:验证回文串

链接验证回文串题序号125类型字符串解题方法双指针法难度简单 题目 如果在将所有大写字符转换为小写字符、并移除所有非字母数字字符之后,短语正着读和反着读都一样。则可以认为该短语是一个 回文串 。 字母和数字都属于字母数字字符。 给你一个字符串 s&#xf…...

【0363】Postgres内核 从 XLogReaderState readBuf 解析 XLOG Record( 8 )

上一篇: 【0362】Postgres内核 XLogReaderState readBuf 有完整 XLOG page header 信息 ? ( 7 ) 直接相关: 【0341】Postgres内核 读取单个 xlog page (2 - 2 ) 文章目录 1. readBuf 获取 page header 大小1.1 XLOG record 跨 page ?1.2 获取 XLOG Record 的 长度(xl…...

docker tdengine windows快速体验

#拉取镜像 docker pull tdengine/tdengine:2.6.0.34#容器运行 docker run -d --name td2.6 --restartalways -p 6030:6030 -p 6041:6041 -p 6043:6043 -p 6044-6049:6044-6049 -p 6044-6045:6044-6045/udp -p 6060:6060 tdengine/tdengine:2.6.0.34#容器数据持久化到本地 #/va…...

详解RabbitMQ在Ubuntu上的安装

​​​​​​​ 目录 Ubuntu 环境安装 安装Erlang 查看Erlang版本 退出命令 ​编辑安装RabbitMQ 确认安装结果 安装RabbitMQ管理界面 启动服务 查看服务状态 通过IP:port访问 添加管理员用户 给用户添加权限 再次访问 Ubuntu 环境安装 安装Erlang RabbitMq需要…...

Python的3D可视化库【vedo】2-2 (plotter模块) 访问绘制器信息、操作渲染器

文章目录 4 Plotter类的方法4.1 访问Plotter信息4.1.1 实例信息4.1.2 演员对象列表 4.2 渲染器操作4.2.1 选择渲染器4.2.2 更新渲染场景 4.3 控制渲染效果4.3.1 渲染窗格的背景色4.3.2 深度剥离效果4.3.3 隐藏线框的线条4.3.4 改为平行投影模式4.3.5 添加阴影4.3.6 环境光遮蔽4…...

【vue2】文本自动省略组件,支持单行和多行省略,超出显示tooltip

代码见文末 vue3实现 最开始就用的vue3实现,如下 Vue3实现方式 vue2开发和使用文档 组件功能 TooltipText 是一个文字展示组件,具有以下功能: 文本显示:支持单行和多行文本显示。自动判断溢出:判断文本是否溢出…...

网络安全产品之认识防病毒软件

随着计算机技术的不断发展,防病毒软件已成为企业和个人计算机系统中不可或缺的一部分。防病毒软件是网络安全产品中的一种,主要用于检测、清除计算机病毒,以及预防病毒的传播。本文我们一起来认识一下防病毒软件。 一、什么是计算机病毒 计算…...

游戏引擎学习第42天

仓库: https://gitee.com/mrxiao_com/2d_game 简介 目前我们正在研究的内容是如何构建一个基本的游戏引擎。我们将深入了解游戏开发的每一个环节,从最基础的技术实现到高级的游戏编程。 角色移动代码 我们主要讨论的是角色的移动代码。我一直希望能够使用一些基…...

区块链智能合约( solidity) 安全编程

引言:本文由天玄链开源开发者提供,欢迎报名公益天玄链训练营 https://blockchain.163.com/trainingCamp 一、重入和竞态 重入和竞态在solidity 编程安全中会多次提及,历史上也造成了重大的损失。 1.1 问题分析 竞态的描述不严格&#xf…...

GUNS搭建

一、准备工作 源码下载: 链接: https://pan.baidu.com/s/1bJZzAzGJRt-NxtIQ82KlBw 提取码: criq 官方文档 二、导入代码 1、导入后端IDE 导入完成需要,需要修改yml文件中的数据库配置,改成自己的。 2、导入前端IDE 我是用npm安装的yarn npm…...

减肥产品网站模板/百度的电话人工客服电话

穆僮电脑小课堂 (QQ群:141826908)摘编整理如果你不小心把ubuntu引导弄坏了,比如重装了windows,比如格式化错了盘等等,那么通过下述方法可以简单的修复ubuntu首先,插入ubuntu的安装盘,没有的话只好做一个了&…...

商丘柘城做网站/北京培训学校

在Eclipse上创建Web项目,默认会产生一个WebRootWEB-INFlib目录,jar包复制到该目录后会自动加载到Web App Libraries库中,效果如下:而如果创建普通的Java项目,一般需要自己创建一个lib目录,再将jar包复制到该…...

wordpress 生成 客户端/企业网站seo优化公司

问:给你一个含 n 个整数的数组 nums ,其中 nums[i] 在区间 [1, n] 内。请你找出所有在 [1, n] 范围内但没有出现在 nums 中的数字,并以数组的形式返回结果。 原题链接:https://leetcode.cn/problems/find-all-numbers-disappeare…...

磁县邯郸网站建设/网络推广电话

文章目录一、传统以太网和虚拟局域网(VLAN)。1.传统以太网的问题。2.虚拟局域网(VLAN)。二、VLAN数据帧。三、以太网二层接口及其配置。1.Access接口。2.Trunk接口。3.Hybrid接口。4.配置示例。一、传统以太网和虚拟局域网(VLAN)。 1.传统以太网的问题。 在典型交换网络中&…...

如何做聚合类网站/全面网络推广营销策划

最近,Google 开源了其 TCP BBR 拥塞控制算法,并提交到了 Linux 内核,从 4.9 开始,Linux 内核已经用上了该算法。根据以往的传统,Google 总是先在自家的生产环境上线运用后,才会将代码开源,此次也…...

网站建设验收程序/神马推广

EIGRP的负载平衡与RIP和OSPF负载平衡有很大区别, EIGRP支持非等价负载平衡,即在两条不等开销的路径上做负载平衡,下面的实例将对EIGRP的非等价负载平衡做演示。 演示目标:理解并配置EIGRP的非等价负载平衡。 演示环境:…...