当前位置: 首页 > news >正文

【动手学深度学习Pytorch】2. Softmax回归代码

零实现

        导入所需要的包:

import torch
from IPython import display
from d2l import torch as d2l

        定义数据集参数、模型参数:

batch_size = 256 # 每次随机读取256张图片
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# 将展平每个图片将其视为长度为784的向量,数据集存在10个类别
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

        实现Softmax操作:

# 实现Softmax
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True) #列数为特征数,行数为样本数return X_exp / partition #广播机制# 尝试进行Softmax操作
X = torch.normal(0, 1, (2,5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)# 实现Softmax回归模型
def net(X):return softmax(torch.matmul(X.reshape(-1,W.shape[0]),W)+b)

        定义交叉熵函数:

# 创建一个数据y_hat,其中包含2个样本在3个类别的预测概率,使用y作为y_hat中概率的索引
y = torch.tensor([0,2])
y_hat = torch.tensor([[0.1, 0.3, 0.6],[0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
# 交叉熵函数
def cross_entropy(y_hat, y):return -torch.log(y_hat[range(len(y_hat)),y])
cross_entropy(y_hat, y)

        将预测类别于真实元素进行比较:

torch.argmax(input, dim=None, keepdim=False):用于返回指定维度中最大值的索引。通常用于分类任务中从预测输出中找到概率最大的类别

.dtype:.dtype 是张量的属性,用于返回该张量的 数据类型 (data type)。每个张量都有一个数据类型,用于定义其中存储元素的类型,例如浮点数、整数或布尔值。

tensor.type(dtype=None):不传入参数时,返回一个字符串,表示张量的类型;传入参数时,返回一个新的张量,该张量的类型与指定类型匹配。

x = torch.tensor([1.0, 2.0, 3.0])  # 默认 float32 类型
print(x.type())  # 输出: torch.FloatTensorx_int = x.type(torch.int64)
print(x_int)         # 输出: tensor([1, 2, 3])
print(x_int.type())  # 输出: torch.LongTensor (int64 的别名)

net.eval():设置为评估模式。

def accuracy(y_hat, y):#计算预测争取的数量# 判断 y_hat 是否为多维张量(例如二维)if len(y_hat.shape)>1 and y_hat.shape[1] > 1:# 如果是多类别分类(第二维大于 1),通过argmax获取每行中概率或分数最大的类别索引y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype)==y  # 比较预测结果和真实标签是否相等return float(cmp.type(y.dtype).sum()) # 返回预测正确的总数量accuracy(y_hat, y) / len(y)def evaluate_accuracy(net, data_iter):#计算在指定数据集上的模型精度# 如果是 PyTorch 模型,设置为评估模式if isinstance(net, torch.nn.Module):net.eval() metric = Accumulator(2)  # 初始化累加器,存储 [正确预测数, 总样本数]for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel()) # 累加每批数据的预测结果return metric[0] / metric[1]  # 返回精度:正确预测数 / 总样本数

        Accumulator实例:

class Accumulator: #在n个变量上累加def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]evaluate_accuracy(net, test_iter)

        定义训练过程: 

net.train():设置为训练模式。

torch.optim.Optimizer.step():用于执行模型参数更新基于之前计算好的梯度(通过反向传播获得),按照优化算法的规则调整模型参数的值,以最小化损失函数。

def train_epoch_ch3(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train()metric = Accumulator(3)for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y) #计算损失if isinstance(updater, torch.optim.Optimizer):updater.zero_grad() # 清除梯度l.backward() # 反向传播计算梯度updater.step() # 根据梯度更新模型参数metric.add(float(l) * len(y),  # 累加当前批次的损失accuracy(y_hat, y),  # 累加当前批次的正确预测数y.size().numel())  # 累加当前批次的样本数else: # 如果是自定义优化器l.sum().backward()updater(X.shape[0]) # 自定义的更新函数,可能需要批次大小作为参数metric.add(float(l.sum()), accuracy(y_hat),y.numel())return metric[0] / metric[2], metric[1] / metric[2]

        定义一个在动画中绘制数据的实用程序类:

class Animator: #实时观看在训练过程中的变化# 初始化绘图环境,包括图表的设置、标签、坐标轴范围、曲线样式等。def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-','m--','g-','r:'),nrows=1,ncols=1,figsize=(3.5, 2,5)):if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols ==1:self.axes = [self.axes,]self.config_axes = lambda:d2l.set_axes(self.axes[0],xlabel, ylabel,xlim, ylim,xscale, yscale,legend)self.X, self.Y, self.fmt = None, None, fmtsdef add(self, x, y):if not hasattr(y, "__len__"):y = [y]n = len(y)

        训练函数: 

# 训练函数
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):# 进行可视化animator = Aminator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3,],legend=['train loss','train acc','test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch2(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch+1, train_metrics+(test_acc,))train_loss, train_acc = train_metrics# 小批量随机梯度下降来优化训练算法
lr = 0.1
def updater(batch_size):return d2l.sgd([W,b],lr,batch_size)num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater(10))

 简洁实现

        导入所需要的包:

import torch
from IPython import display
from d2l import torch as d2l

        初始化数据集、模型参数、损失函数以及训练优化算法:网络加入高斯噪声,增强泛化性。

torch.nn.init.normal_(tensor, mean=0.0, std=1.0):正态分布(高斯分布)随机初始化张量的值

nn.Sequential(*modules):用于将多个模块(如线性层、激活函数等)按顺序组合成一个模型。适合简单的前向计算场景。

nn.Flatten(start_dim=1, end_dim=-1):将输入张量展平成二维张量,适用于线性层输入。

nn.Linear(in_features, out_features, bias=True):实现一个线性层(全连接层)

nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean'):计算分类任务中的交叉熵损失(适用于多分类问题)。
torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False):实现随机梯度下降(SGD)优化算法,用于更新模型参数。

net.parameters():返回模型的可训练参数的迭代器。

batch_size = 256 # 每次随机读取256张图片
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)net = nn.Sequential(nn.Flatten(),nn.Linear(784, 100))
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);loss = nn.CrossEntropyLoss()trainer = torch.optim.SGD(net.parameters(),lr=0.1)

        用之前定义的训练函数训练模型:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater(10))

相关文章:

【动手学深度学习Pytorch】2. Softmax回归代码

零实现 导入所需要的包&#xff1a; import torch from IPython import display from d2l import torch as d2l定义数据集参数、模型参数&#xff1a; batch_size 256 # 每次随机读取256张图片 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size) # 将展平每个…...

技术周总结 11.11~11.17 周日(Js JVM XML)

文章目录 一、11.11 周一1.1&#xff09;问题01&#xff1a;js中的prompt弹窗区分出来用户点击的是 确认还是取消进一步示例 1.2&#xff09;问题02&#xff1a;在 prompt弹窗弹出时默认给弹窗中写入一些内容 二、11.12 周二2.1) 问题02: 详解JVM中的本地方法栈本地方法栈的主要…...

MATLAB 使用教程 —— 矩阵和数组

矩阵和数组MATLAB 中矩阵和数组长什么样&#xff1f;MATLAB 怎么用矩阵计算&#xff1f;创建和操作矩阵矩阵运算示例串联 访问矩阵的元素 矩阵和数组 MATLAB 是“matrix laboratory”的缩写形式。MATLAB 主要用于处理 整个的矩阵和数组&#xff0c;而其他编程语言大多逐个处理…...

React教程第二节之虚拟DOM与Diffing算法理解

1、什么是虚拟DOM 虚拟DOM 是javascript的一个对象&#xff0c;是内存中的一种数据结构&#xff0c;以树的形式存储UI的状态&#xff0c;树中的每个节点都代表着真实的DOM&#xff0c;用来描述我们希望在页面看到的 HTML结构&#xff1b; 现在的MVVM 框架&#xff0c;大多使用…...

C++——类和对象(part2)

前言 本篇博客继续为大家介绍类与对象的知识&#xff0c;承接part1的内容&#xff0c;本篇内容是类与对象的核心内容&#xff0c;稍微有些复杂&#xff0c;如果你对其感兴趣&#xff0c;请继续阅读&#xff0c;下面进入正文部分。 1. 类的默认成员函数 默认成员函数就是用户…...

【FFmpeg系列】:音频处理

前言 在多媒体处理领域&#xff0c;FFmpeg无疑是一个不可或缺的利器。它功能强大且高度灵活&#xff0c;能够轻松应对各种音频和视频处理任务&#xff0c;无论是简单的格式转换&#xff0c;还是复杂的音频编辑&#xff0c;都不在话下。然而&#xff0c;要想真正发挥FFmpeg的潜…...

Python绘制雪花

文章目录 系列目录写在前面技术需求完整代码代码分析1. 代码初始化部分分析2. 雪花绘制核心逻辑分析3. 窗口保持部分分析4. 美学与几何特点总结 写在后面 系列目录 序号直达链接爱心系列1Python制作一个无法拒绝的表白界面2Python满屏飘字表白代码3Python无限弹窗满屏表白代码4…...

vue3 如何调用第三方npm包内部的 pinia 状态管理库方法

抛砖引玉: 如果在开发vue3项目是, 引用了npm第三方包 ,而且这个包内使用了Pinia 状态管理库,那我们如何去调用 npm内部的 Pinia 状态管理库呢? 实际遇到的问题: 今天在制作npm包时遇到的问题,之前Vue2版本的时候状态管理库用的Vuex ,当时调用npm包内的状态管理库很简单,直接引…...

uni-app快速入门(七)--组件路由跳转和API路由跳转及参数传递

uni-app有两种页面路由跳转模式&#xff0c;即使用navigator组件跳转和调用API跳转&#xff0c;API调转不要理解为调用后台接口的API&#xff0c;而是指脚本函数中使用跳转函数。 一、组件路由跳转 1.1 打开新页面 打开新页面使用组件的open-type"navigate",见下面…...

Flink升级程序和版本

Flink DataStream程序通常设计为长时间运行,如几周、几个月甚至几年。与所有长时间运行的服务一样,Flink streaming应用程序也需要维护,包括修复错误、实现改进或将应用程序迁移到更高版本的Flink集群。 这里就来描述下如何更新Flink streaming应用程序,以及如何将正在运行…...

从0安装mysql server

安装 MySQL Server 首先,你需要在 Ubuntu 上安装 MySQL 服务器。运行以下命令来安装:sudo apt update sudo apt install mysql-server安装完成后,MySQL 服务会自动启动。你可以通过以下命令检查 MySQL 服务是否正在运行: sudo systemctl status mysql如果 MySQL 正在运行,…...

web安全测试渗透案例知识点总结(上)——小白入狱

目录 一、Web安全渗透测试概念详解1. Web安全与渗透测试2. Web安全的主要攻击面与漏洞类型3. 渗透测试的基本流程 二、知识点详细总结1. 常见Web漏洞分析2. 渗透测试常用工具及其功能 三、具体案例教程案例1&#xff1a;SQL注入漏洞利用教程案例2&#xff1a;跨站脚本&#xff…...

PHP访问NetSuite REST Web Services

“同等看待欢乐和痛苦、得到和失去、胜利和失败、投入战斗。以此方式履行职责&#xff0c;你就不会招致任何罪恶。” -Bhagavad Gita 为了帮助PHP开发者快速起步&#xff0c;以REST Web Services方式打通与NetSuite的接口&#xff0c;我们答应给一个样例。但是我是不懂PHP的&a…...

【编译】多图解释 什么是短语、直接短语、句柄、素短语、可归约串

一、什么是短语二、什么是“直接”短语&#xff1f;三、什么是句柄&#xff1f;四、什么是素短语&#xff1f;五、什么是最左素短语可归约串就是“最左素短语” 首先&#xff0c;这些概念 都是相对于【句型】的&#xff0c;都是相对于【句型】的&#xff0c;都是相对于【句型】…...

React中事件绑定和Vue有什么区别?

1. 绑定方式 React&#xff1a;使用jsx语法&#xff0c;通过属性绑定事件。Vue&#xff1a;使用指令&#xff08;如v-on&#xff09;在模板中直接绑定事件。 2. 事件处理 React&#xff1a;通过合成事件系统封装原生事件&#xff0c;提供统一的API。Vue&#xff1a;直接使用…...

【DBA攻坚指南:左右Oracle,右手MySQL-学习总结】

处理log file sync等待事件 首先明确什么是log file sync等待事件 从用户提交会话开始&#xff0c;LGWR进程将redo缓存中的信息写入redo日志文件后&#xff0c;LGWR进程通知用户写操作完成&#xff0c;到用户会话接受到LGWR进程通知为止&#xff0c;这整个过程就是可能出现lo…...

C++中的内联函数

在C中&#xff0c;内联函数是一种特殊的函数。 定义 内联函数是在函数定义前加上关键字“inline”的函数。编译器在处理对内联函数的调用时&#xff0c;会尝试将函数体的代码直接插入到函数调用处&#xff0c;而不是像普通函数调用那样&#xff0c;进行跳转指令执行函数体代码…...

ssh.service could not be found“

如果你收到 “ssh.service could not be found” 错误&#xff0c;说明目标主机上没有安装 SSH 服务&#xff0c;或者安装的 SSH 服务的名称不为 ssh。这里有一些解决步骤&#xff1a; 1. 检查 SSH 服务是否已安装 在目标主机上执行以下命令来检查是否安装了 SSH 服务&#x…...

tensorflow有哪些具体影响,和chatgpt有什么关系

### TensorFlow的影响 **1. 深度学习框架的领军者** - **广泛使用**: TensorFlow是由Google开发的开源深度学习框架&#xff0c;广泛应用于各种机器学习任务&#xff0c;包括图像识别、自然语言处理、语音识别等。它是深度学习领域中最受欢迎的框架之一。 - **大规模生产环境*…...

Android OpenGL ES详解——几何着色器

目录 一、概念 1、图元 2、几何着色器 1、输入类型 2、输出类型 3、输出顶点数量最大值限制 二、使用几何着色器 三、应用举例——造几个房子 四、应用举例——爆破物体 1、获取法向量 2、显示法线 五、应用举例——细分三角形 六、应用举例——广告牌技术 一、概…...

Java学生管理系统(GUI和数据库)

Java学生管理系统&#xff08;GUI和数据库&#xff09; 本文简介 本资源演示了一个用Java实现的学生管理系统&#xff0c;结合了图形用户界面&#xff08;GUI&#xff09;和数据库操作。系统实现了学生、课程和账号三张表的管理功能&#xff0c;包括增删改查等操作。通过本资…...

035_Progress_Dialog_in_Matlab中的进度条对话框

进度条 概念 在使用Matlab开发界面时&#xff0c;有一个很好用的工具就是进度条。在计算过程中&#xff0c;为用户提供计算进度的反馈是改善用户体验的重要手段。 一项进行的计算任务&#xff0c;如果其总体进度是比较容易量化&#xff0c;则可以按照0%~100%的方式&#xff0…...

【GPTs】Ai-Ming:AI命理助手,个人运势与未来发展剖析

博客主页&#xff1a; [小ᶻZ࿆] 本文专栏: AIGC | GPTs应用实例 文章目录 &#x1f4af;GPTs指令&#x1f4af;前言&#x1f4af;Ai-Ming主要功能适用场景优点缺点 &#x1f4af;小结 &#x1f4af;GPTs指令 中文翻译&#xff1a; defcomplete_sexagenary&#xff08;年&a…...

如何利用SAP低代码平台快速构建企业级应用?

SAP作为全球领先的企业管理软件解决方案提供商&#xff0c;一直致力于为企业提供全面且高效的业务管理工具。随着技术的快速发展&#xff0c;传统的开发方式已经无法满足企业在快速变化的市场环境下的需求。低代码开发平台应运而生&#xff0c;它通过简化应用程序的创建过程&am…...

Redis设计与实现 学习笔记 第十七章 集群

Redis集群是Redis提供的分布式数据库方案&#xff0c;集群通过分片&#xff08;sharding&#xff0c;水平切分&#xff09;来进行数据共享&#xff0c;并提供复制和故障转移功能。 17.1 节点 一个Redis集群通常由多个节点&#xff08;node&#xff09;组成&#xff0c;在刚开…...

多端校园圈子论坛小程序,多个学校同时代理,校园小程序分展示后台管理源码

社团活动与组织 信息发布&#xff1a;系统支持社团发布活动信息、招募新成员等&#xff0c;方便社团进行线上线下活动的组织和管理。 增强凝聚力&#xff1a;通过系统&#xff0c;社团成员可以更好地交流和互动&#xff0c;增强社团的凝聚力和影响力。 生活服务功能 二手市场…...

鸿蒙核心技术理念

文章目录 1)一次开发,多端部署2)可分可合,自由流转3)统一生态,原生智能1)一次开发,多端部署 “一次开发,多端部署”指的是一个工程,一次开发上架,多端按需部署。目的是支撑开发者高效地开发多种终端设备上的应用 2)可分可合,自由流转 元服务是鸿蒙系统提供的一…...

8. 基于 Redis 实现限流

在高并发的分布式系统中&#xff0c;限流是保证服务稳定性的重要手段之一。通过限流机制&#xff0c;可以控制系统处理请求的频率&#xff0c;避免因瞬时流量过大导致系统崩溃。Redis 是一种高效的缓存数据库&#xff0c;具备丰富的数据结构和原子操作&#xff0c;适合用来实现…...

241117学习日志——[CSDIY] [ByteDance] 后端训练营 [05]

CSDIY&#xff1a;这是一个非科班学生的努力之路&#xff0c;从今天开始这个系列会长期更新&#xff0c;&#xff08;最好做到日更&#xff09;&#xff0c;我会慢慢把自己目前对CS的努力逐一上传&#xff0c;帮助那些和我一样有着梦想的玩家取得胜利&#xff01;&#xff01;&…...

蓝桥杯备赛(持续更新)

16届蓝桥杯算法类知识图谱.pdf 1. 格式打印 %03d&#xff1a;如果是两位数&#xff0c;将会在前面添上一位0 %.2f&#xff1a;会保留两位小数 如果是long&#xff0c;必须在数字后面加上L。 2. 进制转化 2.1. 十进制转任意进制&#xff1a; 十进制转任意进制时&#xff…...