24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)
知识要点
keras 保存成hdf5文件, 1.保存模型和参数, 2.只保存参数
- 1.保存模型和参数
- save_model
- callback ModelCheckpoint
- 2. 只保存参数
- save_weights
- callback ModelCheckpoint save_weights_only = True
保存模型:
- 案例数据: Fashion-MNIST总共有十个类别的图像
- model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5')) # 保存参数的方法
- 加载参数: model.load_weights(os.path.join(logdir, 'fashion_mnist_weight.h5'))
- 保存模型: model.save(os.path.join(logdir, 'fashion_mnist_model.h5'))
- 加载模型: model2 = keras.models.load_model(os.path.join(logdir, 'fashion_mnist_model.h5'))
- 把keras模型保存成savedmodel格式: tf.saved_model.save(model, './keras_saved_model')
一 模型保存和部署
- TFLite是为了将深度学习模型部署在移动端和嵌入式设备的工具包,可以把训练好的TF模型通过转化、部署和优化三个步骤,达到提升运算速度,减少内存、显存占用的效果。
- TFlite主要由Converter和Interpreter组成。Converter负责把TensorFlow训练好的模型转化,并输出为.tflite文件(FlatBuffer格式)。转化的同时,还完成了对网络的优化,如量化。Interpreter则负责把.tflite部署到移动端,嵌入式(embedded linux device)和microcontroller,并高效地执行推理过程,同时提供API接口给Python,Objective-C,Swift,Java等多种语言。简单来说,Converter负责打包优化模型,Interpreter负责高效易用地执行推理。
-
Fashion-MNIST总共有十个类别的图像。每一个类别由训练数据集6000张图像和测试数据集1000张图像。所以训练集和测试集分别包含60000张和10000张。测试训练集用于评估模型的性能。
-
每一个输入图像的高度和宽度均为28像素。数据集由灰度图像组成。Fashion-MNIST,中包含十个类别,分别是t-shirt,trouser,pillover,dress,coat,sandal,shirt,sneaker,bag,ankle boot。
1.1 模型创建
- 导包
# 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
- 时尚数据导入
# 时尚数据导入
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
- 标准化
# 标准化
from sklearn.preprocessing import StandardScaler # preprocessing 预处理
scaler = StandardScaler()x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 784))
x_valid_scaled = scaler.fit_transform(x_valid.astype(np.float32).reshape(-1, 784))
x_test_scaled = scaler.fit_transform(x_test.astype(np.float32).reshape(-1, 784))
- 创建模型
# 创建模型
model = keras.models.Sequential([keras.layers.Dense(512, activation = 'relu', input_shape = (784, )),keras.layers.Dense(256, activation = 'relu'),keras.layers.Dense(128, activation = 'relu'),keras.layers.Dense(10, activation = 'softmax')])model.compile(loss = 'sparse_categorical_crossentropy',optimizer = 'adam',metrics = ['accuracy'])
1.2 保存模型
# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
callbacks = [keras.callbacks.TensorBoard(logdir), # 保存地址# 保存效果最好的模型: save_best_onlykeras.callbacks.ModelCheckpoint(output_model_file, save_best_only = True, save_weights_only = True),keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)] history = model.fit(x_train_scaled, y_train, epochs = 10,validation_data= (x_valid_scaled, y_valid),callbacks = callbacks)
- 保存模型
# 保存模型
output_model_file2 = os.path.join(logdir, 'fashion_mnist_model.h5')
model.save(output_model_file2)
-
保存参数
# 另一种保存参数的方法
model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5'))
- 模型评估
# evaluate 评估
model.evaluate(x_valid_scaled, y_valid) # [0.35909169912338257, 0.88919997215271]
- 模型加载
# 加载模型
model2 = keras.models.load_model(output_model_file2)
model2.evaluate(x_valid_scaled, y_valid) # [0.35909169912338257, 0.88919997215271]
二 保存模型为savemodel格式
# 把keras模型保存成savedmodel格式
tf.saved_model.save(model, './keras_saved_model')
- 读取模型
# 加载savedmodel模型
loaded_saved_model = tf.saved_model.load('./keras_saved_model')
loaded_saved_model
2.1 另一种保存
# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
model.load_weights(output_model_file)
三 tflite_interpreter 的使用
- 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import os
with open('./tflite_models/concrete_func_tf_lite', 'rb') as f:concrete_func_tflite = f.read()
- 创建interpreter
# 创建interpreter
interpreter = tf.lite.Interpreter(model_content = concrete_func_tflite)
# 分配内存
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
- 预测数值
input_data = tf.constant(np.ones(input_details[0]['shape'], dtype = np.float32))
# 传入预测数据
interpreter.set_tensor(input_details[0]['index'], input_data)# 执行预测
interpreter.invoke()# 获取输出
output_results = interpreter.get_tensor(output_details[0]['index'])
print(output_results)
四 to_concrete_function
- 加载文件
# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model(np.ones((1, 784)))
- 把keras模型转化为concrete function
# 把keras模型转化为concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,loaded_keras_model.inputs[0].dtype))
# 使用
keras_concrete_func(tf.constant(np.ones((1, 784), dtype = np.float32)))
五 to_quantized_tflite
5.1 keras to tflite
# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model
# lite 精简版模型 # 创建转化器
keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)
keras_to_tflite_converter
# 给converter添加量化的优化 # 把32位的浮点数变成8位整数
keras_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]# 执行转化
keras_tflite = keras_to_tflite_converter.convert()
# 写入指定文件
import os
if not os.path.exists('./tflite_models'):os.mkdir('./tflite_models')with open('./tflite_models/quantized_keras_tflite', 'wb') as f:f.write(keras_tflite)
5.2 concrete function to tflite
# 把keras模型转化成concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,loaded_keras_model.inputs[0].dtype))
concrete_func_to_tflite_converter = tf.lite.TFLiteConverter.from_concrete_functions([keras_concrete_func])
concrete_func_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
concrete_func_tflite = concrete_func_to_tflite_converter.convert()
with open('./tflite_models/quantized_concrete_func_tf_lite', 'wb') as f:f.write(concrete_func_tflite)
5.3 saved_model to tflite
saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model('./keras_saved_model/')
saved_model_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
saved_model_tflite = saved_model_to_tflite_converter.convert()
with open('./tflite_models/quantized_saved_model_tflite', 'wb') as f:f.write(saved_model_tflite)
相关文章:

24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)
知识要点 keras 保存成hdf5文件, 1.保存模型和参数, 2.只保存参数 1.保存模型和参数 save_modelcallback ModelCheckpoint2. 只保存参数 save_weightscallback ModelCheckpoint save_weights_only True 保存模型: 案例数据: Fashion-MNIST总共有十个类别的图像model.save_w…...

【Echarts图例点击事件】自定义Echarts图例legend点击事件(已解决)
目录先睹为快(效果)1、实现Echarts多条曲线2、点击echarts触发接口请求2.1 先默认隐藏部分数据2.2 自定义legend图例点击事件3、源码下载地址(解压即用)**【写在前面】**这下我又不得不说了,还是客户现场使用时想查询一…...

uniapp-首页配置
为了获取到后台服务器发来的数据,需要配置相应的网络地址。位置在main.js入口文件中。 import { $http } from escook/request-miniprogramuni.$http $http // 配置请求根路径 $http.baseUrl https://api-hmugo-web.itheima.net// 请求开始之前做一些事情 $http.…...

支持DDR5,超频更简单,小雕够给力,技嘉B760M小雕WIFI主板上手
目前13代酷睿已经全员集结了,其中全新的i5 13490F应该依然会备受欢迎,当然了,刚上市不久的13代酷睿价格方面还不是很有吸引力,好在12代酷睿在新一代主板上面依然可用,所以预算有限的朋友,完全可用继续使用1…...

fengMap 自定义dom 偏离实际位置;缩放时飘出地图所在区域
目录 一、问题 二、原因及解决方法 三、总结 一、问题 1.前人写了一份代码,很奇怪。使用 new fengmap.FMCompositeMarker添加的复合覆盖物位置是正常的,缩放的时候也是正常的,仍然处于地图内部;但是new fengmap.FMDomMarker添加…...

TryHackMe-黑我杯
黑我杯 相信我们大家在TryHackMe的日积月累都学到了不少东西,从纯萌新到oscp再到更高 我很高兴能将国内各thm玩家聚集到一起,构建一个更好的学习环境和氛围 本次娱乐分两场: Offensive Pentesting — 中等难度Junior Penetration — 容易难…...

【JAVA程序设计】【C00109】基于SSM(非maven)的员工工资管理系统
基于SSM(非maven)的员工工资管理系统项目简介项目获取开发环境项目技术运行截图项目简介 基于ssm框架非maven开发的企业工资管理系统共分为二个角色:系统管理员、员工 管理员角色包含以下功能: 系统后台登陆、管理员管理、员工信…...

《计算机原理》——HelloWorld.cpp如何运行的
学校《计算机原理》开课啦!特此开辟专栏,将一些知识作为笔记,记录下来。 前言 本篇博客知识点来源于educoder的相关题目 1. 相关知识 1.1 计算机语言 计算机语言是人与计算机之间通讯的语言,计算机语言包括编写计算机程序的字符…...
【面试题】在JS循环中使用await会怎么样?
前言这个问题是这样产生的?某天,在学习异步的知识遇到这样一道题:使用Promise的方式,每隔一秒输出数组中一个值const arr [1, 2, 3] arr.reduce((pre, cur) > {return pre.then(() > {returnnewPromise((resolve, rejec…...

Qt QMessageBox详解
文章目录一.QMessageBox介绍枚举属性函数二.QMessageBox的用法1.导入QMessage库2.弹窗提示3.提供选项的弹窗提示4.作为提示,报警,报错提示窗口一.QMessageBox介绍 文本消息显示框(message box)向用户发出情况警报信息并进一步解释警报或向用户提问&…...
Flutter之beamer路由入门指南
beamer路由入门指南 前言使用方法1、路由配置方式1路由配置方式2路由跳转测试现象前言 Beamer是一个很好用的路由组件,本文以beamer1.5.0版本进行说明,前面博主也介绍了其他路由组件 Flutter实战之go_router路由组件入门指南 、 Flutter之Fluro路由组件入门指南 Flutter之Ge…...

「基础篇」机器学习概览
文章目录1. 什么是机器学习2. 引入机器学习3. 应用场景4. 机器学习分类4.1. 有无人类监督4.2. 是否增量学习4.3. 泛化方式5. 主要挑战6. 测试与验证1. 什么是机器学习 机器学习(Machine Learning,ML)是一个研究领域,让计算机无需…...

揭秘可视化图探索工具 NebulaGraph Explore 是如何实现图计算的
前言 在可视化图探索工具 NebulaGraph Explorer 3.1.0 版本中加入了图计算工作流功能,针对 NebulaGraph 提供了图计算的能力,同时可以利用工作流的 nGQL 运行能力支持简单的数据读取,过滤及写入等数据处理功能。 本文将简单分享下 NebulaGr…...

移动架构43_什么是Jetpack
Android移动架构汇总 文章目录一 Android 开发框架演变1 MVC2 MVP3 MVVM二 什么是JetPack三 如何构建支持Jetpack项目一 Android 开发框架演变 1 MVC Model-View-Controller,模型-视图-控制器,Model负责数据管理,View负责UI显…...
TiDB的分布式事务原理探究
事务开启 获取全局授时作为startTS构建一个tikvTxn对象(包括snapshot)。 事务写 txn.Set方法本质上将kv值写入了一个内存缓存(即kv/memdb_buffer.go中的memDbBuffer)中。该内存kv数据库利用的是golevel提供的功能。 事务回滚 直接将tikvTxn的valid字段…...

【C语言】函数指针和指针函数
文章目录[TOC](文章目录)前言概述函数指针定义:使用:回调函数指针函数前言 今天学一下函数指针 提示:以下是本篇文章正文内容,下面案例可供参考 概述 函数指针:是一个指向函数的指针,在内存空间中存放的…...
Nodejs中npx简介和作用
一、npx简介npm从5.25.2版开始,增加了 npx 命令。方便了我在项目中使用全局包。二、安装Node安装后自带npm模块,可以直接使用npx命令。如果不能使用用,就要手动安装一下。npm install -g npx三、使用npx想要解决的主要问题,就是调…...

Matplotlib精品学习笔记001——绘制3D图形详解+实例讲解
3D图片更生动,或许在时间序列数据的展示上更胜一筹 想法: 学习3D绘图的想法来自科研绘图中。我从事的专业是古植物学,也就是和植物化石打交道。化石有三大信息:1.物种信息,也就是它的分类学价值;2.时间信息…...

学习ifconfig实战技巧,成为网络管理高手
文章目录前言一. ifconfig 命令介绍二. 语法格式及常用选项三. 参考案例3.1 显示网络设备信息3.2 启动和关闭指定的网卡3.3 对指定的网卡设备执行修改IP地址操作3.4 启动和关闭ARP协议3.5 使用ifconfig添加网卡总结前言 大家好,又见面了,我是沐风晓月&a…...
day38|70. 爬楼梯(进阶)、322. 零钱兑换、279.完全平方数
70. 爬楼梯(进阶) 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? 示例 1: 输入:n 2 输出:2 解释:有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘
美国西海岸的夏天,再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至,这不仅是开发者的盛宴,更是全球数亿苹果用户翘首以盼的科技春晚。今年,苹果依旧为我们带来了全家桶式的系统更新,包括 iOS 26、iPadOS 26…...
连锁超市冷库节能解决方案:如何实现超市降本增效
在连锁超市冷库运营中,高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术,实现年省电费15%-60%,且不改动原有装备、安装快捷、…...

[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明
AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...
土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等
🔍 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术,可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势,还能有效评价重大生态工程…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南
🚀 C extern 关键字深度解析:跨文件编程的终极指南 📅 更新时间:2025年6月5日 🏷️ 标签:C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言🔥一、extern 是什么?&…...

在WSL2的Ubuntu镜像中安装Docker
Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包: for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...

嵌入式学习笔记DAY33(网络编程——TCP)
一、网络架构 C/S (client/server 客户端/服务器):由客户端和服务器端两个部分组成。客户端通常是用户使用的应用程序,负责提供用户界面和交互逻辑 ,接收用户输入,向服务器发送请求,并展示服务…...
Java求职者面试指南:计算机基础与源码原理深度解析
Java求职者面试指南:计算机基础与源码原理深度解析 第一轮提问:基础概念问题 1. 请解释什么是进程和线程的区别? 面试官:进程是程序的一次执行过程,是系统进行资源分配和调度的基本单位;而线程是进程中的…...

力扣热题100 k个一组反转链表题解
题目: 代码: func reverseKGroup(head *ListNode, k int) *ListNode {cur : headfor i : 0; i < k; i {if cur nil {return head}cur cur.Next}newHead : reverse(head, cur)head.Next reverseKGroup(cur, k)return newHead }func reverse(start, end *ListNode) *ListN…...