当前位置: 首页 > news >正文

PyTorch实现逻辑回归

最终效果

先看下最终效果:
1
这里用一条直线把二维平面上不同的点分开。

生成随机数据

#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)

数据可视化

def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()

使用x和o来表示两种不同类别的数据。
1

定义模型和损失函数

#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)  # 随机初始化w
b = torch.zeros((1),requires_grad=True)  # 使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()

这里使用了平方损失函数来估算模型准确度。

训练模型

最多训练100次,每次都会更新模型参数,当损失值小于0.03时停止训练。

xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:  # 停止条件break

全部代码

import torch
import matplotlib.pyplot as plt
import matplotlib.markers as mmarkers#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + bn_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)#随机初始化w
b = torch.zeros((1),requires_grad=True)#使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:#停止条件break

相关文章:

PyTorch实现逻辑回归

最终效果 先看下最终效果&#xff1a; 这里用一条直线把二维平面上不同的点分开。 生成随机数据 #创建训练数据 x torch.rand(10,1)*10 #shape(10,1) y 2*x (5 torch.randn(10,1))#构建线性回归参数 w torch.randn((1))#随机初始化w&#xff0c;要用到自动梯度求导 b …...

什么是FPGA原型验证?

EDA工具的使用主要分为设计、验证和制造三大类。验证工作贯穿整个芯片设计流程&#xff0c;可以说芯片的验证阶段占据了整个芯片开发的大部分时间。从芯片需求定义、功能设计开发到物理实现制造&#xff0c;每个环节都需要进行大量的验证。 现如今验证方法也越来越多&#xff…...

基于VUE3+Layui从头搭建通用后台管理系统(前端篇)十四:系统设置模块相关功能实现

一、本章内容 本章使用已实现的公共组件实现系统管理中的系统设置模块相关功能,包括菜单管理、角色管理、日志管理、用户管理、系统配置、数据字典等。 1. 详细课程地址: 待发布 2. 源码下载地址: 待发布 二、界面预览 三、开发视频 3.1 B站视频地址:...

使用Visual Studio(VS)创建空项目的Win32桌面应用程序【main函数入口变WinMain】

前言 在Visual Studio中直接新建Windows桌面应用程序会有很多多余的代码生成&#xff0c;本文将提供从空项目创建Win32项目的方法&#xff0c;解决新建空项目直接使用WinMain代码编译报错的问题 例如&#xff1a;LNK2019 &#xff1a;无法解析的外部符号 参考博客&#xff1…...

基于自动化脚本批量上传依赖到nexus内网私服

前言 因为某些原因某些企业希望私服是不能连接外网的&#xff0c;所以需要某些开源依赖需要我们手动导入到nexus中&#xff0c;尽管nexus为我们提供了web页面。但是一个个手动导入显然是一个庞大的工程。 对此我们就不妨基于脚本的方式实现这一过程。 预期效果 笔者本地仓库…...

Linux中ps命令使用指南

目录 1 前言2 ps命令的含义和作用3 ps命令的基本使用4 常用选项参数5 一些常用情景5.1 查看系统中的所有进程&#xff08;标准语法&#xff09;5.2 使用 BSD 语法查看系统中的所有进程5.3 打印进程树5.4 获取线程信息5.5 获取安全信息5.6 查看以 root 用户身份&#xff08;实际…...

PHP开发语言中,网页端常用的标签

在PHP开发语言中&#xff0c;网页端常用的标签包括以下几种&#xff1a; <html>&#xff1a;用于定义整个HTML文档。<head>&#xff1a;用于定义文档的头部&#xff0c;包含元数据、样式表和脚本等。<title>&#xff1a;用于定义文档的标题&#xff0c;显示…...

Java 入门第四篇 集合

Java 入门第四篇 集合 一&#xff0c;什么是集合 在Java中&#xff0c;集合&#xff08;Collection&#xff09;是一种用于存储和操作一组对象的容器类。它提供了一系列的方法和功能&#xff0c;用于方便地管理和操作对象的集合。集合框架是Java中非常重要和常用的一部分&…...

VBA技术资料MF93:将多个Excel表插入PowerPoint不同位置

我给VBA的定义&#xff1a;VBA是个人小型自动化处理的有效工具。利用好了&#xff0c;可以大大提高自己的工作效率&#xff0c;而且可以提高数据的准确度。我的教程一共九套&#xff0c;分为初级、中级、高级三大部分。是对VBA的系统讲解&#xff0c;从简单的入门&#xff0c;到…...

STM32 MCU的易坑点收集

IIC配置中的Clock No Stretch Mode Clock Stretch Mode时钟延长模式&#xff1a; 时钟延长是一个术语&#xff0c;某些从设备可以把时钟线拉低&#xff0c;主设备发现自己释放时钟线之后时钟线还没有变成高电平&#xff0c;就会停止发送数据&#xff0c;然后等待从设备释放时钟…...

Vue3项目filter.js组件封装

1、element-plus(el-table)修改table的行样式 export function elTableRowClassName({ row, rowIndex }) {if (rowIndex % 2 ! 0) {return default-row} }2、时间戳转换格式 export function parseTimeFilter(dateTime, dateType) {if (dateTime || dateTime undefined ||…...

Linux: pwd命令查看当前工作目录

pwd 是 Linux 和其他类 Unix 操作系统中的一个命令&#xff0c;用于显示当前工作目录的绝对路径。 语法 pwd 描述 pwd 是 "print working directory" 的缩写&#xff0c;它用于打印当前工作目录的完整路径。这对于确定当前目录位置非常有用&#xff0c;特别是在嵌…...

【深度学习】PHP操作mysql数据库总结

一.PHP数据库的扩展分类 1.MySQL 扩展是针对 MySQL 4.1.3 或更早版本设计的&#xff0c;是 PHP 与 MySQL数据库交互的早期扩展。由于其不支持 MySQL 数据库服务器的新特性&#xff0c;且安全性差&#xff0c;在项目开发中不建议使用&#xff0c;可用 MySQLi 扩展代替。 2.MySQ…...

【送书活动】探究AIGC、AGI、GPT和人工智能大模型

文章目录 前言01 《ChatGPT 驱动软件开发》推荐语 02 《ChatGPT原理与实战》推荐语 03 《神经网络与深度学习》推荐语 04 《AIGC重塑教育》推荐语 05 《通用人工智能》推荐语 后记赠书活动 前言 人工智能技术在过去几年中发展迅猛&#xff0c;得益于大数据、云计算、深度学习等…...

Apple Find My「查找」认证芯片找哪家,认准伦茨科技ST17H6x芯片

深圳市伦茨科技有限公司&#xff08;以下简称“伦茨科技”&#xff09;发布ST17H6x Soc平台。成为继Nordic之后全球第二家取得Apple Find My「查找」认证的芯片厂家&#xff0c;该平台提供可通过Apple Find My认证的Apple查找&#xff08;Find My&#xff09;功能集成解决方案。…...

java.lang.IllegalArgumentException: Could not resolve placeholder XXX‘ in value

问题描述 使用Springcloudalibaba的nacos作为配置中心&#xff0c;服务启动时报错&#xff1a; java.lang.IllegalArgumentException: Could not resolve placeholder XXX‘ in value java.lang.IllegalArgumentException: Param ‘serviceName’ is illegal, serviceName is …...

自动机器学习是什么?概念及应用

自动机器学习 (Auto Machine Learning) 的应用和方法 随着众多企业在大量场景中开始采用机器学习&#xff0c;前后期处理和优化的数据量及规模指数级增长。企业很难雇用充足的人手来完成与高级机器学习模型相关的所有工作&#xff0c;因此机器学习自动化工具是未来人工智能 (A…...

el-date-picker限制选择7天内禁止内框选择

需求&#xff1a;elementPlus时间段选择框需要满足&#xff1a;①最多选7天时间。②不能手动输入。 <el-date-picker v-model"timeArrange" focus"timeEditable" :editable"false" type"datetimerange" range-separator"至&qu…...

Navicat 技术指引 | 适用于 GaussDB 分布式的调试器

Navicat Premium&#xff08;16.3.3 Windows 版或以上&#xff09;正式支持 GaussDB 分布式数据库。GaussDB 分布式模式更适合对系统可用性和数据处理能力要求较高的场景。Navicat 工具不仅提供可视化数据查看和编辑功能&#xff0c;还提供强大的高阶功能&#xff08;如模型、结…...

人工智能导论习题集(3)

第五章&#xff1a;不确定性推理 题1题2题3题4题5题6题7题8 题1 题2 题3 题4 题5 题6 题7 题8...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

解锁数据库简洁之道:FastAPI与SQLModel实战指南

在构建现代Web应用程序时&#xff0c;与数据库的交互无疑是核心环节。虽然传统的数据库操作方式&#xff08;如直接编写SQL语句与psycopg2交互&#xff09;赋予了我们精细的控制权&#xff0c;但在面对日益复杂的业务逻辑和快速迭代的需求时&#xff0c;这种方式的开发效率和可…...

Objective-C常用命名规范总结

【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名&#xff08;Class Name)2.协议名&#xff08;Protocol Name)3.方法名&#xff08;Method Name)4.属性名&#xff08;Property Name&#xff09;5.局部变量/实例变量&#xff08;Local / Instance Variables&…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量&#xff0c;这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

Angular微前端架构:Module Federation + ngx-build-plus (Webpack)

以下是一个完整的 Angular 微前端示例&#xff0c;其中使用的是 Module Federation 和 npx-build-plus 实现了主应用&#xff08;Shell&#xff09;与子应用&#xff08;Remote&#xff09;的集成。 &#x1f6e0;️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...

【Go语言基础【13】】函数、闭包、方法

文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数&#xff08;函数作为参数、返回值&#xff09; 三、匿名函数与闭包1. 匿名函数&#xff08;Lambda函…...

Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)

引言 工欲善其事&#xff0c;必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后&#xff0c;我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集&#xff0c;就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...