当前位置: 首页 > news >正文

PyTorch 基础学习(13)- 混合精度训练

系列文章:
《PyTorch 基础学习》文章索引

基本概念

混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16torch.bfloat16)数据类型的优势,提高计算效率和内存利用率。

  • 高精度(torch.float32:适合需要大动态范围的操作,如损失计算、缩减操作(如求和、平均)等。这些操作对数值稳定性要求较高,使用高精度能确保计算结果的准确性。

  • 低精度(torch.float16torch.bfloat16:适合计算密集型操作,如卷积和矩阵乘法。这些操作在低精度下可以显著提升计算速度,同时减少显存占用。

混合精度训练的核心思想是在模型中自动选择合适的数据类型,以在加速计算的同时,尽可能保持结果的准确性。PyTorch 提供了 torch.amp 模块,该模块封装了一些便捷的工具,使得混合精度的实现更加直观和高效。

重要方法及其作用

torch.autocast

torch.autocast 是混合精度训练中的核心工具。它是一个上下文管理器或装饰器,用于在代码的特定部分启用混合精度。在这些被启用的区域内,autocast 将根据操作的特性自动选择合适的数据类型。例如,卷积操作可以自动转换为 float16,而损失计算则保持为 float32

主要参数:

  • device_type:指定设备类型,如 cudacpuxpu
  • dtype:指定在 autocast 区域内使用的低精度数据类型。对于 CUDA 设备,默认是 torch.float16;对于 CPU 设备,默认是 torch.bfloat16
  • enabled:是否启用混合精度。默认为 True
  • cache_enabled:是否启用权重缓存。默认是 True,可以在某些场景下提高性能。

torch.amp.GradScaler

在低精度(如 float16)下,梯度值较小的操作可能会出现下溢现象,导致梯度值变为零,从而影响模型的训练。为了避免这种情况,PyTorch 提供了 GradScaler,它通过在反向传播之前动态缩放损失值,从而放大梯度值,使其在低精度下也能被有效表示。之后,优化器会在更新参数之前对梯度进行反缩放,以确保不会影响学习率。

主要参数:

  • init_scale:初始的缩放因子,默认是 65536.0
  • growth_factor:在没有发生下溢的情况下,缩放因子增长的倍数,默认是 2.0
  • backoff_factor:发生下溢时,缩放因子减少的倍数,默认是 0.5
  • growth_interval:在多少个步骤之后,如果没有下溢,缩放因子会增长,默认是 2000
  • enabled:是否启用梯度缩放,默认为 True

适用的场景

GPU 训练
在使用 CUDA 设备进行深度学习模型训练时,启用混合精度可以显著提升模型的训练速度。尤其是在使用大规模数据和复杂模型(如卷积神经网络、Transformer 模型)时,torch.autocast(device_type="cuda") 能够有效地减少 GPU 的计算负载,并提高吞吐量。

CPU 训练与推理
虽然 GPU 在深度学习中更常用,但在一些特定场景下(如低资源环境或需要在 CPU 上进行部署),混合精度在 CPU 上同样具有优势。使用 torch.autocast(device_type="cpu", dtype=torch.bfloat16) 可以在推理过程中降低计算复杂度,同时保持较高的精度。

3.3 自定义操作
在某些高级用例中,用户可能需要为自定义的自动微分函数实现混合精度支持。通过 torch.amp.custom_fwdtorch.amp.custom_bwd,用户可以定义在特定设备(如 cuda)上执行的前向和反向操作,并确保这些操作在混合精度模式下正常运行。

应用实例

以下是一个在 CUDA 设备上使用混合精度进行训练的完整示例,展示了如何在实践中应用 torch.autocasttorch.amp.GradScaler

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler# 定义简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型和优化器,使用默认精度(float32)
model = SimpleModel().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 创建GradScaler
scaler = GradScaler()# 训练循环
for epoch in range(10):  # 假设有10个epochfor input, target in data_loader:  # 假设有一个data_loaderinput, target = input.cuda(), target.cuda()optimizer.zero_grad()# 在前向传播过程中启用自动混合精度with autocast(device_type="cuda"):output = model(input)loss = loss_fn(output, target)# 使用GradScaler进行反向传播scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch+1} completed.")

代码说明

  • 首先,我们定义了一个简单的神经网络模型,并将其放置在 CUDA 设备上。
  • 在每次训练循环中,我们使用 torch.autocast(device_type="cuda") 上下文管理器包裹前向传播过程,使得模型的计算自动使用混合精度。
  • 使用 GradScaler 对损失进行缩放,并在缩放后的损失上调用 backward() 进行反向传播。这一步骤有助于防止梯度下溢。
  • scaler.step(optimizer) 用于更新模型参数,scaler.update() 则是调整缩放因子。

这种方法既能提高训练速度,又能在较低精度下保持数值稳定性,是在实际项目中应用混合精度训练的有效方案。

注意事项

  • 弃用警告:从 PyTorch 1.10 开始,原有的 torch.cuda.amp.autocasttorch.cpu.amp.autocast 方法被弃用,推荐使用通用的 torch.autocast 代替。这不仅简化了接口,也为未来的设备扩展提供了灵活性。

  • 数据类型匹配:在使用 autocast 时,确保输入数据类型的一致性非常重要。如果在混合精度区域内生成的张量在退出后与其他不同精度的张量混合使用,可能会导致类型不匹配错误。因此,在必要时,需要手动将张量转换为 float32 或其他合适的精度。

  • GradScaler 的适用性:虽然 GradScaler 对大多数模型都有效,但在某些情况下(例如使用 bf16 预训练模型),可能会出现梯度溢出的情况。因此,在使用混合精度训练时,需要根据具体模型的特性进行调整。

通过对这些概念、方法、使用场景和实例的深入理解,您可以在实际项目中更好地应用混合精度训练,从而提升深度学习模型的训练效率和性能。

相关文章:

PyTorch 基础学习(13)- 混合精度训练

系列文章: 《PyTorch 基础学习》文章索引 基本概念 混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16 或 torch.bfloat16)数据类型的优势,…...

Mycat分片-垂直拆分

目录 场景 配置 测试 全局表配置 续接上篇:MySQ分库分表与MyCat安装配置-CSDN博客 续接下篇:Mycat分片-水平拆分-CSDN博客 场景 在业务系统中, 涉及以下表结构 ,但是由于用户与订单每天都会产生大量的数据, 单台服务器的数据 存储及处理能力是有限…...

一元四次方程求解-【附MATLAB代码】

目录 前言 求解方法 ​编辑 MATLAB验证 附:一元四次方程的故事 前言 最近在研究机器人的干涉(碰撞)检测,遇到了一个问题,就是在求椭圆到原点的最短距离时,构建的方程是一个一元四次方程。无论是高中的…...

【极限性能,尽在掌控】ROG NUC:游戏与创作的微型巨擘

初见ROG NUC,你或许会为它的小巧体型惊讶。然而,这看似不起眼的机身内,蕴藏着游戏、创意的强大能量。 掌中风暴,性能无界 ROG NUC搭载英特尔高性能处理器,配合高速NVMe SSD固态硬盘以及可选的高端独立显卡&#xff08…...

Ecosmos开启公测,将深度赋能CIOE中国光博会元宇宙参会新体验

如今,生成式AI技术的发展,极大地降低了3D数字资产的制作成本,元宇宙作为一种可以无缝将物理和数字资产进行融合的技术,在推动电子产业数字化进程、助力产业高质量发展的方面展现出了巨大的潜力。 当前,发展新质生产力是…...

【Kubernetes】k8s集群之包管理器Helm

目录 一.Helm概述 1.Helm的简介 2.Helm的三个重要概念 3.Helm2与Helm3的的区别 二.Helm 部署 1.安装 helm 2.使用 helm 安装 Chart 3.Helm 自定义模板 4.Helm 仓库 每个成功的软件平台都有一个优秀的打包系统,比如Debian、Ubuntu 的 apt,RedH…...

嵌入式linux系统镜像制作day3(构建镜像)

点击上方"蓝字"关注我们 01、上节回顾 嵌入式linux系统镜像制作day1嵌入式linux系统镜像制作day2提前下载好准备工具,不然失败了大眼瞪小眼。 02、构建 Poky 的 Sato 镜像1 环境: ubuntu18.04poky版本:Dizzy 工具git 在开始之前,针对不同的发行版,需要先执行…...

【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】

教师节中秋节国庆节车模特美女举牌生日视频制作教程AE模板改文字软件生成器素材 怎么如何做的【生日视频制作】教师节中秋节国庆节车模特美女举牌AE模板修改文字软件生成器教程特效素材【AE模板】 生日视频制作步骤: 安装AE软件下载AE模板把AE模板导入AE软件修改图…...

RongCallKit iOS 端本地私有 pod 方案

RongCallKit iOS 端本地私有 pod 方案 需求背景 适用于源码集成 CallKit 时,使用 pod 管理 RTC framework 以及源码。集成 CallKit 时,需要定制化修改 CallKit 的样式以及部分 UI 功能。适用于 CallKit 源码 Debug 调试便于定位相关问题。 解决方案 从…...

C++11:可变参数模板

目录 一、概述 二、场景 1.深拷贝的类 2.浅拷贝的类 C使用指南 一、概述 // Args是一个模板参数包&#xff0c;args是一个函数形参参数包 // 声明一个参数包Args...args&#xff0c;这个参数包中可以包含0到任意个模板参数。 template <class ...Args> void ShowList(…...

C++ 与 QML 之间进行数据交互的几种方法

https://www.cnblogs.com/jzcn/p/17774676.html 一、属性绑定 这是最简单的方式&#xff0c;可以在QML中直接绑定C 对象的属性。通过在C 对象中使用Q_PROPERTY宏定义属性&#xff0c;然后在QML中使用绑定语法将属性与QML元素关联起来。 1. person.h #include <QObject&g…...

Javaweb学习之Vue项目的创建(二)

学习资料 Vue.js - 渐进式 JavaScript 框架 | Vue.js (vuejs.org) 准备工作都做完了&#xff0c;接下来开始Vue的正式学习。 第一步&#xff0c;打开VS Code 在VS Code里&#xff0c;我们也需要使用到终端&#xff0c;如果不是以管理员身份打开&#xff0c;在新建Vue项目的时候…...

『深度长文』4种有效提高LLM输出质量的方法!

LLM&#xff0c;全称Large Language Model&#xff0c;意为大型语言模型&#xff0c;是一种基于深度学习的AI技术&#xff0c;能够生成、理解和处理自然语言文本&#xff0c;也因此成为当前大多数AI工具的核心引擎。LLM通过学习海量的文本数据&#xff0c;掌握了词汇、语法、语…...

【工业机器人】工业异常检测大模型AnomalyGPT

AnomalyGPT 工业异常检测视觉大模型AnomalyGPT AnomalyGPT: Detecting Industrial Anomalies using Large Vision-Language Models AnomalyGPT是一种基于大视觉语言模型&#xff08;LVLM&#xff09;的新型工业异常检测&#xff08;IAD&#xff09;方法。它利用LVLM的能力来理…...

【PGCCC】PostgreSQL案例:planning time超长问题分析#PG初级

在使用 PostgreSQL 时&#xff0c;查询的执行计划&#xff08;planning time&#xff09;有时会出现异常长的情况&#xff0c;这可能会影响数据库的整体性能。分析和解决这种问题可以从多个角度入手&#xff0c;以下是常见原因和相应的解决思路&#xff1a; 1. 统计信息不准确…...

【图文并茂】ant design pro 如何给后端发送 json web token - 请求拦截器的使用

上一节有讲过 【图文并茂】ant design pro 如何对接后端个人信息接口 还差一个东西&#xff0c;去获取个人信息的时候&#xff0c;是要发送 token 的&#xff0c;不然会报 403. 就是说在你登录之后才去获得个人信息。这样后端才能知道是谁的信息。 token 就代码了某个人。 …...

【微信小程序】自定义组件 - behaviors

1. 什么是 behaviors 2. behaviors 的工作方式 3. 创建 behavior 调用 Behavior(Object object) 方法即可创建一个共享的 behavior 实例对象&#xff0c;供所有的组件使用&#xff1a; 4. 导入并使用 behavior 5. behavior 中所有可用的节点 6. 同名字段的覆盖和组合规则* 关…...

Linux ubuntu 24.04 安装运行《帝国时代3》免安装绿色版游戏,解决 “Could not load DATAP.BAR”等问题

Linux ubuntu 24.04 安装运行《帝国时代3》游戏&#xff0c;解决 “Could not load DATAP.BAR" 等问题 《帝国时代 3》是一款比较经典的即时战斗游戏&#xff0c;伴随了我半个高中时代&#xff0c;周末有时间就去泡网吧&#xff0c;可惜玩的都是简单人机&#xff0c;高难…...

Springboot 图片

Springboot 图片 因为 server.servlet.context-path: /api 所以 url是这个的时候 http://127.0.0.1:9100/api/staticfiles/image/dd56a59d-da84-441a-8dac-1d97f9e42090.jpeg 配置代码的前面的 /api 是不要写的 package com.gk.study.config;import org.springframework.conte…...

LIMS实验室管理系统如何实现数据自动采集

随着科研技术的不断发展&#xff0c;LIMS实验室管理系统的应用也愈来愈广&#xff0c;已经成为现代化实验室管理不可或缺的工具。LIMS实验室管理系统未与仪器设备对接前&#xff0c;仪器设备产生的数据都是通过人工录入到系统中&#xff0c;再经过人工审核形成最终的数据报告。…...

全自动商用油炸锅介绍:

全自动商用油炸锅‌是一种专门为商业用途设计的厨房设备&#xff0c;旨在高效、节能、卫生地完成大量食品的油炸加工。这种设备通常采用油水混合技术&#xff0c;能够自动过滤残渣&#xff0c;延长换油周期&#xff0c;从而大大降低用油成本。全自动商用油炸锅适合中、小型油炸…...

CE修改器的简单使用

前言 这个系列目前是出于兴趣爱好&#xff0c;最终目的是为了可以用代码控制修改单机游戏。 这篇文章的对象是《植物大战僵尸杂交版》&#xff0c;其余游戏类似。 博客仅做技术研究使用&#xff0c;禁止用作商业用途。 1&#xff0c;安装CE修改器 到官网进行下载&#xff…...

element-plus el-cascader懒加载怎么指定对应的label和value。最后一级怎么判断?

<el-cascader:props"props"placeholder"请选择现地址所在地"v-model"currentaddress"ref"currentaddressRef"change"currentaddressChange"style"width:100%"clearable/> 懒加载需要用到props。 const pro…...

pdf查看密码

pdf有两种密码方式&#xff0c;一种是打开后进入文件内容页面后需要密码才能进行修改等操作&#xff0c;网上有很多方式进行移除密码操作&#xff0c;第二种是打开就需要密码&#xff0c;我这里简单记录一个暴力破解的方式&#xff0c;仅供参考 import PyPDF2 import itertools…...

从bbl和overleaf版本解决Arxiv提交后缺失参考文献Citation on page undefined on input line

debug 食用指南&#xff1a;框架/语言&#xff1a;问题描述&#xff1a;解决方案&#xff1a;问题原因&#xff1a;版本解决方案&#xff1a; 安利时间&#xff1a; 食用指南&#xff1a; 框架使用过程中的问题首先要注意版本发布时间造成方法弃用 当你在CSDN等网站查找不到最…...

Flutter【01】状态管理

声明式编程 Flutter 应用是 声明式 的&#xff0c;这也就意味着 Flutter 构建的用户界面就是应用的当前状态。 当你的 Flutter 应用的状态发生改变时&#xff08;例如&#xff0c;用户在设置界面中点击了一个开关选项&#xff09;你改变了状态&#xff0c;这将会触发用户界面…...

(转载)使用zed相机录制视频

参照下面这个链接 https://blog.csdn.net/peng_258/article/details/127457199?ops_request_misc&request_id&biz_id102&utm_termzed2%E5%BD%95%E5%88%B6%E6%95%B0%E6%8D%AE%E9%9B%86&utm_mediumdistribute.pc_search_result.none-task-blog-2~all~sobaiduweb…...

C/C++中奇妙的类型转换

1.引言 大家在学习C语言的时候&#xff0c;有没有遇见过类似于下面这样的代码呢&#xff1f; // 整形转bool int count 10; while(count--) {cout << count << endl; }// 指针转bool int* ptr cur; while(ptr) {//…… } 众所周知&#xff0c;while循环的判断…...

嵌入式AI快速入门课程-K510篇 (第三篇 环境搭建及开发板操作)

第三篇 环境搭建及开发板操作 文章目录 第三篇 环境搭建及开发板操作1.配置VMware使用桥接网卡1.1 vmware设置1.2 虚拟网络编辑器设置 2.安装软件2.2 安装 Windows 软件2.3 使用MobaXterm远程登录Ubuntu2.4 使用FileZilla在Windows和Ubuntu之间传文件2.5编程示例&#xff1a;Ub…...

C++第三十九弹---C++ STL中的无序容器:unordered_set与unordered_map使用详解

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】【C详解】 目录 1 unordered_set 1.1 unordered_set的接口说明 1.1.1 unordered_set的构造 1.1.2. unordered_set的容量 1.1.3. unordered_set的迭代器 1.1…...

flash代码做网站教程/上海公关公司

计算机编程术语&#xff1a; 参考网址&#xff1a;https://blog.csdn.net/linear_luo/article/details/52318820 application 应用程式 应用、应用程序 application framework 应用程式框架、应用框架 应用程序框架 architecture 架构、系统架构 体系结构 argument 引数&am…...

广州市口碑好的网站制作排名/百度指数怎么查询

在今天的互联网里&#xff0c;高并发、大数据量、大流量已经成为了代言词&#xff0c;那么我们的系统也承受着巨大的压力&#xff0c;首当其冲的解决方案就是redis。 那么redis使用不当就会产生雪崩、穿透、击穿等问题&#xff0c;这也是考验一个程序员技术能力的时刻。 当然面…...

网站更新的意义/百度竞价推广价格

需要准备的硬件 MC20开发板 1个https://item.taobao.com/item.htm?id562661881042GSM/GPRS天线 1根https://item.taobao.com/item.htm?id531979567261IPEX接口转SMA接口转接线 1根https://item.taobao.com/item.htm?id531979903836GPS有源天线 1根https://item.taobao.com/i…...

晋城企业网站建设价格/搜狐视频

大家好&#xff0c;我是老赵&#xff5e;我有一个朋友&#xff5e;做了一个小破站&#xff0c;现在要实现一个站内信web消息推送的功能&#xff0c;对&#xff0c;就是下图这个小红点&#xff0c;一个很常用的功能。不过他还没想好用什么方式做&#xff0c;这里我帮他整理了一下…...

苏州网站建设公司鹅鹅鹅/网址导航哪个好

2019独角兽企业重金招聘Python工程师标准>>> Swift目前已经支持keystone认证&#xff0c;不过官方的安装文档中还使用了TempAuth,这篇翻译&#xff0c;关于auth帮助我们更好的理解swift auth&#xff0c; The Auth System 认证系统 TempAuth Swift的认证系统松散的基…...

常州网站建设机构/常见的网络推广方法

获取或设置属性节点的值 操作标签属性 getAttribute(attrName) 作用 &#xff1a; 获取指定属性的值 attrName : 属性名setAttribute(attrName,value) 作用 &#xff1a;设置属性和对应的值 attrName : 属性名 value &#xff1a;属性值removeAttribute(attrName) 作用 &…...