当前位置: 首页 > news >正文

PyTorch深度学习实战 | 基于线性回归、决策树和SVM进行鸢尾花分类

鸢尾花数据集是机器学习领域非常经典的一个分类任务数据集。它的英文名称为Iris Data Set,使用sklearn库可以直接下载并导入该数据集。数据集总共包含150行数据,每一行数据由4个特征值及一个标签组成。标签为三种不同类别的鸢尾花,分别为:Iris Setosa,Iris Versicolour,Iris Virginica。

对于多分类任务,有较多机器学习的算法可以支持。本文将使用决策树、线性回归、SVM等多种算法来完成这一任务,并对不同方法进行比较。

01、使用Logistic实现鸢尾花分类

在前面介绍过Logistic用于二分类任务,对其进行扩展也用于多分类任务。下面将使用sklearn库完成一个基于Logistic的鸢尾花分类任务。如代码清单1所示,首先是导入sklearn.datasets包从而加载数据集,并将数据集按照测试集占比0.2随机分为训练集和测试集。

代码清单1 导入包以及加载数据集

rom sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import confusion_matrix, precision_score, accuracy_score,recall_score, f1_score, roc_auc_score, \roc_curve
import matplotlib.pyplot as plt# 加载数据集
def loadDataSet():iris_dataset = load_iris()X = iris_dataset.datay = iris_dataset.target# 将数据划分为训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)return X_train, X_test, y_train, y_test

如代码清单2所示,编写函数训练Logistic模型。

代码清单2 训练Logistic模型

# 训练Logistic模性
def trainLS(x_train, y_train):# Logistic生成和训练clf = LogisticRegression()clf.fit(x_train, y_train)return clf

Logistic模型较为简单,不需要额外设置超参数即可开始训练。如代码清单3所示,初始化Logistic模型并将模型在训练集上训练,返回训练好的模型。

代码清单3 测试模型及打印各种评价指标

# 测试模型
def test(model, x_test, y_test):# 将标签转换为one-hot形式y_one_hot = label_binarize(y_test, np.arange(3))# 预测结果y_pre = model.predict(x_test)# 预测结果的概率y_pre_pro = model.predict_proba(x_test)# 混淆矩阵con_matrix = confusion_matrix(y_test, y_pre)print('confusion_matrix:\n', con_matrix)print('accuracy:{}'.format(accuracy_score(y_test, y_pre)))print('precision:{}'.format(precision_score(y_test, y_pre, average='micro')))print('recall:{}'.format(recall_score(y_test, y_pre, average='micro')))print('f1-score:{}'.format(f1_score(y_test, y_pre, average='micro')))# 绘制ROC曲线drawROC(y_one_hot, y_pre_pro)

在预测结果时,为了方便后面绘制ROC曲线,需要首先将测试集的标签转化为one-hot的形式,并得到模型在测试集上预测结果的概率值即y_pre_pro,从而传入drawROC函数完成ROC曲线的绘制。除此外,该函数实现了输出混淆矩阵以及计算准确率、精确率、查全率以及f1-score的功能。

代码清单4 绘制ROC曲线

def drawROC(y_one_hot, y_pre_pro):# AUC值auc = roc_auc_score(y_one_hot, y_pre_pro, average='micro')# 绘制ROC曲线fpr, tpr, thresholds = roc_curve(y_one_hot.ravel(), y_pre_pro.ravel())plt.plot(fpr, tpr, linewidth=2, label='AUC=%.3f' % auc)plt.plot([0, 1], [0, 1], 'k--')plt.axis([0, 1.1, 0, 1.1])plt.xlabel('False Postivie Rate')plt.ylabel('True Positive Rate')plt.legend()plt.show()

如代码清单4所示为绘制ROC曲线的代码实现。最后将加载数据集,训练模型,以及模型验证的整个流程连接起来从而实现main函数,如代码清单5所示。

代码清单5 main函数设置

if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainLS(X_train, y_train)test(model, X_test, y_test)

将上述所有代码放在同一py脚本文件中,如图1所示可得最终的输出结果为

图1 命令行打印的测试结果

绘制得到的ROC曲线如图2所示。

图2 ROC曲线

Logistic是一个较为简单的模型,参数量较少,一般也用于较为简单的分类任务中,当任务更为复杂时,可以选取更为复杂的模型获得更好的效果,下面将使用不同的模型从而验证同一任务在不同模型下的表现。

02、使用决策树实现鸢尾花分类

由于只改动了模型,加载数据集、模型评价等其他部分的代码不需要改动,如代码清单6所示,增加新的函数用于训练决策树模型。

代码清单6 使用决策树模型进行训练

from sklearn import tree
# 训练决策树模性
def trainDT(x_train, y_train):# DT生成和训练clf = tree.DecisionTreeClassifier(criterion="entropy")clf.fit(x_train, y_train)return clf

同时修改main函数中调用的训练函数如代码清单7所示。

代码清单7 修改main函数内容

if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainDT(X_train, y_train)test(model, X_test, y_test)

最后运行可得命令行输出如图3所示。

图3 决策树模型预测结果

以及ROC曲线如图4所示。

图4 决策树模型绘制ROC曲线

相比Logistic模型,决策树模型无论在哪一项指标上都得到了更高的评分,且决策树模型不会像Logistic模型一样受初始化的影响,多次运行程序均可获得相同的输出模型,而Logistic模型运行多次会发现评价指标会在某个范围内上下抖动。

03、使用SVM实现鸢尾花分类

到现在相信大家都已经非常熟悉如何继续修改代码从而实现SVM模型的预测,实现SVM模型的训练代码如代码清单8所示

代码清单8 使用SVM模型进行训练

# 训练SVM模性
from sklearn import svm
def trainSVM(x_train, y_train):# SVM生成和训练clf = svm.SVC(kernel='rbf', probability=True)clf.fit(x_train, y_train)return clf

同时修改main函数,如代码清单9所示。

代码清单9 修改main函数内容

if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainSVM(X_train, y_train)test(model, X_test, y_test)

程序运行输出如图5所示。

图5 使用SVM模型预测结果

绘制得到的ROC曲线如图6所示。

图6 使用SVM模型绘制的ROC曲线

可以发现,随着模型进一步变得复杂,最终预测的各项指标进一步上升,在三个模型中SVM模型的高斯核最终结果在测试集中表现得最好且没有发生过拟合的现象,因此可以选用SVM模型来完成鸢尾花分类这一任务。

相关文章:

PyTorch深度学习实战 | 基于线性回归、决策树和SVM进行鸢尾花分类

鸢尾花数据集是机器学习领域非常经典的一个分类任务数据集。它的英文名称为Iris Data Set,使用sklearn库可以直接下载并导入该数据集。数据集总共包含150行数据,每一行数据由4个特征值及一个标签组成。标签为三种不同类别的鸢尾花,分别为&…...

服务端接口优化方案

一、背景 针对老项目,去年做了许多降本增效的事情,其中发现最多的就是接口耗时过长的问题,就集中搞了一次接口性能优化。本文将给小伙伴们分享一下接口优化的通用方案。 二、接口优化方案总结 1. 批处理 批量思想:批量操作数据…...

【并发基础】Happens-Before模型详解

目录 一、Happens-Before模型简介 二、组成Happens-Before模型的八种规则 2.1 程序顺序规则(as-if-serial语义) 2.2 传递性规则 2.3 volatile变量规则 2.4 监视器锁规则 2.5 start规则 2.6 Join规则 一、Happens-Before模型简介 除了显示引用vo…...

Kubernetes系列---Kubernetes 理论知识 | 初识

Kubernetes系列---Kubernetes 理论知识 | 初识 1.K8s 是什么?2.K8s 特性3.小拓展(业务升级)4.K8s 集群架构与组件①架构拓扑图:②Master 组件③Node 组件 五 K8s 核心概念六 官方提供的三种部署方式总结 1.K8s 是什么&#xff1f…...

KingbaseES 原生XML系列三--XML数据查询函数

KingbaseES 原生XML系列三--XML数据查询函数(EXTRACT,EXTRACTVALUE,EXISTSNODE,XPATH,XPATH_EXISTS,XMLEXISTS) XML的简单使其易于在任何应用程序中读写数据,这使XML很快成为数据交换的一种公共语言。在不同平台下产生的信息,可以很容易加载XML数据到程序…...

【51单片机】点亮一个LED灯(看开发板原理图十分重要)

🎊专栏【51单片机】 🍔喜欢的诗句:更喜岷山千里雪 三军过后尽开颜。 🎆音乐分享【The Right Path】 🥰大一同学小吉,欢迎并且感谢大家指出我的问题🥰 目录 🍔基础内容 &#x1f3f3…...

数据可视化工具 - ECharts以及柱状图的编写

1 快速上手 引入echarts 插件文件到html页面中 <head><meta charset"utf-8"/><title>ECharts</title><!-- step1 引入刚刚下载的 ECharts 文件 --><script src"./echarts.js"></script> </head>准备一个…...

【AI绘画】——Midjourney关键词格式解析(常用参数分享)

目前在AI绘画模型中&#xff0c;Midjourney的效果是公认的top级别&#xff0c;但同时也是相对较难使用的&#xff0c;对小白来说比较难上手&#xff0c;主要就在于Mj没有webui&#xff0c;不能选择参数&#xff0c;怎么找到这些隐藏参数并且触发它是用好Mj的第一步。 今天就来…...

操作符知识点大全(简洁,全面,含使用场景,演示,代码)

目录 一.算术操作符 1.要点&#xff1a; 二.负数原码&#xff0c;反码&#xff0c;补码的互推 1.按位取反操作符&#xff1a;~&#xff08;二进制位&#xff09; 2.原反补互推演示 三.进制位的表示 1.不同进制位的特征&#xff1a; 2.二进制位表示 3.整型的二进制表…...

华工研究生语音课

这门课讲啥 语音蕴含的信息、语音识别的目的 语音的准平稳性、分帧、预加重、时域特征分析&#xff08;能量和过零率&#xff09;、端点检测&#xff08;双门限法&#xff09; 语音的基频及检测&#xff08;主要是自相关法、野点的处理&#xff09; 声音的产生过程&#xf…...

KingbaseES 原生XML系列二 -- XML数据操作函数

KingbaseES 原生XML系列二--XML数据操作函数(DELETEXML,APPENDCHILDXML,INSERTCHILDXML,INSERTCHILDXMLAFTER,INSERTCHILDXMLBEFORE,INSERTXMLAFTER,INSERTXMLBEFORE,UPDATEXML) XML的简单使其易于在任何应用程序中读写数据&#xff0c;这使XML很快成为数据交换的一种公共语言。…...

【Flink】DataStream API使用之源算子(Source)

源算子 创建环境之后&#xff0c;就可以构建数据的业务处理逻辑了&#xff0c;Flink可以从各种来源获取数据&#xff0c;然后构建DataStream进项转换。一般将数据的输入来源称为数据源&#xff08;data source&#xff09;&#xff0c;而读取数据的算子就叫做源算子&#xff08…...

树莓派硬件介绍及配件选择

目录 树莓派Datasheet下载地址&#xff1a; Raspberry 4B 外观图&#xff1a; 技术规格书&#xff1a; 性能介绍&#xff1a; 树莓派配件选用 电源的选用&#xff1a; 树莓派外壳选用&#xff1a; 内存卡/U盘选用 树莓派Datasheet下载地址&#xff1a; Raspberry Pi …...

O2OA (翱途) 平台 V8.0 发布新增数据台账能力

亲爱的小伙伴们&#xff0c;O2OA (翱途) 平台开发团队经过几个月的持续努力&#xff0c;实现功能的新增、优化以及问题的修复。2023 年度 V8.0 版本已正式发布。欢迎大家到 O2OA 的官网上下载进行体验&#xff0c;也希望大家在藕粉社区里多提宝贵建议。本篇我们先为大家介绍应用…...

数控解锁怎么解 数控系统解锁解密

Amazon Fargate 在中国区正式落地&#xff0c;因 数控解锁使用 Serverless 架构&#xff0c;更加适合对性能要求不敏感的服务使用&#xff0c;Pyroscope 是一款基于 Golang 开发的应用程序性能分析工具&#xff0c;Pyroscope 的服务端为无状态服务且性能要求不敏感&#xff0c;…...

3.0 响应式系统的设计与实现

1、Proxy代理对象 Proxy用于对一个普通对象代理&#xff0c;实现对象的拦截和自定义&#xff0c;如拦截其赋值、枚举、函数调用等。里面包含了很多组捕获器&#xff08;trap&#xff09;&#xff0c;在代理对象执行相应的操作时捕获&#xff0c;然后在内部实现自定义。 const…...

Rust 快速入门60分① 看完这篇就能写代码了

Rust 一门赋予每个人构建可靠且高效软件能力的语言https://hannyang.blog.csdn.net/article/details/130467813?spm1001.2014.3001.5502关于Rust安装等内容请参考上文链接&#xff0c;写完上文就在考虑写点关于Rust的入门文章&#xff0c;本专辑将直接从Rust基础入门内容开始讲…...

【5.JS基础-JavaScript的DOM操作】

1 认识DOM和BOM 所以我们学习DOM&#xff0c;就是在学习如何通过JavaScript对文档进行操作的&#xff1b; DOM Tree的理解 DOM的学习顺序 DOM的继承关系图 2 document对象 3 节点&#xff08;Node&#xff09;之间的导航&#xff08;navigator&#xff09; 4 元素&#xff0…...

【大数据之Hadoop】二十九、HDFS存储优化

纠删码和异构存储测试需要5台虚拟机。准备另外一套5台服务器集群。 环境准备&#xff1a; &#xff08;1&#xff09;克隆hadoop105为hadoop106&#xff0c;修改ip地址和hostname&#xff0c;然后重启。 vim /etc/sysconfig/network-scripts/ifcfg-ens33 vim /etc/hostname r…...

SuperMap GIS基础产品组件GIS FAQ集锦(2)

SuperMap GIS基础产品组件GIS FAQ集锦&#xff08;2&#xff09; 【iObjects for Spark】读取GDB参数该如何填写&#xff1f; 【解决办法】可参考以下示例&#xff1a; val GDB_params new util.HashMapString, java.io.Serializable GDB_params.put(FeatureRDDProviderParam…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

java_网络服务相关_gateway_nacos_feign区别联系

1. spring-cloud-starter-gateway 作用&#xff1a;作为微服务架构的网关&#xff0c;统一入口&#xff0c;处理所有外部请求。 核心能力&#xff1a; 路由转发&#xff08;基于路径、服务名等&#xff09;过滤器&#xff08;鉴权、限流、日志、Header 处理&#xff09;支持负…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

Java - Mysql数据类型对应

Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...

CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云

目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...

用docker来安装部署freeswitch记录

今天刚才测试一个callcenter的项目&#xff0c;所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...

九天毕昇深度学习平台 | 如何安装库?

pip install 库名 -i https://pypi.tuna.tsinghua.edu.cn/simple --user 举个例子&#xff1a; 报错 ModuleNotFoundError: No module named torch 那么我需要安装 torch pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple --user pip install 库名&#x…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析

Linux 内存管理实战精讲&#xff1a;核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用&#xff0c;还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...

2025年渗透测试面试题总结-腾讯[实习]科恩实验室-安全工程师(题目+回答)

安全领域各种资源&#xff0c;学习文档&#xff0c;以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具&#xff0c;欢迎关注。 目录 腾讯[实习]科恩实验室-安全工程师 一、网络与协议 1. TCP三次握手 2. SYN扫描原理 3. HTTPS证书机制 二…...