Pytorch实现线性回归Linear Regression
借助 PyTorch 实现深度神经网络 - 线性回归 - 第 2 周 | Coursera
线性回归预测
用PyTorch实现线性回归模块

创建自定义模块(内含一个线性回归)

训练线性回归模型
对于线性回归,特定类型的噪声是高斯噪声

平均损失均方误差函数:

loss求解(导数=0):

梯度下降

表示学习率
学习率过高,可能错过参数的最佳值
学习率过低,需要大量的迭代才能获得最小值
Batch Gradient Descent:使用整个训练集来更新模型的参数
用Pytorch实现线性回归--梯度




每个epoch就是一个iteration:

画图版:
import torch
w=torch.tensor(-10.0,requires_grad=True)
X=torch.arange(-3,3,0.1).view(-1,1)
f=-3*X
# The class for plottingclass plot_diagram():# Constructordef __init__(self, X, Y, w, stop, go = False):start = w.dataself.error = []self.parameter = []print(type(X.numpy()))self.X = X.numpy()self.Y = Y.numpy()self.parameter_values = torch.arange(start, stop)self.Loss_function = [criterion(forward(X), Y) for w.data in self.parameter_values] w.data = start# Executordef __call__(self, Yhat, w, error, n):self.error.append(error)self.parameter.append(w.data)plt.subplot(212)plt.plot(self.X, Yhat.detach().numpy())plt.plot(self.X, self.Y,'ro')plt.xlabel("A")plt.ylim(-20, 20)plt.subplot(211)plt.title("Data Space (top) Estimated Line (bottom) Iteration " + str(n))# Convert lists to PyTorch tensorsparameter_values_tensor = torch.tensor(self.parameter_values)loss_function_tensor = torch.tensor(self.Loss_function)# Plot using the tensorsplt.plot(parameter_values_tensor.numpy(), loss_function_tensor.numpy())plt.plot(self.parameter, self.error, 'ro')plt.xlabel("B")plt.figure()# Destructordef __del__(self):plt.close('all')
gradient_plot = plot_diagram(X, Y, w, stop = 5)
# Define a function for train the modeldef train_model(iter):LOSS=[]for epoch in range (iter):# make the prediction as we learned in the last labYhat = forward(X)# calculate the iterationloss = criterion(Yhat,Y)# plot the diagram for us to have a better ideagradient_plot(Yhat, w, loss.item(), epoch)# store the loss into listLOSS.append(loss.item())# backward pass: compute gradient of the loss with respect to all the learnable parametersloss.backward()# updata parametersw.data = w.data - lr * w.grad.data# zero the gradients before running the backward passw.grad.data.zero_()
train_model(4)




用Pytorch实现线性回归--训练
与上文类似,只是多加了个b
梯度



画函数图:
# The class for plot the diagramclass plot_error_surfaces(object):# Constructordef __init__(self, w_range, b_range, X, Y, n_samples = 30, go = True):W = np.linspace(-w_range, w_range, n_samples)B = np.linspace(-b_range, b_range, n_samples)w, b = np.meshgrid(W, B) Z = np.zeros((30,30))count1 = 0self.y = Y.numpy()self.x = X.numpy()for w1, b1 in zip(w, b):count2 = 0for w2, b2 in zip(w1, b1):Z[count1, count2] = np.mean((self.y - w2 * self.x + b2) ** 2)count2 += 1count1 += 1self.Z = Zself.w = wself.b = bself.W = []self.B = []self.LOSS = []self.n = 0if go == True:plt.figure()plt.figure(figsize = (7.5, 5))plt.axes(projection='3d').plot_surface(self.w, self.b, self.Z, rstride = 1, cstride = 1,cmap = 'viridis', edgecolor = 'none')plt.title('Cost/Total Loss Surface')plt.xlabel('w')plt.ylabel('b')plt.show()plt.figure()plt.title('Cost/Total Loss Surface Contour')plt.xlabel('w')plt.ylabel('b')plt.contour(self.w, self.b, self.Z)plt.show()# Setterdef set_para_loss(self, W, B, loss):self.n = self.n + 1self.W.append(W)self.B.append(B)self.LOSS.append(loss)# Plot diagramdef final_plot(self): ax = plt.axes(projection = '3d')ax.plot_wireframe(self.w, self.b, self.Z)ax.scatter(self.W,self.B, self.LOSS, c = 'r', marker = 'x', s = 200, alpha = 1)plt.figure()plt.contour(self.w,self.b, self.Z)plt.scatter(self.W, self.B, c = 'r', marker = 'x')plt.xlabel('w')plt.ylabel('b')plt.show()# Plot diagramdef plot_ps(self):plt.subplot(121)plt.ylimplt.plot(self.x, self.y, 'ro', label="training points")plt.plot(self.x, self.W[-1] * self.x + self.B[-1], label = "estimated line")plt.xlabel('x')plt.ylabel('y')plt.ylim((-10, 15))plt.title('Data Space Iteration: ' + str(self.n))plt.subplot(122)plt.contour(self.w, self.b, self.Z)plt.scatter(self.W, self.B, c = 'r', marker = 'x')plt.title('Total Loss Surface Contour Iteration' + str(self.n))plt.xlabel('w')plt.ylabel('b')plt.show()
相关文章:
Pytorch实现线性回归Linear Regression
借助 PyTorch 实现深度神经网络 - 线性回归 - 第 2 周 | Coursera 线性回归预测 用PyTorch实现线性回归模块 创建自定义模块(内含一个线性回归) 训练线性回归模型 对于线性回归,特定类型的噪声是高斯噪声 平均损失均方误差函数:…...
十八次(虚拟主机与vue项目、samba磁盘映射、nfs共享)
1、虚拟主机搭建环境准备 将原有的nginx.conf文件备份 [rootserver ~]# cp /usr/local/nginx/conf/nginx.conf /usr/local/nginx/conf/nginx.conf.bak[rootserver ~]# grep -Ev "#|^$" /usr/local/nginx/conf/nginx.conf[rootserver ~]# grep -Ev "#|^$"…...
P1340 兽径管理 题解|最小生成树
题目大意 洛谷中链接 推荐文章:并查集入门 原文 约翰农场的牛群希望能够在 N N N 个草地之间任意移动。草地的编号由 1 1 1 到 N N N。草地之间有树林隔开。牛群希望能够选择草地间的路径,使牛群能够从任一 片草地移动到任一片其它草地。 牛群可在…...
Python,Maskrcnn训练,cannot import name ‘saving‘ from ‘keras.engine‘ ,等问题集合
Python版本3.9,tensorflow2.11.0,keras2.11.0 问题一、module keras.engine has no attribute Layer Traceback (most recent call last):File "C:\Users\Administrator\Desktop\20240801\代码\test.py", line 16, in <module>from mrc…...
Linux常用工具
文章目录 tar打包命令详解unzip命令:解压zip文件vim操作详解netstat详解df命令详解ps命令详解find命令详解 tar打包命令详解 tar命令做打包操作 当 tar 命令用于打包操作时,该命令的基本格式为: tar [选项] 源文件或目录此命令常用的选项及…...
AI未来的发展如何
AI(人工智能)的发展前景非常广阔,随着技术的不断进步和应用场景的不断拓展,AI将在多个领域发挥重要作用。以下是对AI发展前景的详细分析: 一、技术突破与创新 生成式AI的兴起:以ChatGPT为代表的生成式AI技…...
若依替换首页上的logo
...
sed的使用示例
场景:使用sed将多个空格变成单空格,再使用cut来切分得到需要的结果 得到后面这个文件名: ls ./ drwxr-x— 2 root root 6 Jul 18 9:00 7b40f1412d83c1524af7977593607f15 drwxr-x— 2 root root 6 Jul 18 14:00 50af29cef2c65a9d28905a3ce831bcb7 drwxr-x— 2 root root 6 Jul…...
学历不是障碍:大专生如何成功进入软件测试行业
摘要: 在当今技术驱动的职场环境中,软件测试已成为一个关键的职业领域。尽管许多人认为高学历是进入这一行业的先决条件,但实际上,大专学历的学生同样有机会在软件测试领域取得成功。本文将探讨大专生如何通过技能提升、实践经验和…...
文件解析漏洞—IIS解析漏洞—IIS6.X
目录 方式 1:目录解析 方式 2:畸形文件解析 方式 3:PUT 上传漏洞(123.asp;.jpg 解析成 asp) 环境:Windows server 2003 添加 IIS 管理工具——打开 IIS——添加网站 创建完成之后,右击创建的…...
Sqlmap中文使用手册 - Brute force模块参数使用
目录 1. Brute force模块的帮助文档2. 各个参数的介绍2.1 --common-tables2.2 --common-columns2.3 --common-files 1. Brute force模块的帮助文档 Brute force:These options can be used to run brute force checks--common-tables Check existence of common tables--c…...
ubuntu20.04 开源鸿蒙源码编译配置
替换华为源 sudo sed -i "shttp://.*archive.ubuntu.comhttp://repo.huaweicloud.comg" /etc/apt/sources.list && sudo sed -i "shttp://.*security.ubuntu.comhttp://repo.huaweicloud.comg" /etc/apt/sources.list 安装依赖工具 如果是ubun…...
程序员面试 “八股文”在实际工作中是助力、阻力还是空谈?
“八股文”在实际工作中是助力、阻力还是空谈? 作为现在各类大中小企业面试程序员时的必问内容,“八股文”似乎是很重要的存在。但“八股文”是否能在实际工作中发挥它“敲门砖”应有的作用呢?有IT人士不禁发出疑问:程序员面试考…...
广告从用户点击开始到最终扣费的过程
用户点击广告 用户在网页或移动应用上看到广告,并点击广告。这一事件触发了整个广告处理流程。 广告请求触发 用户点击广告后,客户端(如浏览器、APP)向广告系统发送广告点击请求。请求通常包含以下信息: 用户ID 设备信…...
Linux系统编程-信号进程间通信
目录 异步(Asynchronous) 信号 数据结构 1.kill 2.alarm 3.pause 4.setitimer 5.abort 信号集(sigset_t类型) 1.sigemptyset 2.sigfillset 3.sigaddset 4.sigdelset 5.sigismember 信号屏蔽 1.sigprocmask 2.sigpending 3.sigsus…...
Attention Module (SAM)是什么?
SAM(Spatial Attention Module,空间注意力模块)是一种在神经网络中应用的注意力机制,特别是在处理图像数据时,它能够帮助模型更好地关注输入数据中不同空间位置的重要性。以下是关于SAM的详细解释: 1. 基本…...
【C语言】堆排序
堆排序即利用堆的思想来进行排序,总共分为两个步骤: 1. 建堆 升序:建大堆 降序:建小堆 原因分析: 若升序建小堆时间复杂度是O(N^2) 升序建大堆,时间复杂度O(N*logN) 所以升序建大堆…...
ntp服务重启报错Failed to restart ntpd.service: Unit is masked.
问题概述: 重启ntp服务报错Failed to restart ntpd.service: Unit is masked,使用systemctl unmask ntpd.service命令关闭屏蔽还是报错Failed to restart ntpd.service: Unit is masked 解决方法: 重装ntp服务 yum remove ntpyum install…...
面试题-每日5到
16.Files的常用方法都有哪些? Files.exists():检测文件路径是否存在 Files.createFile():创建文件 Files.createDirectory():创建文件夹 Files.delete():删除一个文件或目录 Files.copy():复制文件 Files.move():移动文件 Files.size():查看文件个数 Files.read():读…...
代码美学大师:打造Perl中的个性化代码格式化工具
代码美学大师:打造Perl中的个性化代码格式化工具 在软件开发过程中,代码的可读性至关重要。Perl,作为一种灵活的脚本语言,允许开发者以多种方式实现代码格式化。自定义代码格式化工具不仅能提升代码质量,还能加强团队…...
React 第五十五节 Router 中 useAsyncError的使用详解
前言 useAsyncError 是 React Router v6.4 引入的一个钩子,用于处理异步操作(如数据加载)中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误:捕获在 loader 或 action 中发生的异步错误替…...
日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻
在如今就业市场竞争日益激烈的背景下,越来越多的求职者将目光投向了日本及中日双语岗位。但是,一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧?面对生疏的日语交流环境,即便提前恶补了…...
23-Oracle 23 ai 区块链表(Blockchain Table)
小伙伴有没有在金融强合规的领域中遇见,必须要保持数据不可变,管理员都无法修改和留痕的要求。比如医疗的电子病历中,影像检查检验结果不可篡改行的,药品追溯过程中数据只可插入无法删除的特性需求;登录日志、修改日志…...
基于Flask实现的医疗保险欺诈识别监测模型
基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...
对WWDC 2025 Keynote 内容的预测
借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...
Qt Http Server模块功能及架构
Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...
今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存
文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...
【生成模型】视频生成论文调研
工作清单 上游应用方向:控制、速度、时长、高动态、多主体驱动 类型工作基础模型WAN / WAN-VACE / HunyuanVideo控制条件轨迹控制ATI~镜头控制ReCamMaster~多主体驱动Phantom~音频驱动Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation速…...
招商蛇口 | 执笔CID,启幕低密生活新境
作为中国城市生长的力量,招商蛇口以“美好生活承载者”为使命,深耕全球111座城市,以央企担当匠造时代理想人居。从深圳湾的开拓基因到西安高新CID的战略落子,招商蛇口始终与城市发展同频共振,以建筑诠释对土地与生活的…...
JavaScript基础-API 和 Web API
在学习JavaScript的过程中,理解API(应用程序接口)和Web API的概念及其应用是非常重要的。这些工具极大地扩展了JavaScript的功能,使得开发者能够创建出功能丰富、交互性强的Web应用程序。本文将深入探讨JavaScript中的API与Web AP…...


