【人脸识别】CurricularFace:自适应课程学习人脸识别损失函数
论文题目:《CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition》
论文地址:https://arxiv.org/pdf/2004.00288v1.pdf
代码地址:https://github.com/HuangYG123/CurricularFace
建议先了解下这篇文章:MV-softmax
1.背景
人脸识别中常用损失函数主要包括两类,基于间隔和难样本挖掘,这两种方法损失函数的训练策略都存在缺陷。前一种方法是对所有样本都采用一个固定的间隔值,没有充分利用每个样本自身的难易信息,这可能导致在使用大边际时出现收敛问题;后一种方法则在整个网络训练周期都强调难样本,可能出现网络无法收敛问题。在本论文中,提出了一种新的自适应课程学习损失函数,称为CurricularFace,它能够很好地解决上述两类损失函数存在的问题。
下图是CurricularFace跟ArcFace和 MV-Arc-Softmax两种方法的对比,可以看到CurricularFace的优势还是很明显的,通过自适应的方式实现,在早期突出易样本的作用(红色虚线),而在晚期突出难样本的作用(红色实线)

注:Curriculum Learning即课程学习,它是由Montreal大学的Bengio教授团队在2009年的ICML上提出的,其主要思想是模仿人类学习的特点,按照从简单到困难的程度来学习课程,这样容易使模型找到更好的局部最优,同时加快训练速度。
– MV-Sotamax存在的问题:从training起始阶段就开始强调semi-hard/hard-sample,可能会导致模型的收敛问题!
easy sample first, hard sample later!
2.方法
论文中提出的一种新的自适应课程学习损失CurricularFace,是将课程学习的思想嵌入到损失函数中,以实现一种新的深度人脸识别训练策略。该策略主要针对早期训练阶段的易样本和后期训练阶段的难样本,使其在不同的训练阶段,通过一个课程表自适应地调整简单和困难样本的相对重要性。也就是说,在每个阶段,不同的样本根据其相应的困难程度被赋予不同的重要性。
由于人类学习的本质是先易后难,CurricularFace是以一种适应性的方式将课程学习的理念融入到人脸识别中,这与传统的认知有两处明显不同:
1)首先,课程设计的自适应性。在传统的课程学习中,样本是按照相应的难易程度排序的,这些难易程度往往是由先验知识定义的,然后固定下来建立课程。而在CurricularFace中,做法是由每个Batch随机抽取样本,通过在线挖掘难样本自适应地建立课程。
2)其次,难样本的重要性是自适应的。一方面,易样本和难样本的相对重要性是动态的,可以在不同的训练阶段进行调整。另一方面,当前Batch中每一个难样本的重要性取决于其自身的难易程度。
具体来看,文中选择Batch中的被误分类样本作为难样本,通过调整样本与假类别中心向量之间的余弦相似度的调制系数来加权。为了在整个训练过程中实现自适应课程学习的目标,论文设计了一种新的系数函数,该函数包括以下两个因子:
1)自适应估计参数t,该参数利用样本和其真类别间的Positive余弦相似度的移动平均值来实现自适应,以消除人工调整的负担。
2)余弦角度参数,该参数定义难样本实现自适应分配的的难易性。
上面介绍完了CurricularFace的基本原理,我们来看下其损失函数是如何定义的,如下:

其中,T(cos(θ_y)) = cos(θ_y + m), I (t, cos(θ_j))表示样本的权重函数,N(t, cos(θ_j))定义如下:

Adaptive Estimation of t.
在不同的训练阶段决定一个恰当的t的值是十分重要的。理想情况下,t的值能够指示模型的训练阶段。我们通过经验发现正cosine相似度的平均值是一个好的指示器。可是min-batch的基于统计的方法往往面临一个问题:当许多极端数据被采样到一个mini-batch时,统计可能是一个很大的噪声,估计值可能很不稳定。Exponential Moving Average (EMA)方法是一个常用的解决该问题的方法,假设r(k)是第k个batch的正cosine相似度的平均值,r^(0) = 0,即:

则有(t^(k)随着k的增加,会呈现出单调递增的趋势):



Note : (a, b), a表示在训练过程中[某个时刻] curricular_loss和arcface-loss的比值;b表示max {cos(θ_j), j ≠ yi}
3.训练
3.1.训练步骤

3.2.训练曲线


1.x-axis : iterations, y-axis : 难样本的调整系数
2. t:adaptive parameter; M : MV-Arc-Softmax; M(ours) : gradient modulation coefficients
3.在训练早期,t --> 0,模型可以利用easy-sample加速收敛;在训练中后期t不断增大使得I(t, cos(θ_j)) > 1,这样模型可以更多地关注hard-smaples.
4.实验
从Figure 4中可以看到,在整个训练阶段,CurricularFace对于难样本的决策边界从训练早期到后期自适应性的变化。

最终,与其它方法相比,CurricularFace下的人脸识别效果得到明显改善(如Table4与Table6)


5.结论
论文提出的自适应课程学习损失CurricularFace,将自适应课程学习的思想嵌入到人脸识别中。该方法易于实现,收敛性强,能够明显的提升人脸识别的准确率,而且它解决的是经常在训练过程中出现的问题(如:大边际和难样本),因而具备很高的实用价值。
pytorch代码:
class CurricularFace(nn.Module):"""Implementation for "CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition"."""def __init__(self, in_features, out_features, device_id=None, m = 0.5, s = 64., fp16 = False):super(CurricularFace, self).__init__()self.device_id = device_idself.fp16 = fp16self.m = mself.s = sself.cos_m = math.cos(m)self.sin_m = math.sin(m)self.threshold = math.cos(math.pi - m)self.mm = math.sin(math.pi - m) * mself.kernel = Parameter(torch.FloatTensor(out_features, in_features))self.register_buffer('t', torch.zeros(1))nn.init.xavier_uniform_(self.kernel) #self.kernel = Parameter(torch.Tensor(in_features, out_features))#self.register_buffer('t', torch.zeros(1))#nn.init.normal_(self.kernel, std=0.01)def forward(self, feats, labels):#kernel_norm = F.normalize(self.kernel, dim=0)#feats = F.normalize(feats)#cos_theta = torch.mm(feats, kernel_norm)sub_weights = torch.chunk(self.kernel, len(self.device_id), dim=0)temp_x = feats.cuda(self.device_id[0])weight = sub_weights[0].cuda(self.device_id[0])cos_theta = F.linear(F.normalize(temp_x), F.normalize(weight))for i in range(1, len(self.device_id)):temp_x = x.cuda(self.device_id[i])weight = sub_weights[i].cuda(self.device_id[i])cos_theta = torch.cat((cos_theta, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])), dim=1)cos_theta = cos_theta.clamp(-1.0, 1.0) # for numerical stabilitywith torch.no_grad():origin_cos = cos_theta.clone()target_logit = cos_theta[torch.arange(0, temp_x.size(0)), labels].view(-1, 1)sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m #cos(target+margin)mask = cos_theta > cos_theta_mif self.fp16:cos_theta_m = cos_theta_m.half()final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)hard_example = cos_theta[mask]with torch.no_grad():self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.tif self.fp16:self.t = self.t.half()cos_theta[mask] = hard_example * (self.t + hard_example)if self.device_id != None:cos_theta = cos_theta.cuda(self.device_id[0])cos_theta.scatter_(1, labels.view(-1, 1).long(), final_target_logit)output = cos_theta * self.sreturn output
相关文章:
【人脸识别】CurricularFace:自适应课程学习人脸识别损失函数
论文题目:《CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition》 论文地址:https://arxiv.org/pdf/2004.00288v1.pdf 代码地址:https://github.com/HuangYG123/CurricularFace 建议先了解下这篇文章:…...
springmvc之rest风格(RESTFUL)
目录 一、介绍 1.什么是REST? 2.REST的实质 3.REST风格的优点 4.REST风格的缺点 3.什么是RESTful? 二、代码理解 一、介绍 1.什么是REST? 答:REST(Representational State Transfer) ,表现形式转…...
django项目实战十四(django+bootstrap实现增删改查)进阶混合数据使用modelform上传
目录 一、启用media 1、URL设置 2、settings.py配置 二、url 三、upload.py 新增upload_modelform方法 四、form.py新增UpModelForm 五、创建city表 六、创建city_list.html 接上一篇《django项目实战十三(djangobootstrap实现增删改查)进阶混合数据f…...
2023年CDGA考试模拟题库(1-100)
2023年CDGA考试模拟题库(1-100) 1.以下哪种活动中 ,混淆是不足以保护数据 的?[1分] A.数据共享 B.数据转换 C.数据脱敏 D.以上都正确 答案C 2.关于受控词表描述不正确的是?[1分] A.系统地组织文件档案和内容离不开受控词表 B.受控词表的一个例子是用于出版物分类的都…...
HTML常用基础内容总结
文章目录一、对HTML的感性认知前置知识什么是web前端,什么是web后端前端技术栈、后端技术栈开发与运行的区别浏览器的功能是什么简介写一个简单可运行的的html代码前端开发方式二、VSCode的简单使用三、常用的HTML标签最最基本的HTML结构HTML代码特点注释标签标题标…...
Gorm-学习笔记
1 基本使用 2 创建数据 2.1 如何使用Upsert 使用clause.OnConflict处理数据冲突 2.2 如何使用默认值 通过使用default标签为字段定义默认值 3 查询数据 3.1 First与Find 使用First时,需要注意查询不到数据会返回ErrRecordNotFound。 使用Find查询多条数据&#x…...
【Neo4j】图数据库CypherQueryLanguage随笔
CQL语言随笔 一、Cyther关系描述 如图:唐僧,孙悟空,白骨精三者的关系图: Cypher语言描述他们的关系: (孙悟空)<-[:赶走]-(唐僧)-[:被骗]->(白骨精)-[:被打死]->(孙悟空) 二、CQL语言的使用案例 创建结点…...
STM32Cube串口USART发送接收数据
本文代码使用 HAL 库。 文章目录前言一、USART 同步/异步串行接收/发送器二、USART 原理图三、CubeMX 创建工程四、usart.c 文件解析五,设计实验:在 串口输入字符点亮led实验现象:总结前言 这篇文章介绍 实现 USART 异步模式下 通过 串口助手…...
OpenFeign详解
OpenFeign是什么? OpenFeign: OpenFeign是Spring Cloud 在Feign的基础上支持了SpringMVC的注解,如RequesMapping等等。OpenFeign的FeignClient可以解析SpringMVC的RequestMapping注解下的接口,并通过动态代理的方式产生实现类&am…...
python多线程网络编程
背景 使用过flask框架后,我对request这个全局实例非常感兴趣。它在客户端发起请求后会保存着所有的客户端数据,例如用户上传的表单或者文件等。那么在很多客户端发起请求时,服务器是怎么去区分不同的request对象呢?当查看了大量的…...
BFS-走迷宫
题目描述 给定一个 NM 的网格迷宫 G。G 的每个格子要么是道路,要么是障碍物(道路用 1 表示,障碍物用 0 表示)。 已知迷宫的入口位置为 (x1,y1),出口位置为 (x2...
【蓝牙mesh】Lower协议层介绍
【蓝牙mesh】Lower协议层介绍 Lower层简介 Lower协议层用于处理网络层以下的功能,包括节点的广播、重传、路由和网络拓扑等,是实现蓝牙mesh网络的关键协议之一。其中Lower协议层中最主要的一部分工作就是mesh数据的分片和组包。 Lower层是将Upper层发过…...
Java-重排序,happens-before 和 as-if-serial 语义
目录1. 如何解决重排序带来的问题2. happens-before1. 如何解决重排序带来的问题 对于编译器,JMM 的编译器重排序规则会禁止特定类型的编译器重排序。对于处理器重排序,JMM 的处理器重排序规则会要求编译器在生成指令序列时,插入特定类型的内…...
Nginx安装及介绍
前言:传统结构上(如下图所示)我们只会部署一台服务器用来跑服务,在并发量小,用户访问少的情况下基本够用但随着用户访问的越来越多,并发量慢慢增多了,这时候一台服务器已经不能满足我们了,需要我们增加服务…...
【华为OD机试模拟题】用 C++ 实现 - 寻找路径 or 数组二叉树(2023.Q1)
最近更新的博客 【华为OD机试模拟题】用 C++ 实现 - 获得完美走位(2023.Q1) 文章目录 最近更新的博客使用说明寻找路径 or 数组二叉树题目输入输出描述示例一输入输出示例二输入输出Code使用说明 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,通过…...
LINUX学习记录
回顾系列:两天的时间(2023.2.24-2023.2.25)重新学了遍Linux基础课,收获非常多,以前只会一些简单的Linux命令,对shell,git,管道,复杂Linux命令都不熟悉,学完之…...
华为OD机试用Python实现 -【狼羊过河 or 羊、狼、农夫过河】(2023-Q1 新题)
华为OD机试题 华为OD机试300题大纲狼羊过河 or 羊、狼、农夫过河题目描述输入描述输出描述说明示例一输入输出说明Python 代码实现代码实现思路华为OD机试300题大纲 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,通过率才会高。 华为 OD 清单查看地址…...
【SAP Abap】X-DOC:SAP ABAP 语法更新之Open SQL
SAP ABAP 语法更新之Open SQL1、前言2、演示1、前言 自从 SAP 推出 SAP ON HANA,与之相随的 AS ABAP NW 7.40 版本以后,ABAP 语法也有了较多的更新,本篇对 Open Sql的语法更新部分做一个DEMO演示。 NW 7.40 以前 OpenSQL 的限制:…...
leetcode 困难 —— 数组中的逆序对(分治法)
题目: 在数组中的两个数字,如果前面一个数字大于后面的数字,则这两个数字组成一个逆序对。输入一个数组,求出这个数组中的逆序对的总数。 题解: ① 我最开始想的蠢方法(会超时,可跳过ÿ…...
02.24:图片的风格转换
Github网址:https://github.com/lengstrom/fast-style-transfer 在anaconda prompt中切换环境命令:activate 环境名 列出所有环境名:conda info --envs 安装环境:conda create -n 环境名 pythonx.x.x 删除某环境:co…...
第19节 Node.js Express 框架
Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...
应用升级/灾备测试时使用guarantee 闪回点迅速回退
1.场景 应用要升级,当升级失败时,数据库回退到升级前. 要测试系统,测试完成后,数据库要回退到测试前。 相对于RMAN恢复需要很长时间, 数据库闪回只需要几分钟。 2.技术实现 数据库设置 2个db_recovery参数 创建guarantee闪回点,不需要开启数据库闪回。…...
CTF show Web 红包题第六弹
提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框,很难让人不联想到SQL注入,但提示都说了不是SQL注入,所以就不往这方面想了 先查看一下网页源码,发现一段JavaScript代码,有一个关键类ctfs…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序
一、开发环境准备 工具安装: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 项目初始化: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...
2025盘古石杯决赛【手机取证】
前言 第三届盘古石杯国际电子数据取证大赛决赛 最后一题没有解出来,实在找不到,希望有大佬教一下我。 还有就会议时间,我感觉不是图片时间,因为在电脑看到是其他时间用老会议系统开的会。 手机取证 1、分析鸿蒙手机检材&#x…...
【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具
第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...
【开发技术】.Net使用FFmpeg视频特定帧上绘制内容
目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法,当前调用一个医疗行业的AI识别算法后返回…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
laravel8+vue3.0+element-plus搭建方法
创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...
招商蛇口 | 执笔CID,启幕低密生活新境
作为中国城市生长的力量,招商蛇口以“美好生活承载者”为使命,深耕全球111座城市,以央企担当匠造时代理想人居。从深圳湾的开拓基因到西安高新CID的战略落子,招商蛇口始终与城市发展同频共振,以建筑诠释对土地与生活的…...
