华为开源自研AI框架昇思MindSpore应用案例:人体关键点检测模型Lite-HRNet
如果你对MindSpore感兴趣,可以关注昇思MindSpore社区
一、环境准备
1.进入ModelArts官网
云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpore2.0.0-alpha版本,可以在昇思教程中进入ModelArts官网
选择下方CodeLab立即体验
等待环境搭建完成
2.使用CodeLab体验Notebook实例
下载NoteBook样例代码,Lite-HRNet实现人体关键点检测 ,.ipynb
为样例代码
选择ModelArts Upload Files上传.ipynb
文件
选择Kernel环境
切换至GPU环境,切换成第一个限时免费
进入昇思MindSpore官网,点击上方的安装
获取安装命令
回到Notebook中,在第一块代码前加入命令
conda update -n base -c defaults conda
安装MindSpore 2.0 GPU版本
conda install mindspore=2.0.0a0 -c mindspore -c conda-forge
安装mindvision
pip install mindvision
安装下载download
pip install download
人体关键点检测模型Lite-HRNet
人体关键点检测是计算机视觉的基本任务之一,在许多应用场景诸如自动驾驶、安防等有着重要的地位。可以发现,在这些应用场景下,深度学习模型可能需要部署在IoT设备上,这些设备算力较低,存储空间有限,无法支撑太大的模型,因此轻量但不失高性能的人体关键点检测级模型将极大降低模型部署难度。Lite-HRNet便提供了一轻量级神经网络骨干,通过接上不同的后续模型可以完成不同的任务,其中便包括人体关键点检测,在配置合理的情况下,Lite-HRNet可以以大型神经网络数十分之一的参数量及计算量达到相近的性能。
模型简介
Lite-HRNet由HRNet(High-Resolution Network)改进而来,HRNet的主要思路是在前向传播过程中通过维持不同分辨率的特征,使得最后生成的高阶特征既可以保留低分辨率高阶特征中的图像语义信息,也可以保留高分辨率高阶特征中的物体位置信息,进而提高在分辨率敏感的任务如语义分割、姿态检测中的表现。Lite-HRNet是HRNet的轻量化改进,改进了HRNet中的卷积模块,将HRNet中的参数量从28.5M降低至1.1M,计算量从7.1GFLOPS降低至0.2GFLOPS,但AP75仅下降了7%。
综上,Lite-HRNet具有计算量、参数量低,精度可观的优点,有利于部署在物联网低算力设备上服务于各个应用场景。
数据准备
本案例使用COCO2017数据集作为训练、验证数据集,请首先安装Mindspore Vision套件,并确保安装的Mindspore是GPU版本,随后请在https://cocodataset.org/ 上下载好2017 Train Images、2017 Val Images以及对应的标记2017 Train/Val Annotations,并解压至当前文件夹,文件夹结构下表所示
Lite-HRNet/├── imgs├── src├── annotations├──person_keypoints_train2017.json└──person_keypoints_train2017.json├── train2017└── val2017
训练、测试原始图片如下所示,图片中可能包含多个人体,且包含的人体不一定包含COCO2017中定义的17个关键点,标注中有每个人体的边框、关键点信息,以便处理图像后供模型训练。
数据预处理
src/mindspore_coco.py中定义了供mindspore模型训练、测试的COCO数据集接口,在加载训练数据集时只需指定所用数据集文件夹位置、输入图像的尺寸、目标热力图的尺寸、以及手动设置对训练图像采用的变换即可
import mindspore as ms
import mindspore.dataset as dataset
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.nn as nn
from mindspore.dataset.transforms.py_transforms import Composefrom src.configs.dataset_config import COCOConfig
from src.dataset.mindspore_coco import COCODatasetcfg = COCOConfig(root="./", output_dir="outputs/", image_size=[192, 256], heatmap_size=[48, 64])
trans = Compose([py_vision.ToTensor(),py_vision.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
train_ds = COCODataset(cfg, "../", "train2017", True, transform=trans)
train_loader = dataset.GeneratorDataset(train_ds, ["data", "target", "weight"])
构建网络
Lite-HRNet网络骨干大体结构如下图所示:
网络中存在不同的分辨率分支,网络主干上维持着较高分辨率、较少通道数的输出特征,网络分支上延展出较低分辨率、较多通道数的输出特征,且这些不同分辨率的特征之间通过上采样、下采样的卷积层进行交互、融合。Stage内的Cross Channel Weighting(CCW)则是网络实现轻量化的精髓,它将原HRNet中复杂度较高的1*1卷积以更低复杂度的Spatial Weighting等方法替代,从而实现降低网络参数、计算量的效果。CCW的结构如下图所示
值得注意的是,除了骨干网络,作者在论文中同时也给出了所使用的检测头即SimpleBaseline,为了简洁起见,在本次的Lite-HRNet的Mindspore实现中,检测头(代码中包括IterativeHeads和LiteTopDownSimpleHeatMap)已集成至骨干网络之后,作为整体模型的一部分,直接调用模型即可得到热力图预测输出。
损失函数
此处使用损失函数为JointMSELoss,即关节点的均方差误差损失函数,其源码如下所示,总体流程即计算每个关节点预测热力图与实际热力图的均方差,其中target是根据关节点的人工标注坐标,通过二维高斯分布生成的热力图,target_weight用于指定参与计算的关节点,若某关节点对应target_weight取值为0,则表明该关节点在输入图像中未出现,不参与计算。
"""JointMSELoss"""
import mindspore.nn as nn
import mindspore.ops as opsclass JointsMSELoss(nn.Cell):"""Joint MSELoss"""def __init__(self, use_target_weight):"""JointMSELoss"""super(JointsMSELoss, self).__init__()self.criterion = nn.MSELoss(reduction='mean')self.use_target_weight = use_target_weightdef construct(self, output, target, weight):"""construct"""target = targettarget_weight = weightbatch_size = output.shape[0]num_joints = output.shape[1]spliter = ops.Split(axis=1, output_num=num_joints)mul = ops.Mul()heatmaps_pred = spliter(output.reshape((batch_size, num_joints, -1)))heatmaps_gt = spliter(target.reshape((batch_size, num_joints, -1)))loss = 0for idx in range(num_joints):heatmap_pred = heatmaps_pred[idx].squeeze()heatmap_gt = heatmaps_gt[idx].squeeze()if self.use_target_weight:heatmap_pred = mul(heatmap_pred, target_weight[:, idx])heatmap_gt = mul(heatmap_gt, target_weight[:, idx])loss += 0.5 * self.criterion(heatmap_pred,heatmap_gt)else:loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)return loss/num_joints
模型实现与训练
在实现模型时,需指定模型内部结构,在src/net_configs中已指定原论文中10种结构配置,在训练样例种取Lite_18_coco作为模型结构,此处作为案例,仅设置epoch数量为1,在实际训练中可以设置为200,并且可以加入warmup。由于mindspore的训练接口默认数据集中每条数据只有两列(即训练数据和标签),所以这里需自定义Loss Cell。值得注意的是loss在训练前后变化并不会十分大,训练好的模型的loss为0.0004左右
class CustomWithLossCell(nn.Cell):def __init__(self,net: nn.Cell,loss_fn: nn.Cell):super(CustomWithLossCell, self).__init__()self.net = netself._loss_fn = loss_fndef construct(self, img, target, weight):""" build network """heatmap_pred = self.net(img)return self._loss_fn(heatmap_pred,target,weight)
from src.configs.net_configs import get_netconfig
from mindspore.train.callback import LossMonitor
from src.backbone import LiteHRNetext = get_netconfig("extra_lite_18_coco")
net = LiteHRNet(ext)
criterion = JointsMSELoss(use_target_weight=True)train_loader = train_loader.batch(64)
optim = nn.Adam(net.trainable_params(), learning_rate=2e-3)
loss = JointsMSELoss(use_target_weight=True)
net_with_loss = CustomWithLossCell(net, loss)model = ms.Model(network=net_with_loss, optimizer=optim)
epochs = 1
#Start Training
model.train(epochs, train_loader, callbacks=[LossMonitor(100)], dataset_sink_mode=False)
模型评估
模型评估过程中使用AP、AP50、AP75以及AR50、AR75作为评价指标,val2017作为评价数据集,pycocotool包中已实现根据评价函数,且src/mindspore_coco.py中的evaluate函数也实现了调用该评价函数的接口,只需提供预测关键点坐标等信息即可获得评价指标。此处载入Lite_18_coco的预训练模型进行评价。
from mindspore import load_checkpoint
from mindspore import load_param_into_netfrom src.utils.utils import get_final_preds
import numpy as npdef evaluate_model(model, dataset, output_path):"""Evaluate"""num_samples = len(dataset)all_preds = np.zeros((num_samples, 17, 3),dtype=np.float32)all_boxes = np.zeros((num_samples, 6))image_path = []for i, data in enumerate(dataset):input_data, target, meta = data[0], data[1], data[3]input_data = ms.Tensor(input_data[0], ms.float32).reshape(1, 3, 256, 192)shit = model(input_data).asnumpy()target = target.reshape(shit.shape)c = meta['center'].reshape(1, 2)s = meta['scale'].reshape(1, 2)score = meta['score']preds, maxvals = get_final_preds(shit, c, s)all_preds[i:i + 1, :, 0:2] = preds[:, :, 0:2]all_preds[i:i + 1, :, 2:3] = maxvals# double check this all_boxes partsall_boxes[i:i + 1, 0:2] = c[:, 0:2]all_boxes[i:i + 1, 2:4] = s[:, 0:2]all_boxes[i:i + 1, 4] = np.prod(s*200, 1)all_boxes[i:i + 1, 5] = scoreimage_path.append(meta['image'])dataset.evaluate(0, all_preds, output_path, all_boxes, image_path)net_dict = load_checkpoint("./ckpt/litehrnet_18_coco_256x192.ckpt")
load_param_into_net(net, net_dict)eval_ds = COCODataset(cfg, "./", "val2017", False, transform=trans)
evaluate_model(net, eval_ds, "./result")
模型推理
- Lite-HRNet是关键点检测模型,所以输入待推理图像应为包含单个人体的图像,作者在论文中提及在coco test 2017测试前已使用SimpleBaseline生成的目标检测Bounding Box处理图像,所以待推理图像应仅包含单个人体。
- 网络的输入为(1,3,256,192),所以在输入图像前应先将其变换成网络可处理的形式。
import cv2
from src.utils.utils import get_max_preds
origin_img = cv2.imread("./imgs/man.jpg")
origin_h, origin_w, _ = origin_img.shape
scale_factor = [origin_w/192, origin_h/256]# resize to (112 112 3) and convert to tensor
img = cv2.resize(origin_img, (192, 256))
print(img.shape)
img = trans(img)
# img = np.expand_dims(img, axis=0)
img = ms.Tensor(img)
print(img.shape)# Infer
heatmap_pred = net(img).asnumpy()
pred, _ = get_max_preds(heatmap_pred)# Postprocess
pred = pred.reshape(pred.shape[0], -1, 2)
print(pred[0])
pre_landmark = pred[0] * 4 * scale_factor
# Draw points
for (x, y) in pre_landmark.astype(np.int32):cv2.circle(origin_img, (x, y), 3, (255, 255, 255), -1)# Save image
cv2.imwrite("./imgs/man_infer.jpg", origin_img)
可以看到模型基本正确标注出了关键点的位置\
算法基本流程
- 获取原始数据
- 从数据集的标注json文件中得到各个图像bbox以及关键点坐标信息
- 根据bbox裁剪图像,并放缩至指定尺寸,如果是训练还可以作适当数据增强,生成指定尺寸的目标热力图
- 指定尺寸的输入经过网络前向传播后得到预测的关键点热力图
- 经过处理后取热力图中的最大值坐标作为关键点的预测坐标
相关文章:
华为开源自研AI框架昇思MindSpore应用案例:人体关键点检测模型Lite-HRNet
如果你对MindSpore感兴趣,可以关注昇思MindSpore社区 一、环境准备 1.进入ModelArts官网 云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpo…...
每日OJ题_牛客_天使果冻_递推_C++_Java
目录 牛客_天使果冻_递推 题目解析 C代码 Java代码 牛客_天使果冻_递推 天使果冻 描述: 有 n 个果冻排成一排。第 i 个果冻的美味度是 ai。 天使非常喜欢吃果冻,但她想把最好吃的果冻留到最后收藏。天使想知道前 x个果冻中,美味…...
独立站干货:WordPress主机推荐
WordPress作为全球最受欢迎的独立站建设平台,提供了灵活性和强大的功能,使得建站变得简单而高效。本文将为您详细介绍WordPress建站的流程,并推荐几款实测后觉得好用的主机商。 WordPress建站流程 域名注册 首先需要注册一个域名,…...
支持多种快充协议和支持多种功能的诱骗取电协议芯片
汇铭达XSP15是一款应用于手持电动工具、智能家居、显示器、音箱等充电方案的大功率快充协议芯片,支持最大功率100W给设备快速充电,大大缩短了充电时间。芯片支持通过UART串口发送电压/电流消息供其它芯片读取。支持自动识别连接的是电脑或是充电器。支持…...
Android中常见内存泄漏的场景和解决方案
本文讲解Android 开发中常见内存泄漏场景及其解决方案,内容包括代码示例、原因分析以及最佳实践建议。 1. 静态变量导致的内存泄漏 静态变量的生命周期与应用进程一致,如果静态变量持有了对 Activity 或其他大对象的引用,就可能导致内存泄漏…...
MyBatis Plus中的@TableId注解
TableId 注解用于将某个成员变量指定为数据表主键,以下为使用示例: import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import lo…...
java基础概念33:常见API-Objects工具类
一、使用场景 二、成员方法 2-1、equals方法 源码: 2-2、isNull方法、nonNull方法 三、小结...
脚手架vue-cli,webpack模板
先安装node.js,它是服务器端,用于给页面提供服务。前端学习不需要会node.js,只需要学会node.js衍生出来的npm命令即可。 npm 是node.js的一个工具,作用是进行包管理,npm是node.js的包管理器。 接着安装脚手架ÿ…...
什么是React Native?
写在前面 React Native (RN) 是一个由 Facebook 开发的开源框架,用于构建跨平台的移动应用程序。它允许开发者使用 JavaScript 和 React 来创建原生 iOS 和 Android 应用。RN 的出现极大地简化了移动应用的开发过程,使得开发者可以更快速、更高效地构建…...
Three.js LOD(Level of Detail)通过根据视距调整渲染细节的技术
在 Three.js 中,LOD(Level of Detail)技术是一种通过根据视距调整渲染细节的技术,旨在提高渲染性能并优化用户体验。LOD 技术尤其在处理复杂场景或高多边形模型时显得尤为重要。在这篇博客中,我们将详细介绍 LOD 的概念…...
Vulnhub靶场案例渗透[12]-Grotesque: 1.0.1
文章目录 一、靶场搭建1. 靶场描述2. 下载靶机环境3. 靶场搭建 二、渗透靶场1. 确定靶机IP2. 探测靶场开放端口及对应服务3. 目录扫描4. 敏感信息获取5. 反弹shell6. 权限提升 一、靶场搭建 1. 靶场描述 get flags difficulty: medium about vm: tested and exported from vi…...
招聘和面试
本篇内容是根据2019年4月份#82 Hiring and job interviews音频录制内容的整理与翻译 小组成员 Mat Ryer、Ashley McNamara、Johnny Boursiquot 和 Carmen Andoh 讨论了受聘、雇用和工作面试的过程。如果人是团队中最重要的部分,我们如何选择与谁一起工作࿱…...
Gin 框架入门(GO)-1
解决安装包失败问题(*) go env -w GO111MODULE=on go env -w GOPROXY=https://goproxy.cn,direct 1 介绍 Gin 是一个 Go (Golang) 编写的轻量级 http web 框架,运行速度非常快,Gin 最擅长的就是 Api 接口的高并发。 2 Gin 环境搭建 1.下载并安装 gin go get -u github.…...
LeetCode:700. 二叉搜索树中的搜索
目录 题目描述: 代码: 题目描述: 给定二叉搜索树(BST)的根节点 root 和一个整数值 val。 你需要在 BST 中找到节点值等于 val 的节点。 返回以该节点为根的子树。 如果节点不存在,则返回 null 。 示例 1: 输入:root [4,2,7,1,3…...
用邻接矩阵实现图的深度优先遍历
问题描述 给定一个无向图,用邻接矩阵作为图的存储结构,输出指定顶点出发的深度优先遍历序列。在深度优先遍历的过程中,如果同时出现多个待访问的顶点,则优先选择编号最小的一个进行访问。 输入描述 第一行输入三个正整数&#…...
vue2中实现token的无感刷新
后端配置 设置Token过期时间:在后端(如服务器或网关)配置access_token和refresh_token的过期时间。通常,access_token的过期时间较短,而refresh_token的过期时间较长。提供刷新Token接口:后端需要提供一个…...
无需Photoshop即可在线裁剪和调整图像大小的工具
Bitmind是一个灵活且易于使用的批量图像本地化处理器,经过抓包看,这个工具在浏览器本地运行,不会上传图片到服务器,所以安全性完全有保证。 它可以将图像调整到任何特定尺寸,并在必要时按比例裁剪。 这是一个在线工具…...
云安全之法律和合规
0x00 前言 本文主要内容是从法律,合同,电子举证,以及合规和审计这五个部分来记录一下相关的云安全内容 0x01 法律 受法律约束的影响因素 云服务所在的地区云用户所在的区域数据主体所在的区域 GDPR:通用数据保护法案…...
倒计时功能分享
今天想要分享的是一个面试题,也是一个我们在项目中常用的功能:倒计时。 首先我们在写倒计时的时候必须要考虑到是:准确性、性能。接下来我们一步一步实现这个完美地倒计时功能。 setInterval 先来简单实现一个倒计时的函数: func…...
【论文分享】使用多源数据识别建筑功能:以中国三大城市群为例
建筑功能对城市规划至关重要,而利用多源数据进行建筑功能分类有助于支持城市规划政策。本研究通过分析建筑特征和POI密度,识别了中国三个城市群的建筑功能,并使用XGBoost模型验证了其在大规模映射中的高准确性和有效性。研究强调了建筑环境对…...
华为手机启用ADB无线调试功能
打开开发者模式,勾选USB调试,和“仅充电”模式下允许ADB调试 确认 设置添加adb路径到PATH变量 使用adb查看安卓设置 切换为无线模式: 查看手机IP...
云原生之Kubernetes集群搭建
1、Kubernetets基础概念 传统的服务器架构演进,现在基于docker容器化应用可以完成快速部署,但是对于大型的应用,有可能出现成百上千个容器化应用,一个挂了需要人工管理是相当麻烦,因此急需一个大规模容器编排系统。 Kubernetes Kubernetes 是一个可移植、可扩展的开源平…...
STM32单片机CAN总线汽车线路通断检测
目录 目录 前言 一、本设计主要实现哪些很“开门”功能? 二、电路设计原理图 1.电路图采用Altium Designer进行设计: 2.实物展示图片 三、程序源代码设计 四、获取资料内容 前言 随着汽车电子技术的不断发展,车辆通信接口在汽车电子控…...
大连理工大学概率上机作业免费下载
大连理工大学概率论与数理统计上机资源 本资源库收录了大连理工大学概率论与数理统计课程的上机作业范例代码,旨在通过实际操作加深学生对概率统计概念的理解,帮助学生更好地理解和掌握知识点。 作业内容概览 第一题:随机变量关系探索 数…...
Tomcat 如何管理 Session
Tomcat 如何管理 Session 我们知道,Tomcat 中每一个 Context 容器对应一个 Web 应用,而 Web 应用之间的 Session 应该是独立的,因此 Session 的管理肯定是 Context 级的,也就是一个 Context 一定关联多个 Session。 Tomcat 中主…...
stm32启动过程解析startup启动文件
1.STM32的启动过程模式 1.1 根据boot引脚决定三种启动模式 复位后,在 SYSCLK 的第四个上升沿锁存 BOOT 引脚的值。BOOT0 为专用引脚,而 BOOT1 则与 GPIO 引脚共用。一旦完成对 BOOT1 的采样,相应 GPIO 引脚即进入空闲状态,可用于…...
SystemVerilog学习——构造函数new
一、概述 在 SystemVerilog 中,new 是一个构造函数,用于创建类的实例(即对象)。它在面向对象编程(OOP)中起着重要作用,负责实例化一个对象并进行初始化。与传统编程语言(如 C 或 Jav…...
力扣题目总结
1.游戏玩法分析IV AC: select IFNULL(round(count(distinct(Result.player_id)) / count(distinct(Activity.player_id)), 2), 0) as fraction from (select Activity.player_id as player_idfrom (select player_id, DATE_ADD(MIN(event_date), INTERVAL 1 DAY) as second_da…...
Java API 进阶指南:从核心API到高级应用的全面提升
文章目录 Java API 进阶学习指南1. 深入理解核心API1.1 集合框架(Collections Framework)1.2 输入输出流(I/O Streams)1.3 并发编程(Concurrency)1.4 反射(Reflection)1.5 泛型&…...
esp32c3开发板通过micropython的ubluetooth库连蓝牙设备
ESP32-C3开发板是一款高性能、低功耗的微控制器,搭载了Espressif自家的RISC-V处理器。通过MicroPython,一种面向微控制器的精简版Python编程语言,开发者可以轻松地为ESP32-C3编写代码。MicroPython的ubluetooth库使得ESP32-C3能够通过蓝牙与各…...
新兴网站建设/新品怎么刷关键词
2019-06-19 22:46:12 特别感谢一位我是我们村的希望 学长 我用的是免费版的 打开cmd(W R) 输入: ssh ubuntu[你的IP地址(公网)] 输入你的密码 下载 点击圈住的地方 点击connect,accepted&&save 方框左边是你的文件 方框右边是…...
做动态网站需要什么/网络推广团队哪家好
2016-09-01 01:18齐晓庆 客户经理在加减乘除运算中,运算结果的类型和运算量的类型相同,由于类型不同,所以出错,要达到你目的可以1000\15 1000\25,由于1.5 2.5是小数所以强制转换时出错2016-09-01 01:15齐敦益 客户经理同构数是会出现在它的平…...
大型大型网站建设方案/网络广告策划与制作
小编典典问了几天问题后,我发现MediaInfo可以提供有关视频或音频文件的许多技术和标签信息。我发现subs4me的源代码树中有一个用于MediaInfo的JNI包装器,我认为它非常有用。以下是一些代码片段,显示了如何从媒体文件中提取一些信息࿱…...
汕头网站排名优化报价/如何做一个自己的网站呢
20.7.1. Macro Name http://nagios.sourceforge.net/docs/3_0/macrolist.html 20.7.2. 插件开发手册 https://nagios-plugins.org/doc/guidelines.html#THRESHOLDFORMAT 原文出处:Netkiller 系列 手札 本文作者:陈景峯 转载请与作者联系,同时…...
做网站怎么切psd图/千锋教育地址
应表哥要求,写一个简单的TTS软件,他们单位上用于广播通知用。源码如下:简单说明:public partial class frmMain : Form{public frmMain(){InitializeComponent();comboBox1.DropDownStyle ComboBoxStyle.DropDownList;}SpVoiceUt…...
中组部两学一做网站/百度搜索收录入口
位、字节、字、KB、MB 位:“位(bit)”是电子计算机中最小的数据单位。每一位的状态只能是0或1。 字节:8个二进制位构成1个“字节(Byte)”,它是存储空间的基本计量单位。1个字节可以储存1个英文字母或者半个汉字,换句话说…...