昇思25天学习打卡营第2天|MindSpore快速入门
打卡

目录
打卡
快速入门案例:minist图像数据识别任务
案例任务说明
流程
1 加载并处理数据集
2 模型网络构建与定义
3 模型约束定义
4 模型训练
5 模型保存
6 模型推理
相关参考文档入门理解
MindSpore数据处理引擎
模型网络参数初始化
模型优化器
损失函数
代码
安装
从模型训练到预测推理
self_main_train_and_save.py
self_dataprocess.py
self_network.py
self_modeltrain.py
self_modeltest.py
self_predict.py
快速入门案例:minist图像数据识别任务
案例任务说明
MINIST数据集是有标签的图像数据,图像数据是0-9的手写阿拉伯数字。其中,训练集有6W个,测试集1W个。
目的是训练一个可以高效识别手写阿拉伯数字的模型。
流程
1 加载并处理数据集
涉及到的mindspore接口 mindspore.dataset。例如对数据集的map、batch、shuffle等操作,数据列名获取,对数据集进行迭代访问、查看数据和标签的shape和datatype等。
2 模型网络构建与定义
涉及到 mindspore.nn 类。例如用户可继承nn.Cell类来自定义网络结构,其中的construct类函数包含数据(Tensor)的变换过程。。
3 模型约束定义
包括损失函数、优化器等。如 nn.CrossEntropyLoss() 、nn.SGD(model.trainable_params(), 1e-2)
4 模型训练
- 定义训练函数,用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
- 定义测试函数,用来评估模型的性能。
5 模型保存
- 两种保存方式:
1)模型参数保存:mindspore.save_checkpoint(model, "model.ckpt")
2)统一的中间表示(Intermediate Representation,IR)的保存,MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
6 模型推理
- 两种加载方式:
1)模型参数加载:
> model = network()
> param_dict = mindspore.load_checkpoint("model.ckpt");
> param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
2)统一的中间表示(Intermediate Representation,IR)的加载:
> mindspore.set_context(mode=mindspore.GRAPH_MODE) > graph = mindspore.load("model.mindir") > model = nn.GraphCell(graph) ## nn.GraphCell 仅支持图模式。 > outputs = model(inputs)
保存与加载 — MindSpore master 文档
相关参考文档入门理解
MindSpore数据处理引擎
MindSpore 通过对外暴露API层来构建数据图;内部的Data Processing Pipeline 层用来进行数据加载和预处理多步并行流水线。
高性能数据处理引擎 — MindSpore master 文档

MindSpore 通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。
数据集 Dataset — MindSpore master 文档
数据变换 Transforms — MindSpore master 文档
模型网络参数初始化
Initializer是MindSpore内置的参数初始化基类,所有内置参数初始化方法均继承该类。mindspore.nn中提供的神经网络层封装均提供weight_init、bias_init等入参,可以直接使用实例化的Initializer进行参数初始化。
参数初始化 — MindSpore master 文档
模型优化器
优化器 — MindSpore master 文档
损失函数
损失函数 — MindSpore master 文档
代码
安装
pip/conda均可:
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1
从模型训练到预测推理
训练:
python self_main_train_and_save.py
推理:
python self_predict.py
self_main_train_and_save.py
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# 用download库从公开华为云obs桶下载 MINIST 数据集并解压。因为mindspore.dataset 提供的接口仅支持解压后的数据文件
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True) ## 1 加载数据集
train_dataset = MnistDataset('MNIST_Data/train', shuffle=False)
test_dataset = MnistDataset('MNIST_Data/test')
print(train_dataset.get_col_names()) # 打印数据集中包含的数据列名,用于dataset的预处理。输出['image', 'label']## 2 MindSpore的dataset使用数据处理流水线,这里将处理好的数据集打包为大小为64的batch。
from self_dataprocess import datapipe
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64) ## 3 数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。可使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。
for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")break“”“Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32Shape of label: (64,) Int32”“”
for data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break## 4 模型训练
from self_network import Network
from self_modeltrain import train, loss_fn
from self_modelteset import test
model = Network()
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")## 5 保存模型
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")
self_dataprocess.py
from mindspore.dataset import vision, transforms
def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset
self_network.py
# Define model
from mindspore import nnclass Network(nn.Cell): def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsdef check_network():model = Network()print(model)
self_modeltrain.py
# Instantiate loss function and optimizer
from mindspore import nnloss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train() ## 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 BatchNorm),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。默认Truefor batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
self_modeltest.py
from mindspore import nn def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
self_predict.py
## 加载模型
from self_network import Network# Instantiate a random initialized model
model = Network()# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load) ## param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。## 加载后的模型可以直接用于预测推理。
model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break
相关文章:
昇思25天学习打卡营第2天|MindSpore快速入门
打卡 目录 打卡 快速入门案例:minist图像数据识别任务 案例任务说明 流程 1 加载并处理数据集 2 模型网络构建与定义 3 模型约束定义 4 模型训练 5 模型保存 6 模型推理 相关参考文档入门理解 MindSpore数据处理引擎 模型网络参数初始化 模型优化器 …...
django之url路径
方式一:path 语法:<<转换器类型:自定义>> 作用:若转换器类型匹配到对应类型的数据,则将数据按照关键字传参的方式传递给视图函数 类型: str: 匹配除了”/“之外的非空字符串。 /test/zvxint: 匹配0或任何…...
【OnlyOffice】桌面应用编辑器,插件开发大赛,等你来挑战
OnlyOffice,桌面应用编辑器,最近版本已从8.0升级到了8.1 从PDF、Word、Excel、PPT等全面进行了升级。随着AI应用持续的火热,OnlyOffice也在不断推出AI相关插件。 因此,在此给大家推荐一下OnlyOffice本次的插件开发大赛。 详细信息…...
[学习笔记]SQL学习笔记(连载中。。。)
学习视频:【数据库】SQL 3小时快速入门 #数据库教程 #SQL教程 #MySQL教程 #database#Python连接数据库 目录 1.SQL的基础知识1.1.表(table)和键(key)1.2.外键、联合主键 2.MySQL安装(略,请自行参考视频)3.基本的MySQL语法3.1.规…...
Buuctf之SimpleRev做法
首先,查个壳,64bit,那就丢进ida64中进行反编译进来之后,我们进入main函数,发现里面没什么东西,那就shiftf12搜索字符串,找到关键字符串,双击进入然后再选中该字符串,ctrl…...
【云原生监控】Prometheus 普罗米修斯从搭建到使用详解
目录 一、前言 二、服务监控概述 2.1 什么是微服务监控 2.2 微服务监控指标 2.3 微服务监控工具 三、Prometheus概述 3.1 Prometheus是什么 3.2 Prometheus 特点 3.3 Prometheus 架构图 3.3.1 Prometheus核心组件 3.3.2 Prometheus 工作流程 3.4 Prometheus 应用场景…...
【C++】模板进阶--保姆级解析(什么是非类型模板参数?什么是模板的特化?模板的特化如何应用?)
目录 一、前言 二、什么是C模板? 💦泛型编程的思想 💦C模板的分类 三、非类型模板参数 ⚡问题引入⚡ ⚡非类型模板参数的使用⚡ 🔥非类型模板参数的定义 🔥非类型模板参数的两种类型 ὒ…...
Cookie与Session
Cookie Set-Cookie: sessionIdabc123; ExpiresWed, 09 Jun 2024 10:18:14 GMT; Path/; Secure; HttpOnlySession session作用域 首先需要了解servlet容器可能包含多个web应用。 在servlet容器中同一应用的servlet 对 session数据是可见的,不同应用之间session是相互…...
Nuxt3 的生命周期和钩子函数(十一)
title: Nuxt3 的生命周期和钩子函数(十一) date: 2024/7/5 updated: 2024/7/5 author: cmdragon excerpt: 摘要:本文详细介绍了Nuxt3中几个关键的生命周期钩子和它们的使用方法,包括webpack:done用于Webpack编译完成后执行操作…...
Windows ipconfig命令详解,Windows查看IP地址信息
「作者简介」:冬奥会网络安全中国代表队,CSDN Top100,就职奇安信多年,以实战工作为基础著作 《网络安全自学教程》,适合基础薄弱的同学系统化的学习网络安全,用最短的时间掌握最核心的技术。 ipconfig 1、基…...
在C#/Net中使用Mqtt
net中MQTT的应用场景 c#常用来开发上位机程序,或者其他一些跟设备打交道比较多的系统,所以会经常作为拥有数据的终端,可以用来采集上传数据,而MQTT也是物联网常用的协议,所以下面介绍在C#开发中使用MQTT。 安装MQTTn…...
VBA提取word表格内容到excel
这是一段提取word表格中部分内容的vb代码。 Sub 提取word表格() mypath ThisWorkbook.Path & "\"myname Dir(mypath & "*.doc*")n 4 index of rowsRange("A1:F1") Array("课程代码", "课程名称", "专业&…...
html+css+js图片手动轮播
源代码在界面图片后面 轮播演示用的几张图片是Bing上的,直接用的几张图片的URL,谁加载可能需要等一下,现实中替换成自己的图片即可 关注一下点个赞吧😄 谢谢大佬 界面图片 源代码 <!DOCTYPE html> <html lang&quo…...
【十三】图解 Spring 核心数据结构:BeanDefinition 其二
图解 Spring 核心数据结构:BeanDefinition 其二 概述 前面写过一篇相关文章作为开篇介绍了一下BeanDefinition,本篇将深入细节来向读者展示BeanDefinition的设计,让我们一起来揭开日常开发中使用的bean的神秘面纱,深入细节透彻理解…...
数据库作业
命令 登陆数据库 mysql -uroot -p123456 --prompt"\u\h:\d--> " 创建数据库zcr create database zcr; 修改数据库zcr字符集为gbk alter database zcr default character set gbk collate gbk_chinese_ci; 选择数据库zcr use zcr 查看数据库zc…...
12、matlab中for循环,if else判断语句,break和continue用法以及switch case语句使用
1、前言 在MATLAB中,for循环用于迭代一个固定次数的循环。可以使用if else语句在循环中进行条件判断,根据条件的不同执行相应的代码块。break和continue可以用于控制循环的执行流程,break用于提前结束循环,而continue用于跳过当前…...
AcWing 3207:门禁系统 ← 桶排序中“桶”的思想
【题目来源】https://www.acwing.com/problem/content/3210/【题目描述】 涛涛最近要负责图书馆的管理工作,需要记录下每天读者的到访情况。 每位读者有一个唯一编号,每条记录用读者的编号来表示。 给出读者的来访记录,请问每一条记录中的读者…...
开发个人Go-ChatGPT--3 服务拆分
开发个人Go-ChatGPT–3 服务拆分 个人Go-ChatGPT项目可拆分用户服务(user),AI模型服务(AiModel),… 每个服务都可以再分为 api 服务和 rpc 服务。api 服务对外,可提供给 app 调用。rpc 服务是…...
Android --- 新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了
新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了 大概原因就是,初始化默认Google的安卓模拟器占用的RAM内存是2048,如果电脑的性能和内存一般的话就可能卡死,解决方案是手动修改安卓模拟器的config文件&…...
从入门到深入,Docker新手学习教程
编译整理|TesterHome社区 作者|Ishaan Gupta 以下为作者观点: Docker 彻底改变了我们开发、交付和运行应用程序的方式。它使开发人员能够将应用程序打包到容器中 - 标准化的可执行组件,将应用程序源代码与在任何环境中运行该代码…...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
SkyWalking 10.2.0 SWCK 配置过程
SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外,K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案,全安装在K8S群集中。 具体可参…...
【WiFi帧结构】
文章目录 帧结构MAC头部管理帧 帧结构 Wi-Fi的帧分为三部分组成:MAC头部frame bodyFCS,其中MAC是固定格式的,frame body是可变长度。 MAC头部有frame control,duration,address1,address2,addre…...
FFmpeg 低延迟同屏方案
引言 在实时互动需求激增的当下,无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作,还是游戏直播的画面实时传输,低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架,凭借其灵活的编解码、数据…...
linux arm系统烧录
1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 (忘了有没有这步了 估计有) 刷机程序 和 镜像 就不提供了。要刷的时…...
【单片机期末】单片机系统设计
主要内容:系统状态机,系统时基,系统需求分析,系统构建,系统状态流图 一、题目要求 二、绘制系统状态流图 题目:根据上述描述绘制系统状态流图,注明状态转移条件及方向。 三、利用定时器产生时…...
JDK 17 新特性
#JDK 17 新特性 /**************** 文本块 *****************/ python/scala中早就支持,不稀奇 String json “”" { “name”: “Java”, “version”: 17 } “”"; /**************** Switch 语句 -> 表达式 *****************/ 挺好的ÿ…...
ArcGIS Pro制作水平横向图例+多级标注
今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作:ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等(ArcGIS出图图例8大技巧),那这次我们看看ArcGIS Pro如何更加快捷的操作。…...
Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)
参考官方文档:https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java(供 Kotlin 使用) 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
