当前位置: 首页 > news >正文

pytorch学习---实现线性回归初体验

假设我们的基础模型就是y = wx + b,其中w和b均为参数,我们使用y = 3x+0.8来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8。

步骤如下:

  1. 准备数据
  2. 计算预测值
  3. 计算损失,把参数的梯度置为0,进行反向传播
  4. 更新参数

方式一

该方式没有用pytorch的模型api,手动实现

import torch,numpy
import matplotlib.pyplot as plt# 1、准备数据
learning_rate = 0.01
#y=3x + 0.8
x = torch.rand([500,1])
y_true= x*3 + 0.8# 2、通过模型计算y_predict
w = torch.rand([1,1],requires_grad=True)
b = torch.tensor(0,requires_grad=True,dtype=torch.float32)# 3、通过循环,反向传播,更新参数
for i in range(500):# 4、计算lossy_predict = torch.matmul(x,w) + bloss = (y_true-y_predict).pow(2).mean()# 每次循环判断是否存在梯度,防止累加if w.grad is not None:w.grad.data.zero_()if b.grad is not None:b.grad.data.zero_()# 反向传播loss.backward()w.data = w.data - learning_rate*w.gradb.data = b.data - learning_rate*b.grad# 每50次输出一下结果if i%50==0:print("w,b,loss",w.item(),b.item(),loss.item())#可视化显示
plt.figure(figsize=(20,8))
plt.scatter(x.numpy().reshape(-1),y_true.numpy().reshape(-1))
y_predict = torch.matmul(x,w) + b
plt.plot(x.numpy().reshape(-1),y_predict.detach().numpy().reshape(-1),c="r")
plt.show()

循环500次的效果
在这里插入图片描述
循环2000次的结果
在这里插入图片描述

方式二

方式一的方式虽然已经购简便了,但是还是有些许繁琐,所以我们可以采用pytorchapi来实现。
nn.Moduletorch.nn提供的一个类,是pytorch中我们自定义网络的一个基类,在这个类中定义了很多有用的方法,让我们在继承这个类定义网络的时候非常简单。
当我们自定义网络的时候,有两个方法需要特别注意:
1.__init__需要调用super方法,继承父类的属性和方法
2. forward方法必须实现,用来定义我们的网络的向前计算的过程用前面的y = wx+b的模型举例如下:

#定义模型
from torch import nn
class Lr(nn.Module): #继承nn.Moduledef __init__(self):super(Lr, self).__init__()self.linear = nn.Linear(1,1)def forward(self,x):out = self.linear(x)return out

全部代码如下:

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt#1、定义数据
x = torch.rand([50,1])
y = x*3 + 0.8#定义模型
class Lr(nn.Module): #继承nn.Moduledef __init__(self):super(Lr, self).__init__()self.linear = nn.Linear(1,1)def forward(self,x):out = self.linear(x)return out
#2、实例化模型、loss函数以及优化器
model = Lr()
criterion = nn.MSELoss()   #损失函数
optimizer = optim.SGD(model.parameters(),lr=1e-3) #优化器#3、训练模型
for i in range(3000):out = model(x)# 获取预测值loss = criterion(y,out) #计算损失optimizer.zero_grad() #梯度归零loss.backward() #计算梯度optimizer.step() #更新梯度if(i+1) % 20 ==0:print('Epoch[{}/{}],loss:{:.6f}'.format(i,500,loss.data))#4、模型评估
model.eval() #设置模型为评估模式,即预测模式
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(),y.data.numpy(),c="r")
plt.plot(x.data.numpy(),predict)
plt.show()

注意:

model.eval()表示设置模型为评估模式,即预测模式

model.train(mode=True) 表示设置模型为训练模式

在当前的线性回归中,上述并无区别

但是在其他的一些模型中,训练的参数和预测的参数会不相同,到时候就需要具体告诉程序我们是在进行训练还是预测,比如模型中存在DropoutBatchNorm的时候

循环2000次的结果:
在这里插入图片描述
循环30000次的结果:
在这里插入图片描述

相关文章:

pytorch学习---实现线性回归初体验

假设我们的基础模型就是y wx b,其中w和b均为参数,我们使用y 3x0.8来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8。 步骤如下: 准备数据计算预测值计算损失,把参数的梯度置为0,进行反向传播…...

别再乱写git commit了

B站|公众号:啥都会一点的研究生 写在前面 在很长的一段时间中,使用git commit都是随心所欲,log肥肠简洁,随着代码的迭代,当时有多偷懒,返过头查看git日志就有多懊悔,就和写代码不写doc string…...

八大排序(一)冒泡排序,选择排序,插入排序,希尔排序

一、冒泡排序 冒泡排序的原理是:从左到右,相邻元素进行比较。每次比较一轮,就会找到序列中最大的一个或最小的一个。这个数就会从序列的最右边冒出来。 以从小到大排序为例,第一轮比较后,所有数中最大的那个数就会浮…...

泊松分布简要介绍

泊松分布是一种常见的离散概率分布,它用于描述某个时间段或区域内随机事件发生的次数。它得名于法国数学家西蒙丹尼泊松。 泊松分布的概率质量函数表示某个时间段或区域内事件发生次数的概率。如果随机变量 X 服从泊松分布,记作 X ~ Poisson(λ)&#x…...

C语言每日一题(10):无人生还

文章主题:无人生还🔥所属专栏:C语言每日一题📗作者简介:每天不定时更新C语言的小白一枚,记录分享自己每天的所思所想😄🎶个人主页:[₽]的个人主页🏄&#x1f…...

VSCode开发go手记

断点调试: 安装delve(windows): go get -u github.com/go-delve/delve/cmd/dlv 设置 launch.json 配置文件: ctrlshiftp 输入 Debug: Open launch.json 打开 launch.json 文件,如果第一次打开,会新建一…...

怎么选择AI伪原创工具-AI伪原创工具有哪些

在数字时代,创作和发布内容已经成为了一种不可或缺的活动。不论您是个人博主、企业家还是网站管理员,都会面临一个共同的挑战:如何在互联网上脱颖而出,吸引更多的读者和访客。而正是在这个背景下,AI伪原创工具逐渐崭露…...

【块状链表C++】文本编辑器(指针中 引用 的使用)

》》》算法竞赛 /*** file * author jUicE_g2R(qq:3406291309)————彬(bin-必应)* 一个某双流一大学通信与信息专业大二在读 * * brief 一直在竞赛算法学习的路上* * copyright 2023.9* COPYRIGHT 原创技术笔记:转载…...

echarts的Y轴设置为整数

场景:使用echarts,设置Y轴为整数。通过判断Y轴的数值为整数才显示即可 yAxis: [{name: ,type: value,min: 0, // 最小值// max: 200, // 最大值// splitNumber: 5, // 坐标轴的分割段数// interval: 100 / 5, // 强制设置坐标轴分割间隔度(取本Y轴的最大…...

恢复删除文件?不得不掌握的4个方法!

“删除了的文件还可以恢复吗?有个文件我本来以为不重要了,就把它删除了,没想到现在还需要用到!这可怎么办?有没有办法找回来呢?” 重要的文件一旦丢失或误删可能都会对我们的工作和学习造成比较大的影响。怎…...

GitLab CI/CD:.gitlab-ci.yml 文件常用参数小结

文章目录 一、.gitlab-ci.yml 文件作用二、一个简单的.gitlab-ci.yml 文件示例参考 一、.gitlab-ci.yml 文件作用 可以定义跑CI时想要运行的命令或脚本 可以定义job之间的依赖和缓存 可以执行程序部署并定义部署位置 可以定义想要包含的其他配置文件和模版 二、一个简单的.gi…...

MySQL学习笔记9

MySQL数据表中的数据类型: 在考虑数据类型、长度、标度和精度时,一定要仔细地进行短期和长远的规划,另外,公司制度和希望用户用什么方式访问数据也是要考虑的因素。开发人员应该了解数据的本质,以及数据在数据库里是如…...

从零学习开发一个RISC-V操作系统(三)丨嵌入式操作系统开发的常用概念和工具

本篇文章的内容 一、嵌入式操作习系统开发的常用概念和工具1.1 本地编译和交叉编译1.2 调试器GDB(The GNU Project Debugger)1.3 QEMU模拟器1.4 项目构造工具Make 本系列是博主参考B站课程学习开发一个RISC-V的操作系统的学习笔记,计划从RISC…...

小米机型解锁bl 跳“168小时”限制 操作步骤分析

写到前面的安全提示 了解解锁bl后的风险: 解锁设备后将允许修改系统重要组件,并有可能在一定程度上导致设备受损;解锁后设备安全性将失去保证,易受恶意软件攻击,从而导致个人隐私数据泄露;解锁后部分对系…...

基础练习 回文数

问题描述 1221是一个非常特殊的数&#xff0c;它从左边读和从右边读是一样的&#xff0c;编程求所有这样的四位十进制数。 输出格式 按从小到大的顺序输出满足条件的四位十进制数。 solution1 #include <stdio.h> int main(){int n 1000, n1, n2, n3, n4;while(n &…...

解决Spring Boot 2.7.16 在服务器显示启动成功无法访问问题:从本地到服务器的部署坑

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…...

洛谷P5661:公交换乘 ← CSP-J 2019 复赛第2题

【题目来源】https://www.luogu.com.cn/problem/P5661https://www.acwing.com/problem/content/1164/【题目描述】 著名旅游城市 B 市为了鼓励大家采用公共交通方式出行&#xff0c;推出了一种地铁换乘公交车的优惠方案&#xff1a; 1.在搭乘一次地铁后可以获得一张优惠票&…...

mysql优化之索引

索引官方定义&#xff1a;索引是帮助mysql高效获取数据的数据结构。 索引的目的在于提高查询效率&#xff0c;可以类比字典。 可以简单理解为&#xff1a;排好序的快速查找数据结构 在数据之外&#xff0c;数据库系统还维护着满足特定查找算法的数据结构&#xff0c;这种数据…...

文件系统详解

目录 文件系统&#xff08;1&#xff09; 第一节文件系统的基本概念 一、文件系统的任务 二、文件的存储介质及存储方式 三、文件的分类 第二节 文件的逻辑结构和物理结构 一、文件的逻辑结构 二、文件的物理结构 文件系统&#xff08;2&#xff09; 第三节 文件目…...

有名管道及其应用

创建FIFO文件 1.通过命令&#xff1a; mkfifo 文件名 2.通过函数: mkfifo #include <sys/types.h> #include <sys/stat.h> int mkfifo(const char *pathname, mode_t mode); 参数&#xff1a; -pathname&#xff1a;管道名称的路径 -mode&#xff1a;文件的权限&a…...

加州大学伯克利分校 计算机科学专业

加州大学伯克利分校 计算机科学专业 cs 61a cs 61b cs61c...

一个关于 i++ 和 ++i 的面试题打趴了所有人

前言 都说大城市现在不好找工作&#xff0c;可小城市却也不好招人。 我们公司招了挺久都没招到&#xff0c;主管感到有些心累。 我提了点建议&#xff0c;是不是面试问的太深了&#xff0c;在这种小城市&#xff0c;能干活就行。 他说自己问的面试题都很浅显&#xff0c;如果答…...

程序员的快乐如此简单

最近在GitHub上发起了一个关于Beego框架的小插件的开源仓库&#xff0c;这一举动虽然看似微小&#xff0c;但其中的快乐和意义却是无法用言语表达的。 Beego是一个开源的Go语言Web框架&#xff0c;它采用了MVC架构模式&#xff0c;并集成了很多常用的功能和中间件。小插件是指…...

浅谈云原生Cloud Native

目录 1.云原生是什么2.云原生与传统软件有什么区别3.云原生有哪些代表性的技术 1.云原生是什么 云原生&#xff08;Cloud Native&#xff09;是一种构建和运行应用程序的方法&#xff0c;可以充分利用云计算模型的优势。云原生是一种面向服务的架构&#xff08;SOA&#xff09…...

解决react报错“JSX 表达式必须具有一个父元素“

现象如下&#xff1a; 原因&#xff1a; 新插入的dom元素跟已有的dom元素平级了&#xff0c;必须创建一个共有的根元素 解决办法&#xff1a; 使用<> </>标签作为根元素&#xff0c;把所有子元素包裹起来 <> ....原代码 </> 问题解决&#xff01;…...

Spring学习笔记7 Bean的生命周期

Spring学习笔记6 Bean的实例化方式_biubiubiu0706的博客-CSDN博客 Spring其实就是一个管理Bean对象的工厂.它负责对象的创建,对象的销毁. 这样我们才可以知道在哪个时间节点上调用了哪个类的哪个方法,知道代码该写在哪里 Bean的生命周期之粗略5步 Bean生命周期的管理可以参考S…...

React 如何导出excel

在现代的Web开发中&#xff0c;数据导出是一个非常常见的需求。而在React应用中&#xff0c;我们经常需要将数据导出为Excel文件&#xff0c;以便用户可以轻松地在本地计算机上查看和编辑。本文将介绍如何在React应用中实现导出Excel文件的功能。 章节一&#xff1a;安装依赖 …...

Texlive2020 for win10 宏包更新

用命令提示符更新texlive的宏包,这个方法非常简单实用 1.以管理员身份打开命令提示符 2.系统自动选择镜像网站 tlmgr option repository ctan 3.更新宏包 tlmgr update --self --all 其中–self参数表示升级tlmgr本身,–all表示升级所有宏包,这样就可以将所有宏包更新了 4.列…...

Ps 在用鼠标滚轮缩放图片时,速度太快?

1.原因 在于安装了第三方鼠标优化软件Mos&#xff0c;它起着对第三方鼠标全局浏览效果的优化&#xff0c;使浏览更加顺滑&#xff0c;而不精确&#xff0c;消除了mac使用第三方鼠标浏览页面时的卡顿问题。这也使得像ps、ai这类软件&#xff0c;在进行页面缩放时&#xff0c;变得…...

基于docker进行Grafana + prometheus实现服务监听

基于docker进行Grafana Prometheus实现服务监听 Grafana安装Prometheus安装Jvm监控配置服务器主机监控(基础cpu&#xff0c;内存&#xff0c;磁盘&#xff0c;网络) Grafana安装 docker pull grafana/grafanamkdir /server/grafanachmod 777 /server/grafanadocker run -d -p…...

网站建设金手指稳定/百度首页排名优化服务

目录 第一部 Java入门 第二部 开发框架【知识深度升华】 第三部 【扩展知识广度】 第四部 【源码学习】 第一部 Java入门 &#xff08;根据你聪明智慧&#xff1a;建议快速游览大致了解都是有什么东东就行了&#xff09; https://www.nowcoder.com/stack/23 jav…...

南山做棋牌网站建设/北京网站提升排名

Postgresql统计慢查询 在postgresl开启pg_stat_statements模块后&#xff0c;可通过下方SELECT查找出执行平均时间最长的sql SELECT query AS SQL语句 ,calls 被执行的次数 ,min_time ,max_time ,mean_time ,stddev_time ,total_time 总执行时间 ,rows AS 检索行数 ,100.0 *…...

计算机专业有哪些/seo精灵

head元素元素包含了所有的头部标签元素可以添加在头部区域的元素标签为:title,style,meta,link,script,noscript,base定义不同文档的标题。定义了浏览器工具栏的标题。当网页添加到收藏夹时&#xff0c;显示在收藏夹中的标题。显示在搜索引擎结果页面的标题。eg: 我是标题 标签…...

wordpress 个人博客 主题/成都seo学徒

一维数组定义 定义形式&#xff1a;类型名 数组名 [数组长度]; ——数组长度是一个整型常量表达式。 例如&#xff1a; int a [10];char c [200]; 2.引用形式&#xff1a; 数组名 [ 下标 ] ——C语言规定只能引用单个的数组元素。 a[0…...

网站百度不到验证码怎么办/刷外链工具

Java NIO &#xff1a; 同步非阻塞&#xff0c;服务器实现模式为一个请求一个线程&#xff0c;即客户端发送的连接请求都会注册到多路复用器上&#xff0c;多路复用器轮询到连接有I/O请求时才启动一个线程进行处理。Java AIO(NIO.2) &#xff1a; 异步非阻塞&#xff0c;服务器…...

网站建设上的新闻/中山seo推广优化

1.修改网卡&#xff0c;修改配置文件/etc/network/interfaces之后&#xff0c;重启network出现错误&#xff1a;Unit network.service failed to load。 不用service network restart 用service network-manager restart重启&#xff0c;不会报错。如果没有用&#xff0c;只…...