头歌——人工智能(机器学习 --- 决策树2)
文章目录
- 第5关:基尼系数
- 代码
- 第6关:预剪枝与后剪枝
- 代码
- 第7关:鸢尾花识别
- 代码
第5关:基尼系数
基尼系数
在ID3算法中我们使用了信息增益来选择特征,信息增益大的优先选择。在C4.5算法中,采用了信息增益率来选择特征,以减少信息增益容易选择特征值多的特征的问题。但是无论是ID3还是C4.5,都是基于信息论的熵模型的,这里面会涉及大量的对数运算。能不能简化模型同时也不至于完全丢失熵模型的优点呢?当然有!那就是基尼系数!
CART算法使用基尼系数来代替信息增益率,基尼系数代表了模型的不纯度,基尼系数越小,则不纯度越低,特征越好。这和信息增益与信息增益率是相反的(它们都是越大越好)。
从公式可以看出,相比于信息增益和信息增益率,计算起来更加简单。举个例子,还是使用第二关中提到过的数据集,第一列是编号,第二列是性别,第三列是活跃度,第四列是客户是否流失的标签(0表示未流失,1表示流失)。
代码
import numpy as npdef calcGini(feature, label, index):'''计算基尼系数:param feature:测试用例中字典里的feature,类型为ndarray:param label:测试用例中字典里的label,类型为ndarray:param index:测试用例中字典里的index,即feature部分特征列的索引。该索引指的是feature中第几个特征,如index:0表示使用第一个特征来计算信息增益。:return:基尼系数,类型float'''# 计算子集的基尼指数def calcGiniIndex(label_subset):total = len(label_subset)if total == 0:return 0label_counts = np.bincount(label_subset)probabilities = label_counts / totalgini = 1.0 - np.sum(np.square(probabilities))return gini# 将feature和label转为numpy数组f = np.array(feature)l = np.array(label)# 得到指定特征列的值的集合unique_values = np.unique(f[:, index])total_gini = 0total_samples = len(label)# 按照特征的每个唯一值划分数据集for value in unique_values:# 获取该特征值对应的样本索引subset_indices = np.where(f[:, index] == value)[0]# 获取对应的子集标签subset_label = l[subset_indices]# 计算子集的基尼指数subset_gini = calcGiniIndex(subset_label)# 加权计算总的基尼系数weighted_gini = (len(subset_label) / total_samples) * subset_ginitotal_gini += weighted_ginireturn total_gini
第6关:预剪枝与后剪枝
为什么需要剪枝
决策树的生成是递归地去构建决策树,直到不能继续下去为止。这样产生的树往往对训练数据有很高的分类准确率,但对未知的测试数据进行预测就没有那么准确了,也就是所谓的过拟合。
决策树容易过拟合的原因是在构建决策树的过程时会过多地考虑如何提高对训练集中的数据的分类准确率,从而会构建出非常复杂的决策树(树的宽度和深度都比较大)。在之前的实训中已经提到过,模型的复杂度越高,模型就越容易出现过拟合的现象。所以简化决策树的复杂度能够有效地缓解过拟合现象,而简化决策树最常用的方法就是剪枝。剪枝分为预剪枝与后剪枝。
预剪枝
预剪枝的核心思想是在决策树生成过程中,对每个结点在划分前先进行一个评估,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶结点。
想要评估决策树算法的泛化性能如何,方法很简单。可以将训练数据集中随机取出一部分作为验证数据集,然后在用训练数据集对每个结点进行划分之前用当前状态的决策树计算出在验证数据集上的正确率。正确率越高说明决策树的泛化性能越好,如果在划分结点的时候发现泛化性能有所下降或者没有提升时,说明应该停止划分,并用投票计数的方式将当前结点标记成叶子结点。
举个例子,假如上一关中所提到的用来决定是否买西瓜的决策树模型已经出现过拟合的情况,模型如下:
假设当模型在划分是否便宜这个结点前,模型在验证数据集上的正确率为0.81。但在划分后,模型在验证数据集上的正确率降为0.67。此时就不应该划分是否便宜这个结点。所以预剪枝后的模型如下:
从上图可以看出,预剪枝能够降低决策树的复杂度。这种预剪枝处理属于贪心思想,但是贪心有一定的缺陷,就是可能当前划分会降低泛化性能,但在其基础上进行的后续划分却有可能导致性能显著提高。所以有可能会导致决策树出现欠拟合的情况。
后剪枝
后剪枝是先从训练集生成一棵完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能够带来决策树泛化性能提升,则将该子树替换为叶结点。
后剪枝的思路很直接,对于决策树中的每一个非叶子结点的子树,我们尝试着把它替换成一个叶子结点,该叶子结点的类别我们用子树所覆盖训练样本中存在最多的那个类来代替,这样就产生了一个简化决策树,然后比较这两个决策树在测试数据集中的表现,如果简化决策树在验证数据集中的准确率有所提高,那么该子树就可以替换成叶子结点。该算法以bottom-up的方式遍历所有的子树,直至没有任何子树可以替换使得测试数据集的表现得以改进时,算法就可以终止。
从后剪枝的流程可以看出,后剪枝是从全局的角度来看待要不要剪枝,所以造成欠拟合现象的可能性比较小。但由于后剪枝需要先生成完整的决策树,然后再剪枝,所以后剪枝的训练时间开销更高。
代码
import numpy as np
from copy import deepcopyclass DecisionTree(object):def __init__(self):# 决策树模型self.tree = {}def calcInfoGain(self, feature, label, index):# 计算信息增益的代码def calcInfoEntropy(feature, label):# 计算信息熵label_set = set(label)result = 0for l in label_set:count = 0for j in range(len(label)):if label[j] == l:count += 1p = count / len(label)result -= p * np.log2(p)return resultdef calcHDA(feature, label, index, value):# 计算条件熵count = 0sub_feature = []sub_label = []for i in range(len(feature)):if feature[i][index] == value:count += 1sub_feature.append(feature[i])sub_label.append(label[i])pHA = count / len(feature)e = calcInfoEntropy(sub_feature, sub_label)return pHA * ebase_e = calcInfoEntropy(feature, label)f = np.array(feature)f_set = set(f[:, index])sum_HDA = 0for value in f_set:sum_HDA += calcHDA(feature, label, index, value)return base_e - sum_HDAdef getBestFeature(self, feature, label):max_infogain = 0best_feature = 0for i in range(len(feature[0])):infogain = self.calcInfoGain(feature, label, i)if infogain > max_infogain:max_infogain = infogainbest_feature = ireturn best_featuredef calc_acc_val(self, the_tree, val_feature, val_label):# 计算验证集准确率result = []def classify(tree, feature):if not isinstance(tree, dict):return treet_index, t_value = list(tree.items())[0]f_value = feature[t_index]if isinstance(t_value, dict):classLabel = classify(tree[t_index][f_value], feature)return classLabelelse:return t_valuefor f in val_feature:result.append(classify(the_tree, f))result = np.array(result)return np.mean(result == val_label)def createTree(self, train_feature, train_label):# 创建决策树if len(set(train_label)) == 1:return train_label[0]if len(train_feature[0]) == 1 or len(np.unique(train_feature, axis=0)) == 1:vote = {}for l in train_label:if l in vote.keys():vote[l] += 1else:vote[l] = 1max_count = 0vote_label = Nonefor k, v in vote.items():if v > max_count:max_count = vvote_label = kreturn vote_labelbest_feature = self.getBestFeature(train_feature, train_label)tree = {best_feature: {}}f = np.array(train_feature)f_set = set(f[:, best_feature])for v in f_set:sub_feature = []sub_label = []for i in range(len(train_feature)):if train_feature[i][best_feature] == v:sub_feature.append(train_feature[i])sub_label.append(train_label[i])tree[best_feature][v] = self.createTree(sub_feature, sub_label)return treedef post_cut(self, val_feature, val_label):# 剪枝相关代码def get_non_leaf_node_count(tree):non_leaf_node_path = []def dfs(tree, path, all_path):for k in tree.keys():if isinstance(tree[k], dict):path.append(k)dfs(tree[k], path, all_path)if len(path) > 0:path.pop()else:all_path.append(path[:])dfs(tree, [], non_leaf_node_path)unique_non_leaf_node = []for path in non_leaf_node_path:isFind = Falsefor p in unique_non_leaf_node:if path == p:isFind = Truebreakif not isFind:unique_non_leaf_node.append(path)return len(unique_non_leaf_node)def get_the_most_deep_path(tree):non_leaf_node_path = []def dfs(tree, path, all_path):for k in tree.keys():if isinstance(tree[k], dict):path.append(k)dfs(tree[k], path, all_path)if len(path) > 0:path.pop()else:all_path.append(path[:])dfs(tree, [], non_leaf_node_path)max_depth = 0result = Nonefor path in non_leaf_node_path:if len(path) > max_depth:max_depth = len(path)result = pathreturn resultdef set_vote_label(tree, path, label):for i in range(len(path)-1):tree = tree[path[i]]tree[path[len(path)-1]] = labelacc_before_cut = self.calc_acc_val(self.tree, val_feature, val_label)for _ in range(get_non_leaf_node_count(self.tree)):path = get_the_most_deep_path(self.tree)tree = deepcopy(self.tree)step = deepcopy(tree)for k in path:step = step[k]vote_label = sorted(step.items(), key=lambda item: item[1], reverse=True)[0][0]set_vote_label(tree, path, vote_label)acc_after_cut = self.calc_acc_val(tree, val_feature, val_label)if acc_after_cut > acc_before_cut:set_vote_label(self.tree, path, vote_label)acc_before_cut = acc_after_cutdef fit(self, train_feature, train_label, val_feature, val_label):# 训练决策树模型self.tree = self.createTree(train_feature, train_label)self.post_cut(val_feature, val_label)def predict(self, feature):# 预测函数result = []def classify(tree, feature):if not isinstance(tree, dict):return treet_index, t_value = list(tree.items())[0]f_value = feature[t_index]if isinstance(t_value, dict):classLabel = classify(tree[t_index][f_value], feature)return classLabelelse:return t_valuefor f in feature:result.append(classify(self.tree, f))return np.array(result)
第7关:鸢尾花识别
掌握如何使用sklearn提供的DecisionTreeClassifier
数据简介:
鸢尾花数据集是一类多重变量分析的数据集。通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类(其中分别用0,1,2代替)。
数据集中部分数据与标签如下图所示:
DecisionTreeClassifier
DecisionTreeClassifier的构造函数中有两个常用的参数可以设置:
criterion:划分节点时用到的指标。有gini(基尼系数),entropy(信息增益)。若不设置,默认为gini
max_depth:决策树的最大深度,如果发现模型已经出现过拟合,可以尝试将该参数调小。若不设置,默认为None
和sklearn中其他分类器一样,DecisionTreeClassifier类中的fit函数用于训练模型,fit函数有两个向量输入:
X:大小为[样本数量,特征数量]的ndarray,存放训练样本;
Y:值为整型,大小为[样本数量]的ndarray,存放训练样本的分类标签。
DecisionTreeClassifier类中的predict函数用于预测,返回预测标签,predict函数有一个向量输入:
X:大小为[样本数量,特征数量]的ndarray,存放预测样本。
DecisionTreeClassifier的使用代码如下:
from sklearn.tree import DecisionTreeClassifier
clf = tree.DecisionTreeClassifier()
clf.fit(X_train, Y_train)
result = clf.predict(X_test)
代码
#********* Begin *********#
import pandas as pd
from sklearn.tree import DecisionTreeClassifiertrain_df = pd.read_csv('./step7/train_data.csv').as_matrix()
train_label = pd.read_csv('./step7/train_label.csv').as_matrix()
test_df = pd.read_csv('./step7/test_data.csv').as_matrix()dt = DecisionTreeClassifier()
dt.fit(train_df, train_label)
result = dt.predict(test_df)result = pd.DataFrame({'target':result})
result.to_csv('./step7/predict.csv', index=False)
#********* End *********#
相关文章:
头歌——人工智能(机器学习 --- 决策树2)
文章目录 第5关:基尼系数代码 第6关:预剪枝与后剪枝代码 第7关:鸢尾花识别代码 第5关:基尼系数 基尼系数 在ID3算法中我们使用了信息增益来选择特征,信息增益大的优先选择。在C4.5算法中,采用了信息增益率…...
一七一、React性能优化方式
在 React 中进行性能优化可以通过多种手段来减少渲染次数、优化渲染效率并减少内存消耗。以下是常见的性能优化方法及示例: 1. shouldComponentUpdate shouldComponentUpdate 是类组件中的生命周期方法,它可以让组件在判断是否需要重新渲染时ÿ…...
编写dockerfile生成镜像,并且构建容器运行
编写dockerfile生成镜像,并且构建容器运行 目录 编写dockerfile生成镜像,并且构建容器运行 概述 一、dockerfile文件详解 Dockerfile的基本结构 Dockerfile的常用指令 二、构建过程 概述 随着微服务应用越来越多,大家需要尽快掌握dock…...
Java项目练习——学生管理系统
1. 整体结构 代码实现了基本的学生管理系统功能,包括登录、注册、忘记密码、添加、删除、修改和查询学生信息。 使用了ArrayList来存储用户和学生信息。 使用了Scanner类来处理用户输入。 2. 主要功能模块 登录 (logIn):验证用户名和密码,…...
sqlserver、达梦、mysql的差异
差异项sqlserver达梦mysql单行注释---- 1、-- ,--后面带个空格 2、# 包裹对象名称,如表、表字段等 [tableName] "tableName"tableName表字段自增IDENTITY(1, 1)IDENTITY(1, 1)AUTO_INCREMENT二进制数据类型IMAGEIMAGE、BLOBBLOB 存储一个汉字需…...
Spring AOP(定义、使用场景、用法、3种事务、事务失效场景及解决办法、面试题)
目录 1. AOP定义? 2.常见的AOP使用场景: 3.Spring AOP用法 3.1 Spring AOP中的几个核心概念 3.1.1 切面、切点、通知、连接点 3.1.2 切点表达式AspectJ 3.2 使用 Spring AOP 的步骤总结 3.2.1 添加依赖: 3.2.2 定义切面和切点(切点和…...
Flutter鸿蒙next 封装对话框详解
✅近期推荐:求职神器 https://bbs.csdn.net/topics/619384540 🔥欢迎大家订阅系列专栏:flutter_鸿蒙next 💬淼学派语录:只有不断的否认自己和肯定自己,才能走出弯曲不平的泥泞路,因为平坦的大路…...
【项目实战】通过LLaMaFactory+Qwen2-VL-2B微调一个多模态医疗大模型
前言 随着多模态大模型的发展,其不仅限于文字处理,更能够在图像、视频、音频方面进行识别与理解。医疗领域中,医生们往往需要对各种医学图像进行处理,以辅助诊断和治疗。如果将多模态大模型与图像诊断相结合,那么这会…...
SCSI驱动与 UFS 驱动交互概况
SCSI子系统概况 SCSI(Small Computer System Interface)子系统是 Linux 中的一个模块化框架,用于提供与存储设备的通用接口。通过 SCSI 子系统,可以支持不同类型的存储协议(如 UFS、SATA、SAS),…...
软件工程实践项目:人事管理系统
一、项目的需求说明 通过移动设备登录app提供简单、方便的操作。根据公司原来的考勤管理制度,为公司不同管理层次提供相应的权限功能。通过app上面的各种标准操作,考勤管理无纸化的实现,使公司的考勤管理更加科学规范,从而节省考…...
不使用三方软件,win系统下禁止单个应用联网能力的详细操作教程
本篇文章主要讲解,在win系统环境下,禁止某个应用联网能力的详细操作教程,通过本教程您可以快速掌握自定义对单个程序联网能力的限制和禁止。 作者:任聪聪 日期:2024年10月30日 步骤一、按下win按键(四个小方…...
近似线性可分支持向量机的原理推导
近似线性可分的意思是训练集中大部分实例点是线性可分的,只是一些特殊实例点的存在使得这种数据集不适用于直接使用线性可分支持向量机进行处理,但也没有到完全线性不可分的程度。所以近似线性可分支持向量机问题的关键就在于这些少数的特殊点。 相较于…...
Golang开发环境
Golang开发环境搭建 Go 语言开发包 国外:https://golang.org/dl/ 国内(推荐): https://golang.google.cn/dl/ 编辑器 Golang:https://www.jetbrains.com/go/ Visual Studio Code: https://code.visualstudio.com/ 搭建 Go 语言开发环境,需要…...
测试华为GaussDB(DWS)数仓,并通过APISQL快速将(表、视图、存储过程)发布为API
华为数据仓库服务 数据仓库服务(Data Warehouse Service,简称DWS)是一种基于公有云基础架构和平台的在线数据处理数据库,提供即开即用、可扩展且完全托管的分析型数据库服务。DWS是基于华为融合数据仓库GaussDB产品的云原生服务&a…...
使用GetX实现GetPage中间件
前言 GetX 中间件(Middleware)是 GetX 框架中的一种机制,用于在页面导航时对用户进行权限控制、数据预加载、页面访问条件设置等。通过使用中间件,可以有效地控制用户的访问流程,并在适当条件下引导用户到所需页面。 这…...
Navicat 17 功能简介 | SQL 预览
Navicat 17 功能简介 | SQL 预览 随着 17 版本的发布,Navicat 也带来了众多的新特性,包括兼容更多数据库、全新的模型设计、可视化智能 BI、智能数据分析、可视化查询解释、高质量数据字典、增强用户体验、扩展MongoDB 功能、轻松固定查询结果、便捷URI …...
ubuntu、Debian离线部署gitlab
一、软件包下载 gitlab安装包下载链接 ubuntu: ubuntu/focal 适用于 ubuntu20系列 ubuntu/bionic 适用于 ubuntu18 系列 Debian: debian/buster 适用于 Debian10系列 debian/bullseye 适用于 Debian11、12系列 二、安装gitlab ubuntu需要安装一些环境…...
数据库编程 SQLITE3 Linux环境
永久存储程序数据有两种方式: 用文件存储用数据库存储 对于多条记录的存储而言,采用文件时,插入、删除、查找的效率都会很差,为了提高这些操作的效率,有计算机科学家设计出了数据库存储方式 一、数据库 数据库的基本…...
独孤思维:总有一双眼睛默默观察你做副业
01 独孤昨天在陪伴群,分享了近期小白做副业的一些困扰。 并且以自己经历作为案例,分享了一些经验和方法。 最后顺势推出xx博主的关于365条赚钱信息小报童专栏。 订阅后,可以开拓副业赚钱思路,避免走一些弯路。 甚至于&#x…...
医院信息化与智能化系统(10)
医院信息化与智能化系统(10) 这里只描述对应过程,和可能遇到的问题及解决办法以及对应的参考链接,并不会直接每一步详细配置 如果你想通过文字描述或代码画流程图,可以试试PlantUML,告诉GPT你的文件结构,让他给你对应…...
基于YOLO11/v10/v8/v5深度学习的危险驾驶行为检测识别系统设计与实现【python源码+Pyqt5界面+数据集+训练代码】
《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…...
Flink CDC系列之:学习理解核心概念——Transform
Flink CDC系列之:学习理解核心概念——Transform Transform参数元数据字段函数比较函数逻辑函数字符串函数时间函数条件函数 示例添加计算列参考元数据列使用通配符投影所有字段添加过滤规则重新分配主键重新分配分区键指定表创建配置分类映射用户定义函数已知限制 …...
MyBatis-Plus:简化 CRUD 操作的艺术
一、关于MyBatis-Plus 1.1 简介 MyBatis-Plus 是一个基于 MyBatis 的增强工具,它旨在简化 MyBatis 的使用,提高开发效率。 关于Mybatis 简介 MyBatis 是一款流行的 Java 持久层框架,旨在简化 Java 应用程序与数…...
Windows on ARM编译安装openBLAS
Windows on ARM编译安装openBLAS 要求下载源码OpenBLAS可以使用LLVM工具链(clang-cl和flang)从源代码为Windows on ARM(WoA)进行构建。v0.3.24版本(预构建包)的构建和测试已通过。 要求 LLVM:版本需大于等于17.0.4 LLVM版本16及以下会生成冲突的符号(如_QQ*等)。 LL…...
FPGA编程语言VHDL与Verilog的比较分析!!!
VHDL(VHSIC硬件描述语言)和Verilog都是用于硬件描述和FPGA编程的工业标准语言。它们在语法和设计理念上存在一些差异,以下是两者的比较分析: 1. 历史背景 VHDL: 开发于1980年代初期,最初用于美国国防部的…...
C语言——八股文(笔试面试题)
1、 什么是数组指针,什么是指针数组? 数组指针:指向数组的指针 指针数组:数组中的元素都是指针 2、 什么是位段,什么是联合体 位段(Bit Field):在C语言中,允许在一个整数…...
解决 Oracle 数据库错误 ORA-12516:监听器无法找到匹配协议栈的处理程序
在使用 Oracle 数据库时,有时会遇到错误 ORA-12516,这个错误表明 Oracle 数据库的监听器无法为新的连接请求找到一个可用的处理程序,这通常是因为达到了连接数上限、配置问题或资源限制。本文将详细介绍如何解决这个问题。 一、错误描述 当…...
Flarum:简洁而强大的开源论坛软件
Flarum简介 Flarum是一款开源论坛软件,以其简洁、快速和易用性而闻名。它继承了esoTalk和FluxBB的优良传统,旨在提供一个不复杂、不臃肿的论坛体验。Flarum的核心优势在于: 快速、简单: Flarum使用PHP构建,易于部署&…...
方法+数组
1. 方法 1. 什么是方法 方法定义: // []表示可写可不写[public] [static] type name ( [type formal , type formal , ...]){方法体;[return value ;] }[修饰符] 返回值类型 方法名称([参数类型 形参 , 参数类型 形参 ...]){方法体代码;[return 返回值…...
驱动-----adc
在key1.c的基础上进行对adc1.c进行编写 首先将文件里面的key全部改为adc 再修改一下设备号 按键和adc的区别是什么,按键只需要按一下就触发了,并且不需要返回一个值出来, adc要初始化,启动,返回值 以下是裸机adc的代码: #include <s3c2440.h> #include "ad…...
门户网站的建设成果/seo网络公司
前言:本文章为慕课网上Java企业级电商项目架构演进之路Tomcat集群与Redis分布式的学习笔记.供本人复习之用. 目录 第一章 架构图 第二章 nginx配置介绍 第三章 实践 第一章 架构图 大致的架构图,可以看出nginx的配置是关键 第二章 nginx配置介绍 要配置一个upstream来指明…...
包头网站建设公司/网址域名注册信息查询
实战需求 SwiftUI 图表教程之 02 原生制作水平Bar Chart 北京全年日均最低气温 本文价值与收获 看完本文后,您将能够作出下面的界面 看完本文您将掌握的技能 使用VStack,HStack,Text,Rectangle,Capsule基础填补空白区域Spacer()数组转化为Double(item)/Double(alist.max()!…...
网页设计制作论文/西安seo王
Android学习CursorWrapper与Decorator模式 - Dufresne - 博客园 一 Decorator模式 意图: 动态的给一个对象添加一些额外的职责。就增加功能来说,Decorator模式相比生成子类更为灵活。 动态的给一个对象,而不是对整个类添加额外职责࿰…...
传奇游戏网站/网店网络推广方案
随时准备失去所拥有的,我可能才会倍加珍惜周围的一切! 作者:xhf55555 发表于2011-10-28 12:20:00 原文链接 阅读:69 评论:3 查看评论 转载于:https://www.cnblogs.com/loveyong/archive/2011/10/28/2330160.html...
广告网站建设流程/最厉害的搜索引擎
在用npm安装包的时候,一直出现errorno:4048的错误,提示为用管理员身份登录。找过很多帖子,多为用管理员身份打开cmd,或者npm cache clean --force,都试过了但是不起作用,因为我也没弄清楚原因(泪奔::>_<::&#…...
wordpress先使用/seo外包服务专家
DataTables是一款非常简单的前端表格展示插件,它支持排序,翻页,搜索以及在客户端和服务端分页等多种功能。官方介绍:DataTables is a plug-in for the jQuery Javascript library. It is a highly flexible tool, based upon the …...