当前位置: 首页 > news >正文

DLA 神经网络的极限训练方法:gradient checkpointing

gradient checkpointing

        一般来说,训练的过程需要保存中间结果(不管是GPU还是CPU)。前向传播根据输入(bottom_data)计算输出(top_data),后向传播由top_diff计算bottom_diff(如果某个变量打开梯度进行训练的话)。top和bottom是包含数据和梯度的两个结构体,整个网络的每层top和bottom在训练的过程中都会保存,这消耗了大量的内存。

        如果不保存这些变量,每次传播时重新分配和计算,会大大减少内存的使用量,但是也会使得网络的训练时间无限延长。为了平衡这两个矛盾,论文Training Deep Nets with Sublinear Memory Cost 使用亚线性内存成本训练深度网络:我们提出了一种系统方法来减少深度的内存消耗 神经网络训练。具体来说,我们设计了一种成本高昂的算法 O(sqrt(n)) 内存来训练 n 层网络,只需计算成本 每个小批量的额外前向传递。每隔 sqrt(n)保留一个检查点的feature map。

CODE

  • https://pytorch.org/docs/stable/checkpoint.html
// https://discuss.pytorch.org/t/trying-to-understand-torch-utils-checkpoint/95224
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.notebook import tqdmfrom torch import optim
import torchvision.models as models
from torch import nnCHECKPOINT = True
BATCH_SIZE = 32
dev = "cuda:0"class ImageDataset(Dataset):def __init__(self,length = 100000,size = 244):self.length = lengthself.size = 244def __len__(self):return self.lengthdef __getitem__(self,idx,display = False):return torch.from_numpy(np.random.randn(2,3,self.size,self.size))
train = ImageDataset()
trainloader = DataLoader(train,batch_size = BATCH_SIZE,num_workers = 24,pin_memory = True
)resnet = models.resnet50(pretrained = False)class MODEL(nn.Module):def __init__(self,model):super(MODEL,self).__init__()self.model = modelself.LR = nn.Linear(1000,1000)def forward(self,x):if CHECKPOINT == False:o1 = self.model(x[:,0])o2 = self.model(x[:,1])else:o1 = torch.utils.checkpoint.checkpoint(self.model,x[:,0])o2 = torch.utils.checkpoint.checkpoint(self.model,x[:,1])return torch.mean((self.LR(o1)-o2)**2)resnet = MODEL(resnet).to(dev)optimizer = optim.SGD(resnet.parameters(),lr = .001)for T in tqdm(trainloader):out = torch.mean(resnet(T.float().to(dev)))optimizer.zero_grad()out.backward()optimizer.step()

CG

在这里插入图片描述

  • https://github.com/merrymercy/dtr-prototype

ZeRO-Offload

  • https://arxiv.org/pdf/2101.06840.pdf 大规模模型训练一直是少数人的比赛场地 需要复杂的模型重构和访问昂贵的 GPU 集群。ZeRO-Offload 通过使 几乎每个人都可以访问大型模型训练。它可以训练模型 单个 GPU 上超过 13 亿个参数,与 GPU 相比,大小增加了 10 倍 流行的框架,如PyTorch,它不需要任何模型就可以做到这一点。 从数据科学家改变或牺牲计算效率。 ZeRO-卸载通过卸载数据和计算来实现大型模型训练 中央处理器。为了保持计算效率,它旨在最大限度地减少数据 移入/移出 GPU,减少 CPU 计算时间,同时最大化内存 节省 GPU 成本。因此,ZeRO-Offload可以在单个上实现40 TFlops / GPU。 NVIDIA V100 GPU 用于 10B 参数模型,与单独使用 PyTorch 的 30TF 相比 对于 1.4B 参数模型,可以训练而不会耗尽的最大参数模型 的记忆。ZeRO-Offload 还设计为在以下情况下在多个 GPU 上进行扩展 可用,可在多达 128 个 GPU 上提供近乎线性的加速。此外,它可以 与模型并行性协同工作,训练超过 70 亿的模型 单个 DGX-2 盒子上的参数,与模型尺寸相比增加了 4.5 倍 单独使用模型并行性。通过将计算和内存效率与 易于使用,ZeRO-Offload 使大规模模型训练民主化,使其成为 即使是数据科学家也可以访问,只需访问一个 GPU。

梯度累积

        训练时大的batch一般能得到更稳定的训练效果,梯度累积训练方法是一种用于训练深度神经网络的技术,旨在减少显存需求并提高训练效果。在传统的训练方法中,模型的参数是通过单个批次(batch)的数据计算得到的梯度平均值进行更新。但在梯度累积训练中,模型的参数更新是通过多个批次的梯度累积得到的。

以下是梯度累积训练的基本步骤:

  1. 设置梯度累积步数(accumulation steps),它决定了要累积多少个批次的梯度。

  2. 初始化模型的参数。

  3. 对于每个训练批次(batch):

    • 使用当前批次的数据进行前向传播计算损失。
    • 对损失进行反向传播计算梯度。
    • 累积当前批次的梯度到之前的梯度值上。
  4. 当累积达到设置的步数时,将累积的梯度应用于模型参数的更新:

    • 通过将累积的梯度平均化来获得参数的更新值。
    • 使用更新值来更新模型的参数。
  5. 重复步骤3和4,直到完成所有的训练批次。

梯度累积训练的主要优势在于能够降低每个批次所需的显存量,允许在具有有限显存的硬件上训练更大的模型。此外,梯度累积还可以改善模型的收敛性,提高模型的性能和泛化能力。

相关文章:

DLA 神经网络的极限训练方法:gradient checkpointing

gradient checkpointing 一般来说,训练的过程需要保存中间结果(不管是GPU还是CPU)。前向传播根据输入(bottom_data)计算输出(top_data),后向传播由top_diff计算bottom_diff(如果某个变量打开梯度进行训练的话&#xff…...

python excel 操作

excel文件内容如下: 一、xlrd 读Excel 操作 1、打开Excel文件读取数据 filexlrd.open_workbook(filename)#文件名以及路径,如果路径或者文件名有中文给前面加一个 r 2、常用函数 (1)获取一个sheet工作表 table file.sheets(…...

记一次Linux启动Mysql异常解决

文章目录 第一步: netstat -ntlp 查看端口情况2、启动Mysql3、查看MySQL日志 tail -100f /var/log/mysqld.log4、查看磁盘占用情况:df -h5、思路小结 第一步: netstat -ntlp 查看端口情况 并没有发现3306数据库端口 2、启动Mysql service …...

ATFX汇市:美联储年内或仍将加息依次,美指向下空间不大

环球汇市行情摘要—— 昨日,美元指数上涨0.08%,收盘在102.08点, 欧元贬值0.07%,收盘价1.1003点; 日元贬值0.51%,收盘价142.47点; 英镑升值0.28%,收盘价1.2784点; 瑞…...

【博客687】k8s informer的list-watch机制剖析

k8s informer的list-watch机制剖析 1、list-watch场景: client-go中的reflector模块首先会list apiserver获取某个资源的全量信息,然后根据list到的rv来watch资源的增量信息。希望使用client-go编写的控制器组件在与apiserver发生连接异常时&#xff0c…...

用Python获取链家二手房房源数据,做可视化图分析数据

前言 数据采集的步骤是固定: 发送请求, 模拟浏览器对于url地址发送请求获取数据, 获取网页数据内容 --> 请求那个链接地址, 返回服务器响应数据解析数据, 提取我们需要的数据内容保存数据, 保存本地文件 所需模块 win R 输入cmd 输入安装命令 pip install 模块名 (如果你…...

Yield Guild Games:社区更新 — 2023 年第二季度

本文重点介绍了 Yield Guild Games (YGG) 2023 年第二季度社区更新中涵盖的关键主题,包括公会发展计划 (GAP) 第 3 季的总结、YGG 领导团队的新成员以及 YGG 的最新消息地区公会网络和广泛的游戏合作伙伴生态系统。 在 YGG 品牌焕然一新的基础上,第二季…...

Stable Diffusion - 运动服 (Gymwear Leggings) 风格服装与背景的 LoRA 配置

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/132179050 测试模型:DreamShaper 8 运动裤 (Gymwear Leggings) 是紧身的裤子,通常用于健身、瑜伽、跑步等运动。运动裤的…...

js-7:javascript原型、原型链及其特点

1、原型 JavaScript常被描述为一种基于原型的语言-每个对象拥有一个原型对象。 当试图访问一个对象的属性时,它不仅仅在该对象上搜寻,还会搜寻该对象的原型,以及该对象的原型的原型,依次层层向上搜索,直到找到一个名字…...

无涯教程-Perl - continue 语句函数

可以在 while 和 foreach 循环中使用continue语句。 continue - 语法 带有 while 循环的 continue 语句的语法如下- while(condition) {statement(s); } continue {statement(s); } 具有 foreach 循环的 continue 语句的语法如下- foreach $a (listA) {statement(s); } co…...

【贪心算法】leetcode刷题

贪心算法无固定套路。 核心思想:先找局部最优,再扩展到全局最优。 455.分发饼干 两种思路: 1、从大到小。局部最优就是大饼干喂给胃口大的,充分利用饼干尺寸喂饱一个,全局最优就是喂饱尽可能多的小孩。先遍历的胃口&a…...

PyMySQL库版本引起的python执行sql编码错误

前言 长话短说,之前在A主机(centos7.9)上运行的py脚本拿到B主机上(centos7.9)运行报错: UnicodeEncodeError: latin-1 codec cant encode characters in position 265-266: ordinal not in range(256)两个…...

第二章-算法

第二章-算法 数据结构和算法的关系 算法是解决特定问题求解步骤的描述,在计算机中表现为指令的有限序列,并且每条指令表示一个或多个操作。 算法的特性 算法有五个基本特征:输入、输出、有穷性、确定性和可行性。 输入:算法具…...

‘vue’不是内部或外部命令,也不是可运行的程序或批处理文件的原因及解决方法

今天我在用node.js的时候,结果出现如下错误: C:\Users\xiesj> vue -v vue不是内部或外部命令,也不是可运行的程序或批处理文件。 原因: 1、确定npm是否已正确安装? 2、确定vue以及vue-cli已正确安装?…...

HBase API

我们之后的实际开发中不可能在服务器那边直接使用shell命令一直敲的&#xff0c;一般都是通过API进行操作的。 环境准备 新建Maven项目&#xff0c;导入Maven依赖 <dependencies><dependency><groupId>org.apache.hbase</groupId><artifactId>…...

Qt6之QListWidget——Qt仿ToDesk侧边栏(1)

一、 QLitWidget概述 注意&#xff1a;本文不是简单翻译Qt文档或者接口函数&#xff0c;而侧重于无代码Qt设计器下演示使用。 QListWidget也称列表框类&#xff0c;它提供了一个类似于QListView提供的列表视图&#xff0c;但是它具有一个用于添加和删除项的经典的基于项的接口…...

Prometheus技术文档--基本安装-docker安装并挂载数据卷-《十分钟搭建》

一、查看可安装的版本 docker search prom/prometheus 二、拉取镜像 docker pull prom/prometheus 三、查看镜像 docker images 四、书写配置文件-以及创建挂载目录 宿主机挂载目录位置&#xff1a; 以及准备对应的挂载目录&#xff1a; /usr/local/docker/promethues/se…...

Android 数据库之GreenDAO

GreenDAO 是一款开源的面向 Android 的轻便、快捷的 ORM 框架&#xff0c;将 Java 对象映射到 SQLite 数据库中&#xff0c;我们操作数据库的时候&#xff0c;不再需要编写复杂的 SQL语句&#xff0c; 在性能方面&#xff0c;greenDAO 针对 Android 进行了高度优化&#xff0c;…...

kotlin 编写一个简单的天气预报app(六)使用recyclerView显示forecast内容

要使用RecyclerView显示天气预报的内容 先在grandle里添加recyclerView的引用 implementation androidx.recyclerview:recyclerview:1.3.1创建一个RecyclerView控件&#xff1a;在布局文件中&#xff0c;添加一个RecyclerView控件&#xff0c;用于显示天气预报的列表。 这是一…...

jpa Page 1 of 0 containing UNKNOWN instances错误关于like问题的解决记录

导致这个问题的原因很多&#xff0c;这里记录一下我碰到的问题和解决方法。 网上有说时 pageNo要从0开始&#xff0c;我的不是这个问题。 在使用springboot jpa时&#xff0c;发现使用 t.ip like %?5% 语句&#xff0c;如果数据库记录的ip is null时&#xff0c;将查询不到该…...

Python实战之使用Python进行数据挖掘详解

一、Python数据挖掘 1.1 数据挖掘是什么&#xff1f; 数据挖掘是从大量的、不完全的、有噪声的、模糊的、随机的实际应用数据中&#xff0c;通过算法&#xff0c;找出其中的规律、知识、信息的过程。Python作为一门广泛应用的编程语言&#xff0c;拥有丰富的数据挖掘库&#…...

scala 加载properties文件

利用java.util.Properties加载 import java.io.FileInputStream import java.util.Properties object LoadParameter {//动态获取properties文件可配置参数val props new Properties()def getParameter(s:String,filePath:String): String {props.load(new FileInputStream(f…...

备战秋招012(20230808)

文章目录 前言一、今天学习了什么&#xff1f;二、动态规划1.概念2.题目 总结 前言 提示&#xff1a;这里为每天自己的学习内容心情总结&#xff1b; Learn By Doing&#xff0c;Now or Never&#xff0c;Writing is organized thinking. 提示&#xff1a;以下是本篇文章正文…...

QT中定时器的使用

文章目录 概述步骤 概述 Qt中使用定时器大致有两种&#xff0c;本篇暂时仅描述使用QTimer实现定时器 步骤 // 1.创建定时器对象 QTimer *timer new QTimer(this);// 2.开启一个定时器&#xff0c;5秒触发一次 timer->start(5000); // 3.建立信号槽连接&am…...

【UE4】多人联机教程(重点笔记)

效果 1. 创建房间、搜索房间功能 2. 根据指定IP和端口加入游戏 步骤 1. 新建一个第三人称角色模板工程 2. 创建一个空白关卡&#xff0c;这里命名为“InitMap” 3. 新建一个控件蓝图&#xff0c;这里命名为“UMG_ConnectMenu” 在关卡蓝图中显示该控件蓝图 打开“UMG_Connec…...

【go】GIN参数重复绑定报错EOF问题

文章目录 1 问题描述2 解决&#xff1a;替换为ShouldBindBodyWith 1 问题描述 在 Gin 框架中&#xff0c;当多次调用 ShouldBind() 或 ShouldBindJSON() 方法时&#xff0c;会导致请求体的数据流被读取多次&#xff0c;从而出现 “EOF” 错误。 例如在api层绑定了参数&#x…...

关于MySQL中的binlog

介绍 undo log 和 redo log是由Inno DB存储引擎生成的。 在MySQL服务器架构中&#xff0c;分为三层&#xff1a;连接层、服务层&#xff08;server层&#xff09;、执行层&#xff08;存储引擎层&#xff09; bin log 是 binary log的缩写&#xff0c;即二进制日志。 MySQL…...

我维护电脑的方法

无论是学习还是工作&#xff0c;电脑都是IT人必不可少的重要武器&#xff0c;一台好电脑除了自身配置要经得起考验&#xff0c;后期主人对它的维护也是决定它寿命的重要因素&#xff01; 你日常是怎么维护你的“战友”的呢&#xff0c;维护电脑运行你有什么好的建议吗&#xff…...

AP51656 电流采样降压恒流驱动IC RGB PWM深度调光 LED电源驱动

产品描述 AP51656是一款连续电感电流导通模式的降压恒流源&#xff0c;用于驱动一颗或多颗串联LED 输入电压范围从 5 V 到 60V&#xff0c;输出电流 可达 1.5A 。根据不同的输入电压和 外部器件&#xff0c; 可以驱动高达数十瓦的 LED。 内置功率开关&#xff0c;采用电流采样…...

Python爬虫的解析(学习于b站尚硅谷)

目录 一、xpath  1.xpath插件的安装  2. xpath的基本使用  &#xff08;1&#xff09;xpath的使用方法与基本语法&#xff08;路径查询、谓词查询、内容查询&#xff08;使用text查看标签内容&#xff09;、属性查询、模糊查询、逻辑运算&#xff09;  &#xff08;2&a…...

深圳市住房和建设局网站变更/谷歌竞价广告

Django Web应用程序&#xff08;3&#xff09; 本文主要内容为对项目“学习笔记”设置样式并对其进行部署。 为设置样式&#xff0c;将使用Bootstrap库&#xff1b;另外&#xff0c;我们还将把项目部署到Heroku&#xff0c;这个网站能够让我们能够将项目推送到其服务器&#x…...

怎么弄网站/广告营销公司

org.hibernate.cfg.Configuration实例代表了应用程序到SQL数据库的映射配置&#xff0c; Configuration对象提供了一个buildSessionFactory方法&#xff0c;该方法可以产生一个不可变的SessionFactory对象。 也可以直接实例化Configuration来获取一个实例&#xff0c;并为它指…...

庆阳手机网站设计/游戏推广员是诈骗吗

《隐私计算&#xff1a;区块链从乌托邦走向现实的必由之路》的演讲。莫晓康表示&#xff0c;区块链就像一个美丽的梦想&#xff0c;但是它差了一步&#xff0c;这一步必须由隐私计算来完成。没有隐私计算&#xff0c;即使这个梦想非常美丽&#xff0c;也无法落地&#xff0c;因…...

广州网站开发费用/互联网营销策略有哪些

乱码有时候是一个非常让人头疼的问题&#xff0c;这里就总结一下常用的解决乱码的方法。只知道的用法&#xff0c;却不明白为什么这么用……一、在Java代码中&#xff1a;1 request.setCharacterEncoding("UTF-8");用在哪里&#xff0c;为什么这么用……二、String r…...

我有服务器和模板怎么做网站/白度

一、看门狗原理 在产品化的嵌入式系统中&#xff0c;为了使系统在异常情况下能自动复位&#xff0c;一般都需要引入看门狗。 看门狗其实就是一个可以在一定时间内被复位的计数器。当看门狗启动后&#xff0c;计数器开始自动计数&#xff0c;经过一定时间&#xff0c;如果没有…...

个人未授权做的网站/广州百度竞价外包

背景 我想使用带有Inception-Resnet_v2的keras来预测病理图像.我已经训练了模型并得到了.hdf5文件.由于病理图像非常大(例如&#xff1a;20,000 x 20,000像素),因此我必须扫描图像以获得用于预测的小补丁. 我想使用python2.7的多处理库来加速预测过程.主要思想是使用不同的子进…...