当前位置: 首页 > news >正文

Pytorch学习笔记#1:拟合函数/梯度下降

学习自https://pytorch.org/tutorials/beginner/pytorch_with_examples.html

概念

Pytorch Tensor在概念上和Numpy的array一样是一个nnn维向量的。不过Tensor可以在GPU中进行计算,且可以跟踪计算图(computational graph)和梯度(gradients)。

手动梯度下降拟合函数

我们用三次函数去拟合任意函数。
y^=a+bx+cx2+dx3\hat{y}=a+bx+cx^2+dx^3y^=a+bx+cx2+dx3
定义损失函数L=∑(y−y^)2L=\sum(y-\hat{y})^2L=(yy^)2
那么梯度为:
L:2∗∑(y−y^)L:2*\sum(y-\hat{y})L:2(yy^)
a:2∗∑(y−y^)a:2*\sum(y-\hat{y})a:2(yy^)
b:2∗x∗∑(y−y^)b:2*x*\sum(y-\hat{y})b:2x(yy^)
c:2∗x2∗∑(y−y^)c:2*x^2*\sum(y-\hat{y})c:2x2(yy^)
d:2∗x3∗∑(y−y^)d:2*x^3*\sum(y-\hat{y})d:2x3(yy^)

代码

import torch
import mathdtype = torch.float
device = torch.device("cuda:0") # Run on GPU# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device,dtype=dtype)
y = torch.sin(x)# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)learning_rate = 1e-6
for t in range(2000):# Forward pass: compute predicted yy_pred = a + b * x + c * x **2 + d *x ** 3# Compute and print lossloss = (y_pred - y).pow(2).sum().item()if t % 100 == 99:print(t, loss)# Backprop to compute gradients of a, b, c, d with respect to lossgrad_y_pred = 2.0 * (y_pred - y)grad_a = grad_y_pred.sum()grad_b = (grad_y_pred * x).sum()grad_c = (grad_y_pred * x ** 2).sum()grad_d = (grad_y_pred * x ** 3).sum()# Update weights using gradient descenta -= learning_rate * grad_ab -= learning_rate * grad_bc -= learning_rate * grad_cd -= learning_rate * grad_dprint(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

自动梯度下降拟合函数

通过PyTorch: nn构建神经网络,如果我们需要一个三次函数来拟合,那么我们就需要一个隐藏层为1层,节点为3个的神经网络。
y^=∑i=13(wixi+bi)\hat{y}=\sum_{i=1}^3(w_ix^i+b_i)y^=i=13(wixi+bi)

model = torch.nn.Sequential(torch.nn.Linear(3, 1), #三个节点torch.nn.Flatten(0, 1) # 把三个节点的结果加起来
)

由于我们的神经网络第一层有三个输入(x,x2,x3x,x^2,x^3x,x2,x3),所以我们需要把数据预处理一下

x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

然后我们预测输出就可以直接调用model了

y_pred = model(xx) # y_pred也是一个tensor

损失函数

loss_fn = torch.nn.MSELoss(reduction='sum') # 定义,使用均方误差
loss = loss_fn(y_pred, y) # 计算均方误差
model.zero_grad() # 先把原先模型的梯度信息清零
loss.backward() # 计算反向传播的梯度

完整代码

import torch
import mathx = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)model = torch.nn.Sequential(torch.nn.Linear(3, 1),torch.nn.Flatten(0, 1)
)loss_fn = torch.nn.MSELoss(reduction='sum')learning_rate = 1e-6
for t in range(2000):y_pred = model(xx)loss = loss_fn(y_pred, y)if t % 100 == 99:print(t, loss.item())model.zero_grad()loss.backward()with torch.no_grad(): # 进行梯度下降for param in model.parameters():param -= learning_rate * param.gradlinear_layer = model[0]
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')

相关文章:

Pytorch学习笔记#1:拟合函数/梯度下降

学习自https://pytorch.org/tutorials/beginner/pytorch_with_examples.html 概念 Pytorch Tensor在概念上和Numpy的array一样是一个nnn维向量的。不过Tensor可以在GPU中进行计算,且可以跟踪计算图(computational graph)和梯度(…...

挑战图像处理100问(24)——伽玛校正

伽马校正(Gamma Correction)是一种图像处理技术,用于校正显示设备的非线性响应。通过对图像进行伽马变换,可以将图像的亮度范围映射到显示设备的亮度范围内,从而提高图像的对比度和细节,改善图像的视觉效果…...

高级信息系统项目管理师(高项)软考论文评分标准(附历年高项论文题目汇总)

1、如果您想了解如何高分通过高级信息系统项目管理师(高项)你可以点击一下链接: 高级信息系统项目管理师(高项)高分通过经验分享_高项经验 2、如果您想了解更多的高级信息系统项目管理(高项 软考)原创论文&#xff0…...

MySQL实战记录篇2

事务? 1、事务的特性:原子性、一致性、隔离性、持久性 (ACID) 2、多事务同时执行的时候,可能会出现的问题:脏读、不可重复读、幻读 3、事务隔离级别:读未提交、读提交、可重复读、串行化 4、不…...

C++实现AVL树

目录 一、搜索二叉树 1.1 搜索二叉树概念 二、模拟实现二叉搜索树 2.1 框架 2.2 构造函数 2.2.1 构造函数 2.2.2 拷贝构造 2.2.3 赋值拷贝 2.3 插入函数 2.3.1 insert() 2.3.2 RcInsert() 递归实现 2.4 删除结点函数 2.4.1 Erase() 2.4.2 RcErase() 2.5 中序遍历…...

高并发语言erlang编程初步

初步 下载安装与初步使用 下载并安装,然后开始菜单中有对应的图标,打开就能进入erlang的命令行。当然也可以将其安装路径的bin文件夹加入环境变量,然后就可以在命令行中输入erl进入erlang了。 在erlang语言中,语句结束需要用.标…...

springboot 问题记录

部署到Tomcat中的时候,找不到需要部署的项目; project facets severt-name severt-class安装lombok.jar eclipse添加lombok插件后闪退打不开Clean 项目,project clean clean的作用检查插件部署项目Springboot修改端口号:applica…...

【PAT甲级题解记录】1034 Head of a Gang (30 分)

【PAT甲级题解记录】1034 Head of a Gang (30 分) 前言 Problem:1034 Head of a Gang (30 分) Tags:图的遍历 连通分量统计 DFS Difficulty:剧情模式 想流点汗 想流点血 死而无憾 Address:1034 Head of a Gang (30 分) 问题描述 …...

Python搭建一个steam钓鱼网站,只要免费领游戏,一钓一个准

前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! 我们日常上网的时候,总是会碰到一些盗号的网站,或者是别人发一些链接给你, 里面的内容是一些可以免费购物网站的优惠券、游戏官网上可以免费领取皮肤、打折的游戏。 这些盗号网站统一的目…...

maven 私服nexus安装与使用

一、下载nexus Sonatype公司的一款maven私服产品 1、官网下载地址:https://help.sonatype.com/repomanager3/product-information/download 2、csdn下载地址:https://download.csdn.net/download/u010197591/87522994 二、安装与配置 1、下载后解压如…...

详解数据结构中的顺序表的手动实现,顺序表功能接口【数据结构】

文章目录线性表顺序表接口实现尾插尾删头插头删指定位置插入指定位置删除练习线性表 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构,常见的线性表:顺序表、链表、栈、队列…...

【二】kubernetes操作

k8s卸载重置 名词解释 1、Namespace:名称用来隔离资源,不隔离网络 创建名称空间 一、命名空间namesapce 方式一:命令行创建 kubectl create ns hello删除名称空间 kubectl delete ns hello查询指定的名称空间 kubectl get pod -n kube-s…...

如何在 C++ 中调用 python 解析器来执行 python 代码(五)?

本节研究如何对 import 做白名单 目录 如何在 C 中调用 python 解析器来执行 python 代码(一)?如何在 C 中调用 python 解析器来执行 python 代码(二)?如何在 C 中调用 python 解析器来执行 python 代码&…...

八 SpringMVC【拦截器】登录验证

目录🚩一 SpringMVC拦截器✅ 1.配置文件✅2.登录验证代码(HandlerInterceptor)✅3.继承HandlerInterceptorAdapter(不建议使用)✅4.登录页面jsp✅5.主页面(操作页面)✅6.crud用户在访问页面时 只…...

PhotoShop基础使用

49:图片分类 1:像素图 特点:放大后可见,右一个个色块(像素)组合而成。 优点:容量小,纯天然 JPG、JPEG、png、GIF 2:矢量图 面向对象图像 绘图图像 特点:不…...

借助阿里云 AHPA,苏打智能轻松实现降本增效

作者:元毅 “高猛科技已在几个主要服务 ACK 集群上启用了 AHPA。相比于 HPA 的方案,AHPA 的主动预测模式额外降低了 12% 的资源成本。同时 AHPA 能够提前资源预热、自动容量规划,能够很好的应对突发流量。” ——赵劲松 (高猛科技高级后台工…...

美团2面:如何保障 MySQL 和 Redis 数据一致性?这样答,让面试官爱到 死去活来

美团2面:如何保障 MySQL 和 Redis 的数据一致性? 说在前面 在尼恩的(50)读者社群中,经常遇到一个 非常、非常高频的一个面试题,但是很不好回答,类似如下: 如何保障 MySQL 和 Redis…...

react hooks学习记录

react hook学习记录1.什么是hooks2.State Hook3.Effect Hook4.Ref Hook1.什么是hooks (1). Hook是React 16.8.0版本增加的新特性/新语法 (2). 可以让你在函数组件中使用 state 以及其他的 React 特性 貌似现在更多的也是使用函数式组件的了,重要 2.State Hook imp…...

创新型中小企业认定评定标准

一、公告条件评价得分达到 60 分以上(其中创新能力指标得分不低于 20 分、成长性指标及专业化指标得分均不低于 15 分),或满足下列条件之一:(一)近三年内获得过国家级、省级科技奖励。(二&#…...

记录一次nginx转发代理skywalking白屏 以及nginx鉴权配置

上nginx代码 #user nobody; worker_processes 1; #error_log logs/error.log; #error_log logs/error.log notice; #error_log logs/error.log info; #pid logs/nginx.pid; events { worker_connections 1024; } http { include mime.types; …...

如何使用FarsightAD在活动目录域中检测攻击者部署的持久化机制

关于FarsightAD FarsightAD是一款功能强大的PowerShell脚本,该工具可以帮助广大研究人员在活动目录域遭受到渗透攻击之后,检测到由攻击者部署的持久化机制。 该脚本能够生成并导出各种对象及其属性的CSV/JSON文件,并附带从元数据副本中获取…...

Python - 操作txt文件

文章目录打开txt文件读取txt文件写入txt文件删除txt文件打开txt文件 open(file, moder, bufferingNone, encodingNone, errorsNone, newlineNone, closefdTrue)函数用来打开txt文件。 #方法1,这种方式使用后需要关闭文件 f open("data.txt","r&qu…...

老杜MySQL入门基础 1

1 数据库:DataBase(存储数据的仓库) 2 数据库管理系统:DataBaseManagementSystem(DBMS)(管理数据库中的数据的) DBMS可以对数据库中的数据进行增删改查常见的数据库管理系统:MySQL、Oracle、SQLserver 3 SQL:结构化查询语言 编…...

Vue中splice的使用

splice(index,len,[item])它也可以用来替换/删除/添加数组内某一个或者几个值(该方法会改变原始数组) index:数组开始下标 len: 替换/删除的长度 item:替换的值,删除操作的话 item为空 删除: //删除起始下标为1&…...

Ubuntu通过rsync和inotify实现双机热备

Rsync Inotify双机热备 一、备份机操作 备份机:主服务器或主机文件将需要备份的文件同步到此服务器上,即从主服务器上同步过来进行备份。 1.1安装rsync sudo apt-get install rsync1.2修改/etc/dault/rsync文件 sudo vim /etc/default/rsync修改如…...

【python】异常详解

注:最后有面试挑战,看看自己掌握了吗 文章目录错误分类捕捉异常实例finally的使用捕捉特定异常抛出异常用户自定义异常🌸I could be bounded in a nutshell and count myself a king of infinite space. 特别鸣谢:木芯工作室 、I…...

pc、移动端自适应css

第一步按照vant官网给的rem适配,安装 postcss-pxtorem:npm install postcss-pxtorem; 第二步安装lib-flexible:npm i -S amfe-flexible,记得在main.js文件引入 import ‘amfe-flexible’; 第三步进行 postc…...

Threejs 教程1

threejs核心概念场景、照相机、对象、光、渲染器等1.1.场景Scene 场景是所有物体的容器,对应着显示生活中的三维世界,所有的可视化对象级相关的动作均发生在场景中。1.2.照相机Camera照相机是三维世界中的观察者,类似与眼睛。为了观察这个世界…...

WuThreat身份安全云-TVD每日漏洞情报-2023-02-23

漏洞名称:VMware vRealize Orchestrator 安全漏洞 漏洞级别:中危 漏洞编号:CVE-2023-20855,CNNVD-202302-1754 相关涉及:VMware Cloud Foundation 4.0 漏洞状态:未定义 参考链接:https://tvd.wuthreat.com/#/listDetail?TVD_IDTVD-2023-04396 漏洞名称:TP-LINK Archer C50 WE…...

C语言--模拟实现库函数qsort

什么是qsort qsort是一个库函数,是用来排序的库函数,使用的是快速排序的方法(quicksort)。 qsort的好处在于: 1,现成的 2,可以排序任意类型的数据。 在之前我们已经学过一种排序方法:冒泡排序。排序的原理…...

用前端框架做自适应网站/网站搜索排名优化

1 TreeMap 实战在创建 TreeMap 对象时,如果使用参数为空的构造方法,则根据 Map 对象的 key 进行排序(key必须实现Comparable接口);如果使用参数为 Comparator 的构造方法,则根据 Comparator对key 进行排序。方式一:自然…...

怎样给公司做网站/seo竞价

combo-tree是一款jQuery带多选和过滤功能的树状结构下拉框插件。通过该插件,可以在下拉框中生成指定数据结构的目录树,提供单选和多选,以及过滤功能。它的特点有:在下拉框中显示树状结构。支持单选和多选。返回选择数据的 title 或…...

有没有做废品的网站/个人代运营一般怎么收费

这里主要是用户名与密码的判断:先用sharedpreferences方式存储数据,包含用户名和密码:username,password然后在登录的时候进行判断:代码如下:String name et_username.getText().toString(); String passw…...

泰州哪家做网站建设比较好/网络营销策划书

问题来源: http://www.cnblogs.com/del/archive/2009/06/05/1496857.html#1549294TTrayIcon 的主要属性:TrayIcon.Icon指定托盘图标, 有几种用法:1、设计时选择;2、把一个 TIcon 对象给它;3、使用当前程序图标: TrayIcon1.Icon : Application.Icon;4、TrayIcon1.SetDefaultIcon…...

文网站建设/中国舆情观察网

http://www.bubuko.com/infodetail-648898.html转载于:https://www.cnblogs.com/beijingstruggle/p/5518269.html...

wordpress附件管理/百度优化插件

2019独角兽企业重金招聘Python工程师标准>>> MAVEN的生命周期和插件 maven是通过插件来实现功能的。所谓的生命周期就是我们在构建项目时,maven默认需要是想的一些功能,而每一个功能就通过插件的某一功能来实现。 每个插件会有一个或多个功能…...