当前位置: 首页 > news >正文

《动手深度学习》线性回归简洁实现实例


🎈 作者:Linux猿

🎈 简介:CSDN博客专家🏆,华为云享专家🏆,Linux、C/C++、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊!

🎈 欢迎小伙伴们点赞👍、收藏⭐、留言💬


本文是《动手深度学习》线性回归简洁实现实例的实现和分析,主要对代码进行详细讲解,有问题欢迎在评论区讨论交流。

一、代码实现

实现代码如下所示。

import torch
from torch.utils import data
# d2l包是李沐老师等人开发的动手深度学习配套的包,
# 里面封装了很多有关与数据集定义,数据预处理,优化损失函数的包
from d2l import torch as d2l
# nn 是神经网络 Neural Network 的缩写,提供了一系列的模块和类,实现创建、训练、保存、恢复神经网络
from torch import nn'''
1. 生成数据集,共 1000 条
true_w 和 true_b 是临时变量用于生成数据集
生成 X, y :满足关系 y = Xw + b + noise
'''
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)'''
2. 构造循环读取数据集的迭代器
'''
def load_array(data_arrays, batch_size, is_train=True):  #@save# 构造一个 PyTorch 数据迭代器,对 tensor 进行打包,包装成 dataset。dataset = data.TensorDataset(*data_arrays)# 根据数据集构造一个迭代器return data.DataLoader(dataset, batch_size, shuffle=is_train)# 小批量数据
batch_size = 10
# 设置了一个数据读取的迭代器,每次读取 batch_size(10) 条
data_iter = load_array((features, labels), batch_size)'''
3. 设置全连接层
'''
'''
# nn.Linear(in_features, out_features, bias=True)
# in_features : 输入向量的列数
# out_features : 输出向量的列数
# bias = True 是否包含偏置
执行线性变换:Yn*o = Xn*i Wi*o + b
其中:W 和 b 模型需要学习的参数
在本例中:n = 10,i = 2, o = 1
'''
net = nn.Sequential(nn.Linear(2, 1))
# 设置权重 w 和 偏置 b
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)'''
4. 定义损失函数
'''
# 均方误差,是预测值与真实值之差的平方和的平均值
loss = nn.MSELoss()
# lr 学习率 learning rate
trainer = torch.optim.SGD(net.parameters(), lr=0.03)'''
4. 训练数据
'''
# 超参数 设置批次
num_epochs = 3
for epoch in range(num_epochs): # 进行 num_epochs 个迭代周期for X, y in data_iter:l = loss(net(X) ,y) # 计算损失,net(X) 计算预测值 y1,loss(y1, y) 计算预测值和真实值之间的差距trainer.zero_grad() # 将所有模型参数的梯度置为 0l.backward() # 求梯度,不使用从零实现中 l.sum.backward 的原因是损失计算中使用了平均的 gardtrainer.step() # 优化参数 w 和 bl = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

二、实现解析

针对实例中重要的函数解析如下。

2.1 Linear 函数

nn.Linear(in_features, out_features, bias=True)

神经网络的线性层,也成为全连接层,进行 Y = XW + b 的线性变换。

参数:

in_features : 输入向量的列数

out_features : 输出向量的列数

bias = True 是否包含偏置

in_features 和 out_features 是 W 的行和列。

执行线性变换:Yn*o = Xn*i Wi*o + b

其中:W 和 b 模型需要学习的参数

在本例中:n = 10,i = 2, o = 1。

2.2 Sequential 函数

一个序列容器,用于搭建神经网络的模块,按照传入构造器的顺序添加到 nn.Sequential() 容器中。按照内部模块的顺序自动依次计算并输出结果。

2.3 MSELoss 函数

均方误差,是预测值与真实值之差的平方和的平均值,即:

2.4 TensorDataset 函数

用来对 tensor 进行打包,就好像 python 中的 zip 功能。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等. 另外:TensorDataset 中的参数必须是 tensor。可以参考如下例子:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader# len = 12
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
# len = 12
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
# 将 tensor a 和 b 压缩在一起
train_ids = TensorDataset(a, b)
# 输出
for x, y in train_ids:print(x, y)

输出如下:

tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)

2.5 DataLoader 函数

DataLoader 是用来包装所使用的数据,每次抛出一批数据,下面来看一个例子。

import torch
from torch.utils import data# len = 12
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
# len = 12
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
# 将 tensor a 和 b 压缩在一起
train_ids = data.TensorDataset(a, b)
# 输出
#for x, y in train_ids:
#    print(x, y)BATCH_SIZE = 4
loader = data.DataLoader(dataset=train_ids,batch_size=BATCH_SIZE, # 每次取 BATCH_SIZE=4 个数据shuffle=False, # 不打乱顺序,便于查看num_workers=0)for x, y in loader:print(x, y)break

输出如下:

tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9],[1, 2, 3]]) tensor([44, 55, 66, 44])

 如上所示,输出第一个 BATCH_SIZE=4。

2.6 zero_grad 函数

trainer.zero_grad() 是用来清空模型参数梯度的函数,它将模型参数的梯度缓存设置为 0。在进行反向传播时,梯度会累加,如果不清空梯度,会影响后续的梯度计算。

2.7 backward 函数

对计算图进行梯度计算,求解计算图中所有节点的梯度。

2.8 step 函数

根据 backward 函数计算出的梯度进行参数更新。

参考链接:

线性回归的实现学习_data.tensordataset_带刺的厚崽的博客-CSDN博客

nn.Sequential()_一颗磐石的博客-CSDN博客

【Pytorch基础】torch.nn.MSELoss损失函数_一穷二白到年薪百万的博客-CSDN博客

pytorch之trainer.zero_grad()_FibonacciCode的博客-CSDN博客

清空模型参数梯度的函数 - 知乎

pytorch中backward()函数详解_backward函数_Camlin_Z的博客-CSDN博客

理解Pytorch的loss.backward()和optimizer.step() - 知乎


🎈 感觉有帮助记得「一键三连支持下哦!有问题可在评论区留言💬,感谢大家的一路支持!🤞猿哥将持续输出「优质文章回馈大家!🤞🌹🌹🌹🌹🌹🌹🤞


相关文章:

《动手深度学习》线性回归简洁实现实例

🎈 作者:Linux猿 🎈 简介:CSDN博客专家🏆,华为云享专家🏆,Linux、C/C、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊! &…...

国家数据局正式揭牌,数据专业融合型人才迎来发展良机

📕作者简介:热爱跑步的恒川,致力于C/C、Java、Python等多编程语言,热爱跑步,喜爱音乐的一位博主。 📗本文收录于恒川的日常汇报系列,大家有兴趣的可以看一看 📘相关专栏C语言初阶、C…...

基于springboot实现休闲娱乐代理售票平台系统项目【项目源码+论文说明】

基于springboot实现休闲娱乐代理售票系统演示 摘要 网络的广泛应用给生活带来了十分的便利。所以把休闲娱乐代理售票管理与现在网络相结合,利用java技术建设休闲娱乐代理售票系统,实现休闲娱乐代理售票的信息化。则对于进一步提高休闲娱乐代理售票管理发…...

3.AUTOSAR OS分析(一)

1. AUTOSAR OS诞生背景 在最初接触汽车ECU开发时,提到最多的还是OSEK,比如OSEK NM、OSEK OS等等;而OSEK/VDK操作系统也是最先引入汽车行业;OSEK OS是基于事件触发的操作系统,有以下特性: 固定优先级调度中断处理函数StartOS和StartupHook作为启动阶段的通用接口函数Shutd…...

AB试验(七)利用Python模拟A/B试验

AB试验(七)利用Python模拟A/B试验 到现在,我相信大家理论已经掌握了,轮子也造好了。但有的人是不是总感觉还差点什么?没错,还缺了实战经验。对于AB实验平台完善的公司 ,这个经验不难获得&#…...

Go语言入门-流程控制语句

流程控制 Go语言中有以下几种常见的流程控制语句: 条件语句(Conditional Statements): if语句:用于根据条件执行代码块。else语句:在if条件不满足时执行的语句块。else if语句:用于在多个条件之…...

深入探究ASEMI肖特基二极管MBR60100PT的材质

编辑-Z 在电子零件领域中,肖特基二极管MBR60100PT因其出色的性能和广泛的应用而显得尤为关键。理解其材质不仅有助于我们深入理解其运作原理,也有助于我们做出更合适的电子设计。那么,肖特基二极管MBR60100PT是什么材质呢? 首先&#xff0c…...

python类模拟“对战游戏”

Game类含玩家昵称、生命值、攻击力(整数),暴击率、闪避率(小数),在魔术方法init定义;attack方法中实现两个Game实例对战模拟。 (本笔记适合初通Python类class的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.py…...

Maven第二章:Maven基本概念与生命周期

Maven第二章:Maven基本概念与生命周期 前言 本章主要内容,介绍Maven基本概念,包括maven坐标含义,命名规则,继承与聚合、了解与理解生命周期,如何通过Maven进行依赖和版本管理。 什么是Maven的坐标&#xf…...

<蓝桥杯软件赛>零基础备赛20周--第3周--填空题

报名明年4月蓝桥杯软件赛的同学们,如果你是大一零基础,目前懵懂中,不知该怎么办,可以看看本博客系列:备赛20周合集 20周的完整安排请点击:20周计划 每周发1个博客,共20周(读者可以按…...

【Linux】VM及WindowsServer安装

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是Java方文山,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的专栏《微信小程序开发实战》。🎯&#x1f3a…...

【实用教程】MySQL内置函数

1 背景 在MySQL查询等操作过程中,我们需要根据实际情况,使用其提供的内置函数。今天我们就来一起来学习下这些函数,在之后的使用过程中更加得心应手。 2 MySQL函数 2.1 字符串函数 常用的函数如下: concat(s1,s2,…sn)字符串…...

第十二节——ref

一、概念 ref 被用来给DOM元素或子组件注册引用信息。引用信息会根据父组件的 $refs 对象进行注册。如果在普通的DOM元素上使用,引用信息就是元素; 如果用在子组件上,引用信息就是组件实例。 注意:只要想要在Vue中直接操作DOM元素&#xff…...

少儿编程 2023年9月中国电子学会图形化编程等级考试Scratch编程四级真题解析(判断题)

2023年9月scratch编程等级考试四级真题 判断题(共10题,每题2分,共20分) 11、运行程序后,变量"result"的值是6 答案:对 考点分析:考查积木综合使用,重点考查自定义积木的使用 图中自定义积木实现的功能是获取两个数中最大的那个数并存放在result变量中,左…...

【设计模式三原则】

设计模式三原则 单一职责原则开放封闭原则依赖倒转原则里氏代换原则 我们在进行程序设计的时候,要尽可能地保证程序的可扩展性、可维护性和可读性,所以需要使用一些设计模式,这些设计模式都遵循了以下三个原则,下面来依次为大家介…...

600MW发电机组继电保护自动装置的整定计算及仿真

摘要 随着科技的发展,电力已成为最重要的资源之一,如何保证电力的供应对于国民经济发展和人民生活水平的提高都有非常重要的意义。在电能输送过程中,发电机组是整个过程中最重要的一个基本元素,在电力系统中的输送和分配中被广泛应…...

【蓝桥每日一题]-字符串(保姆级教程 篇1)#atcoder324C~E题

今天来讲字符串题型 目录 题目:atcoder324C题 思路: 题目:atcoder324D题 思路: 题目:atcoder324E题 思路: 题目:atcoder324C题 给一个T字符串,然后给出n个S串,对…...

4.2.1 SQL语句、索引、视图、存储过程

怎么执行一条select语句 1.连接器 接收连接-》管理连接-》校验用户信息 2.查询缓存 kv存储,命中直接返回,否则继续执行 8.0已经删除 3.分析器 词法句法分析生成语法树 4.优化器 指定执行计划,选择查询成本最小的计划 5.执行器 根据执行计划&a…...

1992-2021年全国各地级市经过矫正的夜间灯光数据(GNLD、VIIRS)

1992-2021年全国各地级市经过矫正的夜间灯光数据(GNLD、VIIRS) 1、时间:1992-2021年3月,其中1992-2013年为年度数据,2013-2021年3月为月度数据 2、来源:DMSP、VIIRS 3、范围:分区域汇总&…...

机器人的触发条件有什么区别,如何巧妙的使用

简介​ 维格机器人触发条件,分为3个,分别是: 有新表单提交时、有记录满足条件时、有新的记录创建时 。 看似3个,其实是能够满足我们非常多的使用场景。 本篇将先介绍3个条件的触发条件,然后再列举一些复杂的触发条件如何用现有的触发条件来满足 注意: 维格机器人所有的…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中,UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文全面剖析RNN核心原理,深入讲解梯度消失/爆炸问题,并通过LSTM/GRU结构实现解决方案,提供时间序列预测和文本生成…...

均衡后的SNRSINR

本文主要摘自参考文献中的前两篇,相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程,其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt​ 根发送天线, n r n_r nr​ 根接收天线的 MIMO 系…...

JAVA后端开发——多租户

数据隔离是多租户系统中的核心概念,确保一个租户(在这个系统中可能是一个公司或一个独立的客户)的数据对其他租户是不可见的。在 RuoYi 框架(您当前项目所使用的基础框架)中,这通常是通过在数据表中增加一个…...

C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...

LUA+Reids实现库存秒杀预扣减 记录流水 以及自己的思考

目录 lua脚本 记录流水 记录流水的作用 流水什么时候删除 我们在做库存扣减的时候,显示基于Lua脚本和Redis实现的预扣减 这样可以在秒杀扣减的时候保证操作的原子性和高效性 lua脚本 // ... 已有代码 ...Overridepublic InventoryResponse decrease(Inventor…...

Qt Quick Controls模块功能及架构

Qt Quick Controls是Qt Quick的一个附加模块,提供了一套用于构建完整用户界面的UI控件。在Qt 6.0中,这个模块经历了重大重构和改进。 一、主要功能和特点 1. 架构重构 完全重写了底层架构,与Qt Quick更紧密集成 移除了对Qt Widgets的依赖&…...

基于django+vue的健身房管理系统-vue

开发语言:Python框架:djangoPython版本:python3.8数据库:mysql 5.7数据库工具:Navicat12开发软件:PyCharm 系统展示 会员信息管理 员工信息管理 会员卡类型管理 健身项目管理 会员卡管理 摘要 健身房管理…...

信息系统分析与设计复习

2024试卷 单选题(20) 1、在一个聊天系统(类似ChatGPT)中,属于控制类的是()。 A. 话语者类 B.聊天文字输入界面类 C. 聊天主题辨别类 D. 聊天历史类 ​解析 B-C-E备选架构中分析类分为边界类、控制类和实体类。 边界…...

河北对口计算机高考MySQL笔记(完结版)(2026高考)持续更新~~~~

MySQL 基础概念 数据(Data):文本,数字,图片,视频,音频等多种表现形式,能够被计算机存储和处理。 **数据库(Data Base—简称DB):**存储数据的仓库…...