当前位置: 首页 > news >正文

Pytorch单机多卡分布式训练

Pytorch单机多卡分布式训练

数据并行:

DP和DDP

这两个都是pytorch下实现多GPU训练的库,DP是pytorch以前实现的库,现在官方更推荐使用DDP,即使是单机训练也比DP快。

  1. DataParallel(DP)

    • 只支持单进程多线程,单一机器上进行训练。
    • 模型训练开始的时候,先把模型复制到四个GPU上面,然后把数据分配给四个GPU进行前向传播,前向传播之后再汇总到卡0上面,然后在卡0上进行反向传播,参数更新,再将更新好的模型复制到其他几张卡上。

    在这里插入图片描述

  2. DistributedDataParallel(DDP)

    • 支持多线程多进程,单一或者多个机器上进行训练。通常DDP比DP要快。

    • 先把模型载入到四张卡上,每个GPU上都分配一些小批量的数据,再进行前向传播,反向传播,计算完梯度之后再把所有卡上的梯度汇聚到卡0上面,卡0算完梯度的平均值之后广播给所有的卡,所有的卡更新自己的模型,这样传输的数据量会少很多。

      在这里插入图片描述

DDP代码写法

  1. 初始化

    import torch.distributed as dist
    import torch.utils.data.distributed# 进行初始化,backend表示通信方式,可选择的有nccl(英伟达的GPU2GPU的通信库,适用于具有英伟达GPU的分布式训练)、gloo(基于tcp/ip的后端,可在不同机器之间进行通信,通常适用于不具备英伟达GPU的环境)、mpi(适用于支持mpi集群的环境)
    # init_method: 告知每个进程如何发现彼此,默认使用env://
    dist.init_process_group(backend='nccl', init_method="env://")
    
  2. 设置device

    device = torch.device(f'cuda:{args.local_rank}')	# 设置device,local_rank表示当前机器的进程号,该方式为每个显卡一个进程
    torch.cuda.set_device(device)	# 设定device
    
  3. 创建dataloader之前要加一个sampler

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)	# 加一个sampler
    data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)
    
  4. torch.nn.parallel.DistributedDataParallel包裹模型(先to(device)再包裹模型)

    net = torchvision.models.resnet101(num_classes=10)
    net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])	# 包裹模型
    
  5. 真正训练之前要set_epoch(),否则将不会shuffer数据

    for epoch in range(10):train_sampler.set_epoch(epoch)		# set_epochfor step, data in enumerate(data_loader_train):images, labels = dataimages, labels = images.to(device), labels.to(device)opt.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()opt.step()if step % 10 == 0:print("loss: {}".format(loss.item()))
    
  6. 模型保存

    if args.local_rank == 0:		# local_rank为0表示master进程torch.save(net, "my_net.pth")
    
  7. 运行

    if __name__ == "__main__":parser = argparse.ArgumentParser()# local_rank参数是必须的,运行的时候不必自己指定,DDP会自行提供parser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()main(args)
    
  8. 运行命令

    python -m torch.distributed.launch --nproc_per_node=2 多卡训练.py	# --nproc_per_node=2表示当前机器上有两个GPU可以使用
    

完整代码

import os
import argparse
import torch
import torchvision
import torch.distributed as dist
import torch.utils.data.distributedfrom torchvision import transforms
from torch.multiprocessing import Processdef main(args):# nccl: 后端基于NVIDIA的GPU-to-GPU通信库,适用于具有NVIDIA GPU的分布式训练# gloo: 后端是一个基于TCP/IP的后端,可在不同机器之间进行通信,通常适用于不具备NVIDIA GPU的环境。# mpi: 后端使用MPI实现,适用于具备MPI支持的集群环境。# init_method: 告知每个进程如何发现彼此,如何使用通信后端初始化和验证进程组。 默认情况下,如果未指定 init_method,PyTorch 将使用环境变量初始化方法 (env://)。dist.init_process_group(backend='nccl', init_method="env://") # nccl比较推荐device = torch.device(f'cuda:{args.local_rank}')torch.cuda.set_device(device)trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)net = torchvision.models.resnet101(num_classes=10)net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)net = net.to(device)net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])criterion = torch.nn.CrossEntropyLoss()opt = torch.optim.Adam(params=net.parameters(), lr=0.001)for epoch in range(10):train_sampler.set_epoch(epoch)for step, data in enumerate(data_loader_train):images, labels = dataimages, labels = images.to(device), labels.to(device)opt.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()opt.step()if step % 10 == 0:print("loss: {}".format(loss.item()))if args.local_rank == 0:torch.save(net, "my_net.pth")if __name__ == "__main__":parser = argparse.ArgumentParser()# must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDPparser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()main(args)

参考:

https://zhuanlan.zhihu.com/p/594046884
https://zhuanlan.zhihu.com/p/358974461

相关文章:

Pytorch单机多卡分布式训练

Pytorch单机多卡分布式训练 数据并行: DP和DDP 这两个都是pytorch下实现多GPU训练的库,DP是pytorch以前实现的库,现在官方更推荐使用DDP,即使是单机训练也比DP快。 DataParallel(DP) 只支持单进程多线程…...

asp.net coremvc+efcore增删改查

下面是一个使用 EF Core 在 ASP.NET Core MVC 中完成增删改查的示例&#xff1a; 创建一个新的 ASP.NET Core MVC 项目。 安装 EF Core 相关的 NuGet 包。在项目文件 (.csproj) 中添加以下依赖项&#xff1a; <ItemGroup><PackageReference Include"Microsoft…...

Java基础面试,什么是面向对象,谈谈你对面向对象的理解

前言 马上就要找工作了&#xff0c;从今天开始一天准备1~2道面试题&#xff0c;来打基础&#xff0c;就从Java基础开始吧。 什么是面向对象&#xff0c;谈谈你对面向对象的理解&#xff1f; 谈到面向对象&#xff0c;那就不得不谈到面向过程。面向过程更加注重的是完成一个任…...

Ubuntu系统初始设置

更换国内源 安装截图工具 安装中文输入法 安装QQ 参考&#xff1a; 安装双系统win10Ubuntu20.04LTS&#xff08;详细到我自己都害怕&#xff09; 引导方式磁盘分区方法UEFIGPTLegancyMBR 安装网络助手 sudo apt install net-tools 安装VS Code 使用从官网下载.deb安装包…...

焕新古文化传承之路,AI为古彝文识别赋能

目录 1 古彝文与古典保护 2 古文识别的挑战 2.1 西文与汉文OCR 2.2 古彝文识别难点 3 合合信息&#xff1a;古彝文保护新思路 3.1 图像矫正 3.2 图像增强 3.3 语义理解 3.4 工程技巧 4 总结 1 古彝文与古典保护 彝文指的是云南、贵州、四川等地的彝族人使用的文字&am…...

毛玻璃动画交互效果

效果展示 页面结构组成 从上述的效果展示页面结构来看&#xff0c;页面布局都是比较简单的&#xff0c;只是元素的动画交互比较麻烦。 第一个动画交互是两个圆相互交错来回运动。第二个动画交互是三角绕着圆进行 360 度旋转。 CSS 知识点 animationanimation-delay绝对定位…...

Audio2Face的工作原理

预加载一个3D数字人物模型(Digital Mark),该模型可以通过音频驱动进行面部动画。 用户上传音频文件作为输入。 将音频输入馈送到预训练的深度神经网络中。 Audio2Face加载预制的3d人头mesh 3D数字人物面部模型由大量顶点组成,每个顶点都有xyz坐标。 深度神经网络输入音频特征,…...

【面试题】2023前端面试真题之JS篇

前端面试题库 &#xff08;面试必备&#xff09; 推荐&#xff1a;★★★★★ 地址&#xff1a;前端面试题库 表妹一键制作自己的五星红旗国庆头像&#xff0c;超好看 世界上只有一种真正的英雄主义&#xff0c;那就是看清生活的真相之后&#xff0c;依然热爱生活。…...

Mysql 分布式序列算法

接上文 Mysql分库分表 1.分布式序列简介 在分布式系统下&#xff0c;怎么保证ID的生成满足以上需求&#xff1f; ShardingJDBC支持以上两种算法自动生成ID。这里&#xff0c;使用ShardingJDBC让主键ID以雪花算法进行生成&#xff0c;首先配置数据库&#xff0c;因为默认的注…...

Windows/Linux双系统卸载Ubuntu

参考&#xff1a;双系统下完全卸载ubuntu...

asp.net core mvc 视图组件viewComponents

ASP.NET Core MVC 视图组件&#xff08;View Components&#xff09;是一种可重用的 UI 组件&#xff0c;用于在视图中呈现某些特定的功能块&#xff0c;例如导航菜单、侧边栏、用户信息等。视图组件提供了一种将视图逻辑与控制器解耦的方式&#xff0c;使视图能够更加灵活、可…...

如何保持终身学习

文章目录 2.1. 了解你的大脑2.2 学习是对神经元网络的塑造2.3 大脑的一生 3.学习的心里基础3.1 固定思维与成长思维3.2 我们为什么要学习 4. 学习路径4.1 构建知识模块4.2 大脑是如何使用注意力的4.3 提高专注力4.4 放松一下&#xff0c;学的更好4.5 巩固你的学习痕迹4.6 被动学…...

【RV1103】RTL8723bs (SD卡形状模块)驱动开发

文章目录 前言硬件分析Luckfox Pico的SD卡接口硬件原理图LicheePi zero WiFiBT模块总结 正文Kernel WiFi驱动支持Kernel 设备树支持修改一&#xff1a;修改二&#xff1a; SDK全局配置支持 wifi全局编译脚本支持编译逻辑拷贝rtl8723bs的固件到文件系统的固定目录里面去 上电后手…...

LeetCode 周赛上分之旅 #49 再探内向基环树

⭐️ 本文已收录到 AndroidFamily&#xff0c;技术和职场问题&#xff0c;请关注公众号 [彭旭锐] 和 BaguTree Pro 知识星球提问。 学习数据结构与算法的关键在于掌握问题背后的算法思维框架&#xff0c;你的思考越抽象&#xff0c;它能覆盖的问题域就越广&#xff0c;理解难度…...

kubernetes-v1.23.3 部署 kafka_2.12-2.3.0

文章目录 [toc]构建 debian 基础镜像部署 zookeeper配置 namespace配置 gfs 的 endpoints配置 pv 和 pvc配置 configmap配置 service配置 statefulset 部署 kafka配置 configmap配置 service配置 statefulset 这里采用的部署方式如下&#xff1a; 使用自定义的 debian 镜像作为…...

位置编码器

目录 1、位置编码器的作用 2、代码演示 &#xff08;1&#xff09;、使用unsqueeze扩展维度 &#xff08;2&#xff09;、使用squeeze降维 &#xff08;3&#xff09;、显示张量维度 &#xff08;4&#xff09;、随机失活张量中的数值 3、定义位置编码器类&#xff0c;我…...

Lua多脚本执行

--全局变量 a 1 b "123"for i 1,2 doc "Holens" endprint(c) print("*************************************1")--本地变量&#xff08;局部变量&#xff09; for i 1,2 dolocal d "Holens2"print(d) end print(d)function F1( ..…...

Spirng Cloud Alibaba Nacos注册中心的使用 (环境隔离、服务分级存储模型、权重配置、临时实例与持久实例)

文章目录 一、环境隔离1. Namespace&#xff08;命名空间&#xff09;&#xff1a;2. Group&#xff08;分组&#xff09;&#xff1a;3. Services&#xff08;服务&#xff09;&#xff1a;4. DataId&#xff08;数据ID&#xff09;&#xff1a;5. 实战演示&#xff1a;5.1 默…...

26663-2011 大型液压安全联轴器 课堂随笔

声明 本文是学习GB-T 26663-2011 大型液压安全联轴器. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了大型液压安全联轴器的分类、技术要求、试验方法及检验规则等。 本标准适用于联接两同轴线的传动轴系&#xff0c;可起到限制…...

ChatGPT架构师:语言大模型的多模态能力、幻觉与研究经验

来源 | The Robot Brains Podcast OneFlow编译 翻译&#xff5c;宛子琳、杨婷 9月26日&#xff0c;OpenAI宣布ChatGPT新增了图片识别和语音能力&#xff0c;使得ChatGPT不仅可以进行文字交流&#xff0c;还可以给它展示图片并进行互动&#xff0c;这是一次ChatGPT向多模态进化的…...

二、VXLAN BGP EVPN基本原理

VXLAN BGP EVPN基本原理 1、BGP EVPN2、BGP EVPN路由2.1、Type2路由——MAC/IP路由2.2、Type3路由——Inclusive Multicast路由2.3、Type5路由——Inclusive Multicast路由 ————————————————————————————————————————————————…...

Evil.js

Evil.js install npm i lodash-utils什么&#xff1f;黑心996公司要让你体统跑路了&#xff1f; 想在离开前给你们的项目留点小礼物&#xff1f; 偷偷地把本项目引入你们的项目吧&#xff0c;你们的项目会有但不仅限于如下的神奇效果&#xff1a; 仅在周日时&#xff1a; 当…...

使用sqlmap的 ua注入

文章目录 1.使用sqlmap自带UA头的检测2.使用sqlmap随机提供的UA头3.使用自己写的UA头4.调整level检测 测试环境&#xff1a;bWAPP SQL Injection - Stored (User-Agent) 1.使用sqlmap自带UA头的检测 python sqlmap.py -u http://127.0.0.1:9004/sqli_17.php --cookie“BEEFHOO…...

华为云云耀云服务器L实例评测 | 实例评测使用之体验评测:华为云云耀云服务器管理、控制、访问评测

华为云云耀云服务器L实例评测 &#xff5c; 实例评测使用之体验评测&#xff1a;华为云云耀云服务器管理、控制、访问评测 介绍华为云云耀云服务器 华为云云耀云服务器 &#xff08;目前已经全新升级为 华为云云耀云服务器L实例&#xff09; 华为云云耀云服务器是什么华为云云耀…...

resultmap

自定义映射resultMap resultMap处理字段和属性的映射关系 若字段名和实体类中的属性名称不一致&#xff0c;则可以通过resultMap设置自定义映射 建moudel项目【实现多对一、一对多的表操作demo】 temp员工表、dept部门表 导入依赖【mysql驱动、junit、mybatis、日志依赖log4…...

宽带光纤接入网中影响家宽业务质量的常见原因有哪些

1 引言 虽然家宽业务质量问题约60%发生在家庭网&#xff08;见《家宽用户家庭网的主要质量问题是什么&#xff1f;原因有哪些》一文&#xff09;&#xff0c;但在用户的眼里&#xff0c;所有家宽业务质量问题都是由运营商的网络质量导致的&#xff0c;用户也因此对不同运营商家…...

C++ - 封装 unordered_set 和 unordered_map - 哈希桶的迭代器实现

前言 unordered_set 和 unordered_map 两个容器的底层是哈希表实现的&#xff0c;此处的封装使用的 上篇博客当中的哈希桶来进行封装&#xff0c;相当于是在 哈希桶之上在套上了 unordered_set 和 unordered_map 。 哈希桶的逻辑实现&#xff1a; C - 开散列的拉链法&…...

gradle中主模块/子模块渠道对应关系通过配置实现

前言&#xff1a; 我们开发过程中&#xff0c;经常会面对针对不同的渠道&#xff0c;要产生差异性代码和资源的场景。目前谷歌其实为我们提供了一套渠道包的方案&#xff0c;这里简单描述一下。 比如我主模块依赖module1和module2。如果主模块中声明了2个渠道A和B&#xff0c…...

28383-2012 卷筒料凹版印刷机 学习笔记

声明 本文是学习GB-T 28383-2012 卷筒料凹版印刷机. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了卷筒料凹版印刷机的型式、基本参数、要求、试验方法、检验规则、标志、包装、运输与 贮存。 本标准适用于机组式的卷筒料凹版…...

stable diffusion学习笔记【2023-10-2】

L1&#xff1a;界面 CFG Scale&#xff1a;提示词相关性 denoising&#xff1a;重绘幅度 L2&#xff1a;文生图 女性常用的负面词 nsfw,NSFW,(NSFW:2),legs apart, paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, (…...

女生做网站编辑好吗/网站页面禁止访问

C代码管理没有规矩很难管理&#xff0c;不像类管理那样方便&#xff0c;我们一遍一遍的写代码但是重复的很多&#xff0c;如何管理自己写的C程序以供以后使用方便&#xff1f;因此好用的C代码管理小工具出现了&#xff0c;可用于代码管理&#xff0c;将计算机里面的代码导入到数…...

衢州网站建设有限公司/公司网络营销实施计划

转自&#xff1a;https://x264-settings.wikispaces.com/x264settings 264参数设置 本页面介绍x264参数的用法与目的。参数解释的顺序对应以下帮助内容中的参数出现顺序。x264 --fullhelp 参见 x264 Stats Output , x264_Stats_File , x264 Encoding Suggestions .Table of C…...

vr超市门户网站建设/淘宝交易指数换算工具

ZMap 类 功能介绍 ZMap 是学习百度地图 api 接口&#xff0c;开发基本功能后整的一个脚本类&#xff0c;本类方法功能大多使用 prototype 原型 实现&#xff1b; 包含的功能有&#xff1a;轨迹回放&#xff0c;圈画区域可编辑&#xff0c;判断几个坐标是否在一个圆圈内&#xf…...

接给别人做网站的活/手机做网页的软件

前言 接着上一篇 根据 InstanceKlass 查找 vtable 的数据, 其中留下了一些 存在疑问的地方, 呵呵 本文主要就是探讨这几点疑问的地方 同样适用上面的方式, 可以很轻松的定位到 vtable 的最后一个元素 对应的类是 Test02LoopUpVTable 方法名称的 常量池索引是 74??, 参照最…...

网站建设报告实训步骤/百度客服电话号码

统一项目管理平台&#xff08;UMPlatForm.NET&#xff09; 4.10 用户权限管理模块 统一项目管理平台&#xff08;UMPlatForm.NET&#xff09;,基于.NET的快速开发、整合框架。 4.10用户权限管理模块 在实际应用中我们会发现&#xff0c;权限控制会经常变动&#xff0c;如&#…...

网站服务器站点是什么意思/网站关键词排名优化客服

1&#xff09;允许Client1访问Server1的Web服务 2&#xff09;允许Client1访问网络192.168.2.0/24 3&#xff09;禁止Client1访问其它网络 搭建实验环境 实现此案例需要按照如下步骤进行。 [AR1]ip route-static 0.0.0.0 0.0.0.0 192.168.12.2 [AR3]ip route-static 0.0.0.0 0…...