用PyTorch轻松实现二分类:逻辑回归入门
💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
文章目录
- 🥦引言
- 🥦什么是逻辑回归?
- 🥦分类问题
- 🥦交叉熵
- 🥦代码实现
- 🥦总结
🥦引言
当谈到机器学习和深度学习时,逻辑回归是一个非常重要的算法,它通常用于二分类问题。在这篇博客中,我们将使用PyTorch来实现逻辑回归。PyTorch是一个流行的深度学习框架,它提供了强大的工具来构建和训练神经网络,适用于各种机器学习任务。
在机器学习中已经使用了sklearn库介绍过逻辑回归,这里重点使用pytorch这个深度学习框架
🥦什么是逻辑回归?
我们首先来回顾一下什么是逻辑回归?
逻辑回归是一种用于二分类问题的监督学习算法。它的主要思想是通过一个S形曲线(通常是Sigmoid函数)将输入特征映射到0和1之间的概率值,然后根据这些概率值进行分类决策。在逻辑回归中,我们使用一个线性模型和一个激活函数来实现这个映射。
🥦分类问题
这里以MINIST Dataset手写数字集为例
这个数据集中包含了6w个训练集1w个测试集,类别10个
这里我们不再向之前线性回归那样,根据属于判断具体的数值大小;而是根据输入的值判断从0-9每个数字的概率大小记为p(0)、p(1)…而且十个概率值和为1,我们的目标就是根据输入得到这十个分类对于输入的每一个的概率值,哪个大就是我们需要的。
这里介绍一下与torch相关联的库—torchvision
torchvision:
- “torchvision” 是一个PyTorch的附加库,专门用于处理图像和视觉任务。
它包含了一系列用于数据加载、数据增强、计算机视觉任务(如图像分类、目标检测等)的工具和数据集。 - “torchvision” 提供了许多预训练的视觉模型(例如,ResNet、VGG、AlexNet等),可以用于迁移学习或作为基准模型。
此外,它还包括了用于图像预处理、转换和可视化的函数。
上图已经清楚的显示了,这个库包含了一些自带的数据集,但是并不是我们安装完这个库就有了,而且需要进行调用的,类似在线下载,root指定下载的路径,train表示你需要训练集还是测试集,通常情况下就是两个一个训练,一个测试,download就是判断你下没下载,下载了就是摆设,没下载就给你下载了
我们再来看一个数据集(CIFAR-10)
包含了5w训练样本,1w测试样本,10类。调用方式与上一个类似。
接下来我们从一张图更加直观的查看分类和回归
左边的是回归,右边的是分类
过去我们使用回归例如 y ^ \hat{y} y^=wx+b∈R,这是属于一个实数的;但是在分类问题, y ^ \hat{y} y^∈[0,1]
这说明我们需要寻找一个函数,将原本实数的值经过函数的映射转化为[0,1]之间。这里我们引入Logistic函数,使用极限很清楚的得出x趋向于正无穷的时候函数为1,x趋向于负无穷的时候,函数为0,x=0的时候,函数为0.5,当我们计算的时候将 y ^ \hat{y} y^带入这样就会出现一个0到1的概率了。
下图展示一些其他的Sigmoid函数
🥦交叉熵
过去我们所使用的损失函数普遍都是MSE,这里引入一个新的损失函数—交叉熵
==交叉熵(Cross-Entropy)==是一种用于衡量两个概率分布之间差异的数学方法,常用于机器学习和深度学习中,特别是在分类问题中。它是一个非常重要的损失函数,用于衡量模型的预测与真实标签之间的差异,从而帮助优化模型参数。
在交叉熵的上下文中,通常有两个概率分布:
-
真实分布(True Distribution): 这是指问题的实际概率分布,表示样本的真实标签分布。通常用 p ( x ) p(x) p(x)表示,其中 x x x表示样本或类别。
-
预测分布(Predicted Distribution): 这是指模型的预测概率分布,表示模型对每个类别的预测概率。通常用 q ( x ) q(x) q(x)表示,其中 x x x表示样本或类别。
交叉熵的一般定义如下:
其中, H ( p , q ) H(p, q) H(p,q) 表示真实分布 p p p 和预测分布 q q q 之间的交叉熵。
交叉熵的主要特点和用途包括:
-
度量差异性: 交叉熵度量了真实分布和预测分布之间的差异。当两个分布相似时,交叉熵较小;当它们之间的差异增大时,交叉熵增大。
-
损失函数: 在机器学习中,交叉熵通常用作损失函数,用于衡量模型的预测与真实标签之间的差异。在分类任务中,通常使用交叉熵作为模型的损失函数,帮助模型优化参数以提高分类性能。
-
反向传播: 交叉熵在训练神经网络时非常有用。通过计算交叉熵的梯度,可以使用反向传播算法来调整神经网络的权重,从而使模型的预测更接近真实标签。
在分类问题中,常见的交叉熵损失函数包括二元交叉熵(Binary Cross-Entropy)和多元交叉熵(Categorical Cross-Entropy)。二元交叉熵用于二分类问题,多元交叉熵用于多类别分类问题。
刘二大人的PPT中也介绍了
右边的表格中每组y与 y ^ \hat{y} y^对应的BCE,BCE越高说明越可能,最后将其求均值
🥦代码实现
根据上图可知,线性回归和逻辑回归的流程与函数只区别于Sigmoid函数
这里就是BCEloss的调用,里面的参数代表求不求均值
完整代码如下
import torch.nn.functional as F
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred
model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()
最后绘制一下
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1)) # 相当于reshape
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
运行结果如下
🥦总结
这就是使用PyTorch实现逻辑回归的基本步骤。逻辑回归是一个简单但非常有用的算法,可用于各种分类问题。希望这篇博客能帮助你开始使用PyTorch构建自己的逻辑回归模型。如果你想进一步扩展你的知识,可以尝试在更大的数据集上训练模型或探索其他深度学习算法。祝你好运!
挑战与创造都是很痛苦的,但是很充实。
相关文章:

用PyTorch轻松实现二分类:逻辑回归入门
💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…...

[nltk_data] Error loading stopwords: <urlopen error [WinError 10054]
报错提示: >>> import nltk >>> nltk.download(stopwords) 按照提示执行后 [nltk_data] Error loading stopwords: <urlopen error [WinError 10054] 找到路径C:\\Users\\EDY\\nltk_data,如果没有nltk_data文件夹,在…...

基于Spring Boot的网上租贸系统设计与实现(源码+lw+部署文档+讲解等)
文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序(小蔡coding)有保障的售后福利 代码参考源码获取 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作…...

通过IP地址管理提升企业网络安全防御
在今天的数字时代,企业面临着越来越多的网络安全威胁。这些威胁可能来自各种来源,包括恶意软件、网络攻击和数据泄露。为了提高网络安全防御,企业需要采取一系列措施,其中IP地址管理是一个重要的方面 1. IP地址的基础知识 首先&a…...

termius mac版无需登录注册直接永久使用
1. 下载地址:termius下载 2. 解压安装 3. 当出现 “termius”已损坏,无法打开 则输入以下命令即可:sudo xattr -r -d com.apple.quarantine /Applications/Termius.app 最后去 系统设置-> 隐私与安全性-> 仍要打开 4. 删除app-update.yml文件&…...

TPU编程竞赛|Stable Diffusion大模型巅峰对决,第五届全球校园人工智能算法精英赛正式启动!
目录 赛题介绍 赛题背景 赛题任务 赛程安排 评分机制 奖项设置 近日,2023第五届全球校园人工智能算法精英赛正式开启报名。作为赛题合作方,算丰承办了“算法专项赛”赛道,提供赛题「面向Stable Diffusion的图像提示语优化」,…...
微信小程序 rpx 转 px
前言 略 rpx 转 px let query wx.createSelectorQuery(); query.selectViewport().boundingClientRect(function(res){let rpx2Px 1 * (res.width/750);console.log("1rpx " rpx2Px "px"); }); query.exec();参考 https://blog.csdn.net/qq_39702…...

机器学习之旅-从Python 开始
导读你想知道如何开始机器学习吗?在这篇文章中,我将简要概括一下使用 Python 来开始机器学习的一些步骤。Python 是一门流行的开源程序设计语言,也是在人工智能及其它相关科学领域中最常用的语言之一。机器学习简称 ML,是人工智能…...
100天精通Python(可视化篇)——第103天:Pyecharts绘制多种炫酷水球图参数说明+代码实战
文章目录 专栏导读一、水球图介绍1. 水球图是什么?2. 水球图的应用场景二、水球图类配置选项1. 导包2. Liquid类3. add函数三、水球图实战1. 基础水球图2. 矩形水球图3. 圆棱角矩形水球图4. 三角形水球图5. 菱形水球图6. 箭头型水球图7. 修改数据精度8. 设置无边框9. 多个并排…...

好用的文件备份软件推荐!
为什么需要文件备份软件? 在我们使用计算机的日常工作生活中,可能会遇到各种不同类型的文件,例如文档、Word文档、Excel表格、PPT演示文稿、图片等,这些数据中可能有些对我们来说很重要,但是可能会因为一些意外状况…...

1130 - Host ‘192.168.10.10‘ is not allowed to connect to this MysOL server
mysql 远程登录报错误信息:1130 - Host 124.114.155.70 is not allowed to connect to this MysOL server //需要在mysql 数据库目录下修改 use mysql; //更改用户的登录主机为所有主机,%代表所有主机 update user set host% where userroot; //刷新权…...

如何实现 Es 全文检索、高亮文本略缩处理
如何实现 Es 全文检索、高亮文本略缩处理 前言技术选型JAVA 常用语法说明全文检索开发高亮开发Es Map 转对象使用核心代码 Trans 接口(支持父类属性的复杂映射)Trans 接口的不足真实项目落地效果 前言 最近手上在做 Es 全文检索的需求,类似于…...

Netty(四)NIO-优化与源码
Netty优化与源码 1. 优化 1.1 扩展序列化算法 序列化,反序列化主要用于消息正文的转换。 序列化:将java对象转为要传输对象(byte[]或json,最终都是byte[]) 反序列化:将正文还原成java对象。 //java自带的序列化 // 反序列化 b…...

我的创业之路:我为什么选择 Angular 作为前端的开发框架?
我是一名后端开发人员,在上班时我的主要精力集中在搜索和推荐系统的开发和设计工作上,我比较熟悉的语言包括java、golang和python。对于前端技术中typescript、dom、webpack等流行的框架和工具也懂一些。目前,已成为一名自由职业者࿰…...

阿里云服务器ECS是什么?云服务器详细介绍
阿里云服务器ECS英文全程Elastic Compute Service,云服务器ECS是一种安全可靠、弹性可伸缩的云计算服务,阿里云提供多种云服务器ECS实例规格,如经济型e实例、通用算力型u1、ECS计算型c7、通用型g7、GPU实例等,阿里云服务器网分享阿…...

深入了解快速排序:原理、性能分析与 Java 实现
快速排序(Quick Sort)是一种经典的、高效的排序算法,被广泛应用于计算机科学和软件开发领域。本文将深入探讨快速排序的工作原理、步骤以及其在不同情况下的性能表现。 什么是快速排序? 快速排序是一种基于分治策略的排序算法&am…...
[晕事]今天做了件晕事22;寻找99-sysctl.conf; systemd
这个文件,使用ls命令看不出来是一个链接。 然后满世界的找这个文件怎么来的,后来发现是systemd里的一个文件。 从systemd的源文件里也没找到相关的文件信息。 最后把这个rpm安装包下载下来,才找到这个文件原来是一个链接 #ll /etc/sysctl.d/9…...
2578. 最小和分割
给你一个正整数 num ,请你将它分割成两个非负整数 num1 和 num2 ,满足: num1 和 num2 直接连起来,得到 num 各数位的一个排列。 换句话说,num1 和 num2 中所有数字出现的次数之和等于 num 中所有数字出现的次数。num1…...

Mybatis mapper报错:Class not found: org.jboss.vfs.VFS
报错 Logging initialized using class org.apache.ibatis.logging.stdout.StdOutImpl adapter. Class not found: org.jboss.vfs.VFS JBoss 6 VFS API is not available in this environment. Class not found: org.jboss.vfs.VirtualFile VFS implementation org.apache.iba…...

ARM作业1
三盏灯流水 代码 .text .global _start _start: 1.设置GPIOE寄存器的时钟使能 RCC_MP_AHB4ENSETR[4]->1 0x50000a28 LDR R0,0X50000A28 LDR R1,[R0] 从r0为起始地址的4字节数据取出放在R1 ORR R1,R1,#(0x3<<4) 第4位设置为1 STR R1,[R0] 写回2.设置PE10管…...

接口测试中缓存处理策略
在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...

大数据学习栈记——Neo4j的安装与使用
本文介绍图数据库Neofj的安装与使用,操作系统:Ubuntu24.04,Neofj版本:2025.04.0。 Apt安装 Neofj可以进行官网安装:Neo4j Deployment Center - Graph Database & Analytics 我这里安装是添加软件源的方法 最新版…...

stm32G473的flash模式是单bank还是双bank?
今天突然有人stm32G473的flash模式是单bank还是双bank?由于时间太久,我真忘记了。搜搜发现,还真有人和我一样。见下面的链接:https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

React第五十七节 Router中RouterProvider使用详解及注意事项
前言 在 React Router v6.4 中,RouterProvider 是一个核心组件,用于提供基于数据路由(data routers)的新型路由方案。 它替代了传统的 <BrowserRouter>,支持更强大的数据加载和操作功能(如 loader 和…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案
随着新能源汽车的快速普及,充电桩作为核心配套设施,其安全性与可靠性备受关注。然而,在高温、高负荷运行环境下,充电桩的散热问题与消防安全隐患日益凸显,成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...
Java入门学习详细版(一)
大家好,Java 学习是一个系统学习的过程,核心原则就是“理论 实践 坚持”,并且需循序渐进,不可过于着急,本篇文章推出的这份详细入门学习资料将带大家从零基础开始,逐步掌握 Java 的核心概念和编程技能。 …...

自然语言处理——循环神经网络
自然语言处理——循环神经网络 循环神经网络应用到基于机器学习的自然语言处理任务序列到类别同步的序列到序列模式异步的序列到序列模式 参数学习和长程依赖问题基于门控的循环神经网络门控循环单元(GRU)长短期记忆神经网络(LSTM)…...
OpenLayers 分屏对比(地图联动)
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能,和卷帘图层不一样的是,分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...

vue3+vite项目中使用.env文件环境变量方法
vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量,这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...
力扣-35.搜索插入位置
题目描述 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...