当前位置: 首页 > news >正文

【人脸识别】CurricularFace:自适应课程学习人脸识别损失函数

论文题目:《CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition》
论文地址:https://arxiv.org/pdf/2004.00288v1.pdf
代码地址:https://github.com/HuangYG123/CurricularFace

建议先了解下这篇文章:MV-softmax

1.背景

       人脸识别中常用损失函数主要包括两类,基于间隔和难样本挖掘,这两种方法损失函数的训练策略都存在缺陷。前一种方法是对所有样本都采用一个固定的间隔值,没有充分利用每个样本自身的难易信息,这可能导致在使用大边际时出现收敛问题;后一种方法则在整个网络训练周期都强调难样本,可能出现网络无法收敛问题。在本论文中,提出了一种新的自适应课程学习损失函数,称为CurricularFace,它能够很好地解决上述两类损失函数存在的问题。
       下图是CurricularFace跟ArcFace和 MV-Arc-Softmax两种方法的对比,可以看到CurricularFace的优势还是很明显的,通过自适应的方式实现,在早期突出易样本的作用(红色虚线),而在晚期突出难样本的作用(红色实线)

在这里插入图片描述

注:Curriculum Learning即课程学习,它是由Montreal大学的Bengio教授团队在2009年的ICML上提出的,其主要思想是模仿人类学习的特点,按照从简单到困难的程度来学习课程,这样容易使模型找到更好的局部最优,同时加快训练速度。

– MV-Sotamax存在的问题:从training起始阶段就开始强调semi-hard/hard-sample,可能会导致模型的收敛问题!

easy sample first, hard sample later!

2.方法

       论文中提出的一种新的自适应课程学习损失CurricularFace,是将课程学习的思想嵌入到损失函数中,以实现一种新的深度人脸识别训练策略。该策略主要针对早期训练阶段的易样本和后期训练阶段的难样本,使其在不同的训练阶段,通过一个课程表自适应地调整简单和困难样本的相对重要性。也就是说,在每个阶段,不同的样本根据其相应的困难程度被赋予不同的重要性。
       由于人类学习的本质是先易后难,CurricularFace是以一种适应性的方式将课程学习的理念融入到人脸识别中,这与传统的认知有两处明显不同:
       1)首先,课程设计的自适应性。在传统的课程学习中,样本是按照相应的难易程度排序的,这些难易程度往往是由先验知识定义的,然后固定下来建立课程。而在CurricularFace中,做法是由每个Batch随机抽取样本,通过在线挖掘难样本自适应地建立课程
       2)其次,难样本的重要性是自适应的。一方面,易样本和难样本的相对重要性是动态的,可以在不同的训练阶段进行调整。另一方面,当前Batch中每一个难样本的重要性取决于其自身的难易程度。
       具体来看,文中选择Batch中的被误分类样本作为难样本,通过调整样本与假类别中心向量之间的余弦相似度的调制系数来加权。为了在整个训练过程中实现自适应课程学习的目标,论文设计了一种新的系数函数,该函数包括以下两个因子:
       1)自适应估计参数t,该参数利用样本和其真类别间的Positive余弦相似度的移动平均值来实现自适应,以消除人工调整的负担。
       2)余弦角度参数,该参数定义难样本实现自适应分配的的难易性。

       上面介绍完了CurricularFace的基本原理,我们来看下其损失函数是如何定义的,如下:
在这里插入图片描述
其中,T(cos(θ_y)) = cos(θ_y + m), I (t, cos(θ_j))表示样本的权重函数,N(t, cos(θ_j))定义如下:

在这里插入图片描述

Adaptive Estimation of t.
       在不同的训练阶段决定一个恰当的t的值是十分重要的。理想情况下,t的值能够指示模型的训练阶段。我们通过经验发现正cosine相似度的平均值是一个好的指示器。可是min-batch的基于统计的方法往往面临一个问题:当许多极端数据被采样到一个mini-batch时,统计可能是一个很大的噪声,估计值可能很不稳定。Exponential Moving Average (EMA)方法是一个常用的解决该问题的方法,假设r(k)是第k个batch的正cosine相似度的平均值,r^(0) = 0,即:

在这里插入图片描述
则有(t^(k)随着k的增加,会呈现出单调递增的趋势):
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
Note : (a, b), a表示在训练过程中[某个时刻] curricular_loss和arcface-loss的比值;b表示max {cos(θ_j), j ≠ yi}

3.训练

3.1.训练步骤

在这里插入图片描述

3.2.训练曲线

在这里插入图片描述
在这里插入图片描述
1.x-axis : iterations, y-axis : 难样本的调整系数
2. t:adaptive parameter; M : MV-Arc-Softmax; M(ours) : gradient modulation coefficients
3.在训练早期,t --> 0,模型可以利用easy-sample加速收敛;在训练中后期t不断增大使得I(t, cos(θ_j)) > 1,这样模型可以更多地关注hard-smaples.

4.实验

       从Figure 4中可以看到,在整个训练阶段,CurricularFace对于难样本的决策边界从训练早期到后期自适应性的变化。
在这里插入图片描述
       最终,与其它方法相比,CurricularFace下的人脸识别效果得到明显改善(如Table4与Table6)
在这里插入图片描述
在这里插入图片描述

5.结论

       论文提出的自适应课程学习损失CurricularFace,将自适应课程学习的思想嵌入到人脸识别中。该方法易于实现,收敛性强,能够明显的提升人脸识别的准确率,而且它解决的是经常在训练过程中出现的问题(如:大边际和难样本),因而具备很高的实用价值。

pytorch代码:

class CurricularFace(nn.Module):"""Implementation for "CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition"."""def __init__(self, in_features, out_features, device_id=None, m = 0.5, s = 64., fp16 = False):super(CurricularFace, self).__init__()self.device_id = device_idself.fp16 = fp16self.m = mself.s = sself.cos_m = math.cos(m)self.sin_m = math.sin(m)self.threshold = math.cos(math.pi - m)self.mm = math.sin(math.pi - m) * mself.kernel = Parameter(torch.FloatTensor(out_features, in_features))self.register_buffer('t', torch.zeros(1))nn.init.xavier_uniform_(self.kernel)     #self.kernel = Parameter(torch.Tensor(in_features, out_features))#self.register_buffer('t', torch.zeros(1))#nn.init.normal_(self.kernel, std=0.01)def forward(self, feats, labels):#kernel_norm = F.normalize(self.kernel, dim=0)#feats = F.normalize(feats)#cos_theta = torch.mm(feats, kernel_norm)sub_weights = torch.chunk(self.kernel, len(self.device_id), dim=0)temp_x = feats.cuda(self.device_id[0])weight = sub_weights[0].cuda(self.device_id[0])cos_theta = F.linear(F.normalize(temp_x), F.normalize(weight))for i in range(1, len(self.device_id)):temp_x = x.cuda(self.device_id[i])weight = sub_weights[i].cuda(self.device_id[i])cos_theta = torch.cat((cos_theta, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])), dim=1)cos_theta = cos_theta.clamp(-1.0, 1.0)  # for numerical stabilitywith torch.no_grad():origin_cos = cos_theta.clone()target_logit = cos_theta[torch.arange(0, temp_x.size(0)), labels].view(-1, 1)sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m #cos(target+margin)mask = cos_theta > cos_theta_mif self.fp16:cos_theta_m = cos_theta_m.half()final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)hard_example = cos_theta[mask]with torch.no_grad():self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.tif self.fp16:self.t = self.t.half()cos_theta[mask] = hard_example * (self.t + hard_example)if self.device_id != None:cos_theta = cos_theta.cuda(self.device_id[0])cos_theta.scatter_(1, labels.view(-1, 1).long(), final_target_logit)output = cos_theta * self.sreturn output

相关文章:

【人脸识别】CurricularFace:自适应课程学习人脸识别损失函数

论文题目:《CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition》 论文地址:https://arxiv.org/pdf/2004.00288v1.pdf 代码地址:https://github.com/HuangYG123/CurricularFace 建议先了解下这篇文章&#xff1a…...

springmvc之rest风格(RESTFUL)

目录 一、介绍 1.什么是REST? 2.REST的实质 3.REST风格的优点 4.REST风格的缺点 3.什么是RESTful? 二、代码理解 一、介绍 1.什么是REST? 答:REST(Representational State Transfer) ,表现形式转…...

django项目实战十四(django+bootstrap实现增删改查)进阶混合数据使用modelform上传

目录 一、启用media 1、URL设置 2、settings.py配置 二、url 三、upload.py 新增upload_modelform方法 四、form.py新增UpModelForm 五、创建city表 六、创建city_list.html 接上一篇《django项目实战十三(djangobootstrap实现增删改查)进阶混合数据f…...

2023年CDGA考试模拟题库(1-100)

2023年CDGA考试模拟题库(1-100) 1.以下哪种活动中 ,混淆是不足以保护数据 的?[1分] A.数据共享 B.数据转换 C.数据脱敏 D.以上都正确 答案C 2.关于受控词表描述不正确的是?[1分] A.系统地组织文件档案和内容离不开受控词表 B.受控词表的一个例子是用于出版物分类的都…...

HTML常用基础内容总结

文章目录一、对HTML的感性认知前置知识什么是web前端,什么是web后端前端技术栈、后端技术栈开发与运行的区别浏览器的功能是什么简介写一个简单可运行的的html代码前端开发方式二、VSCode的简单使用三、常用的HTML标签最最基本的HTML结构HTML代码特点注释标签标题标…...

Gorm-学习笔记

1 基本使用 2 创建数据 2.1 如何使用Upsert 使用clause.OnConflict处理数据冲突 2.2 如何使用默认值 通过使用default标签为字段定义默认值 3 查询数据 3.1 First与Find 使用First时,需要注意查询不到数据会返回ErrRecordNotFound。 使用Find查询多条数据&#x…...

【Neo4j】图数据库CypherQueryLanguage随笔

CQL语言随笔 一、Cyther关系描述 如图&#xff1a;唐僧&#xff0c;孙悟空&#xff0c;白骨精三者的关系图&#xff1a; Cypher语言描述他们的关系&#xff1a; (孙悟空)<-[:赶走]-(唐僧)-[:被骗]->(白骨精)-[:被打死]->(孙悟空) 二、CQL语言的使用案例 创建结点…...

STM32Cube串口USART发送接收数据

本文代码使用 HAL 库。 文章目录前言一、USART 同步/异步串行接收/发送器二、USART 原理图三、CubeMX 创建工程四、usart.c 文件解析五&#xff0c;设计实验&#xff1a;在 串口输入字符点亮led实验现象&#xff1a;总结前言 这篇文章介绍 实现 USART 异步模式下 通过 串口助手…...

OpenFeign详解

OpenFeign是什么&#xff1f; OpenFeign&#xff1a; OpenFeign是Spring Cloud 在Feign的基础上支持了SpringMVC的注解&#xff0c;如RequesMapping等等。OpenFeign的FeignClient可以解析SpringMVC的RequestMapping注解下的接口&#xff0c;并通过动态代理的方式产生实现类&am…...

python多线程网络编程

背景 使用过flask框架后&#xff0c;我对request这个全局实例非常感兴趣。它在客户端发起请求后会保存着所有的客户端数据&#xff0c;例如用户上传的表单或者文件等。那么在很多客户端发起请求时&#xff0c;服务器是怎么去区分不同的request对象呢&#xff1f;当查看了大量的…...

BFS-走迷宫

题目描述 给定一个 NM 的网格迷宫 G。G 的每个格子要么是道路,要么是障碍物(道路用 1 表示,障碍物用 0 表示)。 已知迷宫的入口位置为 (x1,y1),出口位置为 (x2...

【蓝牙mesh】Lower协议层介绍

【蓝牙mesh】Lower协议层介绍 Lower层简介 Lower协议层用于处理网络层以下的功能&#xff0c;包括节点的广播、重传、路由和网络拓扑等&#xff0c;是实现蓝牙mesh网络的关键协议之一。其中Lower协议层中最主要的一部分工作就是mesh数据的分片和组包。 Lower层是将Upper层发过…...

Java-重排序,happens-before 和 as-if-serial 语义

目录1. 如何解决重排序带来的问题2. happens-before1. 如何解决重排序带来的问题 对于编译器&#xff0c;JMM 的编译器重排序规则会禁止特定类型的编译器重排序。对于处理器重排序&#xff0c;JMM 的处理器重排序规则会要求编译器在生成指令序列时&#xff0c;插入特定类型的内…...

Nginx安装及介绍

前言&#xff1a;传统结构上(如下图所示)我们只会部署一台服务器用来跑服务&#xff0c;在并发量小&#xff0c;用户访问少的情况下基本够用但随着用户访问的越来越多&#xff0c;并发量慢慢增多了&#xff0c;这时候一台服务器已经不能满足我们了&#xff0c;需要我们增加服务…...

【华为OD机试模拟题】用 C++ 实现 - 寻找路径 or 数组二叉树(2023.Q1)

最近更新的博客 【华为OD机试模拟题】用 C++ 实现 - 获得完美走位(2023.Q1) 文章目录 最近更新的博客使用说明寻找路径 or 数组二叉树题目输入输出描述示例一输入输出示例二输入输出Code使用说明 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,通过…...

LINUX学习记录

回顾系列&#xff1a;两天的时间&#xff08;2023.2.24-2023.2.25&#xff09;重新学了遍Linux基础课&#xff0c;收获非常多&#xff0c;以前只会一些简单的Linux命令&#xff0c;对shell&#xff0c;git&#xff0c;管道&#xff0c;复杂Linux命令都不熟悉&#xff0c;学完之…...

华为OD机试用Python实现 -【狼羊过河 or 羊、狼、农夫过河】(2023-Q1 新题)

华为OD机试题 华为OD机试300题大纲狼羊过河 or 羊、狼、农夫过河题目描述输入描述输出描述说明示例一输入输出说明Python 代码实现代码实现思路华为OD机试300题大纲 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,通过率才会高。 华为 OD 清单查看地址…...

【SAP Abap】X-DOC:SAP ABAP 语法更新之Open SQL

SAP ABAP 语法更新之Open SQL1、前言2、演示1、前言 自从 SAP 推出 SAP ON HANA&#xff0c;与之相随的 AS ABAP NW 7.40 版本以后&#xff0c;ABAP 语法也有了较多的更新&#xff0c;本篇对 Open Sql的语法更新部分做一个DEMO演示。 NW 7.40 以前 OpenSQL 的限制&#xff1a…...

leetcode 困难 —— 数组中的逆序对(分治法)

题目&#xff1a; 在数组中的两个数字&#xff0c;如果前面一个数字大于后面的数字&#xff0c;则这两个数字组成一个逆序对。输入一个数组&#xff0c;求出这个数组中的逆序对的总数。 题解&#xff1a; ① 我最开始想的蠢方法&#xff08;会超时&#xff0c;可跳过&#xff…...

02.24:图片的风格转换

Github网址&#xff1a;https://github.com/lengstrom/fast-style-transfer 在anaconda prompt中切换环境命令&#xff1a;activate 环境名 列出所有环境名&#xff1a;conda info --envs 安装环境&#xff1a;conda create -n 环境名 pythonx.x.x 删除某环境&#xff1a;co…...

[SSD综述 1.3] SSD及固态存储技术半个世纪发展史

在我们今天看来&#xff0c;SSD已不再是个新鲜事物。这多亏了存储行业的前辈们却摸爬滚打了将近半个世纪&#xff0c;才有了SSD的繁荣&#xff0c; 可惜很多前辈都没有机会看到。所有重大的技术革新都是这样&#xff0c;需要长期的技术积累&#xff0c;一代一代的工程师们默默的…...

PAT 1023 组个最小数(分数20)题目有bug

目录 题目描述&#xff1a; 题目讲解&#xff1a; 框架构建&#xff1a; 代码部分&#xff1a; 一个bug&#xff1a; 题目描述&#xff1a; 给定数字 0-9 各若干个。你可以以任意顺序排列这些数字&#xff0c;但必须全部使用。目标是使得最后得到的数尽可能小&#xff08;…...

QML 中的 5 大布局

作者: 一去、二三里 个人微信号: iwaleon 微信公众号: 高效程序员 在 QML 中,可以通过多种方式对元素进行布局 - 手动定位、坐标绑定定位、锚定位(anchors)、定位器和布局管理器。 说到 anchors,可能很多人都不太了解,它是 QML 中一个非常重要的概念,主要提供了一种相…...

使用Python进行数据分析——线性回归分析

大家好&#xff0c;线性回归是确定两种或两种以上变量之间互相依赖的定量关系的一种统计分析方法。根据自变量的个数&#xff0c;可以将线性回归分为一元线性回归和多元线性回归分析。一元线性回归&#xff1a;就是只包含一个自变量&#xff0c;且该自变量与因变量之间的关系是…...

我眼中的柔宇科技

关注、星标公众号&#xff0c;直达精彩内容来源&#xff1a;技术让梦想更伟大作者&#xff1a;李肖遥很早就知道了柔宇科技&#xff0c;当时是因为知道创始人刘自鸿&#xff0c;23岁清华本硕毕业&#xff0c;26岁获斯坦福大学电子工程博士学位&#xff0c;历时不超过3年&#x…...

Allegro如何快速把视图居中显示操作指导

Allegro如何快速把视图居中显示操作指导 用Allegro进行PCB设计的时候,为了方便检查和设计,时常需要将视图居中显示。一般地,会使用鼠标的中键进行放大和缩小,或者使用Zoom in和Zoom out来调整视图 Allegro还支持快速将视图居中 具体操作如下 点击View...

搜索相关功能

一、进入搜索页面 1.1 在pages下创建搜索页面为&#xff1a;search 1.2 在index.vue中点击进入搜素页面 onNavigationBarButtonTap(e){if(e.floatleft){uni.navigateTo({url:/pages/search/search})}},1.3 在pages.json中配置搜索页面头部 {"path" : "pages/…...

【从零开始制作 bt 下载器】一、了解 torrent 文件

【从零开始制作 bt 下载器】一、了解 torrent 文件写作背景了解 torrent 文件认识 bencodepython 解析 torrent 文件解密 torrent 文件结尾写作背景 最先开始是朋友向我诉说使用某雷下载结果显示因为版权无法下载&#xff0c;找其他的下载器有次数限制&#xff0c;于是来询问我…...

SystemVerilog-时序逻辑建模(5)多个时钟和时钟域交叉

数字硬件建模SystemVerilog-时序逻辑建模&#xff08;5&#xff09;多个时钟和时钟域交叉数字门级电路可分为两大类&#xff1a;组合逻辑和时序逻辑。锁存器是组合逻辑和时序逻辑的一个交叉点&#xff0c;在后面会作为单独的主题处理。组合逻辑描述了门级电路&#xff0c;其中逻…...

基本中型网络的仿真(RYU+Mininet的SDN架构)-以校园为例

目录 ​​​​​​​具体问题可以私聊博主 一、设计目标 1.1应用场景介绍 1.2应用场景设计要求 网络配置方式 网络技术要求 网络拓扑要求 互联互通 二、课程设计内容与原理 &#xff08;1&#xff09;预期网络拓扑结构和功能 &#xff08;1&#xff09;网络设备信息 …...