当前位置: 首页 > news >正文

用PyTorch轻松实现二分类:逻辑回归入门

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦什么是逻辑回归?
  • 🥦分类问题
  • 🥦交叉熵
  • 🥦代码实现
  • 🥦总结

🥦引言

当谈到机器学习和深度学习时,逻辑回归是一个非常重要的算法,它通常用于二分类问题。在这篇博客中,我们将使用PyTorch来实现逻辑回归。PyTorch是一个流行的深度学习框架,它提供了强大的工具来构建和训练神经网络,适用于各种机器学习任务。

在机器学习中已经使用了sklearn库介绍过逻辑回归,这里重点使用pytorch这个深度学习框架

🥦什么是逻辑回归?

我们首先来回顾一下什么是逻辑回归?

逻辑回归是一种用于二分类问题的监督学习算法。它的主要思想是通过一个S形曲线(通常是Sigmoid函数)将输入特征映射到0和1之间的概率值,然后根据这些概率值进行分类决策。在逻辑回归中,我们使用一个线性模型和一个激活函数来实现这个映射。

🥦分类问题

这里以MINIST Dataset手写数字集为例
在这里插入图片描述

这个数据集中包含了6w个训练集1w个测试集,类别10个
这里我们不再向之前线性回归那样,根据属于判断具体的数值大小;而是根据输入的值判断从0-9每个数字的概率大小记为p(0)、p(1)…而且十个概率值和为1,我们的目标就是根据输入得到这十个分类对于输入的每一个的概率值,哪个大就是我们需要的。

这里介绍一下与torch相关联的库—torchvision
torchvision:

  • “torchvision” 是一个PyTorch的附加库,专门用于处理图像和视觉任务。
    它包含了一系列用于数据加载、数据增强、计算机视觉任务(如图像分类、目标检测等)的工具和数据集。
  • “torchvision” 提供了许多预训练的视觉模型(例如,ResNet、VGG、AlexNet等),可以用于迁移学习或作为基准模型。
    此外,它还包括了用于图像预处理、转换和可视化的函数。

上图已经清楚的显示了,这个库包含了一些自带的数据集,但是并不是我们安装完这个库就有了,而且需要进行调用的,类似在线下载,root指定下载的路径,train表示你需要训练集还是测试集,通常情况下就是两个一个训练,一个测试,download就是判断你下没下载,下载了就是摆设,没下载就给你下载了

我们再来看一个数据集(CIFAR-10)
在这里插入图片描述
包含了5w训练样本,1w测试样本,10类。调用方式与上一个类似。

接下来我们从一张图更加直观的查看分类和回归
在这里插入图片描述

左边的是回归,右边的是分类


在这里插入图片描述

过去我们使用回归例如 y ^ \hat{y} y^=wx+b∈R,这是属于一个实数的;但是在分类问题, y ^ \hat{y} y^∈[0,1]
这说明我们需要寻找一个函数,将原本实数的值经过函数的映射转化为[0,1]之间。这里我们引入Logistic函数,使用极限很清楚的得出x趋向于正无穷的时候函数为1,x趋向于负无穷的时候,函数为0,x=0的时候,函数为0.5,当我们计算的时候将 y ^ \hat{y} y^带入这样就会出现一个0到1的概率了。

下图展示一些其他的Sigmoid函数
在这里插入图片描述

🥦交叉熵

过去我们所使用的损失函数普遍都是MSE,这里引入一个新的损失函数—交叉熵

==交叉熵(Cross-Entropy)==是一种用于衡量两个概率分布之间差异的数学方法,常用于机器学习和深度学习中,特别是在分类问题中。它是一个非常重要的损失函数,用于衡量模型的预测与真实标签之间的差异,从而帮助优化模型参数。

在交叉熵的上下文中,通常有两个概率分布:

  • 真实分布(True Distribution): 这是指问题的实际概率分布,表示样本的真实标签分布。通常用 p ( x ) p(x) p(x)表示,其中 x x x表示样本或类别。

  • 预测分布(Predicted Distribution): 这是指模型的预测概率分布,表示模型对每个类别的预测概率。通常用 q ( x ) q(x) q(x)表示,其中 x x x表示样本或类别。

交叉熵的一般定义如下:
在这里插入图片描述其中, H ( p , q ) H(p, q) H(p,q) 表示真实分布 p p p 和预测分布 q q q 之间的交叉熵。

交叉熵的主要特点和用途包括:

  • 度量差异性: 交叉熵度量了真实分布和预测分布之间的差异。当两个分布相似时,交叉熵较小;当它们之间的差异增大时,交叉熵增大。

  • 损失函数: 在机器学习中,交叉熵通常用作损失函数,用于衡量模型的预测与真实标签之间的差异。在分类任务中,通常使用交叉熵作为模型的损失函数,帮助模型优化参数以提高分类性能。

  • 反向传播: 交叉熵在训练神经网络时非常有用。通过计算交叉熵的梯度,可以使用反向传播算法来调整神经网络的权重,从而使模型的预测更接近真实标签。

在分类问题中,常见的交叉熵损失函数包括二元交叉熵(Binary Cross-Entropy)和多元交叉熵(Categorical Cross-Entropy)。二元交叉熵用于二分类问题,多元交叉熵用于多类别分类问题。

刘二大人的PPT中也介绍了
在这里插入图片描述
右边的表格中每组y与 y ^ \hat{y} y^对应的BCE,BCE越高说明越可能,最后将其求均值

🥦代码实现

在这里插入图片描述

根据上图可知,线性回归和逻辑回归的流程与函数只区别于Sigmoid函数
在这里插入图片描述
这里就是BCEloss的调用,里面的参数代表求不求均值

完整代码如下

import torch.nn.functional as F
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred
model = LogisticRegressionModel() 
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()

最后绘制一下

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))  # 相当于reshape
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r') 
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

运行结果如下
在这里插入图片描述

🥦总结

这就是使用PyTorch实现逻辑回归的基本步骤。逻辑回归是一个简单但非常有用的算法,可用于各种分类问题。希望这篇博客能帮助你开始使用PyTorch构建自己的逻辑回归模型。如果你想进一步扩展你的知识,可以尝试在更大的数据集上训练模型或探索其他深度学习算法。祝你好运!

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

相关文章:

用PyTorch轻松实现二分类:逻辑回归入门

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…...

[nltk_data] Error loading stopwords: <urlopen error [WinError 10054]

报错提示&#xff1a; >>> import nltk >>> nltk.download(stopwords) 按照提示执行后 [nltk_data] Error loading stopwords: <urlopen error [WinError 10054] 找到路径C:\\Users\\EDY\\nltk_data&#xff0c;如果没有nltk_data文件夹&#xff0c;在…...

基于Spring Boot的网上租贸系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作…...

通过IP地址管理提升企业网络安全防御

在今天的数字时代&#xff0c;企业面临着越来越多的网络安全威胁。这些威胁可能来自各种来源&#xff0c;包括恶意软件、网络攻击和数据泄露。为了提高网络安全防御&#xff0c;企业需要采取一系列措施&#xff0c;其中IP地址管理是一个重要的方面 1. IP地址的基础知识 首先&a…...

termius mac版无需登录注册直接永久使用

1. 下载地址&#xff1a;termius下载 2. 解压安装 3. 当出现 “termius”已损坏,无法打开 则输入以下命令即可&#xff1a;sudo xattr -r -d com.apple.quarantine /Applications/Termius.app 最后去 系统设置-> 隐私与安全性-> 仍要打开 4. 删除app-update.yml文件&…...

TPU编程竞赛|Stable Diffusion大模型巅峰对决,第五届全球校园人工智能算法精英赛正式启动!

目录 赛题介绍 赛题背景 赛题任务 赛程安排 评分机制 奖项设置 近日&#xff0c;2023第五届全球校园人工智能算法精英赛正式开启报名。作为赛题合作方&#xff0c;算丰承办了“算法专项赛”赛道&#xff0c;提供赛题「面向Stable Diffusion的图像提示语优化」&#xff0c…...

微信小程序 rpx 转 px

前言 略 rpx 转 px let query wx.createSelectorQuery(); query.selectViewport().boundingClientRect(function(res){let rpx2Px 1 * (res.width/750);console.log("1rpx " rpx2Px "px"); }); query.exec();参考 https://blog.csdn.net/qq_39702…...

机器学习之旅-从Python 开始

导读你想知道如何开始机器学习吗&#xff1f;在这篇文章中&#xff0c;我将简要概括一下使用 Python 来开始机器学习的一些步骤。Python 是一门流行的开源程序设计语言&#xff0c;也是在人工智能及其它相关科学领域中最常用的语言之一。机器学习简称 ML&#xff0c;是人工智能…...

100天精通Python(可视化篇)——第103天:Pyecharts绘制多种炫酷水球图参数说明+代码实战

文章目录 专栏导读一、水球图介绍1. 水球图是什么?2. 水球图的应用场景二、水球图类配置选项1. 导包2. Liquid类3. add函数三、水球图实战1. 基础水球图2. 矩形水球图3. 圆棱角矩形水球图4. 三角形水球图5. 菱形水球图6. 箭头型水球图7. 修改数据精度8. 设置无边框9. 多个并排…...

好用的文件备份软件推荐!

为什么需要文件备份软件&#xff1f; 在我们使用计算机的日常工作生活中&#xff0c;可能会遇到各种不同类型的文件&#xff0c;例如文档、Word文档、Excel表格、PPT演示文稿、图片等&#xff0c;这些数据中可能有些对我们来说很重要&#xff0c;但是可能会因为一些意外状况…...

1130 - Host ‘192.168.10.10‘ is not allowed to connect to this MysOL server

mysql 远程登录报错误信息&#xff1a;1130 - Host 124.114.155.70 is not allowed to connect to this MysOL server //需要在mysql 数据库目录下修改 use mysql; //更改用户的登录主机为所有主机&#xff0c;%代表所有主机 update user set host% where userroot; //刷新权…...

如何实现 Es 全文检索、高亮文本略缩处理

如何实现 Es 全文检索、高亮文本略缩处理 前言技术选型JAVA 常用语法说明全文检索开发高亮开发Es Map 转对象使用核心代码 Trans 接口&#xff08;支持父类属性的复杂映射&#xff09;Trans 接口的不足真实项目落地效果 前言 最近手上在做 Es 全文检索的需求&#xff0c;类似于…...

Netty(四)NIO-优化与源码

Netty优化与源码 1. 优化 1.1 扩展序列化算法 序列化&#xff0c;反序列化主要用于消息正文的转换。 序列化&#xff1a;将java对象转为要传输对象(byte[]或json&#xff0c;最终都是byte[]) 反序列化&#xff1a;将正文还原成java对象。 //java自带的序列化 // 反序列化 b…...

我的创业之路:我为什么选择 Angular 作为前端的开发框架?

我是一名后端开发人员&#xff0c;在上班时我的主要精力集中在搜索和推荐系统的开发和设计工作上&#xff0c;我比较熟悉的语言包括java、golang和python。对于前端技术中typescript、dom、webpack等流行的框架和工具也懂一些。目前&#xff0c;已成为一名自由职业者&#xff0…...

阿里云服务器ECS是什么?云服务器详细介绍

阿里云服务器ECS英文全程Elastic Compute Service&#xff0c;云服务器ECS是一种安全可靠、弹性可伸缩的云计算服务&#xff0c;阿里云提供多种云服务器ECS实例规格&#xff0c;如经济型e实例、通用算力型u1、ECS计算型c7、通用型g7、GPU实例等&#xff0c;阿里云服务器网分享阿…...

深入了解快速排序:原理、性能分析与 Java 实现

快速排序&#xff08;Quick Sort&#xff09;是一种经典的、高效的排序算法&#xff0c;被广泛应用于计算机科学和软件开发领域。本文将深入探讨快速排序的工作原理、步骤以及其在不同情况下的性能表现。 什么是快速排序&#xff1f; 快速排序是一种基于分治策略的排序算法&am…...

[晕事]今天做了件晕事22;寻找99-sysctl.conf; systemd

这个文件&#xff0c;使用ls命令看不出来是一个链接。 然后满世界的找这个文件怎么来的&#xff0c;后来发现是systemd里的一个文件。 从systemd的源文件里也没找到相关的文件信息。 最后把这个rpm安装包下载下来&#xff0c;才找到这个文件原来是一个链接 #ll /etc/sysctl.d/9…...

2578. 最小和分割

给你一个正整数 num &#xff0c;请你将它分割成两个非负整数 num1 和 num2 &#xff0c;满足&#xff1a; num1 和 num2 直接连起来&#xff0c;得到 num 各数位的一个排列。 换句话说&#xff0c;num1 和 num2 中所有数字出现的次数之和等于 num 中所有数字出现的次数。num1…...

Mybatis mapper报错:Class not found: org.jboss.vfs.VFS

报错 Logging initialized using class org.apache.ibatis.logging.stdout.StdOutImpl adapter. Class not found: org.jboss.vfs.VFS JBoss 6 VFS API is not available in this environment. Class not found: org.jboss.vfs.VirtualFile VFS implementation org.apache.iba…...

ARM作业1

三盏灯流水 代码 .text .global _start _start: 1.设置GPIOE寄存器的时钟使能 RCC_MP_AHB4ENSETR[4]->1 0x50000a28 LDR R0,0X50000A28 LDR R1,[R0] 从r0为起始地址的4字节数据取出放在R1 ORR R1,R1,#(0x3<<4) 第4位设置为1 STR R1,[R0] 写回2.设置PE10管…...

leetcode 502. IPO

假设 力扣&#xff08;LeetCode&#xff09;即将开始 IPO 。为了以更高的价格将股票卖给风险投资公司&#xff0c;力扣 希望在 IPO 之前开展一些项目以增加其资本。 由于资源有限&#xff0c;它只能在 IPO 之前完成最多 k 个不同的项目。帮助 力扣 设计完成最多 k 个不同项目后…...

[软考中级]软件设计师-计算机网络

网络设备 物理层 物理层不能隔离广播域和冲突域 中继器&#xff0c;集线器 集线器可看成是特殊的多路中继器 数据链路层 可以隔离冲突域不能隔离广播域 网桥&#xff0c;交换机 交换机是多端口的网桥 网络层 可以隔离广播域和冲突域 路由器 应用层 网关 协议簇 …...

Linux搭建我的世界MC服务器 【Minecraft外网联机教程】

目录 前言 1. 安装JAVA 2. MCSManager安装 3.局域网访问MCSM 4.创建我的世界服务器 5.局域网联机测试 6.安装cpolar内网穿透 7. 配置公网访问地址 8.远程联机测试 9. 配置固定远程联机端口地址 9.1 保留一个固定tcp地址 9.2 配置固定公网TCP地址 9.3 使用固定公网…...

APISIX 中ETCD 的问题

1. 问题1 &#xff1a; Error: client: etcd cluster is unavailable or misconfigured; error #0: client: endpoint http://etcd:2379 exceeded header timeout error #0: client: endpoint http://etcd:2379 exceeded header timeout 修改APISIX config ETCD_ADVERTISE_CL…...

SSH版本信息可被获取

漏洞描述 Name SSH版本信息可被获取 Description SSH服务允许远程攻击者获得ssh的具体信息&#xff0c;如版本号等等。这可能为攻击者发动进一步攻击提供帮助。 CVE No. CVE-1999-0634 分析结果 该问题不属于漏洞&#xff0c;不存在安全风险。SSH协议是一种安全协议&am…...

android 修改输出apk的包名

一&#xff0c;打包方式使用IDE菜单选项 二、在app级别的build.gradle下配置&#xff1a; static def releaseTime() {return new Date().format("yyyyMMdd.kkmm", TimeZone.getTimeZone("GMT8")) }android.applicationVariants.all { variant ->print…...

uni-app:文本超出部分用省略号表示

效果 前 后 核心代码 white-space: nowrap; /* 强制不换行 */ text-overflow: ellipsis; /* 超过部分省略号代替 */ overflow: hidden; /* 必须同时设置overflow:hidden才能生效 */ 完整代码 <template><view><view class"all_style"><view c…...

轻松实现视频、音频、文案批量合并,享受批量剪辑的便捷

在日常生活中&#xff0c;我们经常会需要将多个视频、音频和文案进行合并剪辑&#xff0c;以制作出符合我们需求的短视频。然而&#xff0c;这个过程通常需要花费大量的时间和精力。幸运的是&#xff0c;现在有一款名为“固乔智剪软件”的工具可以帮助我们轻松完成这个任务。 首…...

Spring Boot、Nacos配置文件的优先级

在标准的 SpringBoot 应用中&#xff0c;本地配置加载顺序如下&#xff1a; 本地 bootstrap 配置&#xff0c;先于 application 配置加载。不带 profile 的配置&#xff0c;先于带 profile 的配置加载。xxx.yaml 先于 xxx.properties 加载。本地配置先于 nacos 配置中心加载。…...

GO脚本-模拟鼠标键盘

01GetCoordinate 获取坐标 package mainimport ("github.com/go-vgo/robotgo" )func main() {// 获取当前鼠标所在的位置x, y : robotgo.GetMousePos()println(x&#xff1a;, x, y&#xff1a;, y)}02GetColor 获取坐标颜色 package mainimport ("fmt&quo…...

深圳网站定制价格低/seo推广学院

C语言第四-五章第四章 数 组4.1数组的概念C 语言可以根据用户需要&#xff0c;用基本数据类型定义特殊性质的数据类型&#xff0c;称为构造类型。构造类型有&#xff1a;数组、结构、联合。数组&#xff1a;相同数据类型变量的有序集合。有序表现在数组元素在内存中连续存放。数…...

海东市公司网站建设/今天的国内新闻

选择文件之后自动上传文件&#xff1a; 这里uploadAsync的值为ture(默认)&#xff0c;则会走fileuploaded回调(能获取到previewId&#xff0c;所以我会用异步)&#xff1b;如果为false&#xff0c;则会走filebatchuploadsuccess回调(获取不到previewId) $(document).ready(fu…...

线上报名小程序怎么做/湖南seo优化价格

https://blog.csdn.net/bazhidao0031/article/details/81450815 转载于:https://www.cnblogs.com/guochen/p/10340837.html...

网页制作软件破解版下载/重庆seo排名优化

内存模型 内存模型定义为什么要有内存模型为什么要重排序&#xff0c;重排序在什么时候排如何约束重排序规则happens-before什么是顺序一致性CAS 实现的原理&#xff0c;是阻塞还是非阻塞方式&#xff1f;什么时候用&#xff0c;使用时需要考虑的问题处理器和 Java 分别怎么保…...

wordpress 源代码/搜索关键词排名优化服务

Kubernetes可视WEBUI Dashboard搭建 支持浏览器&#xff1a;火狐 一&#xff0e;Dashboard下载地址 git clone https://github.com/kubernetes/kubernetes/ 二&#xff0e;部署Dashboard需要文件 [rootk8s_master ui]# ll 总用量 28 -rwxr-xr-x 1 root root 833 3月 20 19:13…...

网站怎么做评估/核心关键词和长尾关键词

这个题真的很水&#xff0c;但我竟然连错&#xff0c;在此警醒自己&#xff01;&#xff01;&#xff01; 写代码改了东边&#xff0c;忘了西边&#xff0c;“认真”这两个字又被我吃了&#xff0c;打脸啪啪啪啪。 #include<iostream>using namespace std;int gcd(int a,…...