在国产GPU寒武纪MLU上快速上手Pytorch使用指南
本文旨在帮助Pytorch使用者快速上手使用寒武纪MLU。以代码块为主,文字尽可能简洁,许多部分对标NVIDIA CUDA。不正确的地方请留言更正。本文不定期更新。
文章目录
- 前言
- Cambricon PyTorch的Python包torch_mlu导入
- 将模型加载到MLU上model.to('mlu')
- 定义损失函数,然后将其拷贝至MLU
- 将数据从CPU拷贝到MLU设备
- 以mnist.py为例的训练代码demo
- 参考引用
前言
大背景:信创改造、信创国产化、GPU国产化。
为使PyTorch支持寒武纪MLU,寒武纪对机器学习框架PyTorch进行了部分定制。若要在寒武纪MLU上运行PyTorch,需要安装并使用寒武纪定制的 Cambricon PyTorch
。
Cambricon PyTorch的Python包torch_mlu导入
Cambricon CATCH是寒武纪发布的一款Python包(包名torch_mlu),提供了在MLU设备上进行张量计算的能力。安装好Cambricon CATCH后,便可使用torch_mlu模块:
import torch # 需安装Cambricon PyTorch
import torch_mlu # 动态扩展MLU后端
附 Cambricon PyTorch源码编译安装
导入 torch 和 torch_mlu 后可以测试在MLU上完成加法运算:
t0 = torch.randn(2, 2, device='mlu') # 在MLU设备上生成Tensor
t1 = torch.randn(2, 2, device='mlu')
result = t0 + t1 # 在MLU设备上完成加法运算
将模型加载到MLU上model.to(‘mlu’)
以ResNet18为例,将模型加载到MLU上用 model.to('mlu')
,对标cuda的 model.to(device)
:
# 定义模型
model = models.__dict__["resnet50"]()
# 将模型加载到MLU上。
mlu_model = model.to('mlu')
定义损失函数,然后将其拷贝至MLU
# 构造损失函数
criterion = nn.CrossEntropyLoss()
# 将损失函数拷贝到MLU上
criterion.to('mlu')
将数据从CPU拷贝到MLU设备
x = torch.randn(1000000, dtype=torch.float)
x_mlu = x.to(torch.device('mlu'), non_blocking=True)
以mnist.py为例的训练代码demo
import torch # 导入原生 PyTorch
import torch_mlu # 导入 Cambricon PyTorch
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torch import nn
from torch import optim
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)# 定义前向计算def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = F.relu(x)x = self.dropout2(x)x = self.fc2(x)output = F.log_softmax(x, dim=1)return output# 模型训练
def train(model, train_data, optimizer, epoch):model = model.train()for batch_idx, (img, label) in enumerate(train_data):img = img.mlu()label = label.mlu()optimizer.zero_grad()out = model(img)loss = F.nll_loss(out, label)# 反向计算loss.backward()# 梯度更新optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(img), len(train_data.dataset),100. * batch_idx / len(train_data), loss.item()))# 模型推理
def validate(val_loader, model):test_loss = 0correct = 0model.eval()with torch.no_grad():for images, target in val_loader:images = images.mlu()target = target.mlu()output = model(images)test_loss += F.nll_loss(output, target, reduction='sum').item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(val_loader.dataset)# 打印精度结果print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(val_loader.dataset),100. * correct / len(val_loader.dataset)))# 主函数
def main():# 定义预处理函数data_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.1307],[0.3081])])# 获取 MNIST 数据集train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True)test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)train_data = DataLoader(train_set, batch_size=64, shuffle=True)test_data = DataLoader(test_set, batch_size=1000, shuffle=False)net_orig = Net()# 模型拷贝到MLU设备net = net_orig.mlu()optimizer = optim.Adadelta(net.parameters(), 1)# 训练10个epochnums_epoch = 10# 训练完成后保存模型save_model = True# 学习率调整策略scheduler = StepLR(optimizer, step_size=1, gamma=0.7)for epoch in range(nums_epoch):train(net, train_data, optimizer, epoch)validate(test_data, net)scheduler.step()if save_model: # 将训练好的模型保存为model.pthif epoch == nums_epoch-1:checkpoint = {"state_dict":net.state_dict(), "optimizer":optimizer.state_dict(), "epoch": epoch}torch.save(checkpoint, 'model.pth')if __name__ == '__main__':main()
参考引用
寒武纪PyTorch v1.13.1用户手册
相关文章:
在国产GPU寒武纪MLU上快速上手Pytorch使用指南
本文旨在帮助Pytorch使用者快速上手使用寒武纪MLU。以代码块为主,文字尽可能简洁,许多部分对标NVIDIA CUDA。不正确的地方请留言更正。本文不定期更新。 文章目录 前言Cambricon PyTorch的Python包torch_mlu导入将模型加载到MLU上model.to(mlu)定义损失函…...
重生奇迹MU觉醒战士攻略
剑士连招技巧:生命之光:PK前起手式,增加血上限。 雷霆裂闪:眩晕住对手,剑士PK战士第一技能,雷霆裂闪是否使用好关系到胜负。 霹雳回旋斩:雷霆裂闪后可以选择用霹雳回旋斩跑出一定范围(因为对手…...
美颜技术详解:深入了解视频美颜SDK的工作机制
本文将深入探讨视频美颜SDK的工作机制,揭示其背后的科技奥秘和算法原理。 1.引言 视频美颜SDK作为一种集成到应用程序中的技术工具,通过先进的算法和图像处理技术,为用户提供令人印象深刻的实时美颜效果。 2.视频美颜SDK的基本工作原理 首…...
3D模型格式转换工具如何实现高性能数据转换?请看CAE系统开发实例!
客户背景 DP Technology是全球知名的CAM的供应商,在全球8个国家设有18个办事处。DP Technology提供的CAMESPRIT系统是一个用于数控编程,优化和仿真全方面的CAM系统。CAMESPRIT的客户来自多个行业,因此支持多种CAD工具和文件格式显得格外重…...
多级缓存:亿级流量的缓存方案
文章目录 一.多级缓存的引入二.JVM进程缓存三.Lua语法入门四.多级缓存1.OpenResty2.查询Tomcat3.Redis缓存预热4.查询Redis缓存5.Nginx本地缓存6.缓存同步 一.多级缓存的引入 传统缓存的问题 传统的缓存策略一般是请求到达Tomcat后,先查询Redis,如果未…...
C语言——高精度乘法
一、引子 高精度乘法相较于高精度加法和减法有更多的不同,加法和减法是一位对应一位进行操作的,而乘法是一个数的每一位对另一个数的每一位进行操作,需要的计算步骤更多。 二、核心算法 void Calculate(int num1[], int num2[], int numres…...
为什么C语言没有被C++所取代呢?
今日话题,为什么C语言没有被C所取代呢?虽然C是一个功能更强大的语言,但C语言在嵌入式领域仍然广泛使用,因为它更轻量级、更具可移植性,并且更适合在资源受限的环境中工作。这就是为什么C语言没有被C所取代的原因。如果…...
基于Spring的枚举类+策略模式设计(以实现多种第三方支付功能为例)
摘要 最近阅读《贯彻设计模式》这本书,里面使用一个更真实的项目来介绍设计模式的使用,相较于其它那些只会以披萨、厨师为例的设计模式书籍是有些进步。但这书有时候为了使用设计模式而强行朝着对应的 UML 图来设计类结构,并且对设计理念缺少…...
基于Linphone android sdk开发Android软话机
1.Linphone简介 1.1 简介 LinPhone是一个遵循GPL协议的开源网络电话或者IP语音电话(VOIP)系统,其主要如下。使用linphone,开发者可以在互联网上随意的通信,包括语音、视频、即时文本消息。linphone使用SIP协议&#…...
[论文分享]TimeDRL:多元时间序列的解纠缠表示学习
论文题目:TimeDRL: Disentangled Representation Learning for Multivariate Time-Series 论文地址:https://arxiv.org/abs/2312.04142 代码地址:暂无 关键要点:多元时间序列,自监督表征学习,分类和预测 摘…...
分享一个好看的vs主题
最近发现了一个很好看的vs主题(个人认为挺好看的),想要分享给大家。 主题的名字叫NightOwl,和vscode的主题颜色挺像的。操作方法也十分简单,首先我们先在最上面哪一行找到扩展。 然后点击管理扩展,再搜索栏…...
什么是云呼叫中心?
云呼叫中心作为一种高效的企业呼叫管理方案,越来越受到企业的青睐,常被用于管理客服和销售业务。那么,云呼叫中心到底是什么? 什么是云呼叫中心? 云呼叫中心是一种基于互联网的呼叫管理系统,与传统的呼叫…...
还在用nvm?来试试更快的node版本管理工具——fnm
前言 📫 大家好,我是南木元元,热衷分享有趣实用的文章,希望大家多多支持,一起进步! 🍅 个人主页:南木元元 目录 什么是node版本管理 常见的node版本管理工具 fnm是什么 安装fnm …...
【Hadoop精讲】HDFS详解
目录 理论知识点 角色功能 元数据持久化 安全模式 SecondaryNameNode(SNN) 副本放置策略 HDFS写流程 HDFS读流程 HA高可用 CPA原则 Paxos算法 HA解决方案 HDFS-Fedration解决方案(联邦机制) 理论知识点 角色功能 元数据持久化 另一台机器就…...
企业需要哪些数字化管理系统?
企业需要哪些数字化管理系统? ✅企业引进管理系统肯定是为了帮助整合和管理大量的数据,从而优化业务流程,提高工作效率和生产力。 ❌但是,如果各个系统之间不互通、无法互相关联数据的话,反而会增加工作量和时间成本…...
【vue】开发常见问题及解决方案
有一些问题不限于 Vue,还适应于其他类型的 SPA 项目。 1. 页面权限控制和登陆验证页面权限控制 页面权限控制是什么意思呢? 就是一个网站有不同的角色,比如管理员和普通用户,要求不同的角色能访问的页面是不一样的。如果一个页…...
飞天使-k8s知识点3-卸载yum 安装的k8s
要彻底卸载使用yum安装的 Kubernetes 集群,您可以按照以下步骤进行操作: 停止 Kubernetes 服务: sudo systemctl stop kubelet sudo systemctl stop docker 卸载 Kubernetes 组件: sudo yum remove -y kubelet kubeadm kubectl…...
ZooKeeper 集群搭建
文章目录 ZooKeeper 概述选举机制搭建前准备分布式配置分布式安装解压缩并重命名配置环境配置服务器编号配置文件 操作集群编写脚本运行脚本搭建过程中常见错误 ZooKeeper 概述 Zookeeper 是一个开源的分布式服务协调框架,由Apache软件基金会开发和维护。以下是对Z…...
Meson:现代的构建系统
Meson是一款现代化、高性能的开源构建系统,旨在提供简单、快速和可读性强的构建脚本。Meson被设计为跨平台的,支持多种编程语言,包括C、C、Fortran、Python等。其目标是替代传统的构建工具,如Autotools和CMake,提供更简…...
【大模型AIGC系列课程 5-2】视觉-语言大模型原理
重磅推荐专栏: 《大模型AIGC》;《课程大纲》 本专栏致力于探索和讨论当今最前沿的技术趋势和应用领域,包括但不限于ChatGPT和Stable Diffusion等。我们将深入研究大型模型的开发和应用,以及与之相关的人工智能生成内容(AIGC)技术。通过深入的技术解析和实践经验分享,旨在…...
震惊!难怪别人家的孩子越来越聪明,原来竟是因为它
前段时间工作调动给孩子换了个新学校,刚开始担心她不能适应新学校的授课方式,但任课老师对她评价很高,夸她上课很专注。 为了训练孩子的专注力,作为家长可没少下功夫,画画,下五子棋等益智游戏的兴趣班没少…...
Linux操作系统(UMASK+SUID+SGID+STICK)
UMASK反掩码 如何查看反掩码:直接在终端窗口运行 umask root用户反掩码:0022 普通用户反掩码:0002 UMASK的作用:确定目录,文件的缺省权限值 以root身份创建目录,观察目录的9位权限值 以root身份创建普通文件…...
Java 中单例模式的常见实现方式
目录 一、什么是单例模式? 二、单例模式有什么作用? 三、常见的创建单例模式的方式 1、饿汉式创建 2、懒汉式创建 3、DCL(Double Checked Lock)双检锁方式创建 3.1、synchronized 同步锁的基本使用 3.2、使用 DCL 中存在的疑…...
【C语言】自定义类型之联合和枚举
目录 1. 前言2. 联合体2.1 联合体类型的声明2.2 联合体的特点2.3 相同成员的结构体和联合体对比2.4 联合体大小的计算2.4 判断当前机器的大小端 3. 枚举3.1 枚举类型的声明3.2 枚举类型的优点3.3 枚举类型的使用 1. 前言 在之前的博客中介绍了自定义类型中的结构体,…...
使用Mosquitto/python3进行MQTT连接
一、简介 MQTT(消息队列遥测传输)是ISO 标准(ISO/IEC PRF 20922)下基于发布/订阅范式的消息协议。它工作在 TCP/IP协议族上,是为硬件性能低下的远程设备以及网络状况糟糕的情况下而设计的发布/订阅型消息协议,为此,它需要一个消息中间件。 …...
JavaWeb笔记之前端开发HTML
一、引言 1.1HTML概念 网页,是网站中的一个页面,通常是网页是构成网站的基本元素,是承载各种网站应用的平台。通俗的说,网站就是由网页组成的。通常我们看到的网页都是以htm或html后缀结尾的文件,俗称 HTML文件。 …...
通过IP地址定位解决被薅羊毛问题
随着互联网的普及,线上交易和优惠活动日益增多,这也为一些不法分子提供了可乘之机。他们利用技术手段,通过大量注册账号或使用虚假IP地址进行异常操作,以获取更多的优惠或利益,这种行为被称为“薅羊毛”。对于企业和平…...
Leetcode 122 买卖股票的最佳时机 II
题意理解: 已知:一个整数数组 prices ,其中 prices[i] 表示某支股票第 i 天的价格 如何哪个时间点买入,哪个时间点卖出,多次交易,能够收益最大化 目的:收益最大化 解题思路: 使用贪心…...
音频文件合成
音频文件合成 音频文件合成 http://ffmpeg.org/download.html https://blog.csdn.net/u013314786/article/details/89682800 http://www.360doc.com/content/19/0317/01/10519289_822112563.shtml https://chaijunkun.blog.csdn.net/article/details/116491526?spm1001.210…...
20231220将NanoPC-T4(RK3399)开发板的Android10的SDK按照Rockchip官方挖掘机开发板编译打包刷机之后启动跑飞
20231220将NanoPC-T4(RK3399)开发板的Android10的SDK按照Rockchip官方挖掘机开发板编译打包刷机之后启动跑飞 2023/12/20 17:19 简略步骤:rootrootrootroot-X99-Turbo:~/3TB$ tar --use-compress-programpigz -xvpf rk3399-android-10.git-20210201.tgz rootrootro…...
网站开发时保证用户登陆的安全/百度快速排名工具
HANA 语法什么的看这里,懒得写博客了 HANA语法参考:SAP HANA Reference 分类数据类型日期时间类型( Date time)DATE, TIME, SECONDDATE, TIMESTAMP数字类型( Numeric)TINYINT, SMALLINT, INTEGER, BIGINT, SMALLDECIMAL,DECIMAL, REAL, DO…...
网络营销方式举个例子/温州seo公司
下面介绍无监督机器学习算法,与前面分类回归不一样的是,这个不知道目标变量是什么,这个问题解决的是我们从这些样本中,我们能发现什么。 这下面主要讲述了聚类算法,跟数据挖掘中的关联挖掘中的两个主要算法。 K均值算法…...
南京企业做网站/营销100个引流方案
第二题生日蜡烛(结果填空)某君从某年开始每年都举办一次生日party,并且每次都要吹熄与年龄相同根数的蜡烛。现在算起来,他一共吹熄了236根蜡烛。请问,他从多少岁开始过生日party的?请填写他开始过生日party的年龄数。注意…...
网站开发用了哪些知识要点/aso搜索优化
由于OpenCV中的直方图统计函数的实现过于复杂,所以本人就动手重写了直方图统计函数,与各位看官们分享交流一下!当然了,代码还是基于OpenCV实现的。本人只实现了计算CV_8U类图像的直方图统计,从一维,二维&am…...
怎么做没有后台程序的网站/企业如何做网站
Collection:是集合的顶层接口,它的子体系有重复的,有唯一的,有有序的,有无序的。 Collection的功能概述: 1.添加功能:boolean add(Object obj):添加一个元素boolean addAll(Collection c):添加…...
美国做调查的网站/短视频seo系统
据国外媒体报道,惠普公司8月23日宣布,计划以16亿美元竞购虚拟存储制造商3PAR。而在一周前,戴尔公司曾出价11.5亿美元收购此公司。 惠普是在给3PAR的董事长和CEO的信中透露其收购价格的。惠普执行副总裁兼首席战略和技术官谢恩罗宾逊ÿ…...