深度学习应用技巧4-模型融合:投票法、加权平均法、集成模型法
大家好,我是微学AI,今天给大家介绍一下,深度学习中的模型融合。它是将多个深度学习模型或其预测结果结合起来,以提高模型整体性能的一种技术。
深度学习中的模型融合技术,也叫做集成学习,是指同时使用多个模型来进行预测或分类,将它们的结果结合起来,从而获得更准确、更鲁棒的结果。这种方法能够弥补单个模型的不足之处,提高模型的性能。
常见的深度学习模型包括卷积神经网络(CNN)、循环神经网络(RNN)等,在实际应用中,通常会使用多个模型来解决同一个任务。然而,单独使用每个模型可能会存在过拟合、欠拟合、训练时间长等一些问题。这时,模型融合技术就派上用场了。 对于模型融合技术,其主要的思想是结合多个模型的优点,减少缺点,从而提高整体的性能。

一、模型融合技巧主要包括以下几个方面:
1. 投票法: 对多个相同类型的模型进行训练,最后通过投票的方式选择输出结果最多的类别作为最终的预测结果。在实践中,通常会使用奇数个模型,以避免出现相同数量的投票结果。
2. 加权平均法: 对多个相同类型的模型的输出结果进行加权平均。采用加权平均法融合的模型,可以根据效果不同,分配不同的权重。
3. 集成多种不同类型的模型: 在深度学习中,常常会使用不同类型的模型,如 CNN、 RNN 、 LSTM 等,将它们进行集成,综合利用不同模型的优点,进一步提高系统的性能。
4. 提前停止模型训练: 训练多个模型时,如果其中一个模型已经达到了最优的状态,可以停止继续训练,以达到快速训练过程和提高融合效果的目的。
模型融合主要提高了深度学习模型的表现力和泛化性能,更好的解决了过拟合等问题。在选择模型融合技巧时,可以根据具体实际应用选择不同的融合方法,灵活运用各种方法,从而得到更好的模型效果。
应用场景想象:
想象你正在参加一个模型比赛,需要对一些数据进行预测。你仅使用了一个模型进行预测,但是你发现这个模型可能存在某些缺陷,导致预测结果不够准确,于是你想到使用模型融合技术。 你开始集成三个不同的模型,每一个模型都有自己的特点和优缺点。为了得到集成模型的预测结果,你可以采用堆叠法,即把三个模型的预测结果作为输入特征,再训练一个新的模型进行预测。通过堆叠,你得到了一个更加准确的预测结果。 如果你觉得模型之间存在评估的差异,你也可以使用加权平均法来集成模型。你可以根据每一个模型的表现,为它们分配一定的权重,然后根据这些权重,对它们输出的结果进行加权平均,从而得到更加精确的预测结果。 最后,你也可以使用投票法,这种方法是将多个模型的预测结果进行投票,选择获得最多票数的结果作为最终预测结果。它会适用于模型数量较多的情况,即使其中某个模型出现了不准确的情况,也不会对最终结果有太大的影响。 通过这些集成模型的方法,你可以在深度学习中获得更为准确、稳定的预测结果,从而提升模型的性能和可靠性。
二、模型融合代码样例:
下面我将利用投票法、加权平均法和集成模型法三种方式对CNN网络进行融合。
import numpy as np
from sklearn.metrics import accuracy_score
from tensorflow import keras
from tensorflow.keras import layers# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()# 将像素值缩放到 0-1 范围内
train_images = train_images.astype('float32') / 255.
test_images = test_images.astype('float32') / 255.# 将标签转换为 one-hot 编码
train_labels = keras.utils.to_categorical(train_labels, 10)
test_labels = keras.utils.to_categorical(test_labels, 10)# 定义 CNN 模型
def create_model():model = keras.Sequential()model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Conv2D(64, (3, 3), activation='relu'))model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Conv2D(64, (3, 3), activation='relu'))model.add(layers.Flatten())model.add(layers.Dense(64, activation='relu'))model.add(layers.Dense(10, activation='softmax'))return model# 创建多个 CNN 模型
models = []
for i in range(3):model = create_model()model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])model.fit(train_images, train_labels, epochs=5, batch_size=128, verbose=1)models.append(model)
1.投票法:对测试集进行预测,并取预测结果的众数作为最终预测结果
predictions = []
for model in models:predictions.append(model.predict(test_images))
y_pred = np.argmax(np.round(np.mean(predictions, axis=0)), axis=1)
y_true = np.argmax(test_labels, axis=1)# 计算准确率
acc = accuracy_score(y_true, y_pred)
print("使用投票法进行模型融合的准确率:", acc)
2.加权平均法:为每个模型定义一个权重,并将预测结果加权平均
weights = [0.2, 0.3, 0.5]
predictions = []
for i, model in enumerate(models):prediction = model.predict(test_images)predictions.append(weights[i] * prediction)
y_pred = np.argmax(np.sum(predictions, axis=0), axis=1)
y_true = np.argmax(test_labels, axis=1)# 计算准确率
acc = accuracy_score(y_true, y_pred)
print("使用加权平均法进行模型融合的准确率:", acc)
3.模型集成法:将多个 CNN 模型堆叠在一起,将其看作一个更强大的模型,并对测试集进行预测
inputs = keras.Input(shape=(28, 28, 1))
outputs = [model(inputs) for model in models]
y = layers.Average()(outputs)
ensemble_model = keras.Model(inputs=inputs, outputs=y)
ensemble_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练集成模型
ensemble_model.fit(train_images, train_labels, epochs=5, batch_size=64)# 测试集成模型
predictions = ensemble_model.predict(test_images)
y_pred = np.argmax(predictions, axis=1)
y_true = np.argmax(test_labels, axis=1)# 计算准确率
acc = accuracy_score(y_true, y_pred)
print("使用模型集成方法进行模型融合的准确率:", acc)
最后我们可以看三种方法的准确率,在实战案例中根据业务需求进行方法的选择。
相关文章:
深度学习应用技巧4-模型融合:投票法、加权平均法、集成模型法
大家好,我是微学AI,今天给大家介绍一下,深度学习中的模型融合。它是将多个深度学习模型或其预测结果结合起来,以提高模型整体性能的一种技术。 深度学习中的模型融合技术,也叫做集成学习,是指同时使用多个…...
【并发编程】深入理解Java内存模型及相关面试题
文章目录优秀引用1、引入2、概述3、JMM内存模型的实现3.1、简介3.2、原子性3.3、可见性3.4、有序性4、相关面试题4.1、你知道什么是Java内存模型JMM吗?4.2、JMM和volatile他们两个之间的关系是什么?4.3、JMM有哪些特性/能说说JMM的三大特性吗?…...
C++编程语言STL之queue介绍
本文主要介绍C编程语言的STL(Standard Template Library)中queue(队列)的相关知识,同时通过示例代码介绍queue的常见用法。1 概述适配器(adaptor)是STL中的一个通用概念。容器、迭代器和函数都有…...
ACO优化蚁群算法
%% 蚁群算法(ant colony optimization,ACO) %清空变量 clear close all clc [ graph ] createGraph(); figure subplot(1,3,1) drawGraph( graph); %% 初始化参数 maxIter 100; antNo 50; tau0 10 * 1 / ( graph.n * mean( graph.edges(:) …...
SwiftUI 常用组件和属性(SwiftUI初学笔记)
本文为初学SwiftUI笔记。记录SwiftUI常用的组件和属性。 组件 共有属性(View的属性) Image("toRight").resizable().background(.red) // 背景色.shadow(color: .black, radius: 2, x: 9, y: 15) //阴影.frame(width: 30, height: 30) // 宽高 可以只设置宽或者高.…...
Centos 中设置代理的两种方法
Centos 中设置代理的两种方法 在使用局域网时,有时在局域网内只有一台电脑可以进行上网,其他电脑只能通过配置代理的方式来上网,在Windows系统中设置代理上网相对简单,如果只需上网的话,只需在浏览器中找到网络连接&am…...
高速PCB设计指南系列(一)
第一篇 PCB布线 在PCB设计中,布线是完成产品设计的重要步骤,可以说前面的准备工作都是为它而做的, 在整个PCB中,以布线的设计过程限定最高,技巧最细、工作量最大。PCB布线有单面布线、 双面布线及多层布线。布线的方…...
云端IDE:TitanIDE v2.6.0 正式发布
原文作者:行云创新技术总监 邓冰寒 引言 云原生集成开发环境 TitanIDE v2.6.0 正式发布了,一起来看看都有那些全新的体验吧! TitanIDE 是一款云IDE, 也称 CloudIDE,作为数字化时代研发体系不可或缺的一环,和企业建设…...
【Python】tqdm 模块
import mathfrom tqdm import tqdm, trange# 计算阶乘 results_1 []for i in range(6666):results_1.append(math.factorial(i))这是一个循环计算阶乘的程序,我们不知道程序运行的具体情况,如果能加上一个程序运行过程的进度条,那可就太有趣…...
论文阅读:Adversarial Cross-Modal Retrieval对抗式跨模式检索
Adversarial Cross-Modal Retrieval 对抗式跨模式检索 跨模态检索研究的核心是学习一个共同的子空间,不同模态的数据可以直接相互比较。本文提出了一种新的对抗性跨模态检索(ACMR)方法,它在对抗性学习的基础上寻求有效的共同子空间…...
计算机网络复习
什么是DHCP和DNS DNS(Domain Name System,域名系统),因特网上作为域名和IP地址相互映射的一个分布式数据库,能够使用户更方便的访问互联网,而不用去记住能够被机器直接读取的IP数串。通过主机名,最终得到该主机名对应的…...
unity动画--动画绑定,转换,用脚本触发
文章目录如何制作和添加动画大概过程示例图将多组图片转化为动画放在对象身上实现动画之间的切换使用脚本触发Parameters(Trigger)如何制作和添加动画 大概过程示例图 将多组图片转化为动画放在对象身上 首先,我们要为我们要对象添加animator 然后我们要设置对应的…...
车载汽车充气泵PCBA方案
汽车为什么会需要充气泵呢?其实是由于乘用车中没有供气源,所以就必需充气泵来给避震器供气。充气泵是为了保障汽车车胎对汽车的行驶安全所配备的,防止遇上紧急问题时没有解决方案,同时也可以检测轮胎胎压。现阶段的充气泵方案&…...
Android 连接 MySQL 数据库教程
在 Android 应用程序中连接 MySQL 数据库可以帮助开发人员实现更丰富的数据管理功能。本教程将介绍如何在 Android 应用程序中使用低版本的 MySQL Connector/J 驱动程序来连接 MySQL 数据库。 步骤一:下载 MySQL Connector/J 驱动程序 首先,我们需要下…...
tmall.item.update.schema.get( 天猫编辑商品规则获取 )
¥开放平台免费API必须用户授权 Schema方式编辑天猫商品时,编辑商品规则获取 公共参数 请求地址: HTTP地址 http://gw.api.taobao.com/router/rest 公共请求参数: 公共响应参数: 点击获取key和secret 请求示例 TaobaoClient client new DefaultTaobao…...
Leetcode 2379. 得到 K 个黑块的最少涂色次数
目录 一、题目内容和对应链接 1.题目对应链接 2.题目内容 二、我的想法 三、其他人的题解 一、题目内容和对应链接 1.题目对应链接 Leetcode 2379. 得到 K 个黑块的最少涂色次数 2.题目内容 给你一个长度为 n 下标从 0 开始的字符串 blocks ,blocks[i] 要…...
[深入理解SSD系列 闪存实战2.1.3] 固态硬盘闪存的物理学原理_NAND Flash 的读、写、擦工作原理
2.1.3.1 Flash 的物理学原理与发明历程 经典物理学认为 物体越过势垒,有一阈值能量;粒子能量小于此能量则不能越过,大于此能 量则可以越过。例如骑自行车过小坡,先用力骑,如果坡很低,不蹬自行车也能 靠惯性过去。如果坡很高,不蹬自行车,车到一半就停住,然后退回去。 …...
总结:Linux内核相关
一、介绍看eBPF和Cilium相关内容时,碰到Cilium是运行在第 3/4 层,不明白怎么做到的,思考原理的时候就想到了内容,本文记录下内核相关知识。https://www.oschina.net/p/cilium?hmsraladdin1e1二、Linux内核主要由哪几个部分组成Li…...
flutter工程创建过程中遇到一些问题。
安装环境版本:JDK7.-JDK 8 Andriod SDK 10 flutter 版本 3.0 1.当创建完后flutter工程后会遇到 run gradle task assemlble Debug 的问题,需要设置远程仓库,共需要修改三个地方build.gradle两处以及flutter 下面的D:\FVM\versions\3.0.0\pac…...
记录实现操作系统互斥锁的一次思考
今天实现操作系统互斥锁的时候遇到一个有趣的问题。 场景 有两个进程分别名为 taskA,taskB,采取时间片轮转的方式交替运行——也即维护了一个 ready_queue,根据时钟中断来 FIFO 地调度任务。它们的任务是无限循环调用 sys_print() 来打印自…...
vscode里如何用git
打开vs终端执行如下: 1 初始化 Git 仓库(如果尚未初始化) git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...
K8S认证|CKS题库+答案| 11. AppArmor
目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...
理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端
🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...
Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility
Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...
如何为服务器生成TLS证书
TLS(Transport Layer Security)证书是确保网络通信安全的重要手段,它通过加密技术保护传输的数据不被窃听和篡改。在服务器上配置TLS证书,可以使用户通过HTTPS协议安全地访问您的网站。本文将详细介绍如何在服务器上生成一个TLS证…...
【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)
骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
解决:Android studio 编译后报错\app\src\main\cpp\CMakeLists.txt‘ to exist
现象: android studio报错: [CXX1409] D:\GitLab\xxxxx\app.cxx\Debug\3f3w4y1i\arm64-v8a\android_gradle_build.json : expected buildFiles file ‘D:\GitLab\xxxxx\app\src\main\cpp\CMakeLists.txt’ to exist 解决: 不要动CMakeLists.…...
WebRTC调研
WebRTC是什么,为什么,如何使用 WebRTC有什么优势 WebRTC Architecture Amazon KVS WebRTC 其它厂商WebRTC 海康门禁WebRTC 海康门禁其他界面整理 威视通WebRTC 局域网 Google浏览器 Microsoft Edge 公网 RTSP RTMP NVR ONVIF SIP SRT WebRTC协…...
怎么开发一个网络协议模块(C语言框架)之(六) ——通用对象池总结(核心)
+---------------------------+ | operEntryTbl[] | ← 操作对象池 (对象数组) +---------------------------+ | 0 | 1 | 2 | ... | N-1 | +---------------------------+↓ 初始化时全部加入 +------------------------+ +-------------------------+ | …...
