pytorch基于ray和accelerate实现多GPU数据并行的模型加速训练
在pytorch的DDP原生代码使用的基础上,ray和accelerate两个库对于pytorch并行训练的代码使用做了更加友好的封装。
以下为极简的代码示例。
ray
ray.py
#coding=utf-8
import os
import sys
import time
import numpy as np
import torch
from torch import nn
import torch.utils.data as Data
import ray
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
import onnxruntime# bellow code use AI model to simulate linear regression, formula is: y = x1 * w1 + x2 * w2 + b
# --- DDP RAY --- # # model structure
class LinearNet(nn.Module):def __init__(self, n_feature):super(LinearNet, self).__init__()self.linear = nn.Linear(n_feature, 1)def forward(self, x):y = self.linear(x)return y# whole train task
def train_task():print("--- train_task, pid: ", os.getpid())# device settingdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("device:", device)device_ids = torch._utils._get_all_device_indices()print("device_ids:", device_ids)if len(device_ids) <= 0:print("invalid device_ids, exit")return# prepare datanum_inputs = 2num_examples = 1000true_w = [2, -3.5]true_b = 3.7features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b + torch.tensor(np.random.normal(0, 0.01, size=num_examples), dtype=torch.float)# load databatch_size = 10dataset = Data.TensorDataset(features, labels)data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)for X, y in data_iter:print(X, y)breakdata_iter = ray.train.torch.prepare_data_loader(data_iter)# model define and initmodel = LinearNet(num_inputs)ddp_model = ray.train.torch.prepare_model(model)print(ddp_model)# cost functionloss = nn.MSELoss()# optimizeroptimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.03)# trainnum_epochs = 6for epoch in range(1, num_epochs + 1):batch_count = 0sum_loss = 0.0for X, y in data_iter:output = ddp_model(X)l = loss(output, y.view(-1, 1))optimizer.zero_grad()l.backward()optimizer.step()batch_count += 1sum_loss += l.item()print('epoch %d, avg_loss: %f' % (epoch, sum_loss / batch_count))# save modelprint("save model, pid: ", os.getpid())torch.save(ddp_model.module.state_dict(), "ddp_ray_model.pt")def ray_launch_task():num_workers = 2scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=True)trainer = TorchTrainer(train_loop_per_worker=train_task, scaling_config=scaling_config)results = trainer.fit()def predict_task():print("--- predict_task")# prepare datanum_inputs = 2num_examples = 20true_w = [2, -3.5]true_b = 3.7features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b + torch.tensor(np.random.normal(0, 0.01, size=num_examples), dtype=torch.float)model = LinearNet(num_inputs)model.load_state_dict(torch.load("ddp_ray_model.pt"))model.eval()x, y = features[6], labels[6]pred_y = model(x)print("x:", x)print("y:", y)print("pred_y:", y)if __name__ == "__main__":print("==== task begin ====")print("python version:", sys.version)print("torch version:", torch.__version__)print("model name:", LinearNet.__name__)ray_launch_task()# predict_task()print("==== task end ====")
accelerate
acc.py
#coding=utf-8
import os
import sys
import time
import numpy as np
from accelerate import Accelerator
import torch
from torch import nn
import torch.utils.data as Data
import onnxruntime# bellow code use AI model to simulate linear regression, formula is: y = x1 * w1 + x2 * w2 + b
# --- accelerate --- # # model structure
class LinearNet(nn.Module):def __init__(self, n_feature):super(LinearNet, self).__init__()self.linear = nn.Linear(n_feature, 1)def forward(self, x):y = self.linear(x)return y# whole train task
def train_task():print("--- train_task, pid: ", os.getpid())# device settingdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("device:", device)device_ids = torch._utils._get_all_device_indices()print("device_ids:", device_ids)if len(device_ids) <= 0:print("invalid device_ids, exit")return# prepare datanum_inputs = 2num_examples = 1000true_w = [2, -3.5]true_b = 3.7features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b + torch.tensor(np.random.normal(0, 0.01, size=num_examples), dtype=torch.float)# load databatch_size = 10dataset = Data.TensorDataset(features, labels)data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)for X, y in data_iter:print(X, y)break# model define and initmodel = LinearNet(num_inputs)# cost functionloss = nn.MSELoss()# optimizeroptimizer = torch.optim.SGD(model.parameters(), lr=0.03)accelerator = Accelerator()model, optimizer, data_iter = accelerator.prepare(model, optimizer, data_iter) # automatically move model and data to gpu as config# trainnum_epochs = 3for epoch in range(1, num_epochs + 1):batch_count = 0sum_loss = 0.0for X, y in data_iter:output = model(X)l = loss(output, y.view(-1, 1))optimizer.zero_grad()accelerator.backward(l)optimizer.step()batch_count += 1sum_loss += l.item()print('epoch %d, avg_loss: %f' % (epoch, sum_loss / batch_count))# save modeltorch.save(model, "acc_model.pt")def predict_task():print("--- predict_task")# prepare datanum_inputs = 2num_examples = 20true_w = [2, -3.5]true_b = 3.7features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b + torch.tensor(np.random.normal(0, 0.01, size=num_examples), dtype=torch.float)model = torch.load("acc_model.pt")model.eval()x, y = features[6], labels[6]pred_y = model(x)print("x:", x)print("y:", y)print("pred_y:", y)if __name__ == "__main__":# launch method: use command line# for example# accelerate launch ACC.py print("python version:", sys.version)print("torch version:", torch.__version__)print("model name:", LinearNet.__name__)train_task()predict_task()print("==== task end ====")
相关文章:
pytorch基于ray和accelerate实现多GPU数据并行的模型加速训练
在pytorch的DDP原生代码使用的基础上,ray和accelerate两个库对于pytorch并行训练的代码使用做了更加友好的封装。 以下为极简的代码示例。 ray ray.py #codingutf-8 import os import sys import time import numpy as np import torch from torch import nn im…...
[蓝帽杯 2022 初赛]domainhacker
打开流量包,追踪TCP流,看到一串url编码 放到瑞士军刀里面解密 最下面这一串会觉得像base64编码 删掉前面两个字符就可以base64解码 依次类推,提取到第13个流,得到一串编码其中里面有密码 导出http对象 发现最后有个1.rar文件 不出…...
在 Pytorch 中使用 TensorBoard
机器学习的训练过程中会产生各类数据,包括 “标量scalar”、“图像image”、“统计图diagram”、“视频video”、“音频audio”、“文本text”、“嵌入Embedding” 等等。为了更好地追踪和分析这些数据,许多可视化工具应运而生,比如之前介绍的…...
Grafana Dashboard 备份方案
文章目录 Grafana Dashboard 备份方案引言工具简介支持的组件要求配置备份安装使用 pypi 安装grafana备份工具配置环境变量使用Grafana Backup Tool 进行备份恢复备份 Grafana Dashboard恢复 Grafana Dashboard结论Grafana Dashboard 备份方案 引言 每个使用 Grafana 的同学都…...
opencv-疲劳检测-眨眼检测
#导入工具包 from scipy.spatial import distance as dist from collections import OrderedDict import numpy as np import argparse import time import dlib import cv2FACIAL_LANDMARKS_68_IDXS OrderedDict([("mouth", (48, 68)),("right_eyebrow",…...
2023-08-24力扣每日一题
链接: 1267. 统计参与通信的服务器 题意: 同行同列可以发生通信,求能发生通信的机器数量 解: 标记每行/每列的机器个数即可 实际代码: #include<bits/stdc.h> using namespace std; class Solution { pub…...
蚂蚁数科持续发力PaaS领域,SOFAStack布局全栈软件供应链安全产品
8月18日,记者了解到,蚂蚁数科再度加码云原生PaaS领域,SOFAStack率先完成全栈软件供应链安全产品及解决方案的布局,包括静态代码扫描Pinpoint、软件成分分析SCA、交互式安全测试IAST、运行时防护RASP、安全洞察Appinsight等&#x…...
Java后端开发面试题——消息中间篇
RabbitMQ-如何保证消息不丢失 交换机持久化: Bean public DirectExchange simpleExchange(){// 三个参数:交换机名称、是否持久化、当没有queue与其绑定时是否自动删除 return new DirectExchange("simple.direct", true, false); }队列持久化…...
C++ Windows API IsDebuggerPresent的作用
IsDebuggerPresent 是 Windows API 中的一个函数,它用于检测当前运行的程序是否正在被调试。当程序被如 Visual Studio 这样的调试器附加时,此函数会返回 TRUE;否则,它会返回 FALSE。 这个函数经常被用在一些安全相关的场景或是防…...
【JVM 内存结构 | 程序计数器】
内存结构 前言简介程序计数器定义作用特点示例应用场景 主页传送门:📀 传送 前言 Java 虚拟机的内存空间由 堆、栈、方法区、程序计数器和本地方法栈五部分组成。 简介 JVM(Java Virtual Machine)内存结构包括以下几个部分&#…...
华为云Stack的学习(一)
一、华为云Stack架构 1.HCS 物理分散、逻辑统一、业务驱动、运管协同、业务感知 2.华为云Stack的特点 可靠性 包括整体可靠性、数据可靠性和单一设备可靠性。通过云平台的分布式架构,从整体系统上提高可靠性,降低系统对单设备可靠性的要求。 可用性…...
人类反馈强化学习RLHF;微软应用商店推出AI摘要功能
🦉 AI新闻 🚀 微软应用商店推出AI摘要功能,快速总结用户对App的评价 摘要:微软应用商店正式推出了AI摘要功能,该功能能够将数千条在线评论总结成一段精练的文字,为用户选择和下载新应用和游戏提供参考。该…...
day1:前端缓存问题
❝ 「目标」: 持续输出!每日分享关于web前端常见知识、面试题、性能优化、新技术等方面的内容。篇幅不会过长,方便理解和记忆。 ❞ ❝ 「主要面向群体:」前端开发工程师(初、中、高级)、应届、转行、培训等同学 ❞ Day…...
学习网络编程No.4【socket编程实战】
引言 北京时间:2023/8/19/23:01,耍了好几天,主要归咎于《我欲封天》这本小说,听了几个晚上之后逐渐入门,在闲暇时间又看了一下,小高潮直接来临,最终在三个昼夜下追完了,哈哈哈&…...
HarmonyOS学习路之方舟开发框架—学习ArkTS语言(状态管理 四)
Observed装饰器和ObjectLink装饰器:嵌套类对象属性变化 上文所述的装饰器仅能观察到第一层的变化,但是在实际应用开发中,应用会根据开发需要,封装自己的数据模型。对于多层嵌套的情况,比如二维数组,或者数…...
arcgis--坐标系
1、arcgis中,投影坐标系的y坐标一定是7位数,X坐标有两种:6位和8位。 6位:省略带号,这是中央经线形式的投影坐标,一般投影坐标中会带CM字样;8位:包括带号,一般投影坐标中…...
LFS学习系列 第5章. 编译交叉工具链(1)
5.1 介绍 本章介绍如何构建交叉编译器及其相关工具。尽管这里的交叉编译是“伪造”、“假装”的,但其原理与真正的交叉工具链相同。 本章中编译的程序将安装在$LFS/tools目录下,以使它们与以下章节中安装的文件分离。而另一方面,库被安装到…...
网络互联与互联网 - TCP 协议详解
文章目录 1 概述2 TCP 传输控制协议2.1 报文格式2.2 三次握手,建立连接2.3 四次挥手,释放连接 3 扩展3.1 实验演示3.2 网工软考 1 概述 在 TCP/IP 协议簇 中有两个传输协议 TCP:Transmission Control Protocol,传输控制协议&…...
开源在线图片设计器,支持PSD解析、AI抠图等,基于Puppeteer生成图片
Github 开源地址: palxiao/poster-design 项目速览 git clone https://github.com/palxiao/poster-design.git cd poster-design npm run prepared # 快捷安装依赖指令 npm run serve # 本地运行将同时运行前端界面与图片生成服务(3000与7001端口),合成图片时…...
在Linux系统上安装和配置Redis数据库,无需公网IP即可实现远程连接的详细解析
文章目录 1. Linux(centos8)安装redis数据库2. 配置redis数据库3. 内网穿透3.1 安装cpolar内网穿透3.2 创建隧道映射本地端口 4. 配置固定TCP端口地址4.1 保留一个固定tcp地址4.2 配置固定TCP地址4.3 使用固定的tcp地址连接 Redis作为一款高速缓存的key value键值对的数据库,在…...
跨平台图表:ChartDirector for .NET 7.1 Crack
什么是新的 ChartDirector for .NET 7.0 支持跨平台使用,但仅限于 .NET 6。这是因为在 .NET 7 中,Microsoft 停止了用于非 Windows 使用的 .NET 图形库 System.Drawing.Common。由于 ChartDirector for .NET 7.0 依赖于该库,因此它不再支持 .…...
【unity数据持久化】XML数据管理器知识点
👨💻个人主页:元宇宙-秩沅 👨💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨💻 本文由 秩沅 原创 👨💻 收录于专栏:Uni…...
Linux——Shell常用运算符
运算符说明举例-eq检测两个数是否相等,相等返回 true。[ $a -eq $b ] 返回 false。-ne检测两个数是否不相等,不相等返回 true。[ $a -ne $b ] 返回 true。-gt检测左边的数是否大于右边的,如果是,则返回 true。[ $a -gt $b ] 返回 …...
C++(4)C++内存管理和命名空间
内存管理 new/delete C语言 malloc free完成对堆内存的申请和释放。 C new delete 类 new:动态申请存储空间的运算符,返回值为申请空间的对应数据类型的地址 int *p new int(10); 申请了一个初始值为10的整型数据 int *p new int[10]; 申…...
一网打尽java注解-克隆-面向对象设计原则-设计模式
文章目录 注解内置注解元注解 对象克隆为什么要克隆?如何克隆浅克隆深克隆 Java设计模式什么是设计模式?为什么要学习设计模式? 建模语言类接口类之间的关系依赖关系关联关系聚合关系组合关系继承关系实现关系 面向对象设计原则单一职责开闭原…...
k8s-statefulset部署myql-Nodeport方式
目录 1、部署openebs(Elastic Block Store) 1.下载镜像(针对k8s1.19) 2.加载镜像(所有节点包括master) 3.下载yaml文件并部署 4.设置默认storageclass 2、编写相关yaml文件 1.编写secret 2.编写state…...
MySQL双主架构、主从架构
为什么要对数据库做优化? MySQL官方说法: 单表2000万数据就达到瓶颈了。所以为了保证查询效率,要让每张表的大小得到控制。 MySQL主主架构 主数据库都负责增删改查。 比如有1000W的数据,有两个主数据库,就将数据分流给…...
基于微信小程序的物流管理系统3txar
在此基础上,结合现有物流管理体系的特点,运用新技术,构建了以 springboot为基础的物流信息化管理体系。首先,以需求为依据,对目前传统物流管理基础业务进行了较为详尽的了解和分析。根据需求分析结果进行了系统的设计&…...
Maven 一键部署到 SSH 服务器
简介 利用 Maven Mojo 功能一键部署 jar 包或 war 包到远程服务器上。 配置 在 maven 的setting.xml 配置服务器 SSH 账号密码。虽然可以在工程的 pom.xml 直接配置,但那样不太安全。 <servers><server><id>iq</id><configuration&…...
docker搭建owncloud,Harbor,构建镜像
1、使用mysql:5.6和 owncloud 镜像,构建一个个人网盘。 拉取镜像 docker pull owncloud docker pull mysql:5.6 2、安装搭建私有仓库 Harbor 1.下载docker-compose 2.安装harbor 3.编辑 harbor.yml文件 使用./intall.sh安装 4.登录 3、编写Dockerfile制作Web应用系…...
建筑工程网校官网/谷歌seo外包
动态路由中的链路状态算法和距离矢量算法有什么区别? OSPF在点到点网络如何构建拓扑和计算路由? OSPF在广播型网络如何构建拓扑和计算路由? 两种动态路由算法 动态路由算法两大类,距离矢量和链路状态。 距离矢量算法是谣传性…...
网站如何做信誉认证/青岛网站关键词排名优化
Seen.js 渲染3D场景为 SVG 或者 HTML5 画布。Seen.js 包含对于 SVG 和 HTML5 Canvas 元素的图形功能的最简单的抽象。http://seenjs.io/...
wordpress 文章 视频/搜索引擎优化的实验结果分析
一,启动jar包命令 1. java -jar Zhglpt.Boot-1.0-SNAPSHOT.jar 格式: java -jar 包名.jar 描述:java -jar 包名.jar 这个命令在linux,window 操作中,一般关闭window 窗口,或者关闭ssh 连接,都…...
网站建设套餐报价/如何写软文
记得第一次接触数据透视表还是在2013年在名古屋SAP项目组,觉得有点难以驾驭(没用过),旁边的SAP顾问似乎有点鄙视我的感觉。 后来,工作中需要维护透视表,慢慢的多了一些了解。 数据透视表(Pivo…...
大网站建设/企业软文营销发布平台
解析语法查询就是调用方法查询的原始查询 例如: 查询所有的查询器的语法为:*:*,因为lucene查询是根据term来做的,既是:key:value类型。*:*表示所有域中的所有值。 api调用语法解析 pom.xml 必须…...
网站百度收录秒收方法/抖音网络营销案例分析
高并发架构 消息队列搜索引擎缓存分库分表读写分离设计高并发系统 高并发架构部分内容 缓存: Redis高可用: 高并发系统设计: 分布式系统 分布式业务系统,就是把原来用 Java 开发的一个大块系统,给拆分成多个子系统&a…...