当前位置: 首页 > news >正文

21- 神经网络模型_超参数搜索 (TensorFlow系列) (深度学习)

知识要点

  • fetch_california_housing:加利福尼亚的房价数据,总计20640个样本,每个样本8个属性表示,以及房价作为target

  • 超参数搜索的方式: 网格搜索, 随机搜索, 遗传算法搜索, 启发式搜索

  • 超参数训练后用: gv.estimator调取最佳模型

  • 函数式添加神经网络:

    • model.add(keras.layers.Dense(layer_size, activation = 'relu'))

    • model.compile(loss = 'mse', optimizer = optimizer)    # optimizer = keras.optimizers.SGD (learning_rate)

    • sklearn_model = KerasRegressor(build_fn = build_model)

from tensorflow.keras.wrappers.scikit_learn import KerasRegressor   # 回归神经网络
# 搜索最佳学习率
def build_model(hidden_layers = 1, layer_size = 30, learning_rate = 3e-3):model = keras.models.Sequential()model.add(keras.layers.Dense(layer_size, activation = 'relu', input_shape = x_train.shape[1:]))for _ in range(hidden_layers - 1):model.add(keras.layers.Dense(layer_size, activation = 'relu'))model.add(keras.layers.Dense(1))optimizer = keras.optimizers.SGD(learning_rate)model.compile(loss = 'mse', optimizer = optimizer)# model.summary()return model
sklearn_model = KerasRegressor(build_fn = build_model)
  • callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]  # 回调函数设置

  • gv = GridSearchCV(sklearn_model, param_grid = params, n_jobs = 1, cv= 5,verbose = 1)  # 找最佳参数

  • gv.fit(x_train_scaled, y_train)


1 导包

from tensorflow import keras
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
cpu=tf.config.list_physical_devices("CPU")
tf.config.set_visible_devices(cpu)
print(tf.config.list_logical_devices())

2 导入数据

from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housinghousing = fetch_california_housing()
x_train_all, x_test, y_train_all, y_test = train_test_split(housing.data,housing.target,random_state= 7)
x_train, x_valid, y_train, y_valid = train_test_split(x_train_all, y_train_all,random_state = 11)

3 标准化处理数据

from sklearn.preprocessing import StandardScaler, MinMaxScalerscaler =StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)

4 函数式定义模型

from tensorflow.keras.wrappers.scikit_learn import KerasRegressor   # 回归神经网络
# 搜索最佳学习率
def build_model(hidden_layers = 1, layer_size = 30, learning_rate = 3e-3):model = keras.models.Sequential()model.add(keras.layers.Dense(layer_size, activation = 'relu', input_shape = x_train.shape[1:]))for _ in range(hidden_layers - 1):model.add(keras.layers.Dense(layer_size, activation = 'relu'))model.add(keras.layers.Dense(1))optimizer = keras.optimizers.SGD(learning_rate)model.compile(loss = 'mse', optimizer = optimizer)# model.summary()return model
sklearn_model = KerasRegressor(build_fn = build_model)

 

5 模型训练

callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]
history = sklearn_model.fit(x_train_scaled, y_train, epochs = 10,validation_data = (x_valid_scaled, y_valid), callbacks = callbacks)

 6 超参数搜索

超参数搜索的方式:

  • 网格搜索

    • 定义n维方格

    • 每个方格对应一组超参数

    • 一组一组参数尝试

  • 随机搜索

  • 遗传算法搜索

    • 对自然界的模拟

    • A: 初始化候选参数集合 --> 训练---> 得到模型指标作为生存概率

    • B: 选择 --> 交叉--> 变异 --> 产生下一代集合

    • C: 重新到A, 循环.

  • 启发式搜索

    • 研究热点-- AutoML的一部分

    • 使用循环神经网络来生成参数

    • 使用强化学习来进行反馈, 使用模型来训练生成参数.

# 使用sklearn 的网格搜索, 或者随机搜索
from sklearn.model_selection import GridSearchCV, RandomizedSearchCVparams = {'learning_rate' : [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2],'hidden_layers': [2, 3, 4, 5], 'layer_size': [20, 60, 100]}gv = GridSearchCV(sklearn_model, param_grid = params, n_jobs = 1, cv= 5,verbose = 1)
gv.fit(x_train_scaled, y_train)
  • 输出最佳参数
# 最佳得分
print(gv.best_score_)    # -0.47164334654808043
# 最佳参数
print(gv.best_params_)  # {'hidden_layers': 5,'layer_size': 100,'learning_rate':0.01}
# 最佳模型
print(gv.estimator)
'''<keras.wrappers.scikit_learn.KerasRegressor object at 0x0000025F5BB12220>'''
gv.score

7 最佳参数建模

model = keras.models.Sequential()
model.add(keras.layers.Dense(100, activation = 'relu', input_shape = x_train.shape[1:]))
for _ in range(4):model.add(keras.layers.Dense(100, activation = 'relu'))
model.add(keras.layers.Dense(1))
optimizer = keras.optimizers.SGD(0.01)
model.compile(loss = 'mse', optimizer = optimizer)
model.summary()

callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]
history = model.fit(x_train_scaled, y_train, epochs = 10,validation_data = (x_valid_scaled, y_valid), callbacks = callbacks)

 8 手动实现超参数搜索

  • 根据参数进行多次模型的训练, 然后记录 loss
# 搜索最佳学习率
learning_rates = [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2]
histories = []
for lr in learning_rates:model = keras.models.Sequential([keras.layers.Dense(30, activation = 'relu', input_shape = x_train.shape[1:]),keras.layers.Dense(1)])optimizer = keras.optimizers.SGD(lr)model.compile(loss = 'mse', optimizer = optimizer, metrics = ['mse'])callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-2)]history = model.fit(x_train_scaled, y_train, validation_data = (x_valid_scaled, y_valid), epochs = 100, callbacks = callbacks)histories.append(history)

 

# 画图
import pandas as pd
def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize = (8, 5))plt.grid(True)plt.gca().set_ylim(0, 1)plt.show()for lr, history in zip(learning_rates, histories): print(lr)plot_learning_curves(history)   

相关文章:

21- 神经网络模型_超参数搜索 (TensorFlow系列) (深度学习)

知识要点 fetch_california_housing&#xff1a;加利福尼亚的房价数据&#xff0c;总计20640个样本&#xff0c;每个样本8个属性表示&#xff0c;以及房价作为target 超参数搜索的方式: 网格搜索, 随机搜索, 遗传算法搜索, 启发式搜索 超参数训练后用&#xff1a; gv.estimat…...

《NFL橄榄球》:芝加哥熊·橄榄1号位

芝加哥熊&#xff08;英语&#xff1a;Chicago Bears&#xff09;是一支职业美式橄榄球球队。位于伊利诺伊州的芝加哥。现时为全国橄榄球联盟的国家联盟北区的球队。他们曾经赢出九次美式橄榄球比赛的冠军&#xff0c;分别为八次旧制全国橄榄球联盟和一次超级碗冠军&#xff08…...

【ES】Elasticsearch核心基础概念:文档与索引

es的核心概念主要是&#xff1a;index(索引)、Document(文档)、Clusters(集群)、Node(节点)与实例&#xff0c;下面我们先来了解一下Document与Index。 RESTful APIs 在讲解Document与Index概念之前&#xff0c;我们先来了解一下RESTful APIs&#xff0c;因为下面讲解Documen…...

实时手势识别(C++与python都可实现)

一、前提配置&#xff1a; Windows&#xff0c;visual studio 2019&#xff0c;opencv&#xff0c;python10&#xff0c;opencv-python&#xff0c;numpy&#xff0c;tensorflow&#xff0c;mediapipe&#xff0c;math 1.安装python环境 这里我个人使用的安装python10&#…...

15个Spring扩展点,一般人知道的不超过5个!

Spring的核心思想就是容器&#xff0c;当容器refresh的时候&#xff0c;外部看上去风平浪静&#xff0c;其实内部则是一片惊涛骇浪&#xff0c;汪洋一片。Spring Boot更是封装了Spring&#xff0c;遵循约定大于配置&#xff0c;加上自动装配的机制。很多时候我们只要引用了一个…...

Elasticsearch:以 “Painless” 方式保护你的映射

Elasticsearch 是一个很棒的工具&#xff0c;可以从各种来源收集日志和指标。 它为我们提供了许多默认处理&#xff0c;以便提供最佳用户体验。 但是&#xff0c;在某些情况下&#xff0c;默认处理可能不是最佳的&#xff08;尤其是在生产环境中&#xff09;&#xff1b; 因此&…...

js几种对象创建方式

适用于不确定对象内部数据方式一&#xff1a;var p new Object(); p.name TOM; p.age 12 p.setName function(name) {this.name name; }// 测试 p.setName(jack) console.log(p.name,p.age)方式二&#xff1a; 对象字面量模式套路&#xff1a;使用{}创建对象&#xff0c;同…...

阿里云服务器ECS适用于哪些应用场景?

云服务器ECS具有广泛的应用场景&#xff0c;既可以作为Web服务器或者应用服务器单独使用&#xff0c;又可以与其他阿里云服务集成提供丰富的解决方案。 云服务器ECS的典型应用场景包括但不限于本文描述&#xff0c;您可以在使用云服务器ECS的同时发现云计算带来的技术红利。 阿…...

Ajax学习笔记01

引入 翻译成中文就是“异步的Javascript和XML”。即使用Javascript语言与服务器进行异步交互&#xff0c;传输的数据为XML&#xff08;当然&#xff0c;传输的数据不只是XML&#xff09;。 AJAX 不是新的编程语言&#xff0c;而是一种使用现有标准的新方法。 AJAX 最大的优点…...

Jinja2----------过滤器的使用、控制语句

目录 1.过滤器的使用 1.过滤器和测试器 2.过滤器 templates/filter.html app.py 效果 3.自定义过滤器 app.py templates/filter.html 效果 2.控制语句 1.if app.py templates/control.html 2.for app.py templates/control.htm 1.过滤器的使用 1.过滤器和测…...

面试了1个自动化测试,开口40W年薪,只能说痴人做梦...

公司前段缺人&#xff0c;也面了不少测试&#xff0c;结果竟然没有一个合适的。一开始瞄准的就是中级的水准&#xff0c;也没指望来大牛&#xff0c;提供的薪资在10-20k&#xff0c;面试的人很多&#xff0c;但平均水平很让人失望。看简历很多都是3年工作经验&#xff0c;但面试…...

冲鸭!33% 程序员月薪达到 5 万元以上~

2023年&#xff0c;随着互联网产业的蓬勃发展&#xff0c;程序员作为一个自带“高薪多金”标签的热门群体&#xff0c;被越来越多的人所关注。在过去充满未知的一年中&#xff0c;他们的职场现状发生了一定的改变。那么&#xff0c;程序员岗位的整体薪资水平、婚恋现状、职业方…...

【RSA】HTTPS中SSL/TLS握手时RSA前后端加密流程

SSL/TLS层的位置 SSL/TLS层在网络模型的位置&#xff0c;它属于应用层协议。接管应用层的数据加解密&#xff0c;并通过网络层发送给对方。 SSL/TLS协议分握手协议和记录协议&#xff0c;握手协议用来协商会话参数&#xff08;比如会话密钥、应用层协议等等&#xff09;&…...

clion在linux设置桌面启动图标(jetbrains全家桶均适用)

clion在linux设置桌面启动图标&#xff08;jetbrains全家桶均适用&#xff09; 网上大部分步骤都只是pycharm的教程&#xff0c;其实对于jetbrains全家桶都适合&#xff0c;vs code编辑器也可以这样。 刚开始是使用pycharm在linux设置的教程&#xff0c;参照&#xff1a;http…...

Java数据结构LinkedList单链表和双链表模拟实现及相关OJ题秒AC总结知识点

本篇文章主要讲述LinkedList链表中从初识到深入相关总结&#xff0c;常见OJ题秒AC&#xff0c;望各位大佬喜欢 一、单链表 1.1链表的概念及结构 1.2无头单向非循环链表模拟实现 1.3测试模拟代码 1.4链表相关面试OJ题 1.4.1 删除链表中等于给定值 val 的所有节点 1.4.2 反转…...

立创EDA 学习 day01 应用下载安装,基本使用的操作

1.下载网站 1.链接&#xff1a;立创EDA下载-立创EDA官方版-PC下载网 (pcsoft.com.cn) 2.安装立创EDA 1.直接 next &#xff08;简单的操作&#xff09; 3.注册账号 1. 最好注册一个账号&#xff0c;等下在原理图转PCB 板的时候要登录&#xff0c;才可以。 4.新建工程 1.新…...

华为OD机试真题Python实现【火星文计算】真题+解题思路+代码(20222023)

火星文计算 题目 已经火星人使用的运算符号为# $ 其与地球人的等价公式如下 x#y=2*x+3*y+4 x$y=3*x+y+2 x y是无符号整数 地球人公式按照 c 语言规则进行计算 火星人公式中$符优先级高于#相同的运算符按从左到右的顺序运算 🔥🔥🔥🔥🔥👉👉👉👉👉👉 华…...

yolov8 修改类别 自定义数据集

yolov8 加载yolo网络模型 yolov8n.yaml nc: 80 # number of classes 分类数量 depth_multiple: 0.33 # scales module repeats 重复规模 width_multiple: 0.25 # scales convolution channels 缩放卷积通道 backbone head 指定配置 coco128.yaml path: ../datasets/coco128 # d…...

Linux环境下验证python项目

公司大佬开发的python rpa跑数项目&#xff0c;Windows运行没问题后&#xff0c;需要搭建一个linux环境进行验证&#xff0c;NOW START&#xff01; Install VMware官网 下载好之后打开按步骤安装 最后一步会让填许可证&#xff08;密钥&#xff09;&#xff0c;这里自行百…...

MAC开发使用技巧

1. 查看所有安装的程序 您可以通过以下步骤在 macOS 中查看所有已安装的程序&#xff1a; 点击屏幕左上角的苹果图标&#xff0c;选择“关于本机”。 在打开的窗口中&#xff0c;选择“系统报告”。 在系统报告窗口中&#xff0c;选择“软件”选项卡&#xff0c;然后选择“安…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制&#xff0c;因此这个了16进制的数据既可以翻译成为这个机器码&#xff0c;也可以翻译成为这个国标码&#xff0c;所以这个时候很容易会出现这个歧义的情况&#xff1b; 因此&#xff0c;我们的这个国…...

Prompt Tuning、P-Tuning、Prefix Tuning的区别

一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)

CSI-2 协议详细解析 (一&#xff09; 1. CSI-2层定义&#xff08;CSI-2 Layer Definitions&#xff09; 分层结构 &#xff1a;CSI-2协议分为6层&#xff1a; 物理层&#xff08;PHY Layer&#xff09; &#xff1a; 定义电气特性、时钟机制和传输介质&#xff08;导线&#…...

Qt Widget类解析与代码注释

#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码&#xff0c;写上注释 当然可以&#xff01;这段代码是 Qt …...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试&#xff0c;通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小&#xff0c;增大可提高计算复杂度duration: 测试持续时间&#xff08;秒&…...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”&#xff0c;物流的终极形态正在诞生 想象这样的场景&#xff1a; 凌晨3点&#xff0c;某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径&#xff1b;AI视觉系统在0.1秒内扫描包裹信息&#xff1b;数字孪生平台正模拟次日峰值流量压力…...

Java + Spring Boot + Mybatis 实现批量插入

在 Java 中使用 Spring Boot 和 MyBatis 实现批量插入可以通过以下步骤完成。这里提供两种常用方法&#xff1a;使用 MyBatis 的 <foreach> 标签和批处理模式&#xff08;ExecutorType.BATCH&#xff09;。 方法一&#xff1a;使用 XML 的 <foreach> 标签&#xff…...

七、数据库的完整性

七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...

第7篇:中间件全链路监控与 SQL 性能分析实践

7.1 章节导读 在构建数据库中间件的过程中&#xff0c;可观测性 和 性能分析 是保障系统稳定性与可维护性的核心能力。 特别是在复杂分布式场景中&#xff0c;必须做到&#xff1a; &#x1f50d; 追踪每一条 SQL 的生命周期&#xff08;从入口到数据库执行&#xff09;&#…...

Spring Boot + MyBatis 集成支付宝支付流程

Spring Boot MyBatis 集成支付宝支付流程 核心流程 商户系统生成订单调用支付宝创建预支付订单用户跳转支付宝完成支付支付宝异步通知支付结果商户处理支付结果更新订单状态支付宝同步跳转回商户页面 代码实现示例&#xff08;电脑网站支付&#xff09; 1. 添加依赖 <!…...