基于迁移学习的手势分类模型训练
1、基本原理介绍
这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略),是在一个已经有过训练基础的模型上,用自己的数据集,进一步训练,使得这个模型能够完成我们需要的任务。
这么做有有这样几个显而易见的好处:
※ 因为模型之前被训练过,所以初始参数不会是0,这样能够加速模型训练
※ 因为预训练模型(什么是预训练模型下文会讲到)在其他数据集上训练过,而其他数据集往往和我们用的数据集存在一定的区别,所以这可以提高模型的泛化能力
※ 通过迁移学习,可以将来自大规模数据的优势转移到小规模或新任务上,提高模型的表现和效果
2、预训练模型
在进行迁移学习时,我们要先找到一个预训练模型。在分类任务领域,比较流行的如resnet系列、mobilenet系列(更轻量化)、vgg(系列)、efficientnet(系列)等等网络,都是比较常用且容易获得的预训练模型,这些模型都能够通过python直接下载。
而且由于上述模型基本都是在ImageNet这一大规模,多分类类别的数据集上进行过训练的,所以对于简单的二分类等少数类别分类,能有较好的效果。
3、训练流程
迁移学习完整的训练流程和一般搭建神经网络的训练模型的流程基本类似:数据预处理->数据集的切分->加载预训练模型(搭建神经网络)->设置超参数/损失函数/优化器等->训练模型
3.1 模型训练
下面的代码是一个利用mobilenet网络训练得到的手势分类模型,该模型能够较准确的分类不同类别手势。
相关解释已在代码中注释说明。
from torchvision.models import mobilenet_v2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation# 定义数据预处理和增强器
transform = Compose([RandomHorizontalFlip(), # 随机水平翻转RandomRotation(10), # 随机旋转10度Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集并应用预处理和增强器
dataset = ImageFolder(root='data', transform=transform)
# 这里由于数据比较少,将所有数据集全部用来训练,得到的模型直接拿来用了,这其实不算是非常规范的操作,仅供参考# 定义网络结构
model = mobilenet_v2(pretrained=True) # 加载预训练模型,也可以试试其他模型,效果差别挺大的
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 5) # 假设是5分类问题,具体几分类,改这里的参数就行了# 将模型移动到设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 定义优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()# 定义训练循环
def train_model(model, criterion, optimizer, num_epochs, train_loader):for epoch in range(num_epochs):model.train() # 设置模型为训练模式train_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()epoch_loss = train_loss / totalepoch_acc = 100. * correct / totalprint(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 创建训练集的DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)# 开始训练模型
train_model(model, criterion, optimizer, num_epochs=15, train_loader=train_loader)
torch.save(model, 'my_model(1).pth')
3.2 数据集文件结构
当然,你也可以自己定义读取数据集的data_loader类。
3.3 模型推理
这段代码是用训练得到的模型对一张图片进行推理测试的,如果需要对系列图片进行推理,评估模型效果,可自行修改,调用对应函数即可。
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
def predict_image(image_path, model_path='my_model(1).pth'):image = Image.open(image_path).convert("RGB")# 对测试的图片进行预处理,需要和训练时处理的方式一样transform = Compose([Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image_tensor = transform(image).unsqueeze(0)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')image_tensor = image_tensor.to(device)model = torch.load(model_path,map_location=device)model.eval()with torch.no_grad():output = model(image_tensor)_, predicted = torch.max(output.data, 1) # 获得分类标记return predicted.item()
if __name__=="__main__":image_path = "test2/6.jpg"print(predict_image(image_path))
3.4 整体项目文件
4、补充说明
这种利用迁移学习策略,进行少类别,不同类别特征差距小的任务需求来说,效果一般来说是比较好的。因为之前做过相关实验,准确率90%以上是很容易的,所以这里没有模型评估,生成混淆矩阵等过程。对于多类别分类,建议有完整的评估体系。
上述使用的方法仅适用于分类任务,对于真正的目标检测如手势识别,直接使用该模型的问题是:由于无法定位手势的位置,所以导致识别不准确。
本实验数据集是不同类别手势图片,为自制,不开源。
相关文章:
基于迁移学习的手势分类模型训练
1、基本原理介绍 这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略)&#x…...
个性化音频生成GPT-SoVits部署使用和API调用
一、训练自己的音色模型步骤 1、准备好要训练的数据,放在Data文件夹中,按照文件模板中的结构进行存放数据 2、双击打开go-webui.bat文件,等待页面跳转 3、页面打开后,开始训练自己的模型 (1)、人声伴奏分…...
MFC列表框示例
本文仅供学习交流,严禁用于商业用途,如本文涉及侵权请及时联系本人将于及时删除 目录 1.示例内容 2.程序步骤 3.运行结果 4.代码全文 1.示例内容 编写一个对话框应用程序CMFC_Li6_4_学生信息Dlg,对话框中有一个列表框,当用户…...
Android TabLayout的简单用法
TabLayout 注意这里添加tab,使用binding.tabLayout.newTab()进行创建 private fun initTabs() {val tab binding.tabLayout.newTab()tab.text "模板库"binding.tabLayout.addTab(tab)binding.tabLayout.addOnTabSelectedListener(object : TabLayout.On…...
基于vite + pnpm monorepo 实现一个UI组件库
基于vite pnpm monorepo的vue组件库 仓库地址 思路 好多文章都是直接咔咔咔的上代码。跟着做也没问题,但总觉得少了些什么。下次做的时候还要找文章参考。。 需求有三个模块,那么就需要三个包。使用monorepo进行分包管理。 a. 组件库 b. 组件库文档…...
FDM3D打印系列——Luck13关节可动模型打印和各种材料的尝试
luck13可动关节模型FDM3D打印制作过程 大家好,我是阿赵。 最近我沉迷于打印一个叫做Luck13的关节超可动人偶。 首先说明一下,这个模型是分为了外甲和骨骼两个部分的。 为什么我会打印了这么多个呢? 一、第一次尝试——PLATPU 刚开始…...
windows10 获取磁盘类型
powershell Get-PhysicalDisk | Select FriendlyName, MediaType FriendlyName MediaType ------------ --------- NVMe PC SN740 NVMe WD 256GB SSD WDC WD10EZEX-75WN4A1 HDD 适用场景 SSD: 适合需要快速访问速度和较高响…...
数据库之运算符
目录 一、算数运算符 二、比较运算符 1.常用比较运算符 2.实现特殊功能的比较运算符 三、逻辑运算符 1.逻辑与运算符(&&或者AND) 2.逻辑或运算符(||或者OR) 3.逻辑非运算符(!或者NOT&#…...
【自动化机器学习AutoML】AutoML工具和平台的使用
自动化机器学习AutoML:AutoML工具和平台的使用 目录 引言什么是AutoMLAutoML的优势常见的AutoML工具和平台 Google Cloud AutoMLH2O.aiAuto-sklearnTPOTMLBox AutoML的基本使用 Google Cloud AutoML使用示例Auto-sklearn使用示例 AutoML的应用场景结论 引言 自动…...
【每日一练】python求最后一个单词的长度
""" 求某变量中最后一个单词的长度 例如s"Good morning, champ! Youre going to rock this day" 分析思路: 遇到字符串问题,经常和列表结合使用来解决, 可以先用列表的.split()分割方法进行单词分割, 再…...
[红明谷CTF 2021]write_shell 1
目录 代码审计check()$_GET["action"] ?? "" 解题 代码审计 <?php error_reporting(0); highlight_file(__FILE__); function check($input){if(preg_match("/| |_|php|;|~|\\^|\\|eval|{|}/i",$input)){// if(preg_match("/| |_||p…...
【Go - sync.once】
sync.Once 是 Go 语言标准库中的一个结构体,它的作用是确保某个操作在全局范围内只被执行一次。这对于实现单例模式或需要一次性初始化资源的场景非常有用。 典型用法 sync.Once 提供了一个方法 Do(f func()),该方法接收一个没有参数和返回值的函数 f …...
Spark RPC框架详解
文章目录 前言Spark RPC模型概述RpcEndpointRpcEndpointRefRpcEnv 基于Netty的RPC实现NettyRpcEndpointRefNettyRpcEnv消息的发送消息的接收RpcEndpointRef的构造方式直接通过RpcEndpoint构造RpcEndpointRef通过消息发送RpcEndpointRef Endpoint的注册Dispatcher消息的投递消息…...
win10安装ElasticSearch7.x和分词插件
说明: 以下内容整理自网络,格式调整优化,更易阅读,希望能对需要的人有所帮助。 一 安装 Java环境 ElasticSearch使用Java开发的,依赖Java环境,安装 ElasticSearch 7.x 之前,需要先安装jdk-8。…...
Linux中,MySQL的用户管理
MySQL库中的表及其作用 user表 User表是MySQL中最重要的一个权限表,记录允许连接到服务器的帐号信息,里面的权限是全局级的。 db表和host表 db表和host表是MySQL数据中非常重要的权限表。db表中存储了用户对某个数据库的操作权限,决定用户…...
个人电脑网络安全 之 防浏览器和端口溢出攻击 和 权限对系统的重要性
防浏览器和端口溢出攻击 该如何防 很多人都不明白 我相信很多人只知道杀毒软件 却不知道网络防火墙 防火墙分两种 : 1、 病毒防火墙 也就是我们说的杀毒软件 2、 网络防火墙 这是用来防软件恶意通信的 使用防火墙 有两种 1、 半开式规则…...
美食聚焦 -- 仿大众点评项目技术难点总结
1 实现点赞功能显示哪些用户点赞过并安装时间顺序排序 使用sort_set 进行存储,把博客id作为key,用户id作为value,时间戳作为score 但存储成功之后还是没有成功按照时间顺序排名,因为sql语句,比如最后in(5…...
拓扑图:揭示复杂系统背后的结构与逻辑
在现代软件开发和运维中,图形化的表示方式越来越重要。拓扑图,作为一种关键的可视化工具,不仅能够帮助我们理解系统的结构和组件间的关系,还能提升系统的可维护性和可扩展性。 什么是拓扑图? 拓扑图是一种展示系统或网络中各个节点(如服务器、交换机、数据库等)及其连…...
Java面试八股之什么是spring boot starter
什么是spring boot starter Spring Boot Starter是Spring Boot项目中的一个重要概念。它是一种依赖管理机制,用于简化Maven或Gradle配置文件中的依赖项声明。Spring Boot Starter提供了一组预定义的依赖关系,这些依赖关系被封装在一个单一的包中&#x…...
探究项目未能获得ASPICE 1、2级能力的原因及改进策略
项目整体未能获得ASPICE 1、2级能力的原因可能涉及多个方面,以下是基于参考文章中的信息和可能的情境进行的分析: 1.过程成熟度不足:ASPICE(Automotive Software Process Improvement and Capability Determination)是…...
WHAT - 不同 HTTP Methods 使用场景、使用方法和可能遇到的问题
目录 前言基本介绍具体介绍前置知识:幂等和非幂等幂等操作非幂等操作幂等性和非幂等性的应用场景总结 1. GET2. POST3. PUT4. PATCH1. 确保操作是幂等的2. 使用版本控制或条件更新3. 全量更新部分属性4. 使用特定操作指令5. 幂等标识符示例代码总结 5. DELETE6. HEA…...
Pytorch使用教学4-张量的索引
1 张量的符号索引 张量也是有序序列,我们可以根据每个元素在系统内的顺序位置,来找出特定的元素,也就是索引。 1.1 一维张量的索引 一维张量由零维张量构成 一维张量索引与Python中的索引一样是是从左到右,从0开始的ÿ…...
【Git多人协作开发】同一分支下的多人协作开发模式
目录 0.前言场景 1.开发者1☞完成准备工作&协作开发 1.1创建dev分支开发 1.2拉取远程dev分支至本地 1.3查看分支情况和分支联系情况 1.4创建本地dev分支且与远程dev分支建立联系 1.5在本地dev分支上开发file.txt 1.6推送push至远程仓库 2.开发者2☞完成准备工作&…...
Vue使用FullCalendar实现日历/周历/月历
Vue使用FullCalendar实现日历/周历/月历 需求背景:项目上遇到新需求,要求实现工单以日/周/月历形式展示。而且要求不同工单根据状态显示不同颜色,一个工单内部,需要以不同颜色显示三个阶段。 效果图 日历 周历 月历 安装插件…...
社交圈子聊天交友系统搭建社交app开发:陌生交友发布动态圈子单聊打招呼群聊app介绍
系统概述 社交圈子部天交友系统是一个集成即时通讯、社区互动、用户管理等功能的在线社交平台。它支持用户创建个人资料,加入兴趣围子,通过文字、图片、语音、视频等多种方式进行交流,满足用户在不同场景下的社交需求 核心功能 -,…...
【微信小程序实战教程】之微信小程序原生开发详解
微信小程序原生开发详解 微信小程序的更新迭代非常频繁,几乎每个月都会有新版本发布,这就会让初学者感觉到学习的压力和难度。其实,我们小程序的每次版本迭代都是在现有小程序架构基础之上进行更新的,如果想要学好小程序开发技术&…...
PHP身份证实名认证接口集成守护电商购物
在这个万物互联的世界里,网购已成为日常生活中不可或缺的一部分。然而,随着线上交易的增加,如何保护消费者和商家免受欺诈,确保每一笔交易的安全,成了亟待解决的难题。这时,身份证实名认证接口应运而生&…...
为什么有了MAC还需要IP?
目录 MAC地址(Media Access Control Address)IP地址(Internet Protocol Address)为什么需要两者? IP地址和MAC地址在网络通信中扮演着不同的角色,它们各自有独特的功能和用途。下面是它们的主要区别和为什么…...
SpringBoot中如何使用RabbitMq
一,RabbitMQ简介和基本概念 RabbitMQ 是一个开源的消息中间件,基于 AMQP(高级消息队列协议)实现。 它由 Erlang 语言开发,并且支持多种编程语言,包括 Java、Python、Ruby、PHP 和 C# 等, 下载…...
LangChain自定义Embedding封装 之 ERNIE Bot
LangChain自定义Embedding封装 之 ERNIE Bot 百度飞浆平台的 ERNIE Bot 导入下面方法 和 环境 ,即可验证 embedding ERNIE_Bot_embedding() class ERNIE_Bot_embedding(BaseModel, Embeddings):client: Anyroot_validator()def validate_environment(cls, value…...
Git 安装教程
1、登录git 官方网站:https://git-scm.com/ 点击左边的 Downloads 或者 右边标识的下载标志,它根据电脑操作系统自动匹配版本 Downloads for Windows 2、以 windows 为例下载对应版本 网络有时可能不大好,阿里镜像下载超快。 下载好以后&a…...
Lua 类管理器
Lua 类管理器 -- ***** Class Manager 类管理*****‘local ClassManager {}local this ClassManagerfunction ClassManager.Class(className, ...)print(ClassManager::Class)--print(className)-- 构建类local cls {__className className}--print(cls)-- 父类集合local …...
实现领域驱动设计(DDD)系列详解:领域模型的持久化
领域驱动设计主要通过限界上下文应对复杂度,它是绑定业务架构、应用架构和数据架构的关键架构单元。设计由领域而非数据驱动,且为了保证定义了领域模型的应用架构和定义了数据模型的数据架构的变化方向相同,就应该在领域建模阶段率先定义领域…...
配置sublime的中的C++编译器(.sublime-build),实现C++20
GCC 4.8: 支持 C11 (部分) GCC 4.9: 支持 C11 和 C14 (部分) GCC 5: 完全支持 C14 GCC 6: 支持 C14 和 C17 (部分) GCC 7: 支持 C17 (大部分) GCC 8: 完全支持 C17,部分支持 C20 GCC 9: 支持更多的 C20 特性 GCC 10: 支持大部分 C20 特性 GCC 11: 更全面地支持 C20 …...
Android14 - 前台Service、图片选择器 、OpenJDK 17、其他适配
前台服务 1. 指定前台服务类型 以 Android 14(API 级别 34)或更高版本为目标平台的应用,需要为应用中的每项前台服务指定服务类型,因为系统需要特定类型的前台服务满足特定用例。具体介绍如下: 在Android 10 在 <service> 元素内引入了 android:foregroundServiceT…...
数据恢复教程:如何从硬盘、SD存储卡、数码相机中恢复误删除数据。
您正在摆弄 Android 设备。突然,您意外删除了一张或多张图片。不用担心,您总能找到一款价格实惠的数据恢复应用。这款先进的软件可帮助 Android 用户从硬盘、安全数字 (SD) 或存储卡以及数码相机中恢复已删除的数据。 Android 上数据被删除的主要原因 在…...
谷粒商城实战笔记-47-商品服务-API-三级分类-网关统一配置跨域
文章目录 一,跨域问题1,跨域问题产生的原因2,预检请求3,跨域解决方案3.1 CORS (Cross-Origin Resource Sharing)后端配置示例(Spring Boot) 3.2 JSONP (JSON with Padding)3.3 代理服务器Nginx代理配置示例…...
stm32平台为例的软件模拟时间,代替RTC调试
stm32平台为例的软件模拟时间,代替RTC调试 我们在开发项目的时候,如果用到RTC,如果真正等待RTC到达指定的时间,那调试时间就太长了。 比如每隔半个小时,存储一次数据,如果要观察10次存储的效果࿰…...
《设计模式之美》读书笔记2
从Linux学习应对大型复杂项目的方法: 1、封装与抽象:封装了不同类型设备的访问细节,抽象为统一的文件访问方式,更高层的代码就能基于统一的访问方式,来访问底层不同类型的设备。这样做的好处是,隔离底层设备…...
C++ STL set_difference 用法
一:功能 给定两个集合A,B;计算集合的差集,即计算出那些只包含在A中而不包含在B中的元素。 二:用法 #include <vector> #include <algorithm> #include <iostream>int main() {std::vector<int&…...
【基础算法总结】优先级队列
优先级队列 1.最后一块石头的重量2.数据流中的第 K 大元素4.前K个高频单词4.数据流的中位数 点赞👍👍收藏🌟🌟关注💖💖 你的支持是对我最大的鼓励,我们一起努力吧!😃😃 1…...
python-绝对值排序(赛氪OJ)
[题目描述] 输入 n 个整数,按照绝对值从大到小排序后输出。保证所有整数的绝对值不同。输入格式: 输入数据有多组,每组占一行,每行的第一个数字为 n ,接着是 n 个整数, n0 表示输入数据的结束,不做处理。输…...
成功者的几个好习惯,你具备了几个
每个人都想成为自己领域的佼佼者,然而,成功并非偶然,它往往与一系列良好的习惯紧密相连。这些习惯如同灯塔,指引着成功者在波涛汹涌的大海中稳健前行。 一、设定明确目标 没有明确的目标,就如同航海没有指南针&#…...
centos中zabbix安装、卸载及遇到的问题
目录 Zabbix简介Zabbix5.0和Zabbix7.0的区别监控能力方面模板和 API 方面性能、速度方面 centos7安装Zabbix(5.0)安装zabbix遇到的问题卸载Zabbix Zabbix简介 Zabbix 是一个基于 WEB 界面的提供分布式系统监视以及网络监视功能的企业级的开源解决方案。zabbix 能监视各种网络参…...
php编译安装
一、基础环境准备 # php使用www用户 useradd -s /sbin/nologin -M www二、下载php包 # 下载地址 https://www.php.net/downloads wget https://www.php.net/distributions/php-8.3.9.tar.gz三、配置编译安装 编译安装之前需要处理必要的依赖,在编译配置安装&…...
[K8S] K8S资源控制器Controller Manager(4)
文章目录 1. 常见的Pod控制器及含义2. Replication Controller控制器2.1 部署ReplicaSet 3. Deployment3.1部署Deployment3.2 运行Deployment3.3 镜像更新方式3.4 Deployment扩容3.5 滚动更新3.6 金丝雀发布(灰度发布)3.7 Deployment版本回退3.8 Deployment 更新策略 4. Daemon…...
C#,.NET常见算法
1.递归算法 1.1.C#递归算法计算阶乘的方法 using System;namespace C_Sharp_Example {public class Program{/// <summary>/// 阶乘:一个正整数的阶乘Factorial是所有小于以及等于该数的正整数的积,0的阶乘是1,n的阶乘是n࿰…...
KubeSphere介绍及一键安装k8s
KubeSphere介绍 官网地址:https://kubesphere.io/zh/ KubeSphere愿景是打造一个以 Kubernetes 为内核的云原生分布式操作系统,它的架构可以非常方便地使第三方应用与云原生生态组件进行即插即用(plug-and-play)的集成࿰…...
Spring 系列
SpringBoot 实体类(Entity)层 实体类(Entity)通常属于模型层(Model Layer)或领域层(Domain Layer)。它们代表应用程序中的核心业务数据结构,与数据库表结构紧密对应。在…...
基于opencv[python]的人脸检测
1 图片爬虫 这里的代码转载自:http://t.csdnimg.cn/T4R4F # 获取图片数据 import os.path import fake_useragent import requests from lxml import etree# UA伪装 head {"User-Agent": fake_useragent.UserAgent().random}pic_name 0 def request_pic…...