当前位置: 首页 > news >正文

Keras使用sklearn中的交叉验证和网格搜索

Keras是Python在深度学习领域非常受欢迎的第三方库,但Keras的侧重点是深度学习,而不是所以的机器学习。事实上,Keras力求极简主义,只专注于快速、简单地定义和构建深度学习模型所需要的内容。Python中的scikit-learn是非常受欢迎的机器学习库,它基于Scipy,用于高效的数值计算。scikit-learn是一个功能齐全的通用机器学习库,并提供了许多在开发深度学习过程中非常有帮助的方法。例如scikit-learn提供了很多用于选择模型和对模型调参的方法,这些方法同样适用于深度学习。

Keras提供了一个Wrapper,将Keras的深度学习模型包装成scikit-learn中的分类模型或回归模型,以便于使用scikit-learn中的方法和函数。对于深度学习模型的包装是通过KerasClassifier(分类模型)和KerasRegressor(回归模型)来实现的。KerasClassifier和KerasRegressor类使用参数build_fn,指定用来创建模型的函数的名称。

Keras的一般构建流程:

model = Sequential() # 定义模型
model.add(Dense(units=64, activation='relu', input_dim=100)) # 定义网络结构
#第一层网络:输出尺寸64,输入尺寸100,activation激活函数relu
model.add(Dense(units=10, activation='softmax')) # 定义网络结构
#第二层网络:输出尺寸10,输入是上一层的输出尺寸64,activation激活函数softmax
model.compile(loss='categorical_crossentropy', # 定义loss函数、优化方法、评估标准optimizer='sgd',metrics=['accuracy'])
#输入训练样本和标签,迭代5次,每次迭代32个数据
model.fit(x_train, y_train, epochs=5, batch_size=32) # 训练模型
loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128) # 评估模型
classes = model.predict(x_test, batch_size=128) # 使用训练好的数据进行预测

参数意义:

keras.layers.Dense(units, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None)

units: 正整数,输出空间维度。
activation: 激活函数。 若不指定,则不使用激活函数 (即,「线性」激活: a(x) = x)。
use_bias: 布尔值,该层是否使用偏置向量。
kernel_initializer: kernel 权值矩阵的初始化器。
bias_initializer: 偏置向量的初始化器。
kernel_regularizer: 运用到 kernel 权值矩阵的正则化函数 。
bias_regularizer: 运用到偏置向的的正则化函数 。
activity_regularizer: 运用到层的输出的正则化函数 。
kernel_constraint: 运用到 kernel 权值矩阵的约束函数 。
bias_constraint: 运用到偏置向量的约束函数。

Keras调用scikit-learn实现交叉验证:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score, KFold
from keras.wrappers.scikit_learn import KerasClassifierdef creat_model():# 构建模型model = Sequential()model.add(Dense(units=12, input_dim=11, activation='relu'))model.add(Dense(units=8, activation='relu'))model.add(Dense(units=1, activation='sigmoid'))# 模型编译model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])return model# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')# 删除id列
data.drop('客户编号',axis=1,inplace=True)X, Y = data.values[:,:-1], data.values[:,-1] # Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, epochs=150, batch_size=10, verbose=0)# 10折交叉验证
kfold = KFold(n_splits=10, shuffle=True, random_state=10)
result = cross_val_score(model, X, Y, cv=kfold)

 Keras调用scikit-learn实现模型调参

在构建深度学习模型时,如何配置一个最优模型一直是进行一个项目的重点。在机器学习中,可以通过算法自动调优这些配置参数,在这里将通过Keras的包装类,借助scikit-learn的网格搜索算法评估神经网络模型的不同配置,并找到最佳评估性能的参数组合。creat_model()函数被定义为具有两个默认值的参数(optimizer和init)的函数,创建模型后,定义要搜索的参数的数值数组,包括优化器(optimizer)、权重初始化方案(init)、epochs和batch_size。

在scikit-learn中的GridSearchCV需要一个字典类型的字段作为需要调整的参数,默认采用3折交叉验证来评估算法,由于4个参数需要进行调参,因此将会产生4✖️3个模型。
Keras调用scikit-learn实现GridSearchCV网格搜索:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasClassifierdef creat_model(optimizer='adam,init='glorot_uniform'):# 构建模型model = Sequential()model.add(Dense(units=12, input_dim=11,kernel_initializer=init, activation='relu'))model.add(Dense(units=8, kernel_initializer=init, activation='relu'))model.add(Dense(units=1, kernel_initializer=init, activation='sigmoid'))# 模型编译model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])return model# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')# 删除id列
data.drop('客户编号',axis=1,inplace=True)X, Y = data.values[:,:-1], data.values[:,-1] # Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, verbose=0)# 构建需要调整的参数
param_gird = {}
param_grid['optimizer'] = ['rmsprop','adam']
param_grid['init'] = ['glorot_uniform', 'normal', 'uniform']
param_gird['epochs'] = [50, 100, 150, 200]
param_gird['batch_size'] = [5, 10, 20]# 调参
grid = GridSearchCV(estimator=model, param_gird=param_grid)
result = grid.fit(X, Y)# 输出结果
print('Best: %f using %s' % (result.best_score_, result.best_params_))

关于Epochs和batch_size的解释?

Epochs是神经网络训练过程中的一个重要超参数,定义为向前和向后传播中所有批次的单次训练迭代。简单说,一个Epoch是将所有的数据输入网络完成一次向前计算及反向传播。在训练过程中,数据会被“轮”多少次,即应当完整遍历数据集多少次(一次为一个Epoch)。如果Epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果Epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。所以,选择适当的Epoch数量需要在充分训练和避免过拟合之间找到平衡。
 

假设我们有1000个数据样本,每次我们送入10个数据进行训练(也就是batch_size为10)。那么完成一个Epoch,我们需要进行100次迭代(也就是100次前向传播和100次反向传播)。具体来说,我们需要将所有的数据都送入神经网络进行一次前向传播和反向传播,所以一次Epoch相当于所有数据集/batch size=N次迭代。
 

相关文章:

Keras使用sklearn中的交叉验证和网格搜索

Keras是Python在深度学习领域非常受欢迎的第三方库,但Keras的侧重点是深度学习,而不是所以的机器学习。事实上,Keras力求极简主义,只专注于快速、简单地定义和构建深度学习模型所需要的内容。Python中的scikit-learn是非常受欢迎的…...

docker--Prometheus、Grafana、node_exporter的安装配置及Springboot集成Prometheus示例

1. 安装Prometheus Prometheus一个系统和服务监控系统。它以给定的时间间隔从配置的目标收集指标,计算规则表达式,显示结果,并在观察到某些条件为真时触发警报。 可观察性侧重于根据系统产生的数据了解系统的内部状态,这有助于确定基础设施是否健康。Prometheus是用于监视…...

数据结构和算法笔记2:二分法

二分法网上有两种写法&#xff0c;一种左闭右闭&#xff0c;一种左闭右开&#xff0c;个人习惯左闭右闭的写法&#xff0c; 有序数组查找数 这是标准二分法&#xff0c;对应力扣的704. 二分查找&#xff1a; 求值为target的索引 int search(vector<int>& nums, i…...

Mybatis3系列课程8-带参数查询

简介 上节课内容中讲解了查询全部, 不需要带条件查, 这节我们讲讲 带条件查询 目标 1. 带一个条件查询-基本数据类型 2.带两个条件查询-连个基本数据类型 3.带一个对象类型查询 为了实现目标, 我们要实现 按照主键 查询某个学生信息, 按照姓名和年级编号查询学生信息 按照学生…...

IDEA shorten command line介绍和JAR manifest 导致mybatis找不到接口类处理

如果类路径太长&#xff0c;或者有许多VM参数&#xff0c;程序就无法启动。原因是大多数操作系统都有命令行长度限制。在这种情况下&#xff0c;IntelliJIDEA将试图缩短类路径。最好选中 classpath file模式。 shorten command line 选项提供三种选项缩短类路径。 none&#x…...

泽攸科技SEM台式扫描电子显微镜

泽攸科技是一家国产的科学仪器公司&#xff0c;专注于研发、生产和销售原位电镜解决方案、扫描电镜整机、台阶仪、探针台等仪器。目前台式扫描电镜分为三个系列&#xff1a;ZEM15、ZEM18、ZEM20。 ZEM15台式扫描电镜&#xff1a; ZEM18台式扫描电镜&#xff1a; ZEM20台式扫描…...

华为交换机配置BGP的基本示例

BGP简介 定义 边界网关协议BGP&#xff08;Border Gateway Protocol&#xff09;是一种实现自治系统AS&#xff08;Autonomous System&#xff09;之间的路由可达&#xff0c;并选择最佳路由的距离矢量路由协议。早期发布的三个版本分别是BGP-1&#xff08;RFC1105&#xff0…...

数据分析基础之《numpy(4)—ndarry运算》

一、逻辑运算 当我们要操作符合某一条件的数据时&#xff0c;需要用到逻辑运算 1、运算符 满足条件返回true&#xff0c;不满足条件返回false # 重新生成8只股票10个交易日的涨跌幅数据 stock_change np.random.normal(loc0, scale1, size(8, 10))# 获取前5行前5列的数据 s…...

分享一个项目——Sambert UI 声音克隆

文章目录 前言一、运行ipynb二、数据标注三、训练四、生成总结 前言 原教程视频 项目链接 运行一个ipynb&#xff0c;就可操作 总共四步 1&#xff09;运行ipynb 2&#xff09;数据标注 3&#xff09;训练 4&#xff09;生成 一、运行ipynb 等运行完毕后&#xff0c;获得该…...

ES6 语法精粹简读

本文旨在记录 ES6 的核心常用语法,略去一些细节。 文章目录 1 var 函数作用域与 let/const 块作用域2 解构赋值数组结构赋值对象结构赋值3 ES6 中字符串的新语法模板字符串模板编译标签模板4 ES6 中的函数默认值rest 参数箭头函数this 指向问题部署管道机制尾调用优化...

uniapp整合echarts(目前性能最优、渲染最快方案)

本文echarts示例如上图,可扫码体验渲染速度及loading效果,下文附带本小程序uniapp相关代码 实现代码 <template><view class="source...

解决Electron应用中的白屏问题的实用方法

在使用Electron构建应用程序时&#xff0c;一些开发者可能会面临窗口加载过程中出现的白屏问题。这种问题主要分为两个方面&#xff1a; Electron未加载完毕HTML&#xff1a; 这时Electron自身产生的白色背景可能导致用户在启动应用时看到一片空白。HTML加载渲染过程中的短暂白…...

大数据---34.HBase数据结构

一、HBase简介 HBase是一个开源的、分布式的、版本化的NoSQL数据库&#xff08;即非关系型数据库&#xff09;&#xff0c;依托Hadoop分布式文件系统HDFS提供分布式数据存储&#xff0c;利用MapReduce来处理海量数据&#xff0c;用Zookeeper作为其分布式协同服务&#xff0c;一…...

【工具使用-有道云笔记】如何在有道云笔记中插入目录

一&#xff0c;简介 本文主要介绍如何在有道云笔记中插入目录&#xff0c;方便后续笔记的查看&#xff0c;供参考。 二&#xff0c;具体步骤 分为两个步骤&#xff1a;1&#xff0c;设置标题格式&#xff1b;2&#xff0c;插入标题。非常简单~ 2.1 设置标题格式 鼠标停在标…...

用户管理第2节课-idea 2023.2 后端一删除表,从零开始---【本人】

一、清空model文件夹下&#xff0c;所有文件 1.1.1效果如下&#xff1a; 1.1代码内容 package com.daisy.usercenter.model;import lombok.Data;Data public class User {private Long id;private String name;private Integer age;private String email; }二、清空mapper文件…...

如何添加jar包到本地Maven项目中

在 Maven 中添加一个外部 JAR 包的依赖&#xff0c;你需要使用 Maven 的 <dependency> 元素来指定该 JAR 包的坐标信息。以下是具体的步骤&#xff1a; 将 JAR 包手动添加到 Maven 本地仓库&#xff1a; 首先&#xff0c;确保将外部 JAR 包手动添加到 Maven 本地仓库。可…...

智能优化算法应用:基于学校优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于学校优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于学校优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.学校优化算法4.实验参数设定5.算法结果6.…...

【MATLAB第85期】基于MATLAB的2023年智能进化算法/元启发式算法合集(持续更新)

【MATLAB第85期】基于MATLAB的2023年智能进化算法/元启发式算法合集&#xff08;持续更新&#xff09; 1.海象进化算法&#xff08;Walrus Optimization Algorithm&#xff09; 作者&#xff1a;Pavel Trojovsk and Mohammad Dehghani 2.暴龙优化算法&#xff08;Tyrannosa…...

[Realtek sdk-3.4.14b]RTL8197FH-VG+RTL8812F WiFi使用功率限制功能使用说明

sdk说明 ** Gateway/AP firmware v3.4.14b – Aug 26, 2019**  Wireless LAN driver changes as:  Refine WiFi Stability and Performance  Add 8812F MU-MIMO  Add 97G/8812F multiple mac-clone  Add 97G 2T3R antenna diversity  Fix 97G/8812F/8814B MP issu…...

Vue中为什么data属性是一个函数而不是一个对象?(看完就会了)

文章目录 一、实例和组件定义data的区别二、组件data定义函数与对象的区别三、原理分析四、结论 一、实例和组件定义data的区别 vue实例的时候定义data属性既可以是一个对象&#xff0c;也可以是一个函数 const app new Vue({el:"#app",// 对象格式data:{foo:&quo…...

基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容

基于 ​UniApp + WebSocket​实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配​微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版​分享

平时用 iPhone 的时候&#xff0c;难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵&#xff0c;或者买了二手 iPhone 却被原来的 iCloud 账号锁住&#xff0c;这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...

C# 表达式和运算符(求值顺序)

求值顺序 表达式可以由许多嵌套的子表达式构成。子表达式的求值顺序可以使表达式的最终值发生 变化。 例如&#xff0c;已知表达式3*52&#xff0c;依照子表达式的求值顺序&#xff0c;有两种可能的结果&#xff0c;如图9-3所示。 如果乘法先执行&#xff0c;结果是17。如果5…...

PHP 8.5 即将发布:管道操作符、强力调试

前不久&#xff0c;PHP宣布了即将在 2025 年 11 月 20 日 正式发布的 PHP 8.5&#xff01;作为 PHP 语言的又一次重要迭代&#xff0c;PHP 8.5 承诺带来一系列旨在提升代码可读性、健壮性以及开发者效率的改进。而更令人兴奋的是&#xff0c;借助强大的本地开发环境 ServBay&am…...

c# 局部函数 定义、功能与示例

C# 局部函数&#xff1a;定义、功能与示例 1. 定义与功能 局部函数&#xff08;Local Function&#xff09;是嵌套在另一个方法内部的私有方法&#xff0c;仅在包含它的方法内可见。 • 作用&#xff1a;封装仅用于当前方法的逻辑&#xff0c;避免污染类作用域&#xff0c;提升…...

stm32wle5 lpuart DMA数据不接收

配置波特率9600时&#xff0c;需要使用外部低速晶振...

Python网页自动化Selenium中文文档

1. 安装 1.1. 安装 Selenium Python bindings 提供了一个简单的API&#xff0c;让你使用Selenium WebDriver来编写功能/校验测试。 通过Selenium Python的API&#xff0c;你可以非常直观的使用Selenium WebDriver的所有功能。 Selenium Python bindings 使用非常简洁方便的A…...

C++实现分布式网络通信框架RPC(2)——rpc发布端

有了上篇文章的项目的基本知识的了解&#xff0c;现在我们就开始构建项目。 目录 一、构建工程目录 二、本地服务发布成RPC服务 2.1理解RPC发布 2.2实现 三、Mprpc框架的基础类设计 3.1框架的初始化类 MprpcApplication 代码实现 3.2读取配置文件类 MprpcConfig 代码实现…...

Monorepo架构: Nx Cloud 扩展能力与缓存加速

借助 Nx Cloud 实现项目协同与加速构建 1 &#xff09; 缓存工作原理分析 在了解了本地缓存和远程缓存之后&#xff0c;我们来探究缓存是如何工作的。以计算文件的哈希串为例&#xff0c;若后续运行任务时文件哈希串未变&#xff0c;系统会直接使用对应的输出和制品文件。 2 …...

初探用uniapp写微信小程序遇到的问题及解决(vue3+ts)

零、关于开发思路 (一)拿到工作任务,先理清楚需求 1.逻辑部分 不放过原型里说的每一句话,有疑惑的部分该问产品/测试/之前的开发就问 2.页面部分(含国际化) 整体看过需要开发页面的原型后,分类一下哪些组件/样式可以复用,直接提取出来使用 (时间充分的前提下,不…...