当前位置: 首页 > news >正文

PyTorch深度学习实战 | 基于线性回归、决策树和SVM进行鸢尾花分类

鸢尾花数据集是机器学习领域非常经典的一个分类任务数据集。它的英文名称为Iris Data Set,使用sklearn库可以直接下载并导入该数据集。数据集总共包含150行数据,每一行数据由4个特征值及一个标签组成。标签为三种不同类别的鸢尾花,分别为:Iris Setosa,Iris Versicolour,Iris Virginica。

对于多分类任务,有较多机器学习的算法可以支持。本文将使用决策树、线性回归、SVM等多种算法来完成这一任务,并对不同方法进行比较。

01、使用Logistic实现鸢尾花分类

在前面介绍过Logistic用于二分类任务,对其进行扩展也用于多分类任务。下面将使用sklearn库完成一个基于Logistic的鸢尾花分类任务。如代码清单1所示,首先是导入sklearn.datasets包从而加载数据集,并将数据集按照测试集占比0.2随机分为训练集和测试集。

代码清单1 导入包以及加载数据集

rom sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import confusion_matrix, precision_score, accuracy_score,recall_score, f1_score, roc_auc_score, \roc_curve
import matplotlib.pyplot as plt# 加载数据集
def loadDataSet():iris_dataset = load_iris()X = iris_dataset.datay = iris_dataset.target# 将数据划分为训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)return X_train, X_test, y_train, y_test

如代码清单2所示,编写函数训练Logistic模型。

代码清单2 训练Logistic模型

# 训练Logistic模性
def trainLS(x_train, y_train):# Logistic生成和训练clf = LogisticRegression()clf.fit(x_train, y_train)return clf

Logistic模型较为简单,不需要额外设置超参数即可开始训练。如代码清单3所示,初始化Logistic模型并将模型在训练集上训练,返回训练好的模型。

代码清单3 测试模型及打印各种评价指标

# 测试模型
def test(model, x_test, y_test):# 将标签转换为one-hot形式y_one_hot = label_binarize(y_test, np.arange(3))# 预测结果y_pre = model.predict(x_test)# 预测结果的概率y_pre_pro = model.predict_proba(x_test)# 混淆矩阵con_matrix = confusion_matrix(y_test, y_pre)print('confusion_matrix:\n', con_matrix)print('accuracy:{}'.format(accuracy_score(y_test, y_pre)))print('precision:{}'.format(precision_score(y_test, y_pre, average='micro')))print('recall:{}'.format(recall_score(y_test, y_pre, average='micro')))print('f1-score:{}'.format(f1_score(y_test, y_pre, average='micro')))# 绘制ROC曲线drawROC(y_one_hot, y_pre_pro)

在预测结果时,为了方便后面绘制ROC曲线,需要首先将测试集的标签转化为one-hot的形式,并得到模型在测试集上预测结果的概率值即y_pre_pro,从而传入drawROC函数完成ROC曲线的绘制。除此外,该函数实现了输出混淆矩阵以及计算准确率、精确率、查全率以及f1-score的功能。

代码清单4 绘制ROC曲线

def drawROC(y_one_hot, y_pre_pro):# AUC值auc = roc_auc_score(y_one_hot, y_pre_pro, average='micro')# 绘制ROC曲线fpr, tpr, thresholds = roc_curve(y_one_hot.ravel(), y_pre_pro.ravel())plt.plot(fpr, tpr, linewidth=2, label='AUC=%.3f' % auc)plt.plot([0, 1], [0, 1], 'k--')plt.axis([0, 1.1, 0, 1.1])plt.xlabel('False Postivie Rate')plt.ylabel('True Positive Rate')plt.legend()plt.show()

如代码清单4所示为绘制ROC曲线的代码实现。最后将加载数据集,训练模型,以及模型验证的整个流程连接起来从而实现main函数,如代码清单5所示。

代码清单5 main函数设置

if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainLS(X_train, y_train)test(model, X_test, y_test)

将上述所有代码放在同一py脚本文件中,如图1所示可得最终的输出结果为

图1 命令行打印的测试结果

绘制得到的ROC曲线如图2所示。

图2 ROC曲线

Logistic是一个较为简单的模型,参数量较少,一般也用于较为简单的分类任务中,当任务更为复杂时,可以选取更为复杂的模型获得更好的效果,下面将使用不同的模型从而验证同一任务在不同模型下的表现。

02、使用决策树实现鸢尾花分类

由于只改动了模型,加载数据集、模型评价等其他部分的代码不需要改动,如代码清单6所示,增加新的函数用于训练决策树模型。

代码清单6 使用决策树模型进行训练

from sklearn import tree
# 训练决策树模性
def trainDT(x_train, y_train):# DT生成和训练clf = tree.DecisionTreeClassifier(criterion="entropy")clf.fit(x_train, y_train)return clf

同时修改main函数中调用的训练函数如代码清单7所示。

代码清单7 修改main函数内容

if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainDT(X_train, y_train)test(model, X_test, y_test)

最后运行可得命令行输出如图3所示。

图3 决策树模型预测结果

以及ROC曲线如图4所示。

图4 决策树模型绘制ROC曲线

相比Logistic模型,决策树模型无论在哪一项指标上都得到了更高的评分,且决策树模型不会像Logistic模型一样受初始化的影响,多次运行程序均可获得相同的输出模型,而Logistic模型运行多次会发现评价指标会在某个范围内上下抖动。

03、使用SVM实现鸢尾花分类

到现在相信大家都已经非常熟悉如何继续修改代码从而实现SVM模型的预测,实现SVM模型的训练代码如代码清单8所示

代码清单8 使用SVM模型进行训练

# 训练SVM模性
from sklearn import svm
def trainSVM(x_train, y_train):# SVM生成和训练clf = svm.SVC(kernel='rbf', probability=True)clf.fit(x_train, y_train)return clf

同时修改main函数,如代码清单9所示。

代码清单9 修改main函数内容

if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainSVM(X_train, y_train)test(model, X_test, y_test)

程序运行输出如图5所示。

图5 使用SVM模型预测结果

绘制得到的ROC曲线如图6所示。

图6 使用SVM模型绘制的ROC曲线

可以发现,随着模型进一步变得复杂,最终预测的各项指标进一步上升,在三个模型中SVM模型的高斯核最终结果在测试集中表现得最好且没有发生过拟合的现象,因此可以选用SVM模型来完成鸢尾花分类这一任务。

相关文章:

PyTorch深度学习实战 | 基于线性回归、决策树和SVM进行鸢尾花分类

鸢尾花数据集是机器学习领域非常经典的一个分类任务数据集。它的英文名称为Iris Data Set,使用sklearn库可以直接下载并导入该数据集。数据集总共包含150行数据,每一行数据由4个特征值及一个标签组成。标签为三种不同类别的鸢尾花,分别为&…...

服务端接口优化方案

一、背景 针对老项目,去年做了许多降本增效的事情,其中发现最多的就是接口耗时过长的问题,就集中搞了一次接口性能优化。本文将给小伙伴们分享一下接口优化的通用方案。 二、接口优化方案总结 1. 批处理 批量思想:批量操作数据…...

【并发基础】Happens-Before模型详解

目录 一、Happens-Before模型简介 二、组成Happens-Before模型的八种规则 2.1 程序顺序规则(as-if-serial语义) 2.2 传递性规则 2.3 volatile变量规则 2.4 监视器锁规则 2.5 start规则 2.6 Join规则 一、Happens-Before模型简介 除了显示引用vo…...

Kubernetes系列---Kubernetes 理论知识 | 初识

Kubernetes系列---Kubernetes 理论知识 | 初识 1.K8s 是什么?2.K8s 特性3.小拓展(业务升级)4.K8s 集群架构与组件①架构拓扑图:②Master 组件③Node 组件 五 K8s 核心概念六 官方提供的三种部署方式总结 1.K8s 是什么&#xff1f…...

KingbaseES 原生XML系列三--XML数据查询函数

KingbaseES 原生XML系列三--XML数据查询函数(EXTRACT,EXTRACTVALUE,EXISTSNODE,XPATH,XPATH_EXISTS,XMLEXISTS) XML的简单使其易于在任何应用程序中读写数据,这使XML很快成为数据交换的一种公共语言。在不同平台下产生的信息,可以很容易加载XML数据到程序…...

【51单片机】点亮一个LED灯(看开发板原理图十分重要)

🎊专栏【51单片机】 🍔喜欢的诗句:更喜岷山千里雪 三军过后尽开颜。 🎆音乐分享【The Right Path】 🥰大一同学小吉,欢迎并且感谢大家指出我的问题🥰 目录 🍔基础内容 &#x1f3f3…...

数据可视化工具 - ECharts以及柱状图的编写

1 快速上手 引入echarts 插件文件到html页面中 <head><meta charset"utf-8"/><title>ECharts</title><!-- step1 引入刚刚下载的 ECharts 文件 --><script src"./echarts.js"></script> </head>准备一个…...

【AI绘画】——Midjourney关键词格式解析(常用参数分享)

目前在AI绘画模型中&#xff0c;Midjourney的效果是公认的top级别&#xff0c;但同时也是相对较难使用的&#xff0c;对小白来说比较难上手&#xff0c;主要就在于Mj没有webui&#xff0c;不能选择参数&#xff0c;怎么找到这些隐藏参数并且触发它是用好Mj的第一步。 今天就来…...

操作符知识点大全(简洁,全面,含使用场景,演示,代码)

目录 一.算术操作符 1.要点&#xff1a; 二.负数原码&#xff0c;反码&#xff0c;补码的互推 1.按位取反操作符&#xff1a;~&#xff08;二进制位&#xff09; 2.原反补互推演示 三.进制位的表示 1.不同进制位的特征&#xff1a; 2.二进制位表示 3.整型的二进制表…...

华工研究生语音课

这门课讲啥 语音蕴含的信息、语音识别的目的 语音的准平稳性、分帧、预加重、时域特征分析&#xff08;能量和过零率&#xff09;、端点检测&#xff08;双门限法&#xff09; 语音的基频及检测&#xff08;主要是自相关法、野点的处理&#xff09; 声音的产生过程&#xf…...

KingbaseES 原生XML系列二 -- XML数据操作函数

KingbaseES 原生XML系列二--XML数据操作函数(DELETEXML,APPENDCHILDXML,INSERTCHILDXML,INSERTCHILDXMLAFTER,INSERTCHILDXMLBEFORE,INSERTXMLAFTER,INSERTXMLBEFORE,UPDATEXML) XML的简单使其易于在任何应用程序中读写数据&#xff0c;这使XML很快成为数据交换的一种公共语言。…...

【Flink】DataStream API使用之源算子(Source)

源算子 创建环境之后&#xff0c;就可以构建数据的业务处理逻辑了&#xff0c;Flink可以从各种来源获取数据&#xff0c;然后构建DataStream进项转换。一般将数据的输入来源称为数据源&#xff08;data source&#xff09;&#xff0c;而读取数据的算子就叫做源算子&#xff08…...

树莓派硬件介绍及配件选择

目录 树莓派Datasheet下载地址&#xff1a; Raspberry 4B 外观图&#xff1a; 技术规格书&#xff1a; 性能介绍&#xff1a; 树莓派配件选用 电源的选用&#xff1a; 树莓派外壳选用&#xff1a; 内存卡/U盘选用 树莓派Datasheet下载地址&#xff1a; Raspberry Pi …...

O2OA (翱途) 平台 V8.0 发布新增数据台账能力

亲爱的小伙伴们&#xff0c;O2OA (翱途) 平台开发团队经过几个月的持续努力&#xff0c;实现功能的新增、优化以及问题的修复。2023 年度 V8.0 版本已正式发布。欢迎大家到 O2OA 的官网上下载进行体验&#xff0c;也希望大家在藕粉社区里多提宝贵建议。本篇我们先为大家介绍应用…...

数控解锁怎么解 数控系统解锁解密

Amazon Fargate 在中国区正式落地&#xff0c;因 数控解锁使用 Serverless 架构&#xff0c;更加适合对性能要求不敏感的服务使用&#xff0c;Pyroscope 是一款基于 Golang 开发的应用程序性能分析工具&#xff0c;Pyroscope 的服务端为无状态服务且性能要求不敏感&#xff0c;…...

3.0 响应式系统的设计与实现

1、Proxy代理对象 Proxy用于对一个普通对象代理&#xff0c;实现对象的拦截和自定义&#xff0c;如拦截其赋值、枚举、函数调用等。里面包含了很多组捕获器&#xff08;trap&#xff09;&#xff0c;在代理对象执行相应的操作时捕获&#xff0c;然后在内部实现自定义。 const…...

Rust 快速入门60分① 看完这篇就能写代码了

Rust 一门赋予每个人构建可靠且高效软件能力的语言https://hannyang.blog.csdn.net/article/details/130467813?spm1001.2014.3001.5502关于Rust安装等内容请参考上文链接&#xff0c;写完上文就在考虑写点关于Rust的入门文章&#xff0c;本专辑将直接从Rust基础入门内容开始讲…...

【5.JS基础-JavaScript的DOM操作】

1 认识DOM和BOM 所以我们学习DOM&#xff0c;就是在学习如何通过JavaScript对文档进行操作的&#xff1b; DOM Tree的理解 DOM的学习顺序 DOM的继承关系图 2 document对象 3 节点&#xff08;Node&#xff09;之间的导航&#xff08;navigator&#xff09; 4 元素&#xff0…...

【大数据之Hadoop】二十九、HDFS存储优化

纠删码和异构存储测试需要5台虚拟机。准备另外一套5台服务器集群。 环境准备&#xff1a; &#xff08;1&#xff09;克隆hadoop105为hadoop106&#xff0c;修改ip地址和hostname&#xff0c;然后重启。 vim /etc/sysconfig/network-scripts/ifcfg-ens33 vim /etc/hostname r…...

SuperMap GIS基础产品组件GIS FAQ集锦(2)

SuperMap GIS基础产品组件GIS FAQ集锦&#xff08;2&#xff09; 【iObjects for Spark】读取GDB参数该如何填写&#xff1f; 【解决办法】可参考以下示例&#xff1a; val GDB_params new util.HashMapString, java.io.Serializable GDB_params.put(FeatureRDDProviderParam…...

C语言printf()函数中整型格式说明符详解

每个整型在printf()函数中对应不同的格式说明符,以实现该整型的打印输出。格式说明符必须使用小写。现在让我们看看各个整型及其格式说明符: 短整型(short) 10进制:%hd16进制:无负数格式,正数使用%hx8进制:无负数格式,正数使用%ho c short s 34; printf("%hd", s…...

阿里云服务器地域和可用区怎么选择合适?

阿里云服务器地域和可用区怎么选择&#xff1f;地域是指云服务器所在物理数据中心的位置&#xff0c;地域选择就近选择&#xff0c;访客距离地域所在城市越近网络延迟越低&#xff0c;速度就越快&#xff1b;可用区是指同一个地域下&#xff0c;网络和电力相互独立的区域&#…...

Java序列化引发的血案

1、引言 阿里巴巴Java开发手册在第一章节&#xff0c;编程规约中OOP规约的第15条提到&#xff1a; **【强制】**序列化类新增属性时&#xff0c;请不要修改serialVersionUID字段&#xff0c;避免反序列失败&#xff1b;如果完全不兼容升级&#xff0c;避免反序列化混乱&#x…...

为Linux系统添加一块新硬盘,并扩展根目录容量

我的原来ubuntu20.04系统装的时候不是LVM格式的分区&#xff0c; 所以先将新硬盘转成LVM&#xff0c;再将原来的系统dd到新硬盘&#xff0c;从新硬盘的分区启动&#xff0c;之后再将原来的分区转成LVM&#xff0c;在融入进来 1&#xff1a;将新硬盘制作成 LVM分区 我的新硬盘…...

树莓派Opencv调用摄像头(Raspberry Pi 11)

前言&#xff1a;本人初玩树莓派opencv&#xff0c;使用的是树莓派Raspberry Pi OS 11&#xff0c;系统若不一致请慎用&#xff0c;本文主要记录在树莓派上通过Opencv打开摄像头的经验。 1、系统版本 进入树莓派&#xff0c;打开终端输入以下代码&#xff08;查看系统的版本&…...

国产ChatGPT命名图鉴

很久不见这般热闹的春天。 随着ChatGPT的威名席卷全球&#xff0c;大洋对岸的中国厂商也纷纷亮剑&#xff0c;各式本土大模型你方唱罢我登场&#xff0c;声势浩大的发布会排满日程表。 有趣的是&#xff0c;在这些大模型产品初入历史舞台之时&#xff0c;带给世人的第一印象其…...

操作系统——进程管理

0.关注博主有更多知识 操作系统入门知识合集 目录 0.关注博主有更多知识 4.1进程概念 4.1.1进程基本概念 思考题&#xff1a; 4.1.2进程状态 思考题&#xff1a; 4.1.3进程控制块PCB 4.2进程控制 思考题&#xff1a; 4.3线程 思考题&#xff1a; 4.4临界资源与临…...

第四十一章 Unity 输入框 (Input Field) UI

本章节我们学习输入框 (Input Field)&#xff0c;它可以帮助我们获取用户的输入。我们点击菜单栏“GameObject”->“UI”->“Input Field”&#xff0c;我们调整一下它的位置&#xff0c;效果如下 我们在层次面板中发现&#xff0c;这个InputField UI元素包含两个子元素&…...

10.集合

1.泛型 1.1泛型概述 泛型的介绍 ​ 泛型是JDK5中引入的特性&#xff0c;它提供了编译时类型安全检测机制 泛型的好处 把运行时期的问题提前到了编译期间避免了强制类型转换 泛型的定义格式 <类型>: 指定一种类型的格式.尖括号里面可以任意书写,一般只写一个字母.例如:…...

强化学习p3-策略学习

Policy Network (策略网络) 我们无法知道策略函数 π \pi π所以要做函数近似&#xff0c;求一个近似的策略函数 使用策略网络 π ( a ∣ s ; θ ) \pi(a|s;\theta) π(a∣s;θ) 去近似策略函数 π ( a ∣ s ) \pi(a|s) π(a∣s) ∑ a ∈ A π ( a ∣ s ; θ ) 1 \sum_{a\in …...