当前位置: 首页 > news >正文

基于集成学习的用户流失预测并利用shap进行特征解释

基于集成学习的用户流失预测并利用shap进行特征解释

小P:小H,如果我只想尽可能的提高准确率,有什么好的办法吗?

小H:优化数据、调参侠、集成学习都可以啊

小P:什么是集成学习啊,听起来就很厉害的样子

小H:集成学习就类似于【三个臭皮匠顶个诸葛亮】,将一些基础模型组合起来使用,以期得到更好的结果

集成学习实战

数据准备

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
import warnings
warnings.filterwarnings('ignore')from scipy import stats
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE 
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, KFold
from sklearn.feature_selection import RFE 
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, VotingClassifier, ExtraTreesClassifier
import xgboost as xgb
from sklearn.metrics import accuracy_score, auc, confusion_matrix, f1_score, \precision_score, recall_score, roc_curve  # 导入指标库
import prettytable
import sweetviz as sv # 自动eda
import toad 
from sklearn.model_selection import StratifiedKFold, cross_val_score  # 导入交叉检验算法# 绘图初始化
%matplotlib inline
pd.set_option('display.max_columns', None) # 显示所有列
sns.set(style="ticks")
plt.rcParams['axes.unicode_minus']=False # 用来正常显示负号
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号# 导入自定义模块
import sys
sys.path.append("/Users/heinrich/Desktop/Heinrich-blog/数据分析使用手册")
from keyIndicatorMapping import *

上述自定义模块keyIndicatorMapping如果有需要的同学可关注公众号HsuHeinrich,回复【数据挖掘-自定义函数】自动获取~

以下数据如果有需要的同学可关注公众号HsuHeinrich,回复【数据挖掘-集成学习】自动获取~

# 读取数据
raw_data = pd.read_csv('classification.csv')
raw_data.head()

image-20230206151936701

# 缺失值填充,SMOTE方法限制非空
raw_data=raw_data.fillna(raw_data.mean())
# 数据集分割
X = raw_data[raw_data.columns.drop('churn')]
y = raw_data['churn']
# 标准化
scaler = StandardScaler() 
scale_data = scaler.fit_transform(X)  
X = pd.DataFrame(scale_data, columns = X.columns)
# 划分训练测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3, random_state=0)
# 过采样
model_smote = SMOTE(random_state=0)  # 建立SMOTE模型对象
X_train, y_train = model_smote.fit_resample(X_train, y_train) 

模型对比

%%time
# 初选分类模型
model_names = ['LR', 'SVC', 'RFC', 'XGBC']  # 不同模型的名称列表
model_lr = LogisticRegression(random_state=10) # 建立逻辑回归对象
model_svc = SVC(random_state=0, probability=True) # 建立支持向量机分类对象
model_rfc = RandomForestClassifier(random_state=10) # 建立随机森林分类对象
model_xgbc = xgb.XGBClassifier(use_label_encoder=False, eval_metric='auc', random_state=10) # 建立XGBC对象# 模型拟合结果
model_list = [model_lr, model_svc, model_rfc, model_xgbc]  # 不同分类模型对象的集合
pre_y_list = [model.fit(X_train, y_train).predict(X_test) for model in model_list]  # 各个回归模型预测的y值列表
CPU times: user 2.49 s, sys: 125 ms, total: 2.62 s
Wall time: 843 ms
# 核心评估指标
metrics_dic = {'model_names':[],'auc':[],'ks':[],'accuracy':[],'precision':[],'recall':[],'f1':[]}
for model_name, model, pre_y in zip(model_names, model_list, pre_y_list):y_prob = model.predict_proba(X_test)  # 获得决策树的预测概率,返回各标签(即0,1)的概率fpr, tpr, thres = roc_curve(y_test, y_prob[:, 1])  # ROC y_score[:, 1]取标签为1的概率,这样画出来的roc曲线为正metrics_dic['model_names'].append(model_name)metrics_dic['auc'].append(auc(fpr, tpr))  # AUCmetrics_dic['ks'].append(max(tpr - fpr)) # KS值metrics_dic['accuracy'].append(accuracy_score(y_test, pre_y))metrics_dic['precision'].append(precision_score(y_test, pre_y))metrics_dic['recall'].append(recall_score(y_test, pre_y))metrics_dic['f1'].append(f1_score(y_test, pre_y))
pd.DataFrame(metrics_dic)

image-20230206152007352

集成学习

%%time
# 建立组合评估器列表 均衡稳定性和准确性 这里只是演示,就将所有模型都纳入了
estimators = [('SVC', model_svc), ('RFC', model_rfc), ('XGBC', model_xgbc), ('LR', model_lr)]  
model_vot = VotingClassifier(estimators=estimators, voting='soft', weights=[1.1, 1.1, 0.9, 1.2],n_jobs=-1)  # 建立组合评估模型
cv = StratifiedKFold(5)  # 设置交叉检验方法 分类算法常用交叉检验方法
cv_score = cross_val_score(model_vot, X_train, y_train, cv=cv, scoring='accuracy')  # 交叉检验
print('{:*^60}'.format('Cross val scores:'),'\n',cv_score) # 打印每次交叉检验得分
print('Mean scores is: %.2f' % cv_score.mean())  # 打印平均交叉检验得分
*********************Cross val scores:********************** [0.73529412 0.7745098  0.85294118 0.85294118 0.87745098]
Mean scores is: 0.82
CPU times: user 2.38 s, sys: 432 ms, total: 2.81 s
Wall time: 5 s
# 模型训练
model_vot.fit(X_train, y_train)  # 模型训练
VotingClassifier(estimators=[('SVC', SVC(probability=True, random_state=0)),('RFC', RandomForestClassifier(random_state=10)),('XGBC',XGBClassifier(base_score=0.5, booster='gbtree',colsample_bylevel=1,colsample_bynode=1,colsample_bytree=1,eval_metric='rmse', gamma=0,gpu_id=-1, importance_type='gain',interaction_constraints='',learning_rate=0.300000012,max...min_child_weight=1, missing=nan,monotone_constraints='()',n_estimators=100, n_jobs=8,num_parallel_tree=1,random_state=10, reg_alpha=0,reg_lambda=1, scale_pos_weight=1,subsample=1, tree_method='exact',use_label_encoder=False,validate_parameters=1,verbosity=None)),('LR', LogisticRegression(random_state=10))],n_jobs=-1, voting='soft', weights=[1.1, 1.1, 0.9, 1.2])
model_confusion_metrics(model_vot, X_test, y_test, 'test')
model_core_metrics(model_vot, X_test, y_test, 'test')
confusion matrix for test+----------+--------------+--------------+
|          | prediction-0 | prediction-1 |
+----------+--------------+--------------+
| actual-0 |      53      |      31      |
| actual-1 |      37      |     179      |
+----------+--------------+--------------+
core metrics for test+-------+----------+-----------+--------+-------+-------+
|  auc  | accuracy | precision | recall |   f1  |   ks  |
+-------+----------+-----------+--------+-------+-------+
| 0.805 |  0.773   |   0.589   | 0.631  | 0.609 | 0.504 |
+-------+----------+-----------+--------+-------+-------+

可以看到集成学习的各项指标表现均优异,只有召回率低于LR

利用shap进行模型解释

shap作为一种经典的事后解释框架,可以对每一个样本中的每一个特征变量,计算出其重要性值,达到解释的效果。该值在shap中被专门称为Shapley Value。

该系列以应用为主,对于具体的理论只会简单的介绍它的用途和使用场景。这里的shap相关知识 可以参考黑盒模型事后归因解析:SHAP方法、SHAP知识点全汇总

学无止境,且学且珍惜~

# pip install shap
import shap   
# 初始化
shap.initjs()  
# 通过采样提高计算效率,但会导致准确率降低。表现在base_value与mean(model.predict_proba(X))存在差异,不建议K太小
# X_test_summary = shap.sample(X_test, 200)
# X_test_summary = shap.kmeans(X_test, 150)
explainer = shap.KernelExplainer(model_vot.predict_proba, X_test)
shap_values = explainer.shap_values(X_test, nsamples = 10)
Using 300 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
  • 单样本查看
# 单样本查看-1概率较高的样本 # 208
shap.force_plot(base_value=explainer.expected_value[1],shap_values=shap_values[1][208],features = X_test.iloc[208,:])

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AxXaXK0k-1679902455430)(null)]

  • base_value:所有样本预测值的均值,即base_value=model_vot.predict_proba(X_test)[:,1].mean()

    ⚠️注意:当进行采样或者kmean加速计算时,会损失一定准确度。即explainer带入的是X_test_summary

  • f(x):预测的实际值model_vot.predict_proba(X_test)[:,1]

  • data:样本特征值

  • shap_values:f(x)-base_value;shap值越大越红,越小越蓝

# 验证base_value
print('所有样本预测标签1的概率均值:',model_vot.predict_proba(X_test)[:,1].mean())
print('base_value:',explainer.expected_value[1])
所有样本预测标签1的概率均值: 0.3519852365700774
base_value: 0.35198523657007774

经验证,base_value计算逻辑正确

# 验证单一样本
i=208
fx=model_vot.predict_proba(X_test)[:,1][i]
da=X_test.iloc[i,:]
sv=fx-explainer.expected_value[1]
sv_val=shap_values[1][i].sum()
print('f(x):',fx)
print('shap_values:',sv,sv_val)
f(x): 0.9264517406651224
shap_values: 0.5744665040950446 0.5744665040950446

经验证,shap_values计算逻辑正确

  • 特征重要性
# 特征重要程度
shap.summary_plot(shap_values[1],X_test,max_display=10,plot_type="bar")

  • 蜂窝图体现特征重要性
# 特征与样本蜂窝图
shap.summary_plot(shap_values[1],X_test,max_display=10)

output_38_0

retention_days越大,蓝色的样本越多,表明较高的retention_days有助于缓减流失

  • 特征的shap值
# 单特征预测结果
shap.dependence_plot("retention_days", shap_values[1], X_test, interaction_index=None)

output_41_0

retention_days低的shape值较大,上面讲到shap越大越红,对于y起到提高作用。即retention_days与流失负相关

# 双特征交叉影响
shap.dependence_plot("retention_days", shap_values[1], X_test, interaction_index='level')

output_43_0

  • 在较低的retention_days(如-1.5),高level(level=1.0)的shepae值较高(红色点),在0.2附近
  • 在较高的retention_days(如1.5),高level(level=1.0)的shepae值较低(红色点),在-0.2附近

总结

集成学习能有效地提高模型的预测性能,但是使得模型内部结构更为复杂,无法直观理解。好在可以借助shap进行常见的特征重要性解释等。

共勉~

相关文章:

基于集成学习的用户流失预测并利用shap进行特征解释

基于集成学习的用户流失预测并利用shap进行特征解释 小P:小H,如果我只想尽可能的提高准确率,有什么好的办法吗? 小H:优化数据、调参侠、集成学习都可以啊 小P:什么是集成学习啊,听起来就很厉害的…...

【Java版oj 】 day17杨辉三角形的变形、计算某字符出现次数

目录 一、杨辉三角形的变形 (1)原题再现 (2)问题分析 (3)完整代码 二、计算某字符出现次数 (1)原题再现 (2)问题分析 (3)完整代…...

智能驾驶芯片赛道混战:如何看待5类玩家的竞争格局?

智能驾驶芯片赛道,一直是业内关注的焦点。 高工智能汽车注意到,针对L0-L2,业内基本采用智能前视一体机(IFC)方案;要实现高速NOA、城市NOA等更为高阶的智驾功能等,则基本采用域控制器方案。从IF…...

vue antd table表格的增删改查(三)input输入框根据关键字模糊查询【后台管理系统 使用filter与indexOf嵌套】

vue antd table表格的增删改查(三)input输入框根据关键字查询【后台管理系统filter与indexOf嵌套】知识回调场景复现利用filter和indexOf方法实现模糊查询1.查询对象为单层的数组元素2.查询对象为多层的数组元素(两层为例)3.查询对…...

【计组】性能指标——速度

衡量计算机性能的指标之一——速度,是指计算机执行完所有指令所耗费时间的长短。 一、概念: 引出了如下概念:机器字长:指计算机一次能处理的二进制位数,也就是我们通常说的32位64位计算机中的位。 机器字长决定了计算…...

【PC自动化测试-4】inspect.exe 详解

1,inspect.exe图解" 检查 "窗口有几个主要部分:● 标题栏。 显示" 检查 HWND (窗口句柄) 。● 菜单栏。 提供对 检查功能 的访问权限。● 工具 栏。 提供对 检查功能 的访问权限。● 树视图。 将 UI 元素的层次结构呈现为树视图控件&…...

比肩ChatGPT的国产AI:文心一言——有话说

🔗 运行环境:chatGPT,文心一言 🚩 撰写作者:左手の明天 🥇 精选专栏:《python》 🔥 推荐专栏:《算法研究》 #### 防伪水印——左手の明天 #### 💗 大家好&am…...

【第13届蓝桥杯】C/C++组B组省赛题目+详解

A.九进制转十进制 题目描述 九进制正整数(2022)9转换成十进制等于多少? 解: 2*9^02*9^12*9^321814581478; B.顺子日期 题目描述 小明特别喜欢顺子。顺子指的就是连续的三个数字:123、456等。顺子日期指的就是在日期的yyyymmdd表示法中&a…...

STM32 KEI 调试新手注意事项

记录一下解决问题的经过:1,用STM32 cubeMX 生成的MKD工程,默认的代码优化级别是level3 , 这个级别 会把一些代码给优化掉,造成一些意想不到的结果,最直观的就是 被优化的语句不能打断点调试,当你打了断点 ,…...

Windows权限提升—令牌窃取、UAC提权、进程注入等提权

Windows权限提升—令牌窃取、UNC提权、进程注入等提权1. 前言2. at本地命令提权2.1. 适用范围2.2. 命令使用2.3. 操作步骤2.3.1. 模拟提权2.3.2. at配合msf提权2.3.2.1. 生成木马文件2.3.2.2. 设置监听2.3.2.3. 设置反弹2.3.2.4. 查看反弹效果3. sc本地命令提权3.1. 适用范围3.…...

不做孔乙己也不做骆驼祥子

对教书育人的探讨前言一、为什么要“育人”1.育人为先2.育人是快乐的二、怎么“育人”前言 借着本次师德师风建设的主题,跟各位老师谈一谈对于“育人”的一些观点,和教育的一些看法。本文仅代表自己的观点,有不到位的地方,大家可以…...

ChatGPT原理解析

文章目录Transformer模型结构构成组件整体流程GPT预训练微调模型GPT2GPT3局限性GPT4相关论文Transformer Transformer,这是一种仅依赖于注意力机制而不使用循环或卷积的简单模型,它简单而有效,并且在性能方面表现出色。 在时序模型中&#…...

常用算法实现【必会】:sort/bfs/dfs

文章目录常用排序算法实现(Go版本)BFS 广度优先遍历,利用queueDFS 深度优先遍历,利用stack前序遍历(根 左 右)中序遍历(左根右)后序遍历(左 右 根)BFS/DFS 总…...

瑟瑟发抖吧——用了这款软件,我的开发效率提升了50%

一、前言 开发中,一直听到有人讨论是否需要重复造轮子,我觉得有能力的人,轮子得造。但是往往开发周期短,用轮子所节省的时间去更好的理解业务,应用到业务中,也能清晰发现轮子的利弊,一定意义上…...

笔记本只使用Linux是什么体验?

个人主页:董哥聊技术我是董哥,嵌入式领域新星创作者创作理念:专注分享高质量嵌入式文章,让大家读有所得!近期,也有朋友问我,笔记本只安装Linux怎么样,刚好我也借此来表达一下我的感受…...

pipeline业务发布

业务环境介绍公司当前业务上线流程首先是通过nginx灰度,dubbo-admin操作禁用,然后发布上线主机,发布成功后,dubbo-admin启用,nginx启用主机;之前是通过手动操作,很不方便,本次优化为…...

【巨人的肩膀】JAVA面试总结(七)

💪MyBatis 1、谈谈你对MyBatis的理解 Mybatis是一个半ORM(对象关系映射)框架,它内部封装了JDBC,加载驱动、创建连接、创建statement等繁杂的过程,开发者开发时只需要关注如何编写SQL语句,可以…...

Python满屏表白代码

目录 前言 爱心界面 无限弹窗 前言 人生苦短,我用Python!又是新的一周啦,本期博主给大家带来了一个全新的作品:满屏表白代码,无限弹窗版!快快收藏起来送给她吧~ 爱心界面 def Heart(): roottk.Tk…...

Spring学习流程介绍

Spring学习流程介绍 Spring技术是JavaEE开发必备技能,企业开发技术选型命中率>90%; Spring有下面两大优势: 简化开发: 降低企业级开发的复杂性 框架整合: 高效整合其他技术,提高企业级应用开发与运行效率 Spring官网: https://spring.io/ Spring发展…...

杭银消金基于 Apache Doris 的统一数据查询网关改造

导读: 随着业务量快速增长,数据规模的不断扩大,杭银消金早期的大数据平台在应对实时性更强、复杂度更高的的业务需求时存在瓶颈。为了更好的应对未来的数据规模增长,杭银消金于 2022 年 10 月正式引入 Apache Doris 1.2 对现有的风…...

Flink学习笔记(六)Time详解

一、Flink中Time的三种类型: Stream数据中的Time(时间)分为以下3种: 1.Event Time(事件产生的时间): 事件的时间戳,通常是生成事件的时间。Event time 是事件本身的时间&#xff0c…...

「Vue面试题」在项目中你是如何解决跨域的?

文章目录一、跨域是什么二、如何解决CORSProxy一、跨域是什么 跨域本质是浏览器基于同源策略的一种安全手段 同源策略(Sameoriginpolicy),是一种约定,它是浏览器最核心也最基本的安全功能 所谓同源(即指在同一个域&…...

java八股文--数据库

数据库1.索引的基本原理2.聚簇和非聚簇索引的区别3.mysql索引的数据结构以及各自的优劣4.索引的设计原则5.事务的基本特性和隔离级别6.mysql主从同步原理7.简述MyISAM和InnoDB的区别8.简述mysql中索引类型及对数据库性能的影响9.Explain语句结果中各个字段分别表示什么10.索引覆…...

vue中名词解释

No名称略写作用应用场景其他1 单页面应用 (Single-page application) SPA 1,控制整个页面 2,抓取更新数据 3,无需加载,进行页面切换 丰富的交互,复杂的业务逻辑的web前端一般要求后端提供api数据…...

基于Java+SSM+Vue的旅游资源网站设计与实现【源码(完整源码请私聊)+论文+演示视频+包运行成功】

博主介绍:专注于Java技术领域和毕业项目实战 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇🏻 不然下次找不到哟 Java项目精品实战案例(200套) 目录 一、效果演示 二、…...

用于人工智能研究的开源Python微电网模拟器pymgrid(入门篇)

pymgrid是一个开源Python库,用于模拟微型电网的三级控制,允许用户创建或自行选择的微电网。并可以使用自定义的算法或pymgrid中包含的控制算法之一来控制这些微电网(基于规则的控制和模型预测控制)。 pymgrid还提供了与OpenAI Gy…...

运算放大器:电压比较器、电压跟随器、同相比例放大器

目录一、单限电压比较器二、滞回电压比较器三、窗口电压比较器四、正点原子直流电机驱动器电路分析实战1、电压采集电路2、电流采集电路3、过流检测电路Ⅰ、采用分压后的输入电压:Ⅱ、采用理想电压源的输入电压:Ⅲ、同相输入电压采用的是非理想电压源&am…...

Vector - CAPL - 实时时间on *(续2)

继续继续。。。四、键盘事件这个键盘事件是我个人起的名字,为了方便与其他事件进行区分,为什么要把这一个单独拉出来说呢,因为它的用处实在是太广泛了,基本只要是使用CANoe做一些基本的自动化测试小工具,都会用到它&am…...

数据质量管理的四个阶段

然而,我们需要按照什么流程来对数据质量进行有效的管控,从而提升数据质量,释放数据价值?一般来讲,数据质量控制流程分为4个阶段:启动、执行、检查、处理。在管控过程中这4个阶段需不断循环,螺旋…...

Spring源码面试最难问题——循环依赖

前言 问:Spring 如何解决循环依赖? 答:Spring 通过提前曝光机制,利用三级缓存解决循环依赖(这原理还是挺简单的,参考:三级缓存、图解循环依赖原理) 再问:Spring 通过提前…...

wordpress摘要添加省略号/宁波seo排名费用

Java NIO Path基本概念Path的创建创建绝对路径Path创建相对路径PathPath类的方法normalize基本概念 Path接口在java.nio.file包下在Java中 ,Path表示文件系统的路径,可以指向文件或者文件夹,有绝对路径和相对路径之分java.nio.file.Path接口和操作系统的path环境变量没有任何关…...

wordpress s7/cps游戏推广平台

文章目录前言一、MHA 概述1.1、MHA 是什么1.2、MHA 的组成1.3、MHA 的特点二、MHA 实验2.1、案例环境2.2、拓扑图2.3、实验目的2.4、实验过程2.4.1、主从复制调整2.4.2、安装 MHA 软件2.4.3、配置节点间SSH面交互无密码认证2.4.4、配置 MHA2.4.5、测试 ssh 无密码认证2.4.6、测…...

网站怎么做排名/百度推广怎么操作流程

什么是过滤器?有什么用?过滤器JavaWeb三大组件之一,它与Servlet很相似。不过滤器是用来拦截请求的,而不是处理请求的。过滤,顾名思义,就是留下我们想要的,丢掉我们不需要的。例如:某…...

新疆生产建设兵团第二中学网站/seo门户 site

给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个结点 p、q,最近公共祖先表示为一个结点 x,满足 x 是 p、q 的祖先且 x 的深度尽可能大(一个节点也可以是它自己的祖…...

网站建设项目描述范文/百度如何收录网站

解决微信浏览器内video全屏问题参考文章: (1)解决微信浏览器内video全屏问题 (2)https://www.cnblogs.com/phpjinggege/p/8270742.html 备忘一下。...

怎么做网站和服务器吗/年轻人不要做网络销售

一、进入dos命令行 按下菜单键windowsR弹出运行框,然后输入cmd,并回车。就会弹出dos命令行: 直接回车; 然后再退出数据库:exit; 然后输入mysqladmin -u root -p password 回车 第一个是密码有密码就打没…...