当前位置: 首页 > news >正文

24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)

知识要点

keras 保存成hdf5文件, 1.保存模型和参数, 2.只保存参数

  • 1.保存模型和参数
    • save_model
    • callback ModelCheckpoint
  • 2. 只保存参数
    • save_weights
    • callback ModelCheckpoint save_weights_only = True

保存模型:

  • 案例数据: Fashion-MNIST总共有十个类别的图像
  • model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5'))      # 保存参数的方法
  • 加载参数: model.load_weights(os.path.join(logdir, 'fashion_mnist_weight.h5'))
  • 保存模型: model.save(os.path.join(logdir, 'fashion_mnist_model.h5'))
  • 加载模型: model2 = keras.models.load_model(os.path.join(logdir, 'fashion_mnist_model.h5'))
  • 把keras模型保存成savedmodel格式: tf.saved_model.save(model, './keras_saved_model')


一 模型保存和部署

  • TFLite是为了将深度学习模型部署在移动端和嵌入式设备的工具包,可以把训练好的TF模型通过转化、部署和优化三个步骤,达到提升运算速度,减少内存、显存占用的效果。
  • TFlite主要由Converter和Interpreter组成。Converter负责把TensorFlow训练好的模型转化,并输出为.tflite文件(FlatBuffer格式)。转化的同时,还完成了对网络的优化,如量化。Interpreter则负责把.tflite部署到移动端,嵌入式(embedded linux device)和microcontroller,并高效地执行推理过程,同时提供API接口给Python,Objective-C,Swift,Java等多种语言。简单来说,Converter负责打包优化模型,Interpreter负责高效易用地执行推理
  • Fashion-MNIST总共有十个类别的图像。每一个类别由训练数据集6000张图像和测试数据集1000张图像。所以训练集和测试集分别包含60000张和10000张。测试训练集用于评估模型的性能。

  • 每一个输入图像的高度和宽度均为28像素。数据集由灰度图像组成。Fashion-MNIST,中包含十个类别,分别是t-shirt,trouser,pillover,dress,coat,sandal,shirt,sneaker,bag,ankle boot。

1.1 模型创建

  • 导包
# 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
  • 时尚数据导入
# 时尚数据导入
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
  • 标准化
# 标准化
from sklearn.preprocessing import StandardScaler    # preprocessing 预处理
scaler = StandardScaler()x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 784))
x_valid_scaled = scaler.fit_transform(x_valid.astype(np.float32).reshape(-1, 784))
x_test_scaled = scaler.fit_transform(x_test.astype(np.float32).reshape(-1, 784))
  • 创建模型
# 创建模型
model = keras.models.Sequential([keras.layers.Dense(512, activation = 'relu', input_shape = (784, )),keras.layers.Dense(256, activation = 'relu'),keras.layers.Dense(128, activation = 'relu'),keras.layers.Dense(10, activation = 'softmax')])model.compile(loss = 'sparse_categorical_crossentropy',optimizer = 'adam',metrics = ['accuracy'])

1.2 保存模型

# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
callbacks = [keras.callbacks.TensorBoard(logdir),  # 保存地址# 保存效果最好的模型: save_best_onlykeras.callbacks.ModelCheckpoint(output_model_file,    save_best_only = True, save_weights_only = True),keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]   history = model.fit(x_train_scaled, y_train, epochs = 10,validation_data= (x_valid_scaled, y_valid),callbacks = callbacks)

  • 保存模型
# 保存模型
output_model_file2 = os.path.join(logdir, 'fashion_mnist_model.h5')
model.save(output_model_file2)
  •  保存参数

# 另一种保存参数的方法
model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5'))
  • 模型评估
# evaluate 评估
model.evaluate(x_valid_scaled, y_valid)   # [0.35909169912338257, 0.88919997215271]
  • 模型加载
# 加载模型
model2 = keras.models.load_model(output_model_file2)
model2.evaluate(x_valid_scaled, y_valid)  # [0.35909169912338257, 0.88919997215271]

二 保存模型为savemodel格式

# 把keras模型保存成savedmodel格式
tf.saved_model.save(model, './keras_saved_model')
  •  读取模型
# 加载savedmodel模型
loaded_saved_model = tf.saved_model.load('./keras_saved_model')
loaded_saved_model

 2.1 另一种保存

# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
model.load_weights(output_model_file)

三 tflite_interpreter 的使用

  • 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import os
with open('./tflite_models/concrete_func_tf_lite', 'rb') as f:concrete_func_tflite = f.read()
  • 创建interpreter
# 创建interpreter
interpreter = tf.lite.Interpreter(model_content = concrete_func_tflite)
# 分配内存
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
  • 预测数值
input_data = tf.constant(np.ones(input_details[0]['shape'], dtype = np.float32))
# 传入预测数据
interpreter.set_tensor(input_details[0]['index'], input_data)# 执行预测
interpreter.invoke()# 获取输出
output_results = interpreter.get_tensor(output_details[0]['index'])
print(output_results)

四 to_concrete_function

  • 加载文件
# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model(np.ones((1, 784)))

  • 把keras模型转化为concrete function
# 把keras模型转化为concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,loaded_keras_model.inputs[0].dtype))
# 使用
keras_concrete_func(tf.constant(np.ones((1, 784), dtype = np.float32)))

五 to_quantized_tflite

5.1 keras to tflite

# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model
# lite 精简版模型   # 创建转化器
keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)
keras_to_tflite_converter
# 给converter添加量化的优化  # 把32位的浮点数变成8位整数
keras_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]# 执行转化
keras_tflite = keras_to_tflite_converter.convert()
# 写入指定文件
import os
if not os.path.exists('./tflite_models'):os.mkdir('./tflite_models')with open('./tflite_models/quantized_keras_tflite', 'wb') as f:f.write(keras_tflite)

5.2 concrete function to tflite

# 把keras模型转化成concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,loaded_keras_model.inputs[0].dtype))
concrete_func_to_tflite_converter = tf.lite.TFLiteConverter.from_concrete_functions([keras_concrete_func])
concrete_func_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
concrete_func_tflite = concrete_func_to_tflite_converter.convert()
with open('./tflite_models/quantized_concrete_func_tf_lite', 'wb') as f:f.write(concrete_func_tflite)

5.3 saved_model to tflite

saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model('./keras_saved_model/')
saved_model_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
saved_model_tflite = saved_model_to_tflite_converter.convert()
with open('./tflite_models/quantized_saved_model_tflite', 'wb') as f:f.write(saved_model_tflite)

相关文章:

24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)

知识要点 keras 保存成hdf5文件, 1.保存模型和参数, 2.只保存参数 1.保存模型和参数 save_modelcallback ModelCheckpoint2. 只保存参数 save_weightscallback ModelCheckpoint save_weights_only True 保存模型: 案例数据: Fashion-MNIST总共有十个类别的图像model.save_w…...

【Echarts图例点击事件】自定义Echarts图例legend点击事件(已解决)

目录先睹为快(效果)1、实现Echarts多条曲线2、点击echarts触发接口请求2.1 先默认隐藏部分数据2.2 自定义legend图例点击事件3、源码下载地址(解压即用)**【写在前面】**这下我又不得不说了,还是客户现场使用时想查询一…...

uniapp-首页配置

为了获取到后台服务器发来的数据,需要配置相应的网络地址。位置在main.js入口文件中。 import { $http } from escook/request-miniprogramuni.$http $http // 配置请求根路径 $http.baseUrl https://api-hmugo-web.itheima.net// 请求开始之前做一些事情 $http.…...

支持DDR5,超频更简单,小雕够给力,技嘉B760M小雕WIFI主板上手

目前13代酷睿已经全员集结了,其中全新的i5 13490F应该依然会备受欢迎,当然了,刚上市不久的13代酷睿价格方面还不是很有吸引力,好在12代酷睿在新一代主板上面依然可用,所以预算有限的朋友,完全可用继续使用1…...

fengMap 自定义dom 偏离实际位置;缩放时飘出地图所在区域

目录 一、问题 二、原因及解决方法 三、总结 一、问题 1.前人写了一份代码,很奇怪。使用 new fengmap.FMCompositeMarker添加的复合覆盖物位置是正常的,缩放的时候也是正常的,仍然处于地图内部;但是new fengmap.FMDomMarker添加…...

TryHackMe-黑我杯

黑我杯 相信我们大家在TryHackMe的日积月累都学到了不少东西,从纯萌新到oscp再到更高 我很高兴能将国内各thm玩家聚集到一起,构建一个更好的学习环境和氛围 本次娱乐分两场: Offensive Pentesting — 中等难度Junior Penetration — 容易难…...

【JAVA程序设计】【C00109】基于SSM(非maven)的员工工资管理系统

基于SSM(非maven)的员工工资管理系统项目简介项目获取开发环境项目技术运行截图项目简介 基于ssm框架非maven开发的企业工资管理系统共分为二个角色:系统管理员、员工 管理员角色包含以下功能: 系统后台登陆、管理员管理、员工信…...

《计算机原理》——HelloWorld.cpp如何运行的

学校《计算机原理》开课啦!特此开辟专栏,将一些知识作为笔记,记录下来。 前言 本篇博客知识点来源于educoder的相关题目 1. 相关知识 1.1 计算机语言 计算机语言是人与计算机之间通讯的语言,计算机语言包括编写计算机程序的字符…...

【面试题】在JS循环中使用await会怎么样?

前言这个问题是这样产生的?某天,在学习异步的知识遇到这样一道题:使用Promise的方式,每隔一秒输出数组中一个值const arr [1, 2, 3] ​ arr.reduce((pre, cur) > {return pre.then(() > {returnnewPromise((resolve, rejec…...

Qt QMessageBox详解

文章目录一.QMessageBox介绍枚举属性函数二.QMessageBox的用法1.导入QMessage库2.弹窗提示3.提供选项的弹窗提示4.作为提示,报警,报错提示窗口一.QMessageBox介绍 文本消息显示框(message box)向用户发出情况警报信息并进一步解释警报或向用户提问&…...

Flutter之beamer路由入门指南

beamer路由入门指南 前言使用方法1、路由配置方式1路由配置方式2路由跳转测试现象前言 Beamer是一个很好用的路由组件,本文以beamer1.5.0版本进行说明,前面博主也介绍了其他路由组件 Flutter实战之go_router路由组件入门指南 、 Flutter之Fluro路由组件入门指南 Flutter之Ge…...

「基础篇」机器学习概览

文章目录1. 什么是机器学习2. 引入机器学习3. 应用场景4. 机器学习分类4.1. 有无人类监督4.2. 是否增量学习4.3. 泛化方式5. 主要挑战6. 测试与验证1. 什么是机器学习 机器学习(Machine Learning,ML)是一个研究领域,让计算机无需…...

揭秘可视化图探索工具 NebulaGraph Explore 是如何实现图计算的

前言 在可视化图探索工具 NebulaGraph Explorer 3.1.0 版本中加入了图计算工作流功能,针对 NebulaGraph 提供了图计算的能力,同时可以利用工作流的 nGQL 运行能力支持简单的数据读取,过滤及写入等数据处理功能。 本文将简单分享下 NebulaGr…...

移动架构43_什么是Jetpack

Android移动架构汇总​​​​​​​ 文章目录一 Android 开发框架演变1 MVC2 MVP3 MVVM二 什么是JetPack三 如何构建支持Jetpack项目一 Android 开发框架演变 1 MVC Model-View-Controller,模型-视图-控制器,Model负责数据管理,View负责UI显…...

TiDB的分布式事务原理探究

事务开启 获取全局授时作为startTS构建一个tikvTxn对象(包括snapshot)。 事务写 txn.Set方法本质上将kv值写入了一个内存缓存(即kv/memdb_buffer.go中的memDbBuffer)中。该内存kv数据库利用的是golevel提供的功能。 事务回滚 直接将tikvTxn的valid字段…...

【C语言】函数指针和指针函数

文章目录[TOC](文章目录)前言概述函数指针定义:使用:回调函数指针函数前言 今天学一下函数指针 提示:以下是本篇文章正文内容,下面案例可供参考 概述 函数指针:是一个指向函数的指针,在内存空间中存放的…...

Nodejs中npx简介和作用

一、npx简介npm从5.25.2版开始,增加了 npx 命令。方便了我在项目中使用全局包。二、安装Node安装后自带npm模块,可以直接使用npx命令。如果不能使用用,就要手动安装一下。npm install -g npx三、使用npx想要解决的主要问题,就是调…...

Matplotlib精品学习笔记001——绘制3D图形详解+实例讲解

3D图片更生动,或许在时间序列数据的展示上更胜一筹 想法: 学习3D绘图的想法来自科研绘图中。我从事的专业是古植物学,也就是和植物化石打交道。化石有三大信息:1.物种信息,也就是它的分类学价值;2.时间信息…...

学习ifconfig实战技巧,成为网络管理高手

文章目录前言一. ifconfig 命令介绍二. 语法格式及常用选项三. 参考案例3.1 显示网络设备信息3.2 启动和关闭指定的网卡3.3 对指定的网卡设备执行修改IP地址操作3.4 启动和关闭ARP协议3.5 使用ifconfig添加网卡总结前言 大家好,又见面了,我是沐风晓月&a…...

day38|70. 爬楼梯(进阶)、322. 零钱兑换、279.完全平方数

70. 爬楼梯(进阶) 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? 示例 1: 输入:n 2 输出:2 解释:有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2…...

SpringBoot全局异常处理

一、目的 当客户端/前端向服务端发送一个请求后,这个请求并不是每次都能完全正确的处理,比如出现一些资源不存在、参数错误或者内部错误等信息的时候,就需要将异常反馈给客户端或者前端。那么这就需要程序有完整的异常处理机制。 在 Java 中所…...

SpringBoot异常处理

目录 一、 错误处理 1. 默认规则 2. 定制错误处理逻辑 二、自定义异常处理 1. 实现 ErrorController 2. RestControllerAdvice/ControllerAdvice ExceptionHandler 实现自定义异常 3. 新建 UserController.class 测试 3 种不同异常的处理 4. 最终效果如下 补充 1. 参…...

《C++ Primer Plus》(第6版)第8章编程练习

《C Primer Plus》(第6版)第8章编程练习《C Primer Plus》(第6版)第8章编程练习1. 打印字符串2. CandyBar3. 将string对象的内容转换为大写4. 设置并打印字符串5. max5()6. maxn()7. SumArray()《C Primer Plus》(第6版…...

RAD Studio 11.3 Alexandria Crack

RAD Studio 11.3 Alexandria Crack 瞄准最新平台版本-此版本增加了对Android 13和Apple macOS Ventura的官方支持。它还支持Ubuntu 22 LTS和Microsoft Windows Server 2022。 使用生物特征认证-New为FireMonkey移动应用程序提供了新的移动生物特征认证组件。 部署嵌入式InterBa…...

Stm32 iic 协议使用

/* 第1个参数为I2C操作句柄 第2个参数为从机设备地址 第3个参数为从机寄存器地址 第4个参数为从机寄存器地址长度 第5个参数为发送的数据的起始地址 第6个参数为传输数据的大小 第7个参数为操作超时时间 */ HAL_I2C_Mem_Write(&hi2c2,salve_add,0,0,PA_BUFF,sizeof(PA_BUFF…...

Malware Dev 02 - Windows SDDL 后门利用之 SCManager

写在最前 如果你是信息安全爱好者,如果你想考一些证书来提升自己的能力,那么欢迎大家来我的 Discord 频道 Northern Bay。邀请链接在这里: https://discord.gg/9XvvuFq9Wb我拥有 OSCP,OSEP,OSWE,OSED&…...

每日一题29——山峰数组的顶部

符合下列属性的数组 arr 称为 山峰数组&#xff08;山脉数组&#xff09; &#xff1a; arr.length > 3 存在 i&#xff08;0 < i < arr.length - 1&#xff09;使得&#xff1a; arr[0] < arr[1] < ... arr[i-1] < arr[i] arr[i] > arr[i1] > ... &g…...

Linux- 系统随你玩之--好用到炸裂的系统级监控、诊断工具

文章目录1、前言2、lsof介绍2.1、问题来了&#xff1a; 所有用户都可以采用该命令吗&#xff1f;3、 服务器安装lsof3.1、安装3.2、检查安装是否正常。4、lsof 命令4.1、常用功能选项4.2、输出内容4.2.1 、FD和 TYPE列5、 lsof 命令实操常见用法6 、常用组合命令7、 结语1、前言…...

第十三节 继承

什么是继承&#xff1f; java中提供一个关键字extends&#xff0c;用这个关键字&#xff0c;我们可以让一个类和另一个类建立父子关系。 public class Student extends People{} student为子类&#xff08;派生类&#xff09;&#xff0c;people为父类&#xff08;基类或者超类…...

【优化】性能优化Springboot 项目配置内置Tomcat使用Http11AprProtocol(AIO)

Springboot 项目配置内置tomcat使用Http11AprProtocol(AIO) Windows版本 1.下载Springboot对应版本tomcat包 下载地址 Apache Tomcat - Apache Tomcat 9 Software Downloads 找到bin目录下 tcnative-1.dll 文件 2 放到jdk的bin目录下 Linux版本 在Springboot中内嵌的Tomcat默…...

wordpress 双侧边栏/黄页88网推广服务

我的SQL SEVER 2005是2010年12装的 那时候一直在用&#xff0c;后来自己的PC不想用来编程了所以就把 PC上面的SQL服务都给禁掉了。现在的情况是怎么都连不上本地数据库sqlexpress倒是可以连接上。有没有什么办法能 解决这个问题&#xff0c; 请大家帮忙解决一下。标题: 连接到…...

东莞专业做淘宝网站/怎么做好网络营销

在CNN(1)和CNN(2)两篇文章中&#xff0c;主要说明的是CNN的基本架构和权值共享(Weight Sharing)&#xff0c;本文则重点介绍卷积的部分。 首先&#xff0c;在卷积之前&#xff0c;我们的数据是4D的tensor&#xff08;width,height,channels,batch&#xff09;&#xff0c;在CNN…...

衡阳市建设局网站/品牌营销推广代运营

https://github.com/qiu8310/crontab http://blog.csdn.net/spring21st/article/details/52789045 Windows 版 Crontab [JAVA] 关于文件编码&#xff0c;由于需要结合 windows 系统的命令行&#xff0c;所以源文件和配置文件都是GBK编码的&#xff0c;请注意。 Crontab 是 Li…...

药物研发网站怎么做/网络整合营销方案

题目链接&#xff1a;http://www.51nod.com/onlineJudge/questionCode.html#!problemId1049 题意&#xff1a;又是仲文题诶&#xff5e; 思路&#xff1a;暴力会超时&#xff0c;又好像没什么专门的算法&#xff0c;自己yy了一个&#xff0c;算是贪心吧&#xff5e; 暴力的话就…...

网站怎么做图片搜索/百度快速排名化

下面的文章讲解了如何为Cassandra设置登录密码&#xff1a;(操作的环境为Linux 下的 Cassandra , Windows 下面存在一些问题, 不知道怎么解决, cqlsh 不能登录) 0.修改配置文件 cassandra.yaml&#xff0c;把 authenticator: AllowAllAuthenticator改为authenticator: Passwor…...

做临时网站/百度推广没有一点效果

本文为大家带来CF延迟133ms怎么解决133ms延迟解决方法。玩家们反应的就是这种情况&#xff0c;非常稳定的133ms延迟&#xff0c;点击开始游戏之后&#xff0c;就一直卡在黑色的界面&#xff0c;无法登陆游戏。Q&#xff1a;cf延迟133ms进不去是为什么&#xff1f;A&#xff1a;…...