当前位置: 首页 > news >正文

Keras使用sklearn中的交叉验证和网格搜索

Keras是Python在深度学习领域非常受欢迎的第三方库,但Keras的侧重点是深度学习,而不是所以的机器学习。事实上,Keras力求极简主义,只专注于快速、简单地定义和构建深度学习模型所需要的内容。Python中的scikit-learn是非常受欢迎的机器学习库,它基于Scipy,用于高效的数值计算。scikit-learn是一个功能齐全的通用机器学习库,并提供了许多在开发深度学习过程中非常有帮助的方法。例如scikit-learn提供了很多用于选择模型和对模型调参的方法,这些方法同样适用于深度学习。

Keras提供了一个Wrapper,将Keras的深度学习模型包装成scikit-learn中的分类模型或回归模型,以便于使用scikit-learn中的方法和函数。对于深度学习模型的包装是通过KerasClassifier(分类模型)和KerasRegressor(回归模型)来实现的。KerasClassifier和KerasRegressor类使用参数build_fn,指定用来创建模型的函数的名称。

Keras的一般构建流程:

model = Sequential() # 定义模型
model.add(Dense(units=64, activation='relu', input_dim=100)) # 定义网络结构
#第一层网络:输出尺寸64,输入尺寸100,activation激活函数relu
model.add(Dense(units=10, activation='softmax')) # 定义网络结构
#第二层网络:输出尺寸10,输入是上一层的输出尺寸64,activation激活函数softmax
model.compile(loss='categorical_crossentropy', # 定义loss函数、优化方法、评估标准optimizer='sgd',metrics=['accuracy'])
#输入训练样本和标签,迭代5次,每次迭代32个数据
model.fit(x_train, y_train, epochs=5, batch_size=32) # 训练模型
loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128) # 评估模型
classes = model.predict(x_test, batch_size=128) # 使用训练好的数据进行预测

参数意义:

keras.layers.Dense(units, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None)

units: 正整数,输出空间维度。
activation: 激活函数。 若不指定,则不使用激活函数 (即,「线性」激活: a(x) = x)。
use_bias: 布尔值,该层是否使用偏置向量。
kernel_initializer: kernel 权值矩阵的初始化器。
bias_initializer: 偏置向量的初始化器。
kernel_regularizer: 运用到 kernel 权值矩阵的正则化函数 。
bias_regularizer: 运用到偏置向的的正则化函数 。
activity_regularizer: 运用到层的输出的正则化函数 。
kernel_constraint: 运用到 kernel 权值矩阵的约束函数 。
bias_constraint: 运用到偏置向量的约束函数。

Keras调用scikit-learn实现交叉验证:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score, KFold
from keras.wrappers.scikit_learn import KerasClassifierdef creat_model():# 构建模型model = Sequential()model.add(Dense(units=12, input_dim=11, activation='relu'))model.add(Dense(units=8, activation='relu'))model.add(Dense(units=1, activation='sigmoid'))# 模型编译model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])return model# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')# 删除id列
data.drop('客户编号',axis=1,inplace=True)X, Y = data.values[:,:-1], data.values[:,-1] # Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, epochs=150, batch_size=10, verbose=0)# 10折交叉验证
kfold = KFold(n_splits=10, shuffle=True, random_state=10)
result = cross_val_score(model, X, Y, cv=kfold)

 Keras调用scikit-learn实现模型调参

在构建深度学习模型时,如何配置一个最优模型一直是进行一个项目的重点。在机器学习中,可以通过算法自动调优这些配置参数,在这里将通过Keras的包装类,借助scikit-learn的网格搜索算法评估神经网络模型的不同配置,并找到最佳评估性能的参数组合。creat_model()函数被定义为具有两个默认值的参数(optimizer和init)的函数,创建模型后,定义要搜索的参数的数值数组,包括优化器(optimizer)、权重初始化方案(init)、epochs和batch_size。

在scikit-learn中的GridSearchCV需要一个字典类型的字段作为需要调整的参数,默认采用3折交叉验证来评估算法,由于4个参数需要进行调参,因此将会产生4✖️3个模型。
Keras调用scikit-learn实现GridSearchCV网格搜索:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasClassifierdef creat_model(optimizer='adam,init='glorot_uniform'):# 构建模型model = Sequential()model.add(Dense(units=12, input_dim=11,kernel_initializer=init, activation='relu'))model.add(Dense(units=8, kernel_initializer=init, activation='relu'))model.add(Dense(units=1, kernel_initializer=init, activation='sigmoid'))# 模型编译model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])return model# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')# 删除id列
data.drop('客户编号',axis=1,inplace=True)X, Y = data.values[:,:-1], data.values[:,-1] # Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, verbose=0)# 构建需要调整的参数
param_gird = {}
param_grid['optimizer'] = ['rmsprop','adam']
param_grid['init'] = ['glorot_uniform', 'normal', 'uniform']
param_gird['epochs'] = [50, 100, 150, 200]
param_gird['batch_size'] = [5, 10, 20]# 调参
grid = GridSearchCV(estimator=model, param_gird=param_grid)
result = grid.fit(X, Y)# 输出结果
print('Best: %f using %s' % (result.best_score_, result.best_params_))

关于Epochs和batch_size的解释?

Epochs是神经网络训练过程中的一个重要超参数,定义为向前和向后传播中所有批次的单次训练迭代。简单说,一个Epoch是将所有的数据输入网络完成一次向前计算及反向传播。在训练过程中,数据会被“轮”多少次,即应当完整遍历数据集多少次(一次为一个Epoch)。如果Epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果Epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。所以,选择适当的Epoch数量需要在充分训练和避免过拟合之间找到平衡。
 

假设我们有1000个数据样本,每次我们送入10个数据进行训练(也就是batch_size为10)。那么完成一个Epoch,我们需要进行100次迭代(也就是100次前向传播和100次反向传播)。具体来说,我们需要将所有的数据都送入神经网络进行一次前向传播和反向传播,所以一次Epoch相当于所有数据集/batch size=N次迭代。
 

相关文章:

Keras使用sklearn中的交叉验证和网格搜索

Keras是Python在深度学习领域非常受欢迎的第三方库,但Keras的侧重点是深度学习,而不是所以的机器学习。事实上,Keras力求极简主义,只专注于快速、简单地定义和构建深度学习模型所需要的内容。Python中的scikit-learn是非常受欢迎的…...

docker--Prometheus、Grafana、node_exporter的安装配置及Springboot集成Prometheus示例

1. 安装Prometheus Prometheus一个系统和服务监控系统。它以给定的时间间隔从配置的目标收集指标,计算规则表达式,显示结果,并在观察到某些条件为真时触发警报。 可观察性侧重于根据系统产生的数据了解系统的内部状态,这有助于确定基础设施是否健康。Prometheus是用于监视…...

数据结构和算法笔记2:二分法

二分法网上有两种写法&#xff0c;一种左闭右闭&#xff0c;一种左闭右开&#xff0c;个人习惯左闭右闭的写法&#xff0c; 有序数组查找数 这是标准二分法&#xff0c;对应力扣的704. 二分查找&#xff1a; 求值为target的索引 int search(vector<int>& nums, i…...

Mybatis3系列课程8-带参数查询

简介 上节课内容中讲解了查询全部, 不需要带条件查, 这节我们讲讲 带条件查询 目标 1. 带一个条件查询-基本数据类型 2.带两个条件查询-连个基本数据类型 3.带一个对象类型查询 为了实现目标, 我们要实现 按照主键 查询某个学生信息, 按照姓名和年级编号查询学生信息 按照学生…...

IDEA shorten command line介绍和JAR manifest 导致mybatis找不到接口类处理

如果类路径太长&#xff0c;或者有许多VM参数&#xff0c;程序就无法启动。原因是大多数操作系统都有命令行长度限制。在这种情况下&#xff0c;IntelliJIDEA将试图缩短类路径。最好选中 classpath file模式。 shorten command line 选项提供三种选项缩短类路径。 none&#x…...

泽攸科技SEM台式扫描电子显微镜

泽攸科技是一家国产的科学仪器公司&#xff0c;专注于研发、生产和销售原位电镜解决方案、扫描电镜整机、台阶仪、探针台等仪器。目前台式扫描电镜分为三个系列&#xff1a;ZEM15、ZEM18、ZEM20。 ZEM15台式扫描电镜&#xff1a; ZEM18台式扫描电镜&#xff1a; ZEM20台式扫描…...

华为交换机配置BGP的基本示例

BGP简介 定义 边界网关协议BGP&#xff08;Border Gateway Protocol&#xff09;是一种实现自治系统AS&#xff08;Autonomous System&#xff09;之间的路由可达&#xff0c;并选择最佳路由的距离矢量路由协议。早期发布的三个版本分别是BGP-1&#xff08;RFC1105&#xff0…...

数据分析基础之《numpy(4)—ndarry运算》

一、逻辑运算 当我们要操作符合某一条件的数据时&#xff0c;需要用到逻辑运算 1、运算符 满足条件返回true&#xff0c;不满足条件返回false # 重新生成8只股票10个交易日的涨跌幅数据 stock_change np.random.normal(loc0, scale1, size(8, 10))# 获取前5行前5列的数据 s…...

分享一个项目——Sambert UI 声音克隆

文章目录 前言一、运行ipynb二、数据标注三、训练四、生成总结 前言 原教程视频 项目链接 运行一个ipynb&#xff0c;就可操作 总共四步 1&#xff09;运行ipynb 2&#xff09;数据标注 3&#xff09;训练 4&#xff09;生成 一、运行ipynb 等运行完毕后&#xff0c;获得该…...

ES6 语法精粹简读

本文旨在记录 ES6 的核心常用语法,略去一些细节。 文章目录 1 var 函数作用域与 let/const 块作用域2 解构赋值数组结构赋值对象结构赋值3 ES6 中字符串的新语法模板字符串模板编译标签模板4 ES6 中的函数默认值rest 参数箭头函数this 指向问题部署管道机制尾调用优化...

uniapp整合echarts(目前性能最优、渲染最快方案)

本文echarts示例如上图,可扫码体验渲染速度及loading效果,下文附带本小程序uniapp相关代码 实现代码 <template><view class="source...

解决Electron应用中的白屏问题的实用方法

在使用Electron构建应用程序时&#xff0c;一些开发者可能会面临窗口加载过程中出现的白屏问题。这种问题主要分为两个方面&#xff1a; Electron未加载完毕HTML&#xff1a; 这时Electron自身产生的白色背景可能导致用户在启动应用时看到一片空白。HTML加载渲染过程中的短暂白…...

大数据---34.HBase数据结构

一、HBase简介 HBase是一个开源的、分布式的、版本化的NoSQL数据库&#xff08;即非关系型数据库&#xff09;&#xff0c;依托Hadoop分布式文件系统HDFS提供分布式数据存储&#xff0c;利用MapReduce来处理海量数据&#xff0c;用Zookeeper作为其分布式协同服务&#xff0c;一…...

【工具使用-有道云笔记】如何在有道云笔记中插入目录

一&#xff0c;简介 本文主要介绍如何在有道云笔记中插入目录&#xff0c;方便后续笔记的查看&#xff0c;供参考。 二&#xff0c;具体步骤 分为两个步骤&#xff1a;1&#xff0c;设置标题格式&#xff1b;2&#xff0c;插入标题。非常简单~ 2.1 设置标题格式 鼠标停在标…...

用户管理第2节课-idea 2023.2 后端一删除表,从零开始---【本人】

一、清空model文件夹下&#xff0c;所有文件 1.1.1效果如下&#xff1a; 1.1代码内容 package com.daisy.usercenter.model;import lombok.Data;Data public class User {private Long id;private String name;private Integer age;private String email; }二、清空mapper文件…...

如何添加jar包到本地Maven项目中

在 Maven 中添加一个外部 JAR 包的依赖&#xff0c;你需要使用 Maven 的 <dependency> 元素来指定该 JAR 包的坐标信息。以下是具体的步骤&#xff1a; 将 JAR 包手动添加到 Maven 本地仓库&#xff1a; 首先&#xff0c;确保将外部 JAR 包手动添加到 Maven 本地仓库。可…...

智能优化算法应用:基于学校优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于学校优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于学校优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.学校优化算法4.实验参数设定5.算法结果6.…...

【MATLAB第85期】基于MATLAB的2023年智能进化算法/元启发式算法合集(持续更新)

【MATLAB第85期】基于MATLAB的2023年智能进化算法/元启发式算法合集&#xff08;持续更新&#xff09; 1.海象进化算法&#xff08;Walrus Optimization Algorithm&#xff09; 作者&#xff1a;Pavel Trojovsk and Mohammad Dehghani 2.暴龙优化算法&#xff08;Tyrannosa…...

[Realtek sdk-3.4.14b]RTL8197FH-VG+RTL8812F WiFi使用功率限制功能使用说明

sdk说明 ** Gateway/AP firmware v3.4.14b – Aug 26, 2019**  Wireless LAN driver changes as:  Refine WiFi Stability and Performance  Add 8812F MU-MIMO  Add 97G/8812F multiple mac-clone  Add 97G 2T3R antenna diversity  Fix 97G/8812F/8814B MP issu…...

Vue中为什么data属性是一个函数而不是一个对象?(看完就会了)

文章目录 一、实例和组件定义data的区别二、组件data定义函数与对象的区别三、原理分析四、结论 一、实例和组件定义data的区别 vue实例的时候定义data属性既可以是一个对象&#xff0c;也可以是一个函数 const app new Vue({el:"#app",// 对象格式data:{foo:&quo…...

Linux中一些知识积累(持续补充)

如何安装Eigen3库&#xff1f; 在linux中直接命令安装。Eigen/Dense 是 Eigen 库中的一个模块&#xff0c;提供了对密集矩阵&#xff08;Dense Matrix&#xff09;的支持。 sudo apt install libeigen3-devLinux 中VScode中运行C时&#xff0c;gdb 的Launch与Attach有什么区别…...

内网渗透基础

内网 内网指的是内部局域网&#xff0c;常说的LAN&#xff08;local area network&#xff09;。常见家庭wifi网络和小型的企业网络&#xff0c;通常内部计算机直接访问路由器设备&#xff0c;路由器设备接入移动电信的光纤实现上网。 内部局域网可以通过交换机/防火墙组成多个…...

【2023年网络安全优秀创新成果大赛专刊】银行数据安全解决方案(天空卫士)

在2023年网络安全优秀创新成果大赛&#xff0c;成都分站中&#xff0c;天空卫士银行数据安全方案获得优秀解决方案奖。与此同时&#xff0c;天空卫士受信息安全杂志邀请&#xff0c;编写《银行数据安全解决方案》。12月6日&#xff0c;天空卫士编写的《银行数据安全解决方案》做…...

嵌入式串口输入详细实例

学习目标 掌握串口初始化流程掌握串口输出单个字符掌握串口输出字符串掌握通过串口printf熟练掌握串口开发流程学习内容 需求 串口循环输出内容到PC机。 串口数据发送 添加Usart功能。 首先,选中Firmware,鼠标右键,点击Manage Project Items 接着,将gd32f4xx_usart.c添…...

springboot(ssm智慧生活商城系统 网上购物系统Java系统

springboot(ssm智慧生活商城系统 网上购物系统Java系统 开发语言&#xff1a;Java 框架&#xff1a;ssm/springboot vue JDK版本&#xff1a;JDK1.8&#xff08;或11&#xff09; 服务器&#xff1a;tomcat 数据库&#xff1a;mysql 5.7&#xff08;或8.0&#xff09; 数…...

Peter算法小课堂—贪心与二分

太戈编程655题 题目描述&#xff1a; 有n辆车大甩卖&#xff0c;第i辆车售价a[i]元。有m个人带着现金来申请购买&#xff0c;第i个到现场的人带的现金为b[i]元&#xff0c;只能买价格不超过其现金额的车子。你是大卖场总经理&#xff0c;希望将车和买家尽量多地进行一对一配对…...

搭建Vue前端项目的流程

1、安装nodejs 测试安装是否成功 $ npm -v 6.14.16 $ node -v v12.22.122、全局安装npm install -g vue/cli&#xff0c;后续会使用到vue命令 $ vue --version vue/cli 5.0.8使用vue create demo_project_fe命令创建项目&#xff0c;使用箭头键来选择&#xff0c;确认使用回车…...

1.使用 Blazor 利用 ASP.NET Core 生成第一个 Web 应用

参考 https://dotnet.microsoft.com/zh-cn/learn/aspnet/blazor-tutorial/create 1.使用vs2022创建新项目 选择 C# -> Windows -> Blzxor Server 应用模板 2.项目名称BlazorApp下一步 3.选择 .NET6.0 或 .NET7.0 或 .NET8.0 创建 4.运行BlazorApp 5.全部选择是。 信…...

如何入门 GPT 并快速跟上当前的大语言模型 LLM 进展?

入门GPT 首先说第一个问题&#xff1a;如何入门GPT模型&#xff1f; 最直接的方式当然是去阅读官方的论文。GPT模型从2018年的GPT-1到现在的GPT-4已经迭代了好几个版本&#xff0c;通过官方团队发表的论文是最能准确理清其发展脉络的途径&#xff0c;其中包括GPT模型本身和一…...

【pentaho】kettle读取Hive表不支持bigint和timstamp类型解决。

一、bigint类型 报错: Unable to get value BigNumber(16) from database resultset显示kettle认为此应该是decimal类型(kettle中是TYPE_BIGNUMBER或称BigNumber)&#xff0c;但实际hive数据库中是big类型。 修改kettle源码解决&#xff1a; kettle中java.sql.Types到kettle…...

什么程序做教育网站好/网站提交收录软件

本博文介绍如何使用 UTL_SMTP来发送邮件UTL_SMTP是基于SMTP协议来发送邮件1、需要安装UTL_SMTP包要利用oracle的系统包实现发送邮件的功能&#xff0c;必须先以sys用户登录执行以下两个脚本&#xff1a;$ORACLE_HOME/rdbms/admin/utlsmtp.sql$ORACLE_HOME/rdbms/admin/utltcp.s…...

中山移动网站建设报价/世界杯大数据

最近在看李沐的实用机器学习课程&#xff0c;讲到regression问题的loss的时候有弹幕问&#xff1a;“为什么要平方&#xff1f;”如果是几年前学生问我这个问题&#xff0c;我会回答&#xff1a;“因为做回归的时候的我们的残差有正有负&#xff0c;取个平方求和以后可以很简单…...

永州公司做网站/河南关键词排名顾问

接口callable <V> 类型参数 V-call方法的结构类型 public interface Callable<V> 返回结果并且可能抛出的异常的任务。实现者定义一个不带任何参数的的call()方法&#xff0c; Callable 接口类似于Runnable ,两者都是为了哪些真实实例可能被另一个线程执行的类…...

陕西省人民政府采购网/曲靖seo

本文同步发布于 个人博客 前言 上周周赛因为忘记起床导致没打TAT 本次周赛战绩: rk5&#xff0c;总完成时间20min&#xff0c;还有奖品&#xff0c;好耶&#xff01; A 2129.将标题首字母大写 题意 给出一个包含若干个单词的句子&#xff0c;把所有字母变为小写字母&#…...

什么是网络营销策略?/seo指的是什么意思

&#xff08;本文发表于《程序员》2010年3月刊&#xff09; 借鉴丰田方法对大型软件组织进行敏捷改造 &#xff08;上&#xff09; 本文以 ThoughtWorks 中国公司与 某大型 电 信 设备 提供商 合作的 咨询项目 案例 为 背景 &#xff0c; 介 绍 如何采用丰…...

企业网站建设的好处/网站流量数据分析

【H5】 svg画扇形饼图 效果图如下&#xff1a; 封装代码如下&#xff1a; 代码内有详细注解哦&#xff01; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widt…...