当前位置: 首页 > news >正文

【深度学习实验】线性模型(五):使用Pytorch实现线性模型:基于鸢尾花数据集,对模型进行评估(使用随机梯度下降优化器)

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 线性模型linear_model

2. 损失函数loss_function

3. 鸢尾花数据预处理

4. 初始化权重和偏置

5. 优化器

6. 迭代

7. 测试集预测

8. 实验结果评估

9. 完整代码


一、实验介绍

        线性模型是机器学习中最基本的模型之一,通过对输入特征进行线性组合来预测输出。本实验旨在展示使用随机梯度下降优化器训练线性模型的过程,并评估模型在鸢尾花数据集上的性能。

 二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入库

import torch
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import metrics
  • PyTorch
    • 优化器模块(optim
  • scikit-learn
    • 数据模块(load_iris)
    • 数据划分(train_test_split)
    • 评估指标模块(metrics

1. 线性模型linear_model

        该函数接受输入数据x,使用随机生成的权重w和偏置b,计算输出值output。这里的线性模型的形式为 output = x * w + b

def linear_model(x):return torch.matmul(x, w) + b

2. 损失函数loss_function

      这里使用的是均方误差(MSE)作为损失函数,计算预测值与真实值之间的差的平方。

def loss_function(y_true, y_pred):loss = (y_pred - y_true) ** 2return loss

3. 鸢尾花数据预处理

  • 加载鸢尾花数据集并进行预处理

    • 将数据集分为训练集和测试集

    • 将数据转换为PyTorch张量

iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
x_train = torch.tensor(x_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
x_test = torch.tensor(x_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)

4. 初始化权重和偏置

w = torch.rand(1, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

5. 优化器

        使用随机梯度下降(SGD)优化器进行模型训练,指定学习率和待优化的参数w, b。

optimizer = optim.SGD([w, b], lr=0.01) # 使用SGD优化器

        

6. 迭代

num_epochs = 100
for epoch in range(num_epochs):optimizer.zero_grad()       # 梯度清零prediction = linear_model(x_train, w, b)loss = loss_function(y_train, prediction)loss.mean().backward()      # 计算梯度optimizer.step()            # 更新参数if (epoch + 1) % 10 == 0:print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.mean().item()}")
  • 在每个迭代中:

    • 将优化器的梯度缓存清零,然后使用当前的权重和偏置对输入 x 进行预测,得到预测结果 prediction

    • 使用 loss_function 计算预测结果与真实标签之间的损失,得到损失张量 loss

    • 调用 loss.mean().backward() 计算损失的平均值,并根据计算得到的梯度进行反向传播。

    • 调用 optimizer.step() 更新权重和偏置,使用优化器进行梯度下降更新。

    • 每隔 10 个迭代输出当前迭代的序号、总迭代次数和损失的平均值。

7. 测试集预测

        在测试集上进行预测,使用训练好的模型对测试集进行预测

with torch.no_grad():test_prediction = linear_model(x_test, w, b)test_prediction = torch.round(test_prediction) # 四舍五入为整数test_prediction = test_prediction.detach().numpy()

8. 实验结果评估

  • 使用 metrics 模块计算分类准确度(accuracy)、精确度(precision)、召回率(recall)和F1得分(F1 score)。
  • 输出经过优化后的参数 w 和 b,以及在测试集上的评估指标。
accuracy = metrics.accuracy_score(y_test, test_prediction)
precision = metrics.precision_score(y_test, test_prediction, average='macro')
recall = metrics.recall_score(y_test, test_prediction, average='macro')
f1 = metrics.f1_score(y_test, test_prediction, average='macro')
print("The optimized parameters are:")
print("w:", w.flatten().tolist())
print("b:", b.item())print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

        本实验使用随机梯度下降优化器训练线性模型,并在鸢尾花数据集上取得了较好的分类性能。实验结果表明,经过优化后的模型能够对鸢尾花进行准确的分类,并具有较高的精确度、召回率和F1得分。

9. 完整代码

import torch
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import metricsdef linear_model(x, w, b):return torch.matmul(x, w) + bdef loss_function(y_true, y_pred):loss = (y_pred - y_true) ** 2return loss# 加载鸢尾花数据集并进行预处理
iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
x_train = torch.tensor(x_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
x_test = torch.tensor(x_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)w = torch.rand(x_train.shape[1], 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
optimizer = optim.SGD([w, b], lr=0.01) # 使用SGD优化器num_epochs = 100
for epoch in range(num_epochs):optimizer.zero_grad()       # 梯度清零prediction = linear_model(x_train, w, b)loss = loss_function(y_train, prediction)loss.mean().backward()      # 计算梯度optimizer.step()            # 更新参数if (epoch + 1) % 10 == 0:print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.mean().item()}")# 在测试集上进行预测
with torch.no_grad():test_prediction = linear_model(x_test, w, b)test_prediction = torch.round(test_prediction) # 四舍五入为整数test_prediction = test_prediction.detach().numpy()accuracy = metrics.accuracy_score(y_test, test_prediction)
precision = metrics.precision_score(y_test, test_prediction, average='macro')
recall = metrics.recall_score(y_test, test_prediction, average='macro')
f1 = metrics.f1_score(y_test, test_prediction, average='macro')
print("The optimized parameters are:")
print("w:", w.flatten().tolist())
print("b:", b.item())print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

相关文章:

【深度学习实验】线性模型(五):使用Pytorch实现线性模型:基于鸢尾花数据集,对模型进行评估(使用随机梯度下降优化器)

目录 一、实验介绍 二、实验环境 1. 配置虚拟环境 2. 库版本介绍 三、实验内容 0. 导入库 1. 线性模型linear_model 2. 损失函数loss_function 3. 鸢尾花数据预处理 4. 初始化权重和偏置 5. 优化器 6. 迭代 7. 测试集预测 8. 实验结果评估 9. 完整代码 一、实验介…...

ADB底层原理

介绍 adb的全称为Android Debug Bridge,就是起到调试桥的作用。通过adb我们可以在Eclipse/Android Studio中方便通过DDMS来调试Android程序,说白了就是debug工具。adb是android sdk里的一个工具, 用这个工具可以直接操作管理android模拟器或者真实的and…...

etcd之读性能主要影响因素

1、Raft模块-线性读ReadIndex-节点之间的RTT延时、磁盘IO 线性读时Follower节点首先会向Raft 模块发送ReadIndex请求,此时Raft模块会先向各节点发送心跳确认,一半以上节点确认 Leader 身份后由leader节点将已提交日志索引 (committed index) 封装成 Rea…...

【Stable Diffusion】安装 Comfyui 之 window版

序言 由于stable diffusion web ui无法做到对流程进行控制,只是点击个生成按钮后,一切都交给AI来处理。但是用于生产生活是需要精细化对各个流程都要进行控制的。 故也就有个今天的猪脚:Comfyui 步骤 下载comfyui项目配置大模型和vae下载…...

Ansys Zemax | 如何建立二向分色分光镜

分光镜(Beam splitter)可被运用在许多不同的场合。一般而言,入射光抵达二向分色分光镜(dichroic beam splitter)时,会根据波长的差异产生穿透或反射的现象。这篇文章将说明如何在OpticStudio的非序列模式(non-sequential mode)中建立二向分色分光镜&…...

Mybatis学习笔记8 查询返回专题

1.返回实体类 2.返回List<实体类> 3.返回Map 4.返回List<Map> 5.返回Map<String,Map> 6.resultMap结果集映射 7.返回总记录条数 新建模块 依赖 目录结构 1.返回实体类 如果返回多条,用单个实体接收会出异常 2.返回List<实体类> 即使返回一条记…...

【测试开发】基础篇 · 专业术语 · 软件测试生命周期 · bug的描述 · bug的级别 · bug的生命周期 · 处理争执

【测试开发】基础篇 文章目录 【测试开发】基础篇1. 软件测试生命周期1.1 软件生命周期1.2 软件测试生命周期 2. 描述bug3. 如何定义bug的级别3.1 为什么要对bug进行级别划分3.2 bug的一些常见级别 4. bug的生命周期5. 产生争执这么怎么办&#xff08;处理人际关系&#xff09;…...

​bing许少辉乡村振兴战略下传统村落文化旅游设计images

​bing许少辉乡村振兴战略下传统村落文化旅游设计images...

第三十一章 Classes - 继承规则

第三十一章 Classes - 继承规则 继承规则 与其他基于类的语言一样&#xff0c;可以通过继承组合多个类定义。 类定义可以扩展&#xff08;或继承&#xff09;多个其他类。这些类又可以扩展其他类。 请注意&#xff0c;类不能继承 Python 中定义的类&#xff08;即 .py 文件中…...

华为云HECS安装docker并安装mysql

1、运行安装指令 yum install docker都选择y&#xff0c;直到安装成功 2、查看是否安装成功 运行版本查看指令&#xff0c;显示docker版本&#xff0c;证明安装成功 docker --version 3、启用并运行docker 3.1启用docker指令 systemctl enable docker 3.2 运行docker指令…...

MQ - 04 基础篇_存储_消息数据和元数据的存储设计

文章目录 导图概述元数据信息的存储消息数据的存储数据存储结构设计思路一 (Kafka的方案)思路二 (RocketMQ、RabbitMQ 和 Pulsar 的底层存储 BookKeeper 采用的方案)消息数据的分段实现根据偏移量定位根据索引定位 (RabbitMQ 和 RocketMQ的思路)使用场景消息数据存储格式…...

JavaScript:隐式转换、显示转换、隐式操作、显示操作

一、理解js隐式转换 JavaScript 中的隐式转换是指不需要显式地调用转换函数&#xff0c;而是在执行期间自动发生的数据类型的转换。即在使用不同类型的值进行操作时&#xff0c;JavaScript会自动进行类型转换。这种转换通常发生在不同数据类型之间进行运算或比较时。 序号分类…...

2023全新TwoNav开源网址导航系统源码 | 去授权版

2023全新TwoNav开源网址导航系统源码 已过授权 所有功能可用 测试环境&#xff1a;NginxPHP7.4MySQL5.6 一款开源的书签导航管理程序&#xff0c;界面简洁&#xff0c;安装简单&#xff0c;使用方便&#xff0c;基础功能免费。 TwoNav可帮助你将浏览器书签集中式管理&#…...

Android 12 源码分析 —— 应用层 六(StatusBar的UI创建和初始化)

Android 12 源码分析 —— 应用层 六&#xff08;StatusBar的UI创建和初始化) 在前面的文章中,我们分别介绍了Layout整体布局,以及StatusBar类的初始化.前者介绍了整体上面的布局,后者介绍了三大窗口的创建的入口处,以及需要做的准备工作.现在我们分别来细化三大窗口的UI创建和…...

华为云ROMA Connect亮相Gartner®全球应用创新及商业解决方案峰会,助力企业应用集成和数字化转型

9月13日-9月14日 Gartner全球应用创新及商业解决方案峰会在伦敦举行 本届峰会以“重塑软件交付&#xff0c;驱动业务价值”为主题&#xff0c;全球1000多位业内专家交流最新的企业应用、软件工程、解决方案架构、集成与自动化、API等企业IT战略和新兴技术热门话题。 9月13日…...

虚拟线上发布会带来颠覆性新体验,3D虚拟场景直播迸发品牌新动能

虚拟线上发布会是近年来在数字化营销领域备受关注的形式&#xff0c;而随着虚拟现实技术的不断进步&#xff0c;3D虚拟场景直播更成为了品牌宣传、推广的新选择。可以说&#xff0c;虚拟线上发布会正在以其颠覆性的新体验&#xff0c;为品牌带来全新的活力。 1.突破时空限制&am…...

Linux arm64 pte相关宏

文章目录 一、pte 和 pfn1.1 pte_pfn1.2 pfn_pte 二、其他宏参考资料 一、pte 和 pfn // linux-5.4.18/arch/arm64/include/asm/pgtable.h#define pte_pfn(pte) (__pte_to_phys(pte) >> PAGE_SHIFT) #define pfn_pte(pfn,prot) \__pte(__phys_to_pte_val((phys_addr_t)…...

MVCC:多版本并发控制案例分析(一)

&#xff08;笔记总结自b站马士兵教育课程&#xff09; 一、简介 MVCC&#xff1a;全称multi-version Concurency control&#xff0c;多版本并发控制&#xff0c;是为了解决并发读写问题存在的。MVCC的实现原理由三部分组成&#xff1a;隐藏字段、undolog、readview。 二、概…...

以数据为中心的安全市场快速增长

根据Adroit Market Research的数据&#xff0c;2021年全球以数据为中心的安全市场规模估计为27.6亿美元&#xff0c;预计到2030年将增长至393.48亿美元&#xff0c;2021年至2030年的复合年增长率为30.9%。 研究人员表示&#xff0c;以数据为中心的安全强调保护数据本身&#x…...

AUTOSAR汽车电子嵌入式编程精讲300篇-经典 AUTOSAR 安全防御能力的分析及改善(下)

目录 4.4.2 Security 攻击 4.4.3 Security 要求 4.4.4 SDSA 有效性验证 经典 AUTOSAR 安全防御能力分析...

【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15

缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下&#xff1a; struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...

深入剖析AI大模型:大模型时代的 Prompt 工程全解析

今天聊的内容&#xff0c;我认为是AI开发里面非常重要的内容。它在AI开发里无处不在&#xff0c;当你对 AI 助手说 "用李白的风格写一首关于人工智能的诗"&#xff0c;或者让翻译模型 "将这段合同翻译成商务日语" 时&#xff0c;输入的这句话就是 Prompt。…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

【第二十一章 SDIO接口(SDIO)】

第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

MMaDA: Multimodal Large Diffusion Language Models

CODE &#xff1a; https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA&#xff0c;它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构&#xf…...

ardupilot 开发环境eclipse 中import 缺少C++

目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同&#xff0c;结合所安装的tensorflow的目录结构修改from语句即可。 原语句&#xff1a; from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后&#xff1a; from tensorflow.python.keras.lay…...

Go 语言并发编程基础:无缓冲与有缓冲通道

在上一章节中&#xff0c;我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道&#xff0c;它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好&#xff0…...

CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝

目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为&#xff1a;一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...

Golang——6、指针和结构体

指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...