16- 梯度提升分类树GBDT (梯度下降优化) (算法)
-
梯度提升算法
from sklearn.ensemble import GradientBoostingClassifier
clf = GradientBoostingClassifier(subsample=0.8,learning_rate = 0.005)
clf.fit(X_train,y_train)
1、交叉熵
1.1、信息熵
- 构建好一颗树,数据变的有顺序了(构建前,一堆数据,杂乱无章;构建一颗,整整齐齐,顺序),用什么度量衡表示,数据是否有顺序:信息熵
- 物理学,热力学第二定律(熵),描述的是封闭系统的混乱程度
- 信息熵,和物理学中熵类似的
1.2、交叉熵
由信息熵可以引出交叉熵!
小明在学校玩王者荣耀被发现了,爸爸被叫去开家长会,心里悲屈的很,就想法子惩罚小明。到家后,爸爸跟小明说:既然你犯错了,就要接受惩罚,但惩罚的程度就看你聪不聪明了。这样吧,我们俩玩猜球游戏,我拿一个球,你猜球的颜色,我可以回答你任何问题,你每猜一次,不管对错,你就一个星期不能玩王者荣耀,当然,猜对,游戏停止,否则继续猜。当然,当答案只剩下两种选择时,此次猜测结束后,无论猜对猜错都能100%确定答案,无需再猜一次,此时游戏停止。
1.2.1、题目一
爸爸拿来一个箱子,跟小明说:里面有橙、紫、蓝及青四种颜色的小球任意个,各颜色小球的占比不清楚,现在我从中拿出一个小球,你猜我手中的小球是什么颜色?
为了使被罚时间最短,小明发挥出最强王者的智商,瞬间就想到了以最小的代价猜出答案,简称策略1,小明的想法是这样的。
1.2.2、题目二
爸爸还是拿来一个箱子,跟小明说:箱子里面有小球任意个,但其中1/2是橙色球,1/4是紫色球,1/8是蓝色球及1/8是青色球。我从中拿出一个球,你猜我手中的球是什么颜色的?
小明毕竟是最强王者,仍然很快得想到了答案,简称策略2,他的答案是这样的。
这就需要引入交叉熵,其用来衡量在给定的真实分布下,使用非真实分布所指定的策略消除系统的不确定性所需要付出的努力的大小。
1.3、sigmoid
后面算法推导过程中都会使用到上面的基本方程,因此先对以上概念公式,有基本了解!
2、GBDT分类树
2.1、梯度提升分类树概述
GBDT分类树 sigmoid + 决策回归树 一一> 概率问题!
-
损失函数是交叉熵
-
概率计算使用sigmoid
-
使用 mse 作为分裂标准(同梯度提升回归树)
2.2、梯度提升分类树应用
1、加载数据
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifierX,y = datasets.load_iris(return_X_y = True)
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state = 1124)
2、普通决策树表现
model = DecisionTreeClassifier()
model.fit(X_train,y_train)
model.score(X_test,y_test) # 输出:0.8421052631578947
3、梯度提升分类树表现
from sklearn.ensemble import GradientBoostingClassifier
clf = GradientBoostingClassifier(subsample=0.8,learning_rate = 0.005)
clf.fit(X_train,y_train)
clf.score(X_test,y_test) # 输出:0.9473684210526315
3、GBDT分类树算例演示
3.1、算法公式
-
概率计算(sigmoid函数)
-
函数初始值(这个函数即是sigmoid分母中的F(x),用于计算概率)
逻辑回归中的函数是线性函数,GBDT中的函数不是线性函数,但是作用类似!
-
计算残差公式
-
均方误差(根据均方误差,筛选最佳裂分条件)
-
决策树叶节点预测值(相当于负梯度)
-
梯度提升
-
根据以上公式,即可进行代码演算了~
3.2、算例演示
3.2.1、创建数据
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn import tree
import graphviz
X = np.arange(1,11).reshape(-1,1)
y = np.array([0,0,0,1,1]*2)
display(X,y)
3.2.2、构造GBDT训练预测
# 默认情况下,损失函数就是Log-loss == 交叉熵!
clf = GradientBoostingClassifier(n_estimators=3,learning_rate=0.1,max_depth=1)
clf.fit(X,y)
y_ = clf.predict(X)
print('真实的类别:',y)
print('算法的预测:',y_)
proba_ = clf.predict_proba(X)
print('预测概率是:\n',proba_)
3.2.3、GBDT可视化
第一棵树
dot_data = tree.export_graphviz(clf[0,0],filled = True)
graph = graphviz.Source(dot_data)
graph
第二棵树
dot_data = tree.export_graphviz(clf[1,0],filled = True)
graph = graphviz.Source(dot_data)
graph
第三棵树
dot_data = tree.export_graphviz(clf[2,0],filled = True)
graph = graphviz.Source(dot_data)
graph
每棵树,根据属性进行了划分,每棵树的叶节点都有预测值,这些具体都是如何计算的呢?且看,下面详细的计算工程~
3.2.4、计算步骤
首先,计算初始值 :
F0 = np.log(y.sum()/(1-y).sum())
F0 # 输出结果:-0.40546510810816444
# 此时未裂分,所有的数据都是F0
F0 = np.array([F0]*10)
# 然后,计算残差
# 残差,F0带入sigmoid计算的即是初始概率
residual0 = y - 1/(1 + np.exp(-F0))
residual0
# 输出:array([-0.4, -0.4, -0.4, 0.6, 0.6, -0.4, -0.4, -0.4, 0.6, 0.6])
3.2.5、拟合第一棵树
根据残差的mse,计算最佳分裂条件
lower_mse = ((residual0 - residual0.mean())**2).mean()
best_split = {}
# 分裂标准 mse
for i in range(0,10):if i == 9:mse = ((residual0 - residual0.mean())**2).mean()else:left_mse = ((residual0[:i+1] - residual0[:i+1].mean())**2).mean()right_mse = ((residual0[i+1:] - residual0[i+1:].mean())**2).mean()mse = left_mse*(i+1)/10 + right_mse*(10-i-1)/10if lower_mse > mse:lower_mse = msebest_split.clear()best_split['X[0] <= '] = X[i:i + 2].mean() print('从第%d个进行分裂'%(i + 1),np.round(mse,4))
# 从第八个样本这里进行分类,最优的选择,和算法第一颗画图的结果一致
print('最小的mse是:',lower_mse)
print('最佳裂分条件是:',best_split)
现在我们知道了,分裂条件是:X[0] <= 8.5!然后计算决策树叶节点预测值(相当于负梯度),其中的 就是残差residual0
3.2.6、拟合第二棵树
第一棵树的负梯度(预测值)
# 第一棵预测的结果,负梯度
gamma = np.array([gamma1]*8 + [gamma2]*2)
gamma '''输出:array([-0.625, -0.625, -0.625, -0.625, -0.625, -0.625,-0.625, -0.625, 2.5 , 2.5 ])'''
梯度提升
# F(x) 随着梯度提升树,提升,发生变化
learning_rate = 0.1
F1 = F0 + gamma*learning_rate
F1 ''' 输出 array([-0.46796511, -0.46796511, -0.46796511, -0.46796511,
-0.46796511, -0.46796511, -0.46796511, -0.46796511, -0.15546511, -0.15546511])'''
根据 F1 计算残差
residual1 = y - 1/(1 + np.exp(-F1))
residual1 '''array([-0.38509799, -0.38509799, -0.38509799, 0.61490201,
0.61490201, -0.38509799, -0.38509799, -0.38509799, 0.53878818, 0.53878818])'''
根据新的残差residual1的mse,计算最佳分裂条件
lower_mse = ((residual1 - residual1.mean())**2).mean()
best_split = {}
# 分裂标准 mse
for i in range(0,10):if i == 9:mse = ((residual1 - residual1.mean())**2).mean()else:left_mse = ((residual1[:i+1] - residual1[:i+1].mean())**2).mean()right_mse = ((residual1[i+1:] - residual1[i+1:].mean())**2).mean()mse = left_mse*(i+1)/10 + right_mse*(10-i-1)/10if lower_mse > mse:lower_mse = msebest_split.clear()best_split['X[0] <= '] = X[i:i + 2].mean() print('从第%d个进行分裂'%(i + 1),np.round(mse,4))
# 从第八个样本这里进行分类,最优的选择,和算法第一颗画图的结果一致
print('最小的mse是:',lower_mse)
print('最佳裂分条件是:',best_split)
现在我们知道了,第二棵树分裂条件是:X[0] <= 8.5 !然后计算决策树叶节点预测值(相当于负梯度),其中的 就是残差residual1
3.2.7、拟合第三棵树
第二棵树的负梯度
# 第二棵树预测值
gamma = np.array([gamma1]*8 + [gamma2]*2)
gamma
梯度提升
# F(x) 随着梯度提升树,提升,发生变化
learning_rate = 0.1
F2 = F1 + gamma*learning_rate
F2
根据 F2 计算残差
residual2 = y - 1/(1 + np.exp(-F2))
residual2
根据新的残差residual2的 mse,计算最佳分裂条件
lower_mse = ((residual2 - residual2.mean())**2).mean()
best_split = {}
# 分裂标准 mse
for i in range(0,10):if i == 9:mse = ((residual2 - residual2.mean())**2).mean()else:left_mse = ((residual2[:i+1] - residual2[:i+1].mean())**2).mean()right_mse = ((residual2[i+1:] - residual2[i+1:].mean())**2).mean()mse = left_mse*(i+1)/10 + right_mse*(10-i-1)/10if lower_mse > mse:lower_mse = msebest_split.clear()best_split['X[0] <= '] = X[i:i + 2].mean() print('从第%d个进行分裂'%(i + 1),np.round(mse,4))
# 从第八个样本这里进行分类,最优的选择,和算法第一颗画图的结果一致
print('最小的mse是:',lower_mse)
print('最佳裂分条件是:',best_split)
现在我们知道了,第三棵树分裂条件是:X[0] <= 3.5!然后计算决策树叶节点预测值(相当于负梯度),其中的 就是残差residual2
# 计算第三颗树的预测值
# 前三个是一类
# 后七个是一类
# 左边分支
gamma1 = residual2[:3].sum()/((y[:3] - residual2[:3])*(1 - y[:3] +
residual2[:3])).sum()
print('第三棵树左边决策树分支,预测值:',gamma1)# 右边分支
gamma2 =residual2[3:].sum()/((y[3:] - residual2[3:])*(1 - y[3:] +
residual2[3:])).sum()
print('第三棵树右边决策树分支,预测值:',gamma2)
3.2.8、预测概率计算
计算第三棵树的F3(x)
# 第三棵树预测值
gamma = np.array([gamma1]*3 + [gamma2]*7)
# F(x) 随着梯度提升树,提升,发生变化
learning_rate = 0.1
F3 = F2 + gamma*learning_rate
概率公式如下:
proba = 1/(1 + np.exp(-F3))
# 类别:0,1,如果这个概率大于等于0.5类别1,小于0.5类别0
display(proba)
# 进行转换,类别0,1的概率都展示
np.column_stack([1- proba,proba])
# 算法预测概率
clf.predict_proba(X)
结论:
-
手动计算的概率和算法预测的概率完全一样!
-
GBDT分类树,计算过程原理如上
4、GBDT分类树原理推导
4.1、损失函数:
-
定义交叉熵为函数
其中 ,即sigmoid函数
- 表示决策回归树 DecisionTreeRegressor F(x) 表示每一轮决策树的value,即负梯度
4.2、损失函数化简
-
损失函数化简:
-
-
化简过程
4.3、损失函数求导
将F(x)看成整体变量,进行求导
一阶导数:
4.4、初始值 计算
4.4.1、初始值方程构建
之前的GBDT回归树,初始值是多少:平均值
现在的GBDT分类树 ,计算初始值 ,令
5、GBDT二分类步骤总结
Step - 1:
Step - 2:for i in range(M):
a.
b. 根据残差 ,寻找最小 mse 裂分条件
c.
d.
相关文章:
16- 梯度提升分类树GBDT (梯度下降优化) (算法)
梯度提升算法 from sklearn.ensemble import GradientBoostingClassifier clf GradientBoostingClassifier(subsample0.8,learning_rate 0.005) clf.fit(X_train,y_train) 1、交叉熵 1.1、信息熵 构建好一颗树,数据变的有顺序了(构建前,…...
SpringCloud+Nacos+Gateway
SpringCloudNacosGatewaySpringBoot整合GatewayNacos一. 环境准备1. 版本环境2. 服务环境二. 实战1.创建用户服务2.创建订单服务3.创建网关服务4.测试三. 避坑指南问题1--503问题问题2--网关服务启动报错SpringBoot整合GatewayNacos 本篇文章只演示通过gateway网关服务访问其他…...
高通开发系列 - linux kernel内核升级msm-3.18升至msm-4.9(2)
By: fulinux E-mail: fulinux@sina.com Blog: https://blog.csdn.net/fulinus 喜欢的盆友欢迎点赞和订阅! 你的喜欢就是我写作的动力! 目录 返回高通开发系列 - 总目录 前面我们升级了msm-4.9内核系统正常启动了,文件系统也正常工作,但那是使用了老基线的文件系统,其yocto…...
Spring依赖注入与反转控制到底是个啥?
目录 1. 引言 2. 管中窥豹 3.1 Spring 依赖注入 3.2 Bean 的依赖注入方式有两种 4. 总结 1. 引言 此文目的是用通俗易懂的语言讲清楚什么是依赖注入与反转控制,在看了大量的博客文章后归纳总结,便于后续巩固!我相信,大多数…...
Linux Shell脚本讲解
目录 Shell脚本基础 Shell脚本组成 Shell脚本工作方式 编写简单的Shell脚本 Shell脚本参数 Shell脚本接收参数 Shell脚本判断用户参数 文件测试与逻辑测试语句 整数测试比较语句 字符串比较语句 Shell流程控制 if条件判断语句 单分支 双分支 多分支 for循环语句…...
Linux:用户空间非法指针coredump简析
1. 前言 限于作者能力水平,本文可能存在谬误,因此而给读者带来的损失,作者不做任何承诺。 2. 背景 本文分析基于 ARM32 架构,Linux-4.14 内核代码。 3. 问题分析 3.1 测试范例 void main(void) {*(int *)0 8; }运行程序会 …...
带你玩转Jetson之Deepstream简明教程(四)DeepstreamApp如何使用以及用于工程验证。
1.DeepstreamApp是什么? 如果你安装完毕deepstream整体框架,会在你的系统执行目录内有可执行文件,文件名字是deepstream-app。这是一个可执行脚本文件,通过deepstream框架中的代码在安装的时候编译后install到系统根目录内。 此脚…...
快速搭建个人在线书库,随时随地畅享阅读!
前边我们利用NAS部署了个人的导航页、小说站、云笔记,今天,我们再看看怎么部署一个个人的在线书库。 相信很多朋友都在自己的电脑中收藏了大量的PDF、MOBI等格式的电子书籍,但是一旦换了一台设备,要么是无法翻阅,要么…...
电子纸墨水屏的现实应用场景
电子纸挺好个东西,大家都把注意力集中在商超场景 其实还有更多有趣的场景方案可用,价值也不小,比如: 一、仓库场景 通过亮灯拣选,提高仓库作业效率 二、仓库循环使用标签 做NFC类发卡式应用,替代传统纸…...
常量const、引用、指针的大杂烩
文章目录1 普通引用1.1 对普通值的普通引用1.2 对常量值的普通引用1.3 对普通指针的普通引用1.4 对常量指针的普通引用1.5 对指针常量的普通引用1.6 对指向常量的指针常量的普通引用2 常量引用2.1 对普通值的常量引用2.2 对常量值的常量引用2.3 对普通指针的常量引用2.4 对常量…...
宝塔搭建实战php开源likeadmin通用管理移动端uniapp源码(四)
大家好啊,我是测评君,欢迎来到web测评。 上一期给大家分享了pc端的部署方式,今天来给大家分享uniapp端在本地搭建,与打包发布到宝塔的方法。感兴趣的朋友可以自行下载学习。 技术架构 vscode node16 vue3 uniapp vite types…...
Hive的分区表与分桶表内部表外部表
文章目录1 Hive分区表1.1 Hive分区表的概念?1.1.1 分区表注意事项1.2 分区表物理存储结构1.3 分区表使用场景1.4 静态分区表是什么?1.4.1 静态分区表案例1.4.2 分区表练习一1.4.3 分区操作1.5 动态分区表是什么?1.5.1 动态态分区表案例&#…...
和数集团打造《神念无界:源起山海》,诠释链游领域创新与责任
首先,根据网上资料显示,一部《传奇》,二十年热血依旧。 《传奇》所缔造的成绩,承载的是多少人的青春回忆,《传奇》无疑已经在游戏史上写下了浓墨重彩的一笔。 相比《传奇》及背后的研发运营公司娱美德名声大噪&#x…...
小白入门模拟IC设计,如何快速学习?
众所周知,模拟电路很难学。以最普遍的晶体管来说,我们分析它的时候必须首先分析直流偏置,其次在分析交流输出电压。可以说,确定工作点就是一项相当麻烦的工作(实际中来说),晶体管的参数多、参数…...
51单片机——中断系统之外部中断实验,小白讲解,相互学习
中断介绍 中断是为使单片机具有对外部或内部随机发生的事件实时处理而设置的,中断功能的存在,很大程度上提高了单片机处理外部或内部事件的能力。它也是单片机最重要的功能之一,是我们学些单片机必须要掌握的。 为了更容易的理解中断概念&…...
如何设计一个秒杀系统
秒杀系统要如何设计? 前言 高并发下如何设计秒杀系统?这是一个高频面试题。这个问题看似简单,但是里面的水很深,它考查的是高并发场景下,从前端到后端多方面的知识。 秒杀一般出现在商城的促销活动中,指定…...
厄瓜多尔公司注册方案
简介: 经济概况与商机 厄瓜多尔是世界上第74大国家,是南美西部国家,与哥伦比亚,秘鲁和太平洋接壤。厄瓜多尔地处世界中心,地理位置优越,地理位置优越-赤道线零纬度,使其成为通往太平洋的理想枢…...
安全渗透环境准备(工具下载)
数据来源 01 一些VM虚拟机的安装 攻击机kali: kali官网 渗透测试工具Kali Linux安装与使用 kali汉化 虚拟机网络建议设置成NAT模式,桥接有时不稳定。 靶机OWASP_Broken_Web_Apps: 迅雷下载 网盘下载 安装教程 开机之后需要登录&am…...
118.(leaflet篇)leaflet空间判断-点与geojson面图层的空间关系(turf实现)
听老人家说:多看美女会长寿 地图之家总目录(订阅之前建议先查看该博客) 文章末尾处提供保证可运行完整代码包,运行如有问题,可“私信”博主。 效果如下所示: 下面献上完整代码,代码重要位置会做相应解释 <!DOCTYPE html> <html>...
目标检测与目标跟踪算法技术汇总
现如今chatgpt的爆火,我也使用了一段时间,问了许多关于人工智能技术的问题,基本是它能够回答了大部分的原理的,至于其人工智能涉及到的算法以及网络,考虑到也没有图,可能在给出这类回答上,是不太…...
Linux 系统启动过程
过去几十年,公用事业行业发生了重大变化。能源需求的转变导致企业利润率的波动,但不是运营成本的波动。 许多公用事业公司通过后勤部门流程自动化来削减成本,比如招采流程自动化。 在招采活动中,人工招采会产生盲点。由于公共事业…...
【每日一题Day118】LC1124表现良好的最长时间段 | 前缀和+单调栈/哈希表
表现良好的最长时间段【LC1124】 给你一份工作时间表 hours,上面记录着某一位员工每天的工作小时数。 我们认为当员工一天中的工作小时数大于 8 小时的时候,那么这一天就是「劳累的一天」。 所谓「表现良好的时间段」,意味在这段时间内&#…...
vue使用nprogress(进度条)
目录 1.安装 2.引入 3.配置 4.使用 5.使用场景 6.改变颜色 1.安装 npm install --save nprogress2.引入 import NProgress from nprogress import nprogress/nprogress.css3.配置 NProgress.configure({easing: ease, // 动画方式,和css动画属性一样&#…...
@NotNull 、@NotBlank、@NotEmpty区别和使用
引言 今天在使用validation校验的时候,发现了使用校验不起作用,一时间有点摸不到头绪,就看了一下同事提交的代码,发现了问题在用NotNull用法,用的有些错误,所以在这里讲一下NotNull、NotBlank、NotEmpty区…...
Nacos——Nacos简介以及Nacos Server安装
资料来源:02-Nacos配置管理-什么是配置中心_哔哩哔哩_bilibili nacos记得下载2.x版本的,负责以后新建配置的时候会出现“发布错误,请检查参数是否正确”错误!!!! 目录 一、Nacos简介 1.1 四…...
Presto 文档和笔记
1. Presto Presto 官网 Presto 文档 2. 配置 3.1 node 配置 cat etc/node.properties # Generated by Apache Ambari. Fri Feb 10 14:52:10 2023node.data-dir/mnt/bmr/presto/data node.environmentproduction node.idbmr-master-4b7cbaa3.2 jvm 配置 cat etc/jvm.confi…...
大尺度衰落与小尺度衰落
一. 大尺度衰落 无线电磁波信号在收发天线长距离(远大于传输波长)或长时间范围发生的功率变化,称为大尺度衰落,一般可以用路径损耗模型来描述,路径损耗是由发射功率在空间中的辐射扩散造成的,根据功率传输…...
完美解决:重新安装VMware Tools灰色。以及共享文件夹的创建(centos8)
解决:重新安装VMware Tools灰色问题:重新安装VMware Tools灰色解决方案-挂载VMware中的linux.iso1. vmtools的linux.iso挂载及安装2. 共享文件夹的创建及配置问题:重新安装VMware Tools灰色 发现一个小问题,我的vm虚拟机安装后发…...
达梦数据库作业管理
一、基本功能 作业系统大致包含作业,警报,操作员三部分。 作业可运行DMPL/SQL脚本,定期备份数据库,检查等。可定时执行,也可通过警报触发执行,可产生警报通知用户状态。一个作业由多个步骤组成,…...
数据结构-考研难点代码突破(C++实现树型查找-二叉搜索树(二叉排序树))
文章目录1.二叉搜索树基本操作二叉搜索树的效率分析2. C实现1.二叉搜索树基本操作 二叉排序树是具有下列特性的二叉树: 若左子树非空,则左子树上所有结点的值均小于根结点的值。若右子树非空,则右子树上所有结点的值均大于根结点的值。左、…...
如何做公众号小说网站赚钱/俄罗斯引擎搜索
CNC加工中心的高精高效,安全是前提。安全生产离不开优秀的车间管理,设备的精良保养以及丰富的加工经验。 1.预先开机 正式加工前可以进行开机空转,让CNC加工中心主轴空转几分钟,可以让主轴的轴承充分润滑,减少加工误…...
烟台网站建设推荐企汇互联见效付款/百度客服怎么转人工
当用到了java.sql.Date时间等非内置对象时,如果对象为null则会出现此异常。最简单的方法就是保证非内置对象不为null。 在项目业务中随着需求的变化而变化,并不能保证内置对象都不为null,因此有必要对此异常进行解决,以达到通用的…...
国外优惠卷网站怎么做/推介网
目前ROS1.0版本更新已经接近尾声,后续的ROS版本更新将集中在ROS2.0版本,但很多人可能不想卸载ROS1.0,因此本文介绍一种如何比较方便的在ROS1.0和ROS2.0两个版本之间切换的方法熟悉ROS1的都知道,打开新终端后需要重新source才能使用…...
wordpress注册添加算术验证/成品短视频app下载有哪些
面试一个公司,和那边技术负责人讨论了一个关于启动一个Service,首先问了我一下有什么启动方式,然后各自生命周期是怎样的,都一一回答,接着就问如果一个Service通过两种方式使用,然后怎么进行关闭࿰…...
网站设计与制作平台/网站开发费用
MySQL查询超时问题是什么原因呢?应该如何解决呢?下面就为您详细介绍MySQL查询超时问题的解决方法,希望可以帮助到您。mysql>show variables like %timeout;打印结果如下:-----------------------------------| Variable_name |…...
网站里面嵌入的地图是怎么做的/新软件推广
缓存大量小文件?Redis是首选! 2011-07-14 13:41 QLeelulu cnblogs 我要评论(0) 字号:T | T缓存文件,我们可以选择用Web、文件系统或数据库来做,比如本文中列出的Nginx、MooseFS以及Redis。作者需要将3KW条小数据做缓存…...