昇思25天学习打卡营第2天|MindSpore快速入门
打卡

目录
打卡
快速入门案例:minist图像数据识别任务
案例任务说明
流程
1 加载并处理数据集
2 模型网络构建与定义
3 模型约束定义
4 模型训练
5 模型保存
6 模型推理
相关参考文档入门理解
MindSpore数据处理引擎
模型网络参数初始化
模型优化器
损失函数
代码
安装
从模型训练到预测推理
self_main_train_and_save.py
self_dataprocess.py
self_network.py
self_modeltrain.py
self_modeltest.py
self_predict.py
快速入门案例:minist图像数据识别任务
案例任务说明
MINIST数据集是有标签的图像数据,图像数据是0-9的手写阿拉伯数字。其中,训练集有6W个,测试集1W个。
目的是训练一个可以高效识别手写阿拉伯数字的模型。
流程
1 加载并处理数据集
涉及到的mindspore接口 mindspore.dataset。例如对数据集的map、batch、shuffle等操作,数据列名获取,对数据集进行迭代访问、查看数据和标签的shape和datatype等。
2 模型网络构建与定义
涉及到 mindspore.nn 类。例如用户可继承nn.Cell类来自定义网络结构,其中的construct类函数包含数据(Tensor)的变换过程。。
3 模型约束定义
包括损失函数、优化器等。如 nn.CrossEntropyLoss() 、nn.SGD(model.trainable_params(), 1e-2)
4 模型训练
- 定义训练函数,用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
- 定义测试函数,用来评估模型的性能。
5 模型保存
- 两种保存方式:
1)模型参数保存:mindspore.save_checkpoint(model, "model.ckpt")
2)统一的中间表示(Intermediate Representation,IR)的保存,MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
6 模型推理
- 两种加载方式:
1)模型参数加载:
> model = network()
> param_dict = mindspore.load_checkpoint("model.ckpt");
> param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
2)统一的中间表示(Intermediate Representation,IR)的加载:
> mindspore.set_context(mode=mindspore.GRAPH_MODE) > graph = mindspore.load("model.mindir") > model = nn.GraphCell(graph) ## nn.GraphCell 仅支持图模式。 > outputs = model(inputs)
保存与加载 — MindSpore master 文档
相关参考文档入门理解
MindSpore数据处理引擎
MindSpore 通过对外暴露API层来构建数据图;内部的Data Processing Pipeline 层用来进行数据加载和预处理多步并行流水线。
高性能数据处理引擎 — MindSpore master 文档

MindSpore 通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。
数据集 Dataset — MindSpore master 文档
数据变换 Transforms — MindSpore master 文档
模型网络参数初始化
Initializer是MindSpore内置的参数初始化基类,所有内置参数初始化方法均继承该类。mindspore.nn中提供的神经网络层封装均提供weight_init、bias_init等入参,可以直接使用实例化的Initializer进行参数初始化。
参数初始化 — MindSpore master 文档
模型优化器
优化器 — MindSpore master 文档
损失函数
损失函数 — MindSpore master 文档
代码
安装
pip/conda均可:
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1
从模型训练到预测推理
训练:
python self_main_train_and_save.py
推理:
python self_predict.py
self_main_train_and_save.py
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# 用download库从公开华为云obs桶下载 MINIST 数据集并解压。因为mindspore.dataset 提供的接口仅支持解压后的数据文件
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True) ## 1 加载数据集
train_dataset = MnistDataset('MNIST_Data/train', shuffle=False)
test_dataset = MnistDataset('MNIST_Data/test')
print(train_dataset.get_col_names()) # 打印数据集中包含的数据列名,用于dataset的预处理。输出['image', 'label']## 2 MindSpore的dataset使用数据处理流水线,这里将处理好的数据集打包为大小为64的batch。
from self_dataprocess import datapipe
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64) ## 3 数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。可使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。
for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")break“”“Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32Shape of label: (64,) Int32”“”
for data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break## 4 模型训练
from self_network import Network
from self_modeltrain import train, loss_fn
from self_modelteset import test
model = Network()
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")## 5 保存模型
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")
self_dataprocess.py
from mindspore.dataset import vision, transforms
def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset
self_network.py
# Define model
from mindspore import nnclass Network(nn.Cell): def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsdef check_network():model = Network()print(model)
self_modeltrain.py
# Instantiate loss function and optimizer
from mindspore import nnloss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train() ## 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 BatchNorm),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。默认Truefor batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
self_modeltest.py
from mindspore import nn def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
self_predict.py
## 加载模型
from self_network import Network# Instantiate a random initialized model
model = Network()# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load) ## param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。## 加载后的模型可以直接用于预测推理。
model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break
相关文章:
昇思25天学习打卡营第2天|MindSpore快速入门
打卡 目录 打卡 快速入门案例:minist图像数据识别任务 案例任务说明 流程 1 加载并处理数据集 2 模型网络构建与定义 3 模型约束定义 4 模型训练 5 模型保存 6 模型推理 相关参考文档入门理解 MindSpore数据处理引擎 模型网络参数初始化 模型优化器 …...
django之url路径
方式一:path 语法:<<转换器类型:自定义>> 作用:若转换器类型匹配到对应类型的数据,则将数据按照关键字传参的方式传递给视图函数 类型: str: 匹配除了”/“之外的非空字符串。 /test/zvxint: 匹配0或任何…...
【OnlyOffice】桌面应用编辑器,插件开发大赛,等你来挑战
OnlyOffice,桌面应用编辑器,最近版本已从8.0升级到了8.1 从PDF、Word、Excel、PPT等全面进行了升级。随着AI应用持续的火热,OnlyOffice也在不断推出AI相关插件。 因此,在此给大家推荐一下OnlyOffice本次的插件开发大赛。 详细信息…...
[学习笔记]SQL学习笔记(连载中。。。)
学习视频:【数据库】SQL 3小时快速入门 #数据库教程 #SQL教程 #MySQL教程 #database#Python连接数据库 目录 1.SQL的基础知识1.1.表(table)和键(key)1.2.外键、联合主键 2.MySQL安装(略,请自行参考视频)3.基本的MySQL语法3.1.规…...
Buuctf之SimpleRev做法
首先,查个壳,64bit,那就丢进ida64中进行反编译进来之后,我们进入main函数,发现里面没什么东西,那就shiftf12搜索字符串,找到关键字符串,双击进入然后再选中该字符串,ctrl…...
【云原生监控】Prometheus 普罗米修斯从搭建到使用详解
目录 一、前言 二、服务监控概述 2.1 什么是微服务监控 2.2 微服务监控指标 2.3 微服务监控工具 三、Prometheus概述 3.1 Prometheus是什么 3.2 Prometheus 特点 3.3 Prometheus 架构图 3.3.1 Prometheus核心组件 3.3.2 Prometheus 工作流程 3.4 Prometheus 应用场景…...
【C++】模板进阶--保姆级解析(什么是非类型模板参数?什么是模板的特化?模板的特化如何应用?)
目录 一、前言 二、什么是C模板? 💦泛型编程的思想 💦C模板的分类 三、非类型模板参数 ⚡问题引入⚡ ⚡非类型模板参数的使用⚡ 🔥非类型模板参数的定义 🔥非类型模板参数的两种类型 ὒ…...
Cookie与Session
Cookie Set-Cookie: sessionIdabc123; ExpiresWed, 09 Jun 2024 10:18:14 GMT; Path/; Secure; HttpOnlySession session作用域 首先需要了解servlet容器可能包含多个web应用。 在servlet容器中同一应用的servlet 对 session数据是可见的,不同应用之间session是相互…...
Nuxt3 的生命周期和钩子函数(十一)
title: Nuxt3 的生命周期和钩子函数(十一) date: 2024/7/5 updated: 2024/7/5 author: cmdragon excerpt: 摘要:本文详细介绍了Nuxt3中几个关键的生命周期钩子和它们的使用方法,包括webpack:done用于Webpack编译完成后执行操作…...
Windows ipconfig命令详解,Windows查看IP地址信息
「作者简介」:冬奥会网络安全中国代表队,CSDN Top100,就职奇安信多年,以实战工作为基础著作 《网络安全自学教程》,适合基础薄弱的同学系统化的学习网络安全,用最短的时间掌握最核心的技术。 ipconfig 1、基…...
在C#/Net中使用Mqtt
net中MQTT的应用场景 c#常用来开发上位机程序,或者其他一些跟设备打交道比较多的系统,所以会经常作为拥有数据的终端,可以用来采集上传数据,而MQTT也是物联网常用的协议,所以下面介绍在C#开发中使用MQTT。 安装MQTTn…...
VBA提取word表格内容到excel
这是一段提取word表格中部分内容的vb代码。 Sub 提取word表格() mypath ThisWorkbook.Path & "\"myname Dir(mypath & "*.doc*")n 4 index of rowsRange("A1:F1") Array("课程代码", "课程名称", "专业&…...
html+css+js图片手动轮播
源代码在界面图片后面 轮播演示用的几张图片是Bing上的,直接用的几张图片的URL,谁加载可能需要等一下,现实中替换成自己的图片即可 关注一下点个赞吧😄 谢谢大佬 界面图片 源代码 <!DOCTYPE html> <html lang&quo…...
【十三】图解 Spring 核心数据结构:BeanDefinition 其二
图解 Spring 核心数据结构:BeanDefinition 其二 概述 前面写过一篇相关文章作为开篇介绍了一下BeanDefinition,本篇将深入细节来向读者展示BeanDefinition的设计,让我们一起来揭开日常开发中使用的bean的神秘面纱,深入细节透彻理解…...
数据库作业
命令 登陆数据库 mysql -uroot -p123456 --prompt"\u\h:\d--> " 创建数据库zcr create database zcr; 修改数据库zcr字符集为gbk alter database zcr default character set gbk collate gbk_chinese_ci; 选择数据库zcr use zcr 查看数据库zc…...
12、matlab中for循环,if else判断语句,break和continue用法以及switch case语句使用
1、前言 在MATLAB中,for循环用于迭代一个固定次数的循环。可以使用if else语句在循环中进行条件判断,根据条件的不同执行相应的代码块。break和continue可以用于控制循环的执行流程,break用于提前结束循环,而continue用于跳过当前…...
AcWing 3207:门禁系统 ← 桶排序中“桶”的思想
【题目来源】https://www.acwing.com/problem/content/3210/【题目描述】 涛涛最近要负责图书馆的管理工作,需要记录下每天读者的到访情况。 每位读者有一个唯一编号,每条记录用读者的编号来表示。 给出读者的来访记录,请问每一条记录中的读者…...
开发个人Go-ChatGPT--3 服务拆分
开发个人Go-ChatGPT–3 服务拆分 个人Go-ChatGPT项目可拆分用户服务(user),AI模型服务(AiModel),… 每个服务都可以再分为 api 服务和 rpc 服务。api 服务对外,可提供给 app 调用。rpc 服务是…...
Android --- 新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了
新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了 大概原因就是,初始化默认Google的安卓模拟器占用的RAM内存是2048,如果电脑的性能和内存一般的话就可能卡死,解决方案是手动修改安卓模拟器的config文件&…...
从入门到深入,Docker新手学习教程
编译整理|TesterHome社区 作者|Ishaan Gupta 以下为作者观点: Docker 彻底改变了我们开发、交付和运行应用程序的方式。它使开发人员能够将应用程序打包到容器中 - 标准化的可执行组件,将应用程序源代码与在任何环境中运行该代码…...
Linux应用开发之网络套接字编程(实例篇)
服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...
使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装
以下是基于 vant-ui(适配 Vue2 版本 )实现截图中照片上传预览、删除功能,并封装成可复用组件的完整代码,包含样式和逻辑实现,可直接在 Vue2 项目中使用: 1. 封装的图片上传组件 ImageUploader.vue <te…...
C# 类和继承(抽象类)
抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...
土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等
🔍 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术,可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势,还能有效评价重大生态工程…...
三体问题详解
从物理学角度,三体问题之所以不稳定,是因为三个天体在万有引力作用下相互作用,形成一个非线性耦合系统。我们可以从牛顿经典力学出发,列出具体的运动方程,并说明为何这个系统本质上是混沌的,无法得到一般解…...
NFT模式:数字资产确权与链游经济系统构建
NFT模式:数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新:构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议:基于LayerZero协议实现以太坊、Solana等公链资产互通,通过零知…...
Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)
参考官方文档:https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java(供 Kotlin 使用) 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...
回溯算法学习
一、电话号码的字母组合 import java.util.ArrayList; import java.util.List;import javax.management.loading.PrivateClassLoader;public class letterCombinations {private static final String[] KEYPAD {"", //0"", //1"abc", //2"…...
七、数据库的完整性
七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...
认识CMake并使用CMake构建自己的第一个项目
1.CMake的作用和优势 跨平台支持:CMake支持多种操作系统和编译器,使用同一份构建配置可以在不同的环境中使用 简化配置:通过CMakeLists.txt文件,用户可以定义项目结构、依赖项、编译选项等,无需手动编写复杂的构建脚本…...
