当前位置: 首页 > news >正文

20- widedeep及函数式构建模型 (TensorFlow系列) (深度学习)

知识要点

  • wide&deep: 模型构建中, 卷积后数据和原始数据结合进行输出.
  • fetch_california_housing:加利福尼亚的房价数据,总计20640个样本,每个样本8个属性表示,以及房价作为target,所有属性值均为number,详情可调用fetch_california_housing()['DESCR']了解每个属性的具体含义;目标值为连续值
  • wide&deep结合: concat = keras.layers.concatenate([input, hidden2])   # 将卷积后的结果和原始的输入值进行结合
  • mse: 均方误差
  • 多输入wide&deep模型: concat = keras.layers.concatenate([input_wide, hidden2])  # 定义两个输入创建模型,  然后其中一个进行深度卷积, 另一个直接用来结合卷积后的结果. 同时注意需要对输入特征数据进行调整.
  • model = keras.models.Model(inputs=[input_wide,input_deep],outputs =[output,output2]# 多输入输出
  • 定义模型回调函数:     # log_dir 文件夹目录
callbacks = [keras.callbacks.TensorBoard(log_dir),keras.callbacks.ModelCheckpoint(output_model_file, save_best_only = True),keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]
  • 函数式实现wide&deep的方法:
# 子类API的写法, pytch
class WideDeepModel(keras.models.Model):def __init__(self):'''定义模型的层次'''super().__init__()self.hidden1 = keras.layers.Dense(32, activation = 'relu')self.hidden2 = keras.layers.Dense(32, activation = 'relu')self.output_layer = keras.layers.Dense(1)def call(self, input):'''完成模型的正向传播'''hidden1 = self.hidden1(input)hidden2 = self.hidden2(hidden1)# 拼接concat = keras.layers.concatenate([input, hidden2])output = self.output_layer(concat)return output'''定义实例对象'''
model = WideDeepModel()
model.build(input_shape = (None, 8))


1 wide and deep模型

1.1 背景

Wide and deep 模型是 TensorFlow 在 2016 年 6 月左右发布的一类用于分类和回归的模型,并应用到了 Google Play 的应用推荐中。wide and deep 模型的核心思想是结合线性模型的记忆能力(memorization)和 DNN 模型的泛化能力(generalization),在训练过程中同时优化 2 个模型的参数,从而达到整体模型的预测能力最优

记忆(memorization)即从历史数据中发现item或者特征之间的相关性

泛化(generalization)即相关性的传递,发现在历史数据中很少或者没有出现的新的特征组合。

1.2 网络结构原理

1.3 稀疏特征

离散值特征: 只能从N个值中选择一个

  • 比如性别, 只能是男女

  • one-hot编码表示的离散特征, 我们就认为是稀疏特征.

  • Eg: 专业= {计算机, 人文, 其他}, 人文 = [0, 1, 0]

  • Eg: 词表 = {人工智能,深度学习,你, 我, 他 , ..} 他= [0, 0, 0, 0, 1, 0, ...]

  • 叉乘 = {(计算机, 人工智能), (计算机, 你)...}

  • 叉乘可以用来精确刻画样本, 实现记忆效果.

  • 优点:

    • 有效, 广泛用于工业界, 比如广告点击率预估(谷歌, 百度的主要业务), 推荐算法.

  • 缺点:

    • 需要人工设计.

    • 叉乘过度, 可能过拟合, 所有特征都叉乘, 相当于记住了每一个样本.

    • 泛化能力差, 没出现过就不会起效果

密集特征

  • 向量表达

    • Eg: 词表 = {人工智能, 我们, 他}

    • 他 = [0.3, 0.2, 0.6, ...(n维向量)]

    • 每个词都可以用一个密集向量表示, 那么词和词之间就可以计算距离.

  • Word2vec工具可以方便的将词语转化为向量.

    • 男 - 女 = 国王 - 王后

  • 优点:

    • 带有语义信息, 不同向量之间有相关性.

    • 兼容没有出现过的特征组合.

    • 更少人工参与

  • 缺点:

    • 过度泛化, 比如推荐不怎么相关的产品.

1.4 简单神经网络实现回归任务  (加利福尼亚州房价数据)

  • concat = keras.layers.concatenate([input, hidden2])

1.4.1 导包

from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

1.4.2 加利福尼亚州房价数据导入

from sklearn.datasets import fetch_california_housing
housing = fetch_california_housing()x_train_all, x_test, y_train_all, y_test = train_test_split(housing.data,housing.target,random_state= 7)
x_train, x_valid, y_train, y_valid = train_test_split(x_train_all, y_train_all,random_state = 11)

1.4.3 标准化数据

from sklearn.preprocessing import StandardScaler, MinMaxScaler
scaler =StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)

1.4.4 基础神经网络 实现回归任务

# 定义网络
model = keras.models.Sequential([# input_dim是传入数据, input_shape一定要是元组keras.layers.Dense(128, activation = 'relu', input_shape = x_train.shape[1:]),keras.layers.Dense(64, activation = 'tanh'),keras.layers.Dense(1)])

 1.4.5 配置和训练模型

# 配置
model.compile(loss = 'mean_squared_error', optimizer = 'sgd', metrics = ['mse'])
# epochs 迭代次数
history = model.fit(x_train_scaled, y_train, validation_data = (x_valid_scaled, y_valid), epochs = 30)

 1.4.6 图文显示

# 定义画图函数, 看是否过拟合
def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize = (8, 5))plt.grid(True)plt.gca().set_ylim(0, 1)plt.show()
plot_learning_curves(history)

1.5 定义回调函数

  • 回调函数中添加保存最佳参数的模型
  • 定义提前停止的条件    # 连续多少次变化幅度小于某值时停止训练
log_dir = './callback'
if not os.path.exists(log_dir):  # 如果没有直接创建os.mkdir(log_dir)# 模型文件保存格式, 一般为h5, 会保存层级
output_model_file = os.path.join(log_dir, 'model.h5')  callbacks = [keras.callbacks.TensorBoard(log_dir),keras.callbacks.ModelCheckpoint(output_model_file, save_best_only = True),keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]
# epochs 迭代次数
history = model.fit(x_train_scaled,y_train,validation_data =(x_valid_scaled,y_valid),epochs = 50, callbacks = callbacks)

2 多输入 wide&deep模型

2.1 wide&deep模型 (内部进行结合)

input = keras.layers.Input(shape = x_train.shape[1:])    # (11610, 8)
hidden1 = keras.layers.Dense(32, activation = 'relu')(input)
hidden2 = keras.layers.Dense(32, activation = 'relu')(hidden1)concat = keras.layers.concatenate([input, hidden2])
output = keras.layers.Dense(1)(concat)
model = keras.models.Model(inputs = [input], outputs = output)

model.compile(loss = 'mean_squared_error', optimizer = 'Adam', metrics= ['mse'])
history = model.fit(x_train_scaled, y_train,validation_data = (x_valid_scaled, y_valid),epochs= 20)

import pandas as pd
def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize = (8, 5))plt.grid(True)plt.gca().set_ylim(0, 1)plt.show()
plot_learning_curves(history)

2.2 wide&deep方式二 (多输入)

  • 定义两个输入创建模型
# 多输入
# 定义两个输入
input_wide = keras.layers.Input(shape = [5])
input_deep = keras.layers.Input(shape = [6])hidden1 = keras.layers.Dense(30, activation = 'relu')(input_deep)
hidden2 = keras.layers.Dense(30, activation = 'relu')(hidden1)
concat = keras.layers.concatenate([input_wide, hidden2])
output = keras.layers.Dense(1)(concat)model = keras.models.Model(inputs = [input_wide, input_deep], outputs =[output])

# 对输入数据进行修改
x_train_scaled_wide = x_train_scaled[:, :5]
x_train_scaled_deep = x_train_scaled[:, 2:]x_valid_scaled_wide = x_valid_scaled[:, :5]
x_valid_scaled_deep = x_valid_scaled[:, 2:]x_test_scaled_wide = x_test_scaled[:, :5]
x_test_scaled_deep = x_test_scaled[:, 2:]history = model.fit([x_train_scaled_wide, x_train_scaled_deep],y_train, validation_data = ([x_valid_scaled_wide, x_valid_scaled_deep], y_valid), epochs= 20)

 2.3 wide&deep方式三  (多输出)

  • 双输入双输出
# 多输出 # 定义两个输入
input_wide = keras.layers.Input(shape = [5])
input_deep = keras.layers.Input(shape = [6])hidden1 = keras.layers.Dense(30, activation = 'relu')(input_deep)
hidden2 = keras.layers.Dense(30, activation = 'relu')(hidden1)
concat = keras.layers.concatenate([input_wide, hidden2])output = keras.layers.Dense(1)(concat)
output2 = keras.layers.Dense(1)(hidden2)model = keras.models.Model(inputs=[input_wide,input_deep],outputs =[output,output2])

# 对输入数据进行修改
x_train_scaled_wide = x_train_scaled[:, :5]
x_train_scaled_deep = x_train_scaled[:, 2:]x_valid_scaled_wide = x_valid_scaled[:, :5]
x_valid_scaled_deep = x_valid_scaled[:, 2:]x_test_scaled_wide = x_test_scaled[:, :5]
x_test_scaled_deep = x_test_scaled[:, 2:]history = model.fit([x_train_scaled_wide, x_train_scaled_deep],[y_train, y_train], validation_data = ([x_valid_scaled_wide, x_valid_scaled_deep], [y_valid, y_valid]), epochs= 20)

  •  在该模型的效果一般

3 子类API 实现wide&deep模型

3.1 函数构建模型

  • 卷积后的结果结合原始输入进行运算
# 子类API的写法, pytch
class WideDeepModel(keras.models.Model):def __init__(self):'''定义模型的层次'''super().__init__()self.hidden1 = keras.layers.Dense(32, activation = 'relu')self.hidden2 = keras.layers.Dense(32, activation = 'relu')self.output_layer = keras.layers.Dense(1)def call(self, input):'''完成模型的正向传播'''hidden1 = self.hidden1(input)hidden2 = self.hidden2(hidden1)# 拼接concat = keras.layers.concatenate([input, hidden2])output = self.output_layer(concat)return output'''定义实例对象'''
model = WideDeepModel()
model.build(input_shape = (None, 8))

# 配置
model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mse'])
history = model.fit(x_train_scaled, y_train, validation_data = (x_valid_scaled, y_valid), epochs= 20)

相关文章:

20- widedeep及函数式构建模型 (TensorFlow系列) (深度学习)

知识要点 wide&deep: 模型构建中, 卷积后数据和原始数据结合进行输出.fetch_california_housing:加利福尼亚的房价数据,总计20640个样本,每个样本8个属性表示,以及房价作为target,所有属性值均为number&#xff0…...

大家一起做测试的,凭什么你现在拿20k,我却还只有10k?...

最近我发现一个神奇的事情,我一个97年的朋友居然已经当上了测试项目组长,据我所知他去年还是在深圳的一家创业公司做苦逼的测试狗,短短8个月,到底发生了什么? 于是我立刻私聊他八卦一番。 原来他所在的公司最近正在裁…...

>>数据管理:DAMA简介「考试和续期」

关于DAMA,这里就不再多做描述,可以参考以前写的一些简介或官方介绍。下面就考试再做一些详细介绍。 1 区别 CDGA:数据治理工程师(Certified Data Governance Associate),“DAMA中国”组织的数据治理方面的职业认证考试。 CDGP:数据治理专家(Certified Data Governa…...

React的生命周期详细讲解

什么是生命周期? 所谓的React生命周期,就是指组件从被创建出来,到被使用,最后被销毁的这么一个过程。而在这个过程中,React提供了我们会自动执行的不同的钩子函数,我们称之为生命周期函数。**组件的生命周期…...

蓝蓝算法二期工程day3,一万年太久,只争朝夕

思路: 最好想的是用hashmap,当然用c的话也可以用两个数组,一个数组用于存放字符串,自动对应ACSII码,一个将对应ACSII码的数字对应其下标,当然这也是用的映射的思想。 import java.util.*;public class Cac…...

程序代码的自动化生成方案设计

程序设计就能够适用这种代码自动化生成方法的前提是:PLC 程序代码具有高度重复性,执行的是相同数据处理或者逻辑判断,而相关变量组 是离 散 的,没 有规 律 可循 。以 I/O 变量和中间 变量的地 址 映 射 程序为例 ,程序代码为赋 值 语 句 ,高度重复;IO 变量和与 其 对应 的中间 …...

Go 稀疏数组学习与实现

仍然还是一个数组 基本介绍 一般就是指二维以上的数组 当一个数组中大部分元素是0 ,或者为同一个值的数组时,可以使用系数数组来保存该数组. 稀疏数组的处理方法: 记录数组一共有几行几列,有多少个不同的值把具有不同值的元素的行列及值记录在一个小规模的数组中,从而缩小程…...

MySQL 学习笔记(借鉴黑马程序员MySQL)

MySQL视频课链接 MySQL概述 数据库相关概念 数据库是存储数据的仓库,数据是有组织的进行存储(DataBase) 数据库管理系统是操纵和管理数据库的大型软件(DataBase Management System) SQL是操作关系型数据库的编程语…...

中级工程师职称申报到底需要参加答辩不?

获得中级工程师职称的方式有认定、评审、考试这几种形式。 甘建二老师先来简单说一下关于认定和考试这两种: 1.认定:中级职称认定一般是根据各地职称认定政策,如果你想走认定渠道,首先本人简历条件、业绩、奖项等非常优秀&#…...

MM32开发教程(LED灯)

文章目录前言一、MM32介绍和STM32的区别二、板载LED灯原理图三、代码编写总结前言 今天将为大家介绍一款性能高体积小的MM32,这款开发板出自百问网团队。他就是灵动的MM32F3273,他体积非常小便于携带。 有128KB的SRAM、512KB的Flash、而且还支持双TypeC…...

win10安装docker

1.win10安装docker,前提必须是要安装WSL2。 现在Docker Desktop默认使用WSL 2来运行,而不是以前的Hyper-V。 WSL2 全称是Windows Subsystem on Linux。意思是,在win10,可以直接启动一个Linux。因为docker依赖Linux内核。 可查看…...

设计模式系列 - 代理模式及动态代理详解

定义 为其他对象提供一种代理以控制对这个对象的访问。在某些情况下,一个对象不适合或者不能直接引用另一个对象,而代理对象可以在客户端和目标对象之间起到中介的作用。 结构 抽象角色:通过接口或抽象类声明真实角色实现的业务方法。 代…...

【分享】订阅集简云畅捷通T+cloud连接器自动同步财务费用单至畅捷通

方案场景 伴随公司发展和数字化水平提高,大量的财务单据需要手动审核和录入,这些重复机械的操作占据大量人力,同时极容易出现数据出错或丢失等情况,严重影响着企业经营效率。 使用集简云提供服务的畅捷通TCloud钉钉连接器完成财…...

GPT的发展历程

GPT是当前最火的人工智能技术之一,自推出以来就广受关注。但大家对这个技术了解多少,又知道它经历了什么? GPT的诞生离不开谷歌在人工智能领域的努力和研究。2004年,谷歌成立了人工智能实验室(现已成为谷歌 AI实验室&…...

iOS开发笔记之九十八——关于Memory Leak总结笔记

*****阅读完此文,大概需要3分钟******关于Memory leak(内存泄漏)的问题,如果是面试被问这个问题以及此类问题,主要涉及下面3个方面:内存泄漏的常见场景有哪些,列举几个常见的例子?开…...

HTML基础语法

一 前端简介构成语言说明结构HTML页面元素和内容表现CSS网页元素的外观和位置等页面样式(美化)行为JavaScript网页模型的定义和页面交互二 HTML1.简介HTML(Hyper Text Markup Language):超文本标记语言。网页结构整体&…...

微软新版必应gpt人工智能体验教程

大家好,我是雄雄,欢迎关注微信公众号:** 雄雄的小课堂 ** 现在是:2023年2月28日18:35:02 前言 前几天,发了一篇文章,主要介绍了如何申请新必应的内测名单,其实一共也就那几步,然后等着就行: 文章连接:new bing如何快速申请内测资格,从而体验人工智能? 今天,终于…...

你问我答|虚拟机、容器和无服务器,怎么选?

在新技术层出不穷的当下,每家企业都希望不断降低成本,并提高运营效率,一个方法就是寻找不同的技术方案来优化运营。      例如,曾经一台服务器只能运行一个应用(裸机);接着,一台服务器的资源可以划分为多个块,从而运行多个应用(虚拟化);再到后来,应用越来越多,为了方便它们…...

某建筑设计研究院“综合布线管理软件”应用实践

某建筑设计研究院有限公司(简称“某院”)隶属于国务院国资委直属的大型骨干科技型中央企业。“某院”前身为中央直属设计公司,创建于1952年。成立近70年来,始终秉承优良传统,致力于推进国内勘察设计产业的创新发展&…...

R语言绘制SCI论文中常见的箱线散点图,并自动进行方差分析计算显著性水平

显著性标记箱线散点图 本篇笔记的内容是在R语言中利用ggplot2,ggsignif,ggsci,ggpubr等包制作箱线散点图,并计算指定变量之间的显著性水平,对不同分组进行特异性标记,最终效果如下。 加载R包 library(ggplo…...

后进先出(LIFO)详解

LIFO 是 Last In, First Out 的缩写,中文译为后进先出。这是一种数据结构的工作原则,类似于一摞盘子或一叠书本: 最后放进去的元素最先出来 -想象往筒状容器里放盘子: (1)你放进的最后一个盘子&#xff08…...

深入剖析AI大模型:大模型时代的 Prompt 工程全解析

今天聊的内容,我认为是AI开发里面非常重要的内容。它在AI开发里无处不在,当你对 AI 助手说 "用李白的风格写一首关于人工智能的诗",或者让翻译模型 "将这段合同翻译成商务日语" 时,输入的这句话就是 Prompt。…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

React hook之useRef

React useRef 详解 useRef 是 React 提供的一个 Hook&#xff0c;用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途&#xff0c;下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

visual studio 2022更改主题为深色

visual studio 2022更改主题为深色 点击visual studio 上方的 工具-> 选项 在选项窗口中&#xff0c;选择 环境 -> 常规 &#xff0c;将其中的颜色主题改成深色 点击确定&#xff0c;更改完成...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

渲染学进阶内容——模型

最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...

相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)

【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...

【C语言练习】080. 使用C语言实现简单的数据库操作

080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”&#xff0c;物流的终极形态正在诞生 想象这样的场景&#xff1a; 凌晨3点&#xff0c;某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径&#xff1b;AI视觉系统在0.1秒内扫描包裹信息&#xff1b;数字孪生平台正模拟次日峰值流量压力…...