当前位置: 首页 > news >正文

【李沐】3.2线性回归从0开始实现

%matplotlib inline
import random
import torch
from d2l import torch as d2l

1、生成数据集:
看最后的效果,用正态分布弄了一些噪音
在这里插入图片描述
上面这个具体实现可以看书,又想了想还是上代码把:
在这里插入图片描述
按照上面生成噪声,其中最后那个代表服从正态分布的噪声

def synthetic_data(w, b, num_examples):  # 定义函数 synthetic_data,接受权重 w、偏差 b 和样本数量 num_examples 作为参数"""生成 y = Xw + b + 噪声 的合成数据集"""# 生成一个形状为 (num_examples, len(w)) 的特征矩阵 X,其中的元素是从均值为 0、标准差为 1 的正态分布中随机采样得到X = torch.normal(0, 1, (num_examples, len(w)))# 计算目标值 y,通过将特征矩阵 X 与权重 w 相乘,然后加上偏差 b,模拟线性回归的预测过程y = torch.matmul(X, w) + b# 给目标值 y 添加一个小的随机噪声,以模拟真实数据中的噪声。噪声从均值为 0、标准差为 0.01 的正态分布中随机采样得到y += torch.normal(0, 0.01, y.shape)# 返回特征矩阵 X 和目标值 y(将目标值 y 重塑为列向量的形式)return X, y.reshape((-1, 1)
# 定义真实的权重 true_w 为 [2, -3.4]
true_w = torch.tensor([2, -3.4])# 定义真实的偏差 true_b 为 4.2
true_b = 4.2# 调用 synthetic_data 函数生成合成数据集,传入真实的权重 true_w、偏差 true_b 和样本数量 1000
# 这将返回特征矩阵 features 和目标值 labels
features, labels = synthetic_data(true_w, true_b, 1000)

2、读取数据集
注意一般情况下要打乱。
下面函数的作用是该函数接收批量⼤⼩、特征矩阵和标签向量作为输⼊,⽣成⼤⼩为batch_size的⼩批量。每个⼩批量包含⼀组特征和标签。

def data_iter(batch_size, features, labels):num_examples = len(features)  # 获取样本数量indices = list(range(num_examples))  # 创建一个样本索引列表,表示样本的顺序# 将样本索引列表随机打乱,以便随机读取样本,没有特定的顺序random.shuffle(indices)# 通过循环每次取出一个批次大小的样本for i in range(0, num_examples, batch_size):# 计算当前批次的样本索引范围,确保不超出总样本数量batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])# 通过索引获取对应的特征和标签,然后通过 yield 返回这个批次的数据# yield 使得函数可以作为迭代器使用,在每次迭代时产生一个新的批次数据yield features[batch_indices], labels[batch_indices]

3、初始化模型参数
第一步:前面两行代码,,我
们通过从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重,并将偏置初始化为0。
计算梯度使用2.5节引入的自动微分

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

4、定义模型
这里注意b是一个标量和向量相加,咋办?
前面说过向量的广播机制,就相当于是加到每一个上面

def linreg(X, w, b): #@save
"""线性回归模型"""
return torch.matmul(X, w) + b

5、定义损失函数
y.reshape(y_hat.shape))啥意思?
y_hat是真实值,这里的意思是弄成和y_hat相同的大小

def squared_loss(y_hat, y): #@save
"""均⽅损失"""
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

6、优化算法
问:这里的参数是啥参数?params
更新完的参数不用返回吗?
为什么需要梯度清零?

def sgd(params, lr, batch_size):  # 定义函数 sgd,接受参数 params、学习率 lr 和批次大小 batch_size"""小批量随机梯度下降"""with torch.no_grad():  # 使用 torch.no_grad() 来关闭梯度跟踪,以减少内存消耗for param in params:  # 遍历模型参数列表param -= lr * param.grad / batch_size  # 更新参数:参数 = 参数 - 学习率 * 参数梯度 / 批次大小param.grad.zero_()  # 清零参数的梯度,以便下一轮梯度计算

7、训练
问:反向传播是为了干啥?
是为了计算梯度,那梯度是啥呢
梯度是参数更快收敛的方向(就是向量)
优化方法是干啥的?
优化方法就是根据上面传过来的梯度,计算参数更新
所以,这几章看完后需要梳理深度学习的整个过程,以及每块有哪些方法,这些方法的特点和用那种方法更好
问(1)每个epoch训练多少数据?
整个训练集
(2)损失函数是啥?
损失函数是用来计算真实值域预测值之间的距离,当然是距离越小越好,可以拿均方误差想一下
(3)l.sum().backward()是啥意思?
看注释,补充:.backward() 方法用于执行自动求导,计算总的损失值对于模型参数的梯度。这将会构建计算图并沿着图的反向传播路径计算梯度。
(4)但是上面所说的梯度保存在哪里呢?
w.grad 和 b.grad 中
(5)但是sgd中也没有用到w.grad 啊?
用到了,param 可以是 w 或者 b,而 param.grad 则是相应参数的梯度。
(6)新问题:train_l = loss(net(features, w, b), labels)不是在前面已经计算过损失函数了吗?为啥在这里还需要计算?
前面计算损失函数是间断性的,目的是更新模型参数。
后面仍然计算的目的是根据更新完的参数对模型在整个训练集上与真实标签的差距做一个评估。

lr = 0.03  # 设置学习率为 0.03,控制每次参数更新的步幅num_epochs = 3  # 设置训练的轮次(迭代次数)为 3,即遍历整个数据集的次数net = linreg  # 定义模型 net,通常表示线性回归模型loss = squared_loss  # 定义损失函数 loss,通常为均方损失函数,用于衡量预测值与真实值之间的差距
for epoch in range(num_epochs):  # 迭代 num_epochs 轮,进行训练for X, y in data_iter(batch_size, features, labels):  # 遍历数据集的每个批次l = loss(net(X, w, b), y)  # 计算当前批次的损失值 l,表示预测值与真实值之间的差距# 因为 l 的形状是 (batch_size, 1),而不是一个标量。将 l 中的所有元素加起来,# 并计算关于 [w, b] 的梯度l.sum().backward()sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数,执行随机梯度下降算法with torch.no_grad():train_l = loss(net(features, w, b), labels)  # 在整个训练集上计算损失值# 打印当前迭代轮次和训练损失值的均值print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

8、练习中的问题

  1. 如果我们将权重初始化为零,会发⽣什么。算法仍然有效吗?
    无效,为啥?因为,不同的X输入是相同的输出

相关文章:

【李沐】3.2线性回归从0开始实现

%matplotlib inline import random import torch from d2l import torch as d2l1、生成数据集: 看最后的效果,用正态分布弄了一些噪音 上面这个具体实现可以看书,又想了想还是上代码把: 按照上面生成噪声,其中最后那…...

一百五十六、Kettle——Linux上安装的Kettle9.3连接ClickHouse数据库(亲测,附流程截图)

一、目标 kettle9.3在Linux上安装好后,需要与ClickHouse数据库建立连接 二、前提准备 (一)在Linux已经安装好kettle并可以启动kettle (二)已知kettle和ClickHouse版本 1、kettle版本是9.3 2、ClickHouse版本是21…...

图数据库_Neo4j和SpringBoot整合使用_创建节点_删除节点_创建关系_使用CQL操作图谱---Neo4j图数据库工作笔记0009

首先需要引入依赖 springboot提供了一个spring data neo4j来操作 neo4j 可以看到它的架构 这个是下载下来的jar包来看看 有很多cypher对吧 可以看到就是通过封装的驱动来操作graph database 然后开始弄一下 首先添加依赖...

Uniapp连接蓝牙设备

一、效果图 二、流程图 三、实现 UI <uni-list><uni-list :border="true"><!-- 显示圆形头像 -->...

linux切换到root用户:su root和sudo su命令的区别

前言 工作过程中遇到需要切换到root用户下去执行命令 方法1&#xff1a;工作中常会选择这个方法 利用su root命令 临时获取root用户权限&#xff0c;工作目录不变 好处&#xff1a;不需要知道root用户的密码&#xff0c;直接输入普通用户的密码即可 方法2 利用sudo su命…...

kafka-- kafka集群 架构模型职责分派讲解

一、 kafka集群 架构模型职责分派讲解 生产者将消息发送到相应的Topic&#xff0c;而消费者通过从Topic拉取消息来消费 Kafka奇数个节点消费者consumer会将消息拉去过来生产者producer会将消息发送出去数据管理 放在zookeeper...

Effective C++条款07——为多态基类声明virtual析构函数(构造/析构/赋值运算)

有许多种做法可以记录时间&#xff0c;因此&#xff0c;设计一个TimeKeeper base class和一些derived classes 作为不同的计时方法&#xff0c;相当合情合理&#xff1a; class TimeKeeper { public:TimeKeeper();~TimeKeeper();// ... };class AtomicClock: public TimeKeepe…...

用友Java后端笔试2023-8-5

计算被直线划分区域 在笛卡尔坐标系&#xff0c;存在区域[A,B],被不同线划分成多块小的区域&#xff0c;简单起见&#xff0c;假设这些不同线都直线并且不存在三条直线相交于一点的情况。 img 那么&#xff0c;如何快速计算某个时刻&#xff0c;在 X 坐标轴上[ A&#xff0c;…...

idea2023 springboot2.7.5+mybatis+jsp 初学单表增删改查

创建项目 因为2.7.14使用量较少&#xff0c;特更改spring-boot为2.7.5版本 配置端口号 打开Sm01Application类&#xff0c;右键运行启动项目&#xff0c;或者按照如下箭头启动 启动后&#xff0c;控制台提示如下信息表示成功 此刻在浏览器中输入&#xff1a;http://lo…...

大语言模型之四-LlaMA-2从模型到应用

最近开源大语言模型LlaMA-2火出圈&#xff0c;从huggingface的Open LLM Leaderboard开源大语言模型排行榜可以看到LlaMA-2还是非常有潜力的开源商用大语言模型之一&#xff0c;相比InstructGPT&#xff0c;LlaMA-2在数据质量、培训技术、能力评估、安全评估和责任发布方面进行了…...

Android 远程真机调研

背景 现有的安卓测试机器较少&#xff0c;很难满足 SDK 的兼容性测试及线上问题&#xff08;特殊机型&#xff09;验证&#xff0c;基于真机成本较高且数量较多的前提下&#xff0c;可以考虑使用云测平台上的机器进行验证&#xff0c;因此需要针对各云测平台进行调研、比较。 …...

B. 攻防演练 (2021CCPC女生赛)

题意&#xff1a; 给出一个长度为n的字符&#xff0c;字符是前m个小写字母&#xff0c;有q个询问&#xff0c;每次询问一个最短子序列的长度满足不是[l,r]内任意一个子序列 思路&#xff1a; [l,r]中子序列可以看成是从[l,r]中的某个位置开始&#xff0c;跳到下一个字符的位…...

MAC环境,在IDEA执行报错java: -source 1.5 中不支持 diamond 运算符

Error:(41, 51) java: -source 1.5 中不支持 diamond 运算符 (请使用 -source 7 或更高版本以启用 diamond 运算符) 进入设置 修改java版本 pom文件中加入 <plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-compiler-plugin&l…...

Tomcat日志中文乱码

修改安装目录下的日志配置 D:\ProgramFiles\apache-tomcat-9.0.78\conf\logging.properties java.util.logging.ConsoleHandler.encoding GBK...

最小生成树 — Prim算法

同Kruskal算法一样&#xff0c;Prim算法也是最小生成树的算法&#xff0c;但与Kruskal算法有较大的差别。 Prim算法整体是通过“解锁” “选中”的方式&#xff0c;点 -> 边 -> 点 -> 边。 因为是最小生成树&#xff0c;所以针对的也是无向图&#xff0c;所以可以随意…...

如何使用PHP Smarty模板进行AJAX交互?

首先&#xff0c;我们要明白&#xff0c;AJAX是一种在无需刷新整个页面的情况下&#xff0c;与服务器进行通信的技术。这对于改善用户体验来说&#xff0c;是个大宝贝。而PHP Smarty模板则是PHP的一种模板引擎&#xff0c;它使得设计和开发人员能够更好地分离逻辑和显示。 现在…...

nginx反向代理、负载均衡

修改nginx.conf的配置 upstream nginx_boot{# 30s内检查心跳发送两次包&#xff0c;未回复就代表该机器宕机&#xff0c;请求分发权重比为1:2server 192.168.87.143 weight100 max_fails2 fail_timeout30s; server 192.168.87.1 weight200 max_fails2 fail_timeout30s;# 这里的…...

React Native文本添加下划线

import { StyleSheet } from react-nativeconst styles StyleSheet.create({mExchangeCopyText: {fontWeight: bold, color: #1677ff, textDecorationLine: underline} })export default styles...

微服务-Nacos(配置管理)

配置更改热更新 在Nacos中添加配置信息&#xff1a; 在弹出表单中填写配置信息&#xff1a; 配置获取的步骤如下&#xff1a; 1.引入Nacos的配置管理客户端依赖&#xff08;A、B服务&#xff09;&#xff1a; <!--nacos的配置管理依赖--><dependency><groupId&…...

UML图绘制 -- 类图

1.类图的画法 类 整体是个矩形&#xff0c;第一层类名&#xff0c;第二层属性&#xff0c;第三层方法。 &#xff1a;public- : private# : protected空格: 默认的default 对应的类写法。 public class Student {public String name;public Integer age;protected I…...

【网络】每天掌握一个Linux命令 - iftop

在Linux系统中&#xff0c;iftop是网络管理的得力助手&#xff0c;能实时监控网络流量、连接情况等&#xff0c;帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...

基于Flask实现的医疗保险欺诈识别监测模型

基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施&#xff0c;由雇主和个人按一定比例缴纳保险费&#xff0c;建立社会医疗保险基金&#xff0c;支付雇员医疗费用的一种医疗保险制度&#xff0c; 它是促进社会文明和进步的…...

抖音增长新引擎:品融电商,一站式全案代运营领跑者

抖音增长新引擎&#xff1a;品融电商&#xff0c;一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中&#xff0c;品牌如何破浪前行&#xff1f;自建团队成本高、效果难控&#xff1b;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...

【SQL学习笔记1】增删改查+多表连接全解析(内附SQL免费在线练习工具)

可以使用Sqliteviz这个网站免费编写sql语句&#xff0c;它能够让用户直接在浏览器内练习SQL的语法&#xff0c;不需要安装任何软件。 链接如下&#xff1a; sqliteviz 注意&#xff1a; 在转写SQL语法时&#xff0c;关键字之间有一个特定的顺序&#xff0c;这个顺序会影响到…...

《通信之道——从微积分到 5G》读书总结

第1章 绪 论 1.1 这是一本什么样的书 通信技术&#xff0c;说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号&#xff08;调制&#xff09; 把信息从信号中抽取出来&am…...

【开发技术】.Net使用FFmpeg视频特定帧上绘制内容

目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法&#xff0c;当前调用一个医疗行业的AI识别算法后返回…...

基于Java+MySQL实现(GUI)客户管理系统

客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息&#xff0c;对客户进行统一管理&#xff0c;可以把所有客户信息录入系统&#xff0c;进行维护和统计功能。可通过文件的方式保存相关录入数据&#xff0c;对…...

scikit-learn机器学习

# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: # Also add the following code, # so that every time the environment (kernel) starts, # just run the following code: import sys sys.path.append(/home/aistudio/external-libraries)机…...

基于PHP的连锁酒店管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的连锁酒店管理系统 一 介绍 连锁酒店管理系统基于原生PHP开发&#xff0c;数据库mysql&#xff0c;前端bootstrap。系统角色分为用户和管理员。 技术栈 phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注册/登录/注销 2 个人中…...

HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散

前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说&#xff0c;在叠衣服的过程中&#xff0c;我会带着团队对比各种模型、方法、策略&#xff0c;毕竟针对各个场景始终寻找更优的解决方案&#xff0c;是我个人和我司「七月在线」的职责之一 且个人认为&#xff0c…...