在国产GPU寒武纪MLU上快速上手Pytorch使用指南
本文旨在帮助Pytorch使用者快速上手使用寒武纪MLU。以代码块为主,文字尽可能简洁,许多部分对标NVIDIA CUDA。不正确的地方请留言更正。本文不定期更新。
文章目录
- 前言
- Cambricon PyTorch的Python包torch_mlu导入
- 将模型加载到MLU上model.to('mlu')
- 定义损失函数,然后将其拷贝至MLU
- 将数据从CPU拷贝到MLU设备
- 以mnist.py为例的训练代码demo
- 参考引用
前言
大背景:信创改造、信创国产化、GPU国产化。
为使PyTorch支持寒武纪MLU,寒武纪对机器学习框架PyTorch进行了部分定制。若要在寒武纪MLU上运行PyTorch,需要安装并使用寒武纪定制的 Cambricon PyTorch。
Cambricon PyTorch的Python包torch_mlu导入
Cambricon CATCH是寒武纪发布的一款Python包(包名torch_mlu),提供了在MLU设备上进行张量计算的能力。安装好Cambricon CATCH后,便可使用torch_mlu模块:
import torch # 需安装Cambricon PyTorch
import torch_mlu # 动态扩展MLU后端
附 Cambricon PyTorch源码编译安装
导入 torch 和 torch_mlu 后可以测试在MLU上完成加法运算:
t0 = torch.randn(2, 2, device='mlu') # 在MLU设备上生成Tensor
t1 = torch.randn(2, 2, device='mlu')
result = t0 + t1 # 在MLU设备上完成加法运算
将模型加载到MLU上model.to(‘mlu’)
以ResNet18为例,将模型加载到MLU上用 model.to('mlu'),对标cuda的 model.to(device) :
# 定义模型
model = models.__dict__["resnet50"]()
# 将模型加载到MLU上。
mlu_model = model.to('mlu')
定义损失函数,然后将其拷贝至MLU
# 构造损失函数
criterion = nn.CrossEntropyLoss()
# 将损失函数拷贝到MLU上
criterion.to('mlu')
将数据从CPU拷贝到MLU设备
x = torch.randn(1000000, dtype=torch.float)
x_mlu = x.to(torch.device('mlu'), non_blocking=True)
以mnist.py为例的训练代码demo
import torch # 导入原生 PyTorch
import torch_mlu # 导入 Cambricon PyTorch
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torch import nn
from torch import optim
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout2d(0.25)self.dropout2 = nn.Dropout2d(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)# 定义前向计算def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = F.relu(x)x = self.dropout2(x)x = self.fc2(x)output = F.log_softmax(x, dim=1)return output# 模型训练
def train(model, train_data, optimizer, epoch):model = model.train()for batch_idx, (img, label) in enumerate(train_data):img = img.mlu()label = label.mlu()optimizer.zero_grad()out = model(img)loss = F.nll_loss(out, label)# 反向计算loss.backward()# 梯度更新optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(img), len(train_data.dataset),100. * batch_idx / len(train_data), loss.item()))# 模型推理
def validate(val_loader, model):test_loss = 0correct = 0model.eval()with torch.no_grad():for images, target in val_loader:images = images.mlu()target = target.mlu()output = model(images)test_loss += F.nll_loss(output, target, reduction='sum').item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(val_loader.dataset)# 打印精度结果print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(val_loader.dataset),100. * correct / len(val_loader.dataset)))# 主函数
def main():# 定义预处理函数data_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.1307],[0.3081])])# 获取 MNIST 数据集train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True)test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)train_data = DataLoader(train_set, batch_size=64, shuffle=True)test_data = DataLoader(test_set, batch_size=1000, shuffle=False)net_orig = Net()# 模型拷贝到MLU设备net = net_orig.mlu()optimizer = optim.Adadelta(net.parameters(), 1)# 训练10个epochnums_epoch = 10# 训练完成后保存模型save_model = True# 学习率调整策略scheduler = StepLR(optimizer, step_size=1, gamma=0.7)for epoch in range(nums_epoch):train(net, train_data, optimizer, epoch)validate(test_data, net)scheduler.step()if save_model: # 将训练好的模型保存为model.pthif epoch == nums_epoch-1:checkpoint = {"state_dict":net.state_dict(), "optimizer":optimizer.state_dict(), "epoch": epoch}torch.save(checkpoint, 'model.pth')if __name__ == '__main__':main()
参考引用
寒武纪PyTorch v1.13.1用户手册
相关文章:
在国产GPU寒武纪MLU上快速上手Pytorch使用指南
本文旨在帮助Pytorch使用者快速上手使用寒武纪MLU。以代码块为主,文字尽可能简洁,许多部分对标NVIDIA CUDA。不正确的地方请留言更正。本文不定期更新。 文章目录 前言Cambricon PyTorch的Python包torch_mlu导入将模型加载到MLU上model.to(mlu)定义损失函…...
重生奇迹MU觉醒战士攻略
剑士连招技巧:生命之光:PK前起手式,增加血上限。 雷霆裂闪:眩晕住对手,剑士PK战士第一技能,雷霆裂闪是否使用好关系到胜负。 霹雳回旋斩:雷霆裂闪后可以选择用霹雳回旋斩跑出一定范围(因为对手…...
美颜技术详解:深入了解视频美颜SDK的工作机制
本文将深入探讨视频美颜SDK的工作机制,揭示其背后的科技奥秘和算法原理。 1.引言 视频美颜SDK作为一种集成到应用程序中的技术工具,通过先进的算法和图像处理技术,为用户提供令人印象深刻的实时美颜效果。 2.视频美颜SDK的基本工作原理 首…...
3D模型格式转换工具如何实现高性能数据转换?请看CAE系统开发实例!
客户背景 DP Technology是全球知名的CAM的供应商,在全球8个国家设有18个办事处。DP Technology提供的CAMESPRIT系统是一个用于数控编程,优化和仿真全方面的CAM系统。CAMESPRIT的客户来自多个行业,因此支持多种CAD工具和文件格式显得格外重…...
多级缓存:亿级流量的缓存方案
文章目录 一.多级缓存的引入二.JVM进程缓存三.Lua语法入门四.多级缓存1.OpenResty2.查询Tomcat3.Redis缓存预热4.查询Redis缓存5.Nginx本地缓存6.缓存同步 一.多级缓存的引入 传统缓存的问题 传统的缓存策略一般是请求到达Tomcat后,先查询Redis,如果未…...
C语言——高精度乘法
一、引子 高精度乘法相较于高精度加法和减法有更多的不同,加法和减法是一位对应一位进行操作的,而乘法是一个数的每一位对另一个数的每一位进行操作,需要的计算步骤更多。 二、核心算法 void Calculate(int num1[], int num2[], int numres…...
为什么C语言没有被C++所取代呢?
今日话题,为什么C语言没有被C所取代呢?虽然C是一个功能更强大的语言,但C语言在嵌入式领域仍然广泛使用,因为它更轻量级、更具可移植性,并且更适合在资源受限的环境中工作。这就是为什么C语言没有被C所取代的原因。如果…...
基于Spring的枚举类+策略模式设计(以实现多种第三方支付功能为例)
摘要 最近阅读《贯彻设计模式》这本书,里面使用一个更真实的项目来介绍设计模式的使用,相较于其它那些只会以披萨、厨师为例的设计模式书籍是有些进步。但这书有时候为了使用设计模式而强行朝着对应的 UML 图来设计类结构,并且对设计理念缺少…...
基于Linphone android sdk开发Android软话机
1.Linphone简介 1.1 简介 LinPhone是一个遵循GPL协议的开源网络电话或者IP语音电话(VOIP)系统,其主要如下。使用linphone,开发者可以在互联网上随意的通信,包括语音、视频、即时文本消息。linphone使用SIP协议&#…...
[论文分享]TimeDRL:多元时间序列的解纠缠表示学习
论文题目:TimeDRL: Disentangled Representation Learning for Multivariate Time-Series 论文地址:https://arxiv.org/abs/2312.04142 代码地址:暂无 关键要点:多元时间序列,自监督表征学习,分类和预测 摘…...
分享一个好看的vs主题
最近发现了一个很好看的vs主题(个人认为挺好看的),想要分享给大家。 主题的名字叫NightOwl,和vscode的主题颜色挺像的。操作方法也十分简单,首先我们先在最上面哪一行找到扩展。 然后点击管理扩展,再搜索栏…...
什么是云呼叫中心?
云呼叫中心作为一种高效的企业呼叫管理方案,越来越受到企业的青睐,常被用于管理客服和销售业务。那么,云呼叫中心到底是什么? 什么是云呼叫中心? 云呼叫中心是一种基于互联网的呼叫管理系统,与传统的呼叫…...
还在用nvm?来试试更快的node版本管理工具——fnm
前言 📫 大家好,我是南木元元,热衷分享有趣实用的文章,希望大家多多支持,一起进步! 🍅 个人主页:南木元元 目录 什么是node版本管理 常见的node版本管理工具 fnm是什么 安装fnm …...
【Hadoop精讲】HDFS详解
目录 理论知识点 角色功能 元数据持久化 安全模式 SecondaryNameNode(SNN) 副本放置策略 HDFS写流程 HDFS读流程 HA高可用 CPA原则 Paxos算法 HA解决方案 HDFS-Fedration解决方案(联邦机制) 理论知识点 角色功能 元数据持久化 另一台机器就…...
企业需要哪些数字化管理系统?
企业需要哪些数字化管理系统? ✅企业引进管理系统肯定是为了帮助整合和管理大量的数据,从而优化业务流程,提高工作效率和生产力。 ❌但是,如果各个系统之间不互通、无法互相关联数据的话,反而会增加工作量和时间成本…...
【vue】开发常见问题及解决方案
有一些问题不限于 Vue,还适应于其他类型的 SPA 项目。 1. 页面权限控制和登陆验证页面权限控制 页面权限控制是什么意思呢? 就是一个网站有不同的角色,比如管理员和普通用户,要求不同的角色能访问的页面是不一样的。如果一个页…...
飞天使-k8s知识点3-卸载yum 安装的k8s
要彻底卸载使用yum安装的 Kubernetes 集群,您可以按照以下步骤进行操作: 停止 Kubernetes 服务: sudo systemctl stop kubelet sudo systemctl stop docker 卸载 Kubernetes 组件: sudo yum remove -y kubelet kubeadm kubectl…...
ZooKeeper 集群搭建
文章目录 ZooKeeper 概述选举机制搭建前准备分布式配置分布式安装解压缩并重命名配置环境配置服务器编号配置文件 操作集群编写脚本运行脚本搭建过程中常见错误 ZooKeeper 概述 Zookeeper 是一个开源的分布式服务协调框架,由Apache软件基金会开发和维护。以下是对Z…...
Meson:现代的构建系统
Meson是一款现代化、高性能的开源构建系统,旨在提供简单、快速和可读性强的构建脚本。Meson被设计为跨平台的,支持多种编程语言,包括C、C、Fortran、Python等。其目标是替代传统的构建工具,如Autotools和CMake,提供更简…...
【大模型AIGC系列课程 5-2】视觉-语言大模型原理
重磅推荐专栏: 《大模型AIGC》;《课程大纲》 本专栏致力于探索和讨论当今最前沿的技术趋势和应用领域,包括但不限于ChatGPT和Stable Diffusion等。我们将深入研究大型模型的开发和应用,以及与之相关的人工智能生成内容(AIGC)技术。通过深入的技术解析和实践经验分享,旨在…...
【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器
一.自适应梯度算法Adagrad概述 Adagrad(Adaptive Gradient Algorithm)是一种自适应学习率的优化算法,由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率,适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院挂号小程序
一、开发准备 环境搭建: 安装DevEco Studio 3.0或更高版本配置HarmonyOS SDK申请开发者账号 项目创建: File > New > Create Project > Application (选择"Empty Ability") 二、核心功能实现 1. 医院科室展示 /…...
第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明
AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...
学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”
2025年#高考 将在近日拉开帷幕,#AI 监考一度冲上热搜。当AI深度融入高考,#时间同步 不再是辅助功能,而是决定AI监考系统成败的“生命线”。 AI亮相2025高考,40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕,江西、…...
SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)
上一章用到了V2 的概念,其实 Fiori当中还有 V4,咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务),代理中间件(ui5-middleware-simpleproxy)-CSDN博客…...
rnn判断string中第一次出现a的下标
# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...
第7篇:中间件全链路监控与 SQL 性能分析实践
7.1 章节导读 在构建数据库中间件的过程中,可观测性 和 性能分析 是保障系统稳定性与可维护性的核心能力。 特别是在复杂分布式场景中,必须做到: 🔍 追踪每一条 SQL 的生命周期(从入口到数据库执行)&#…...
HubSpot推出与ChatGPT的深度集成引发兴奋与担忧
上周三,HubSpot宣布已构建与ChatGPT的深度集成,这一消息在HubSpot用户和营销技术观察者中引发了极大的兴奋,但同时也存在一些关于数据安全的担忧。 许多网络声音声称,这对SaaS应用程序和人工智能而言是一场范式转变。 但向任何技…...
es6+和css3新增的特性有哪些
一:ECMAScript 新特性(ES6) ES6 (2015) - 革命性更新 1,记住的方法,从一个方法里面用到了哪些技术 1,let /const块级作用域声明2,**默认参数**:函数参数可以设置默认值。3&#x…...
鸿蒙HarmonyOS 5军旗小游戏实现指南
1. 项目概述 本军旗小游戏基于鸿蒙HarmonyOS 5开发,采用DevEco Studio实现,包含完整的游戏逻辑和UI界面。 2. 项目结构 /src/main/java/com/example/militarychess/├── MainAbilitySlice.java // 主界面├── GameView.java // 游戏核…...
