当前位置: 首页 > news >正文

昇思MindSpore 应用学习-基于MobileNetv2的垃圾分类

基于MobileNetv2的垃圾分类

本文档主要介绍垃圾分类代码开发的方法。通过读取本地图像数据作为输入,对图像中的垃圾物体进行检测,并将检测结果图片保存到文件中。

1、实验目的

  • 了解熟悉垃圾分类应用代码的编写(Python语言);
  • 了解Linux操作系统的基本使用;
  • 掌握atc命令进行模型转换的基本操作。

2、MobileNetv2模型原理介绍

MobileNet网络是由Google团队于2017年提出的,专注于移动端、嵌入式或IoT设备的轻量级CNN网络。相比于传统的卷积神经网络,MobileNet网络使用深度可分离卷积(Depthwise Separable Convolution)的思想,在准确率小幅度降低的前提下,大大减小了模型参数与运算量。同时引入宽度系数 α和分辨率系数 β,使模型满足不同应用场景的需求。
由于MobileNet网络中ReLU激活函数处理低维特征信息时会存在大量的丢失,因此MobileNetV2网络提出使用倒残差结构(Inverted residual block)和Linear Bottlenecks来设计网络,以提高模型的准确率,且优化后的模型更小。

图中Inverted residual block结构是先使用1x1卷积进行升维,然后使用3x3的DepthWise卷积,最后使用1x1的卷积进行降维,与Residual block结构相反。Residual block是先使用1x1的卷积进行降维,然后使用3x3的卷积,最后使用1x1的卷积进行升维。

  • 说明:
    详细内容可参见MobileNetV2论文

3、实验环境

本案例支持win_x86和Linux系统,CPU/GPU/Ascend均可运行。
在动手进行实践之前,确保您已经正确安装了MindSpore。不同平台下的环境准备请参考《MindSpore环境搭建实验手册》。

4、数据处理

4.1 数据准备

MobileNetV2的代码默认使用ImageFolder格式管理数据集,每一类图片整理成单独的一个文件夹, 数据集结构如下:

└─ImageFolder├─train│   class1Folder│   ......└─evalclass1Folder......
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y  # 卸载当前安装的mindspore库
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14  # 安装指定版本的mindspore库
# 查看当前 mindspore 版本
!pip show mindspore  # 显示当前安装的mindspore库的信息
from download import download  # 从download模块导入download函数# 下载data_en数据集
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MindStudio-pc/data_en.zip" 
path = download(url, "./", kind="zip", replace=True)  # 下载指定URL的data_en数据集,并保存为zip文件from download import download  # 重复导入download模块(可优化为一次导入)# 下载预训练权重文件
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/ComputerVision/mobilenetV2-200_1067.zip" 
path = download(url, "./", kind="zip", replace=True)  # 下载指定URL的预训练权重文件,并保存为zip文件

代码解析

  1. %%capture captured_output:
    • 这是一个Jupyter Notebook的魔法命令,用于捕获输出,避免在Notebook中显示执行命令的输出。
  2. !pip uninstall mindspore -y:
    • 使用pip命令卸载名为mindspore的库,-y参数表示自动确认卸载。
  3. !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14:
    • 通过pip安装指定版本的mindspore库(2.2.14),并指定使用的镜像源。
  4. !pip show mindspore:
    • 查询并显示当前安装的mindspore库的信息,包括版本、位置等。
  5. from download import download:
    • download模块导入download函数,用于下载文件。
  6. url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MindStudio-pc/data_en.zip":
    • 定义一个变量url,存储待下载的数据集的URL。
  7. path = download(url, "./", kind="zip", replace=True):
    • 调用download函数下载指定URL的文件,并保存到当前目录,文件类型为zipreplace=True表示如果已存在同名文件则替换。
  8. url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/ComputerVision/mobilenetV2-200_1067.zip":
    • 定义另一个变量url,存储待下载的预训练权重文件的URL。
  9. path = download(url, "./", kind="zip", replace=True):
    • 再次调用download函数,下载预训练权重文件,操作与之前相同。

API 解析

  • pip:Python的包管理工具,用于安装和管理Python包。
  • download(url, path, kind, replace):自定义的下载函数,通常用来从指定的URL下载文件,参数包括:
    • url:待下载文件的链接。
    • path:保存文件的路径。
    • kind:文件类型(如zip)。
    • replace:是否替换已存在的同名文件。

4.2 数据加载

将模块导入,具体如下:
import math  # 导入数学库,用于数学运算
import numpy as np  # 导入NumPy库,用于数组和数学计算
import os  # 导入os库,用于操作系统相关的功能
import random  # 导入random库,用于生成随机数from matplotlib import pyplot as plt  # 导入matplotlib库用于绘图
from easydict import EasyDict  # 导入EasyDict库,用于简化字典操作
from PIL import Image  # 导入PIL库,用于图像处理
import numpy as np  # 再次导入NumPy库(可优化为一次导入)
import mindspore.nn as nn  # 导入MindSpore的神经网络模块
from mindspore import ops as P  # 导入MindSpore的操作模块,简化命名
from mindspore.ops import add  # 导入加法操作
from mindspore import Tensor  # 导入Tensor类,用于创建张量
import mindspore.common.dtype as mstype  # 导入数据类型模块
import mindspore.dataset as de  # 导入MindSpore的数据集模块
import mindspore.dataset.vision as C  # 导入Vision模块,处理图像数据
import mindspore.dataset.transforms as C2  # 导入数据转换模块
import mindspore as ms  # 导入MindSpore库,简化命名
from mindspore import set_context, nn, Tensor, load_checkpoint, save_checkpoint, export  # 导入多种功能
from mindspore.train import Model  # 导入模型训练模块
from mindspore.train import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig  # 导入训练回调和检查点模块# 设置GLOG日志相关的环境变量
os.environ['GLOG_v'] = '3'  # 设置日志等级,3表示ERROR
os.environ['GLOG_logtostderr'] = '0'  # 0表示日志输出到文件,1表示输出到控制台
os.environ['GLOG_log_dir'] = '../../log'  # 设置日志输出目录
os.environ['GLOG_stderrthreshold'] = '2'  # 设置错误输出的阈值,2表示只输出WARNING及以上级别# 设置MindSpore的上下文环境
set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0)  # 设置为图模式执行,目标设备为CPU,设备ID为0
  1. import math:
    • 导入Python标准的数学库,用于数学运算。
  2. import numpy as np:
    • 导入NumPy库,通常用于高效的数组和数学计算。
  3. import os:
    • 导入os模块,用于与操作系统进行交互,例如文件路径操作。
  4. import random:
    • 导入random库,用于生成随机数。
  5. from matplotlib import pyplot as plt:
    • 导入matplotlib库中的pyplot模块,用于绘制图形。
  6. from easydict import EasyDict:
    • 导入EasyDict库,允许使用点号访问字典键,简化代码。
  7. from PIL import Image:
    • 导入PIL库中的Image模块,用于处理图像文件。
  8. import mindspore.nn as nn:
    • 从MindSpore框架中导入神经网络模块,以便构建神经网络。
  9. from mindspore import ops as P:
    • 导入MindSpore的操作模块,简化后续调用的命名。
  10. from mindspore.ops import add:
    • 导入加法操作,用于后续计算。
  11. from mindspore import Tensor:
    • 导入Tensor类,用于创建和操作多维数组(张量)。
  12. import mindspore.common.dtype as mstype:
    • 导入MindSpore的基本数据类型模块。
  13. import mindspore.dataset as de:
    • 导入MindSpore的数据集处理模块,以进行数据加载和处理。
  14. import mindspore.dataset.vision as C:
    • 导入MindSpore的视觉数据集模块,处理与图像相关的数据集。
  15. import mindspore.dataset.transforms as C2:
    • 导入数据转换模块,用于对数据进行预处理和变换。
  16. import mindspore as ms:
    • 导入MindSpore库,简化后续代码中对MindSpore功能的调用。
  17. from mindspore import set_context, nn, Tensor, load_checkpoint, save_checkpoint, export:
    • 导入多个功能,包括设置上下文、神经网络模块、张量处理以及模型的保存和加载。
  18. from mindspore.train import Model:
    • 导入MindSpore的模型训练模块。
  19. from mindspore.train import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig:
    • 导入训练过程中常用的回调函数,包括监控损失、模型保存等功能。
  20. os.environ['GLOG_v'] = '3':
    • 设置GLOG日志输出级别为ERROR。
  21. os.environ['GLOG_logtostderr'] = '0':
    • 将日志输出设置为文件而非控制台。
  22. os.environ['GLOG_log_dir'] = '../../log':
    • 设置日志文件的输出目录。
  23. os.environ['GLOG_stderrthreshold'] = '2':
    • 设置标准错误输出的阈值为WARNING及以上级别。
  24. set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0):
    • 设置MindSpore的上下文环境为图模式,目标设备为CPU,设备ID为0,便于进行模型训练和推理。
  • os:操作系统接口模块,提供与操作系统交互的功能。
  • math:标准数学库,提供数学计算功能。
  • numpy:用于高效的数组运算和数值计算的库。
  • PIL(Python Imaging Library):用于图像处理和打开、操作图像文件。
  • matplotlib.pyplot:用于绘图的库。
  • mindspore:深度学习框架,适用于构建和训练神经网络。
  • set_context:设置MindSpore的执行上下文,包括运行模式和设备类型。
  • Tensor:MindSpore中用于表示多维数组(张量)的数据结构。
配置后续训练、验证、推理用到的参数:
# 垃圾分类数据集标签,以及用于标签映射的字典。
garbage_classes = {'干垃圾': ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服'],  # 定义干垃圾类别及其对应物品'可回收物': ['报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张'],  # 定义可回收物类别及其对应物品'湿垃圾': ['菜叶', '橙皮', '蛋壳', '香蕉皮'],  # 定义湿垃圾类别及其对应物品'有害垃圾': ['电池', '药片胶囊', '荧光灯', '油漆桶']  # 定义有害垃圾类别及其对应物品
}# 定义所有垃圾类别的中文名称
class_cn = ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服','报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张','菜叶', '橙皮', '蛋壳', '香蕉皮','电池', '药片胶囊', '荧光灯', '油漆桶']# 定义所有垃圾类别的英文名称
class_en = ['Seashell', 'Lighter', 'Old Mirror', 'Broom', 'Ceramic Bowl', 'Toothbrush', 'Disposable Chopsticks', 'Dirty Cloth','Newspaper', 'Glassware', 'Basketball', 'Plastic Bottle', 'Cardboard', 'Glass Bottle', 'Metalware', 'Hats', 'Cans', 'Paper','Vegetable Leaf', 'Orange Peel', 'Eggshell', 'Banana Peel','Battery', 'Tablet capsules', 'Fluorescent lamp', 'Paint bucket']# 定义从英文类别名到索引的映射
index_en = {'Seashell': 0, 'Lighter': 1, 'Old Mirror': 2, 'Broom': 3, 'Ceramic Bowl': 4, 'Toothbrush': 5, 'Disposable Chopsticks': 6, 'Dirty Cloth': 7,'Newspaper': 8, 'Glassware': 9, 'Basketball': 10, 'Plastic Bottle': 11, 'Cardboard': 12, 'Glass Bottle': 13, 'Metalware': 14, 'Hats': 15, 'Cans': 16, 'Paper': 17,'Vegetable Leaf': 18, 'Orange Peel': 19, 'Eggshell': 20, 'Banana Peel': 21,'Battery': 22, 'Tablet capsules': 23, 'Fluorescent lamp': 24, 'Paint bucket': 25
}# 训练超参
config = EasyDict({"num_classes": 26,  # 类别数量"image_height": 224,  # 输入图像高度"image_width": 224,  # 输入图像宽度"backbone_out_channels": 1280,  # 主干网络输出通道数"batch_size": 16,  # 训练批次大小"eval_batch_size": 8,  # 验证批次大小"epochs": 10,  # 训练轮数"lr_max": 0.05,  # 最大学习率"momentum": 0.9,  # 动量"weight_decay": 1e-4,  # 权重衰减"save_ckpt_epochs": 1,  # 每隔多少轮保存一次模型"dataset_path": "./data_en",  # 数据集路径"class_index": index_en,  # 类别索引映射"pretrained_ckpt": "./mobilenetV2-200_1067.ckpt"  # 预训练模型的路径
})
  1. 垃圾分类数据集标签字典:
    • garbage_classes:
      • 定义不同类型垃圾的中文标签及其对应的具体物品列表。
      • 包括四类:
        • 干垃圾(如贝壳、打火机等)
        • 可回收物(如报纸、塑料瓶等)
        • 湿垃圾(如菜叶、蛋壳等)
        • 有害垃圾(如电池、药片胶囊等)
  2. 垃圾类别列表:
    • class_cn:
      • 包含所有垃圾类别的中文名称。
    • class_en:
      • 包含所有垃圾类别的英文名称。
  3. 类别索引映射:
    • index_en:
      • 定义英文类别名称与其对应索引的映射,便于在训练过程中进行标签管理。
  4. 训练超参数配置:
    • config:
      • 使用EasyDict来组织训练超参数,方便后续访问和修改。
      • 包括:
        • num_classes: 总类别数(26种垃圾)。
        • image_heightimage_width: 输入图像的尺寸(224x224)。
        • backbone_out_channels: 主干网络的输出通道数(1280)。
        • batch_size: 训练时的批次大小(16)。
        • eval_batch_size: 验证时的批次大小(8)。
        • epochs: 训练的轮数(10)。
        • lr_max: 最大学习率(0.05)。
        • momentum: 动量(0.9)。
        • weight_decay: 权重衰减的系数(1e-4)。
        • save_ckpt_epochs: 模型保存的频率(每1轮保存一次)。
        • dataset_path: 数据集的路径。
        • class_index: 类别索引映射,方便后续训练。
        • pretrained_ckpt: 预训练模型的路径。
  • EasyDict: EasyDict是一个方便的字典类,可以通过点操作符访问属性,使得代码更加简洁易读。
  • 字典和列表: Python内置数据结构,用于存储和管理数据。
  • 配置参数的结构化: 通过结构化配置(如EasyDict)来集中管理模型超参数,方便修改和读取。
数据预处理操作

利用ImageFolderDataset方法读取垃圾分类数据集,并整体对数据集进行处理。
读取数据集时指定训练集和测试集,首先对整个数据集进行归一化,修改图像频道等预处理操作。然后对训练集的数据依次进行RandomCropDecodeResize、RandomHorizontalFlip、RandomColorAdjust、shuffle操作,以增加训练数据的丰富度;对测试集进行Decode、Resize、CenterCrop等预处理操作;最后返回处理后的数据集。

def create_dataset(dataset_path, config, training=True, buffer_size=1000):"""create a train or eval datasetArgs:dataset_path(string): the path of dataset.config(struct): the config of train and eval in different platform.Returns:train_dataset, val_dataset"""# 根据训练或测试模式设置数据集路径data_path = os.path.join(dataset_path, 'train' if training else 'test')# 创建图像文件夹数据集,指定并行工作线程数和类别索引ds = de.ImageFolderDataset(data_path, num_parallel_workers=4, class_indexing=config.class_index)# 设置图像的目标高度和宽度resize_height = config.image_heightresize_width = config.image_width# 定义归一化操作normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])# 定义通道转换操作(HWC到CHW)change_swap_op = C.HWC2CHW()# 定义类型转换操作,将标签转换为int32类型type_cast_op = C2.TypeCast(mstype.int32)if training:# 如果是训练模式,定义数据增强的操作crop_decode_resize = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)color_adjust = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)# 组合训练数据的转换操作train_trans = [crop_decode_resize, horizontal_flip_op, color_adjust, normalize_op, change_swap_op]# 对图像应用数据增强操作train_ds = ds.map(input_columns="image", operations=train_trans, num_parallel_workers=4)# 对标签应用类型转换操作train_ds = train_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)# 打乱训练数据train_ds = train_ds.shuffle(buffer_size=buffer_size)# 按照批次大小创建训练数据集ds = train_ds.batch(config.batch_size, drop_remainder=True)else:# 如果是评估模式,定义评估数据的处理操作decode_op = C.Decode()resize_op = C.Resize((int(resize_width/0.875), int(resize_width/0.875)))center_crop = C.CenterCrop(resize_width)# 组合评估数据的转换操作eval_trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]# 对图像应用评估处理操作eval_ds = ds.map(input_columns="image", operations=eval_trans, num_parallel_workers=4)# 对标签应用类型转换操作eval_ds = eval_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)# 按照批次大小创建评估数据集ds = eval_ds.batch(config.eval_batch_size, drop_remainder=True)return ds  # 返回处理后的数据集
  1. 函数定义:
    • create_dataset(dataset_path, config, training=True, buffer_size=1000):
      • 创建训练或评估的数据集。
  2. 参数说明:
    • dataset_path: 数据集的路径。
    • config: 配置对象,包含训练和评估的参数。
    • training: 布尔值,指示当前是否为训练模式(默认为True)。
    • buffer_size: 用于打乱数据时的缓冲区大小(默认为1000)。
  3. 数据集路径设置:
    • 根据模式(训练或测试)设置数据集的路径。
  4. 创建数据集:
    • 使用de.ImageFolderDataset创建图像文件夹数据集,指定并行工作线程数和类别索引。
  5. 图像预处理:
    • 设定图像的高度和宽度。
    • 定义归一化操作normalize_op
    • 定义通道转换操作change_swap_op
    • 定义类型转换操作type_cast_op
  6. 训练模式处理:
    • 如果training为True,定义数据增强操作,如随机裁剪、水平翻转和颜色调整。
    • 创建训练数据的转换列表train_trans,并将其应用于数据集。
    • 打乱训练数据并按照批次大小进行分组。
  7. 评估模式处理:
    • 如果training为False,定义评估数据的处理操作,包括解码、调整大小和中心裁剪。
    • 创建评估数据的转换列表eval_trans,并将其应用于数据集。
    • 按照评估批次大小进行分组。
  8. 返回数据集:
    • 返回创建和处理后的数据集。
  • os.path.join: 用于拼接文件路径。
  • de.ImageFolderDataset: MindSpore中的数据集类,用于创建图像文件夹数据集。
  • C.Normalize: 图像归一化操作。
  • C.HWC2CHW: 将图像的通道格式从HWC(高度-宽度-通道)转换为CHW格式。
  • C.RandomCropDecodeResize: 随机裁剪并调整大小的操作,用于数据增强。
  • C.RandomHorizontalFlip: 随机水平翻转操作,用于数据增强。
  • C.RandomColorAdjust: 随机颜色调整操作,用于数据增强。
  • C.Decode, C.Resize, C.CenterCrop: 用于图像解码、大小调整和中心裁剪的操作。
  • map: 将操作应用于数据集中的指定列。
展示部分处理后的数据:
# 创建数据集,设置为评估模式
ds = create_dataset(dataset_path=config.dataset_path, config=config, training=False)# 打印数据集的大小
print(ds.get_dataset_size())# 创建字典迭代器以获取数据
data = ds.create_dict_iterator(output_numpy=True)._get_next()# 提取图像和标签
images = data['image']  # 获取图像数据
labels = data['label']  # 获取标签数据# 绘制前四个图像及其标签
for i in range(1, 5):plt.subplot(2, 2, i)  # 创建2行2列的子图plt.imshow(np.transpose(images[i], (1, 2, 0)))  # 将图像从CHW格式转为HWC格式以便显示plt.title('label: %s' % class_en[labels[i]])  # 设置标题为对应的英文标签plt.xticks([])  # 不显示x轴刻度plt.show()  # 显示图像
  1. 创建数据集:
    • ds = create_dataset(dataset_path=config.dataset_path, config=config, training=False):
      • 调用create_dataset函数,创建用于评估的数据集。这里的training=False表示我们不在训练模式下。
  2. 获取数据集大小:
    • print(ds.get_dataset_size()):
      • 打印数据集的大小,返回数据集中样本的数量。
  3. 创建字典迭代器:
    • data = ds.create_dict_iterator(output_numpy=True)._get_next():
      • 创建一个字典迭代器,允许从数据集中读取数据,并将输出格式设置为NumPy数组。
      • _get_next()方法从迭代器中获取下一个数据字典。
  4. 提取图像和标签:
    • images = data['image']:
      • 从数据字典中提取图像数据。
    • labels = data['label']:
      • 从数据字典中提取标签数据。
  5. 绘制图像及其标签:
    • 使用for循环遍历前四个图像(索引从1到4)。
    • plt.subplot(2, 2, i):
      • 创建一个2行2列的子图,i表示当前子图的位置。
    • plt.imshow(np.transpose(images[i], (1, 2, 0))):
      • 使用imshow绘制图像。np.transpose将图像从CHW(通道-高度-宽度)格式转换为HWC(高度-宽度-通道)格式,以便正确显示。
    • plt.title('label: %s' % class_en[labels[i]]):
      • 设置子图的标题,标题为对应的英文标签。
    • plt.xticks([]):
      • 隐藏x轴刻度,增加图像的可读性。
  6. 显示图像:
    • plt.show():
      • 显示所有绘制的图像及其对应标签。
  • create_dataset: 用于创建和处理数据集的函数。
  • get_dataset_size(): 获取数据集大小的方法。
  • create_dict_iterator: 创建一个字典迭代器,用于逐步读取数据。
  • output_numpy: 设置输出格式为NumPy数组。
  • imshow: Matplotlib函数,用于显示图像。
  • title: 设置子图的标题。
  • xticks([]): 隐藏指定的刻度线,提升图像的可读性。
  • show(): 显示所有当前绘制的图形。

5、MobileNetV2模型搭建

使用MindSpore定义MobileNetV2网络的各模块时需要继承mindspore.nn.Cell。Cell是所有神经网络(Conv2d等)的基类。
神经网络的各层需要预先在__init__方法中定义,然后通过定义construct方法来完成神经网络的前向构造。原始模型激活函数为ReLU6,池化模块采用全局平均池化层。

__all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']def _make_divisible(v, divisor, min_value=None):# 确保通道数为指定的倍数if min_value is None:min_value = divisornew_v = max(min_value, int(v + divisor / 2) // divisor * divisor)if new_v < 0.9 * v:new_v += divisorreturn new_vclass GlobalAvgPooling(nn.Cell):"""Global avg pooling definition.Args:Returns:Tensor, output tensor.Examples:>>> GlobalAvgPooling()"""def __init__(self):super(GlobalAvgPooling, self).__init__()def construct(self, x):# 对输入x进行全局平均池化x = P.mean(x, (2, 3))  # 在高和宽维度上计算均值return xclass ConvBNReLU(nn.Cell):"""Convolution/Depthwise fused with Batchnorm and ReLU block definition.Args:in_planes (int): 输入通道数。out_planes (int): 输出通道数。kernel_size (int): 卷积核大小。stride (int): 第一个卷积层的步幅,默认为1。groups (int): 通道组数,对于深度可分离卷积,等于输入通道数。默认为1。Returns:Tensor, output tensor.Examples:>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)"""def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):super(ConvBNReLU, self).__init__()padding = (kernel_size - 1) // 2  # 计算填充大小in_channels = in_planesout_channels = out_planes# 根据分组选择卷积类型if groups == 1:conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding)else:out_channels = in_planesconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad',padding=padding, group=in_channels)layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]  # 定义卷积、Batchnorm和ReLU6层self.features = nn.SequentialCell(layers)  # 组合成SequentialCelldef construct(self, x):# 前向传播output = self.features(x)return outputclass InvertedResidual(nn.Cell):"""Mobilenetv2 residual block definition.Args:inp (int): 输入通道数。oup (int): 输出通道数。stride (int): 第一个卷积层的步幅,默认为1。expand_ratio (int): 输入通道的扩展比。Returns:Tensor, output tensor.Examples:>>> ResidualBlock(3, 256, 1, 1)"""def __init__(self, inp, oup, stride, expand_ratio):super(InvertedResidual, self).__init__()assert stride in [1, 2]  # 步幅只能为1或2hidden_dim = int(round(inp * expand_ratio))  # 计算扩展后的通道数self.use_res_connect = stride == 1 and inp == oup  # 判断是否使用残差连接layers = []if expand_ratio != 1:layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))  # 扩展卷积layers.extend([ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),  # 深度卷积nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),  # 线性卷积nn.BatchNorm2d(oup),  # BatchNorm])self.conv = nn.SequentialCell(layers)  # 组合层self.cast = P.Cast()  # 类型转换操作def construct(self, x):identity = x  # 保存输入x = self.conv(x)  # 前向传播if self.use_res_connect:return P.add(identity, x)  # 如果使用残差连接,则返回相加的结果return x  # 否则返回卷积结果class MobileNetV2Backbone(nn.Cell):"""MobileNetV2 architecture.Args:class_num (int): 类别数量。width_mult (int): 通道数的乘子,默认为1。has_dropout (bool): 是否使用dropout,默认为false。inverted_residual_setting (list): 反向残差设置,默认为None。round_nearest (list): 通道数的近似值,默认为8。Returns:Tensor, output tensor.Examples:>>> MobileNetV2(num_classes=1000)"""def __init__(self, width_mult=1., inverted_residual_setting=None, round_nearest=8,input_channel=32, last_channel=1280):super(MobileNetV2Backbone, self).__init__()block = InvertedResidual  # 定义块类型# 设置反向残差块的配置self.cfgs = inverted_residual_settingif inverted_residual_setting is None:self.cfgs = [[1, 16, 1, 1],[6, 24, 2, 2],[6, 32, 3, 2],[6, 64, 4, 2],[6, 96, 3, 1],[6, 160, 3, 2],[6, 320, 1, 1],]# 构建第一层input_channel = _make_divisible(input_channel * width_mult, round_nearest)self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)features = [ConvBNReLU(3, input_channel, stride=2)]  # 第一个卷积层# 构建反向残差块for t, c, n, s in self.cfgs:output_channel = _make_divisible(c * width_mult, round_nearest)for i in range(n):stride = s if i == 0 else 1features.append(block(input_channel, output_channel, stride, expand_ratio=t))  # 添加反向残差块input_channel = output_channel  # 更新输入通道数features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))  # 最后一个卷积层self.features = nn.SequentialCell(features)  # 组合成SequentialCellself._initialize_weights()  # 初始化权重def construct(self, x):# 前向传播x = self.features(x)return xdef _initialize_weights(self):"""Initialize weights.Args:Returns:None.Examples:>>> _initialize_weights()"""self.init_parameters_data()for _, m in self.cells_and_names():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),m.weight.data.shape).astype("float32")))  # 权重初始化if m.bias is not None:m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))  # 偏置初始化elif isinstance(m, nn.BatchNorm2d):m.gamma.set_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))  # gamma初始化m.beta.set_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))  # beta初始化@propertydef get_features(self):return self.featuresclass MobileNetV2Head(nn.Cell):"""MobileNetV2 architecture.Args:class_num (int): 类别数量,默认为1000。has_dropout (bool): 是否使用dropout,默认为false。Returns:Tensor, output tensor.Examples:>>> MobileNetV2(num_classes=1000)"""def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False, activation="None"):super(MobileNetV2Head, self).__init__()# mobilenet headhead = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)] if not has_dropout else[GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)])self.head = nn.SequentialCell(head)  # 组合成SequentialCellself.need_activation = Trueif activation == "Sigmoid":self.activation = nn.Sigmoid()  # 设置激活函数为Sigmoidelif activation == "Softmax":self.activation = nn.Softmax()  # 设置激活函数为Softmaxelse:self.need_activation = Falseself._initialize_weights()  # 初始化权重def construct(self, x):# 前向传播x = self.head(x)if self.need_activation:x = self.activation(x)return xdef _initialize_weights(self):"""Initialize weights.Args:Returns:None.Examples:>>> _initialize_weights()"""self.init_parameters_data()for _, m in self.cells_and_names():if isinstance(m, nn.Dense):m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))  # 权重初始化if m.bias is not None:m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))  # 偏置初始化@propertydef get_head(self):return self.headclass MobileNetV2(nn.Cell):"""MobileNetV2 architecture.Args:class_num (int): 类别数量。width_mult (int): 通道数的乘子,默认为1。has_dropout (bool): 是否使用dropout,默认为false。inverted_residual_setting (list): 反向残差设置,默认为None。round_nearest (int): 通道数的近似值,默认为8。Returns:Tensor, output tensor.Examples:>>> MobileNetV2(backbone, head)"""def __init__(self, num_classes=1000, width_mult=1., has_dropout=False, inverted_residual_setting=None, \round_nearest=8, input_channel=32, last_channel=1280):super(MobileNetV2, self).__init__()self.backbone = MobileNetV2Backbone(width_mult=width_mult, \inverted_residual_setting=inverted_residual_setting, \round_nearest=round_nearest, input_channel=input_channel, last_channel=last_channel).get_featuresself.head = MobileNetV2Head(input_channel=self.backbone.out_channels, num_classes=num_classes, \has_dropout=has_dropout).get_headdef construct(self, x):# 前向传播x = self.backbone(x)x = self.head(x)return xclass MobileNetV2Combine(nn.Cell):"""MobileNetV2Combine architecture.Args:backbone (Cell): 特征提取层。head (Cell): 全连接层。Returns:Tensor, output tensor.Examples:>>> MobileNetV2(num_classes=1000)"""def __init__(self, backbone, head):super(MobileNetV2Combine, self).__init__(auto_prefix=False)self.backbone = backbone  # 特征提取部分self.head = head  # 分类部分def construct(self, x):# 前向传播x = self.backbone(x)x = self.head(x)return xdef mobilenet_v2(backbone, head):# 返回组合后的MobileNetV2模型return MobileNetV2Combine(backbone, head)
  1. 模块和函数定义:
    • __all__: 定义模块的公开接口。
    • _make_divisible: 确保某个值是指定的倍数(如8或16),用于调整通道数。
  2. 全局平均池化类:
    • GlobalAvgPooling: 实现全局平均池化操作,通过P.mean函数在高和宽维度计算均值。
  3. 卷积、BatchNorm和ReLU模块:
    • ConvBNReLU: 组合卷积层、BatchNorm层和ReLU激活函数的模块,支持深度可分离卷积。
  4. 反向残差模块:
    • InvertedResidual: 实现MobileNetV2的关键残差块,支持输入通道数的扩展。
  5. MobileNetV2主干网络:
    • MobileNetV2Backbone: 构建MobileNetV2的特征提取部分,包括设置反向残差块的配置以及初始化权重。
  6. MobileNetV2头部:
    • MobileNetV2Head: 定义模型的头部,处理全局平均池化和最后的全连接层,实现分类功能。
  7. MobileNetV2综合模型:
    • MobileNetV2Combine: 组合特征提取和分类功能的完整模型。
  8. MobileNetV2模型:
    • MobileNetV2: 包含主干和头部的完整模型,执行前向传播。
  9. 构建函数:
    • mobilenet_v2: 返回组合的MobileNetV2模型。
  • nn.Cell: MindSpore框架中的基本模块类,所有网络模块都应继承此类。
  • P.mean: 用于计算输入张量沿指定维度的均值。
  • nn.Conv2d: 2D卷积层。
  • nn.BatchNorm2d: 2D批量归一化层。
  • nn.ReLU6: 6修正线性单元激活函数。
  • nn.Dense: 全连接层。
  • Tensor: MindSpore中的张量数据结构。
  • SequentialCell: 用于将多个层组合成一个顺序层的容器。

6、MobileNetV2模型的训练与测试

训练策略

一般情况下,模型训练时采用静态学习率,如0.01。随着训练步数的增加,模型逐渐趋于收敛,对权重参数的更新幅度应逐渐降低,以减小模型训练后期的抖动。因此,模型训练时可以采用动态下降的学习率,常见的学习率下降策略有:

  • Polynomial decay/square decay
  • Cosine decay
  • Exponential decay
  • Stage decay

这里使用cosine decay下降策略:

def cosine_decay(total_steps, lr_init=0.0, lr_end=0.0, lr_max=0.1, warmup_steps=0):"""Applies cosine decay to generate learning rate array.Args:total_steps(int): 所有训练步骤的总数。lr_init(float): 初始学习率。lr_end(float): 结束学习率。lr_max(float): 最大学习率。warmup_steps(int): 热身阶段的总步骤。Returns:list, 学习率数组。"""# 将学习率参数转换为浮点数lr_init, lr_end, lr_max = float(lr_init), float(lr_end), float(lr_max)# 计算总的衰减步骤decay_steps = total_steps - warmup_stepslr_all_steps = []  # 初始化学习率数组# 计算每步增加的学习率inc_per_step = (lr_max - lr_init) / warmup_steps if warmup_steps else 0# 循环生成每一步的学习率for i in range(total_steps):if i < warmup_steps:# 在热身阶段,逐步增加学习率lr = lr_init + inc_per_step * (i + 1)else:# 计算余弦衰减cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))# 根据余弦衰减计算当前学习率lr = (lr_max - lr_end) * cosine_decay + lr_endlr_all_steps.append(lr)  # 将当前学习率添加到数组中return lr_all_steps  # 返回生成的学习率数组
  1. 函数定义:
    • cosine_decay(total_steps, lr_init=0.0, lr_end=0.0, lr_max=0.1, warmup_steps=0):
      • 应用余弦衰减生成学习率数组。
  2. 参数说明:
    • total_steps: 整个训练过程中的总步骤数。
    • lr_init: 初始学习率。
    • lr_end: 结束时的学习率。
    • lr_max: 训练中间的最大学习率。
    • warmup_steps: 热身阶段的步骤数。
  3. 参数转换:
    • lr_initlr_endlr_max转换为浮点数,以确保计算时的正确性。
  4. 衰减步骤计算:
    • decay_steps = total_steps - warmup_steps:
      • 计算进行余弦衰减的步骤数。
  5. 学习率数组初始化:
    • lr_all_steps = []:
      • 初始化一个空的学习率数组用于存储每一步的学习率。
  6. 热身阶段学习率计算:
    • inc_per_step: 在热身期间,每步增加的学习率。
    • 如果warmup_steps不为0,则计算初始学习率到最大学习率的增加步长。
  7. 循环生成学习率:
    • 对于每一步i,判断是否在热身阶段。
    • 在热身阶段,学习率线性增加。
    • 在衰减阶段,使用余弦函数计算衰减值。
    • 根据余弦衰减公式计算当前学习率,并将其添加到数组中。
  8. 返回学习率数组:
    • 函数结束时返回生成的学习率数组。
  • math.cos: Python内置的数学库,用于计算余弦值。
  • math.pi: π的值,用于计算余弦衰减的角度。
  • list: 用于创建一个空的学习率数组,以便存储每个训练步骤的学习率。

在模型训练过程中,可以添加检查点(Checkpoint)用于保存模型的参数,以便进行推理及中断后再训练使用。使用场景如下:

  • 训练后推理场景
    1. 模型训练完毕后保存模型的参数,用于推理或预测操作。
    2. 训练过程中,通过实时验证精度,把精度最高的模型参数保存下来,用于预测操作。
  • 再训练场景
    1. 进行长时间训练任务时,保存训练过程中的Checkpoint文件,防止任务异常退出后从初始状态开始训练。
    2. Fine-tuning(微调)场景,即训练一个模型并保存参数,基于该模型,面向第二个类似任务进行模型训练。

这里加载ImageNet数据上预训练的MobileNetv2进行Fine-tuning,只训练最后修改的FC层,并在训练过程中保存Checkpoint:

def switch_precision(net, data_type):# 检查当前设备是否为Ascendif ms.get_context('device_target') == "Ascend":# 将整个网络转换为指定的数据类型net.to_float(data_type)# 遍历网络中的所有子模块for _, cell in net.cells_and_names():# 如果子模块是全连接层(Dense)if isinstance(cell, nn.Dense):# 将全连接层的权重转换为float32类型cell.to_float(ms.float32)
  1. 函数定义:
    • switch_precision(net, data_type):
      • 用于切换神经网络的精度,特别适用于Ascend设备。
  2. 设备检查:
    • if ms.get_context('device_target') == "Ascend"::
      • 检查当前上下文是否设置为Ascend设备,以确保仅在该设备上执行精度转换。
  3. 网络精度转换:
    • net.to_float(data_type):
      • 将整个网络的精度转换为指定的数据类型(如ms.float16ms.float32)。
  4. 遍历网络中的子模块:
    • for _, cell in net.cells_and_names()::
      • 遍历网络中的每一个子模块, _ 用于忽略名称,cell 是具体的网络模块。
  5. 全连接层处理:
    • if isinstance(cell, nn.Dense)::
      • 检查当前模块是否为全连接层(Dense层)。
    • cell.to_float(ms.float32):
      • 将全连接层的权重转换为float32类型。这通常是为了提高数值稳定性和计算精度。
  • ms.get_context: 获取当前的上下文设置,通常用于确认设备类型(如Ascend、GPU或CPU)。
  • nn.Dense: 表示全连接层的类,通常用于构建神经网络的线性部分。
  • to_float: 用于将网络或层的参数数据类型转换为指定的浮点数据类型。
模型训练与测试

在进行正式的训练之前,定义训练函数,读取数据并对模型进行实例化,定义优化器和损失函数。
首先简单介绍损失函数及优化器的概念:

  • 损失函数:又叫目标函数,用于衡量预测值与实际值差异的程度。深度学习通过不停地迭代来缩小损失函数的值。定义一个好的损失函数,可以有效提高模型的性能。
  • 优化器:用于最小化损失函数,从而在训练过程中改进模型。

定义了损失函数后,可以得到损失函数关于权重的梯度。梯度用于指示优化器优化权重的方向,以提高模型性能。
在训练MobileNetV2之前,对MobileNetV2Backbone层的参数进行了固定,使其在训练过程中对该模块的权重参数不进行更新;只对MobileNetV2Head模块的参数进行更新。
MindSpore支持的损失函数有SoftmaxCrossEntropyWithLogits、L1Loss、MSELoss等。这里使用SoftmaxCrossEntropyWithLogits损失函数。
训练测试过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个运行的loss值有一定随机性,不一定完全相同。

from mindspore.amp import FixedLossScaleManager
import timeLOSS_SCALE = 1024  # 定义损失缩放因子# 创建训练和评估数据集
train_dataset = create_dataset(dataset_path=config.dataset_path, config=config)
eval_dataset = create_dataset(dataset_path=config.dataset_path, config=config)
step_size = train_dataset.get_dataset_size()  # 获取训练步骤数量# 创建MobileNetV2的主干网络
backbone = MobileNetV2Backbone()  # last_channel=config.backbone_out_channels
# 冻结主干网络的参数(可以根据需要注释这两行)
for param in backbone.get_parameters():param.requires_grad = False# 从预训练模型加载参数
load_checkpoint(config.pretrained_ckpt, backbone)# 创建MobileNetV2的头部
head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
network = mobilenet_v2(backbone, head)  # 组合主干网络和头部# 定义损失函数、优化器和模型
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')  # 定义损失函数
loss_scale = FixedLossScaleManager(LOSS_SCALE, drop_overflow_update=False)  # 定义固定损失缩放管理器
lrs = cosine_decay(config.epochs * step_size, lr_max=config.lr_max)  # 生成学习率数组
opt = nn.Momentum(network.trainable_params(), lrs, config.momentum, config.weight_decay, loss_scale=LOSS_SCALE)  # 定义动量优化器# 定义用于训练的train_loop函数。
def train_loop(model, dataset, loss_fn, optimizer):# 定义正向计算函数def forward_fn(data, label):logits = model(data)  # 获取模型预测loss = loss_fn(logits, label)  # 计算损失return loss# 使用mindspore.value_and_grad获得微分函数grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)  # 计算损失和梯度# 定义一次训练的步骤def train_step(data, label):loss, grads = grad_fn(data, label)  # 计算损失和梯度optimizer(grads)  # 更新模型参数return loss  # 返回损失值size = dataset.get_dataset_size()  # 获取数据集的大小model.set_train()  # 设置模型为训练模式for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):  # 遍历数据集loss = train_step(data, label)  # 执行训练步骤if batch % 10 == 0:  # 每10个批次输出一次损失loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")# 定义用于测试的test_loop函数。
def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()  # 获取数据集的大小model.set_train(False)  # 设置模型为评估模式total, test_loss, correct = 0, 0, 0  # 初始化计数器for data, label in dataset.create_tuple_iterator():  # 遍历评估数据集pred = model(data)  # 获取模型预测total += len(data)  # 累计样本数量test_loss += loss_fn(pred, label).asnumpy()  # 计算累计损失correct += (pred.argmax(1) == label).asnumpy().sum()  # 计算正确预测的数量test_loss /= num_batches  # 计算平均损失correct /= total  # 计算准确率print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")  # 输出测试结果print("============== Starting Training ==============")
# 由于时间原因,训练过程只进行了2个epoch,可以根据需求调整。
epoch_begin_time = time.time()
epochs = 2  # 设置训练的总轮数
for t in range(epochs):begin_time = time.time()  # 开始时间print(f"Epoch {t+1}\n-------------------------------")train_loop(network, train_dataset, loss, opt)  # 执行训练ms.save_checkpoint(network, "save_mobilenetV2_model.ckpt")  # 保存模型检查点end_time = time.time()  # 结束时间times = end_time - begin_time  # 计算每个epoch的时间print(f"per epoch time: {times}s")test_loop(network, eval_dataset, loss)  # 执行评估
epoch_end_time = time.time()  # 训练结束时间
times = epoch_end_time - epoch_begin_time  # 计算总时间
print(f"total time:  {times}s")
print("============== Training Success ==============")
  1. 导入与常量定义:
    • from mindspore.amp import FixedLossScaleManager: 导入固定损失缩放管理器。
    • LOSS_SCALE = 1024: 定义用于损失缩放的常量。
  2. 数据集创建:
    • 使用create_dataset函数创建训练和评估数据集,并获取训练步骤数。
  3. 模型构建:
    • 创建MobileNetV2的主干网络和头部。
    • 冻结主干网络的参数(可选)并加载预训练模型的参数。
  4. 损失函数与优化器:
    • 定义使用的损失函数(Softmax交叉熵)。
    • 创建损失缩放管理器和学习率数组,通过余弦衰减生成学习率,然后定义动量优化器。
  5. 训练循环train_loop:
    • 定义正向计算函数和一次训练步骤,计算损失和梯度。
    • 遍历数据集并进行训练,每10个批次输出一次当前损失。
  6. 测试循环test_loop:
    • 评估模型性能,计算准确率和平均损失,并输出结果。
  7. 训练过程控制:
    • 开始训练,循环进行指定数量的epoch,每个epoch进行训练和测试,并保存模型检查点。
  • create_dataset: 创建数据集的函数,参数指定数据集路径和配置信息。
  • ms.save_checkpoint: 保存模型检查点的函数。
  • nn.SoftmaxCrossEntropyWithLogits: 计算Softmax交叉熵损失的类。
  • nn.Momentum: 定义动量优化器的类。
  • ms.value_and_grad: 用于计算模型输出的梯度。
  • dataset.create_tuple_iterator(): 创建数据集的迭代器,以批次形式遍历数据。

7、模型推理

加载模型Checkpoint进行推理,使用load_checkpoint接口加载数据时,需要把数据传入给原始网络,而不能传递给带有优化器和损失函数的训练网络。

CKPT = "save_mobilenetV2_model.ckpt"  # 定义模型检查点路径def image_process(image):"""处理单张图像。Args:image: 形状为(H, W, C)的图像。"""# 定义均值和标准差,用于图像归一化mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]std = [0.229 * 255, 0.224 * 255, 0.225 * 255]# 归一化图像image = (np.array(image) - mean) / std# 转置图像,使其形状变为(C, H, W)image = image.transpose((2, 0, 1))# 将图像转换为Tensorimg_tensor = Tensor(np.array([image], np.float32))  # 增加一个维度用于批处理return img_tensordef infer_one(network, image_path):# 打开并调整图像大小image = Image.open(image_path).resize((config.image_height, config.image_width))# 进行图像处理和推理logits = network(image_process(image))# 获取预测结果pred = np.argmax(logits.asnumpy(), axis=1)[0]# 输出预测结果print(image_path, class_en[pred])  # class_en为类名列表def infer():# 初始化MobileNetV2主干网络和头部backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels)head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)# 构建完整的网络network = mobilenet_v2(backbone, head)# 从检查点加载模型参数load_checkpoint(CKPT, network)# 对指定范围内的图像进行推理for i in range(91, 100):infer_one(network, f'data_en/test/Cardboard/000{i}.jpg')  # 进行推理
infer()  # 执行推理函数
  1. 检查点定义:
    • CKPT = "save_mobilenetV2_model.ckpt":
      • 定义了存储模型的检查点路径。
  2. 图像处理函数image_process:
    • 该函数负责对单张图像进行预处理。
    • 参数说明:
      • image: 输入的图像,形状为(H, W, C)。
    • 归一化:
      • 使用给定的均值和标准差对图像进行归一化处理,将像素值标准化。
    • 转置操作:
      • 将图像形状从(H, W, C)转置为(C, H, W),以符合模型输入要求。
    • 返回Tensor:
      • 将处理后的图像转换为MindSpore的Tensor格式,增加一个维度以支持批处理。
  3. 推理函数infer_one:
    • 该函数执行单张图像的推理。
    • 读取和调整图像:
      • 使用PIL库打开图像并调整大小。
    • 执行推理:
      • 处理图像并将其输入到网络中,获取输出logits。
    • 预测结果:
      • 使用np.argmax函数获取预测类别,并打印图像路径及其对应的类别。
  4. 推理主函数infer:
    • 初始化MobileNetV2的主干网络和头部。
    • 组合构建完整的网络模型。
    • 从检查点加载模型参数。
    • 对指定范围内的图像进行推理,调用infer_one函数。
  5. 启动推理:
    • 最后调用infer()函数开始推理过程。
  • Image.open: 从文件中加载图像的PIL库方法。
  • Tensor: MindSpore中用于存储数据的类,类似于其他深度学习框架中的Tensor。
  • np.argmax: NumPy函数,用于返回数组中最大值的索引。
  • load_checkpoint: 从指定的检查点加载模型参数的函数。

8、导出AIR/GEIR/ONNX模型文件

导出AIR模型文件,用于后续Atlas 200 DK上的模型转换与推理。当前仅支持MindSpore+Ascend环境。

# 初始化MobileNetV2的主干网络和头部
backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels)
head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
# 组合构建完整的网络模型
network = mobilenet_v2(backbone, head)
# 从检查点加载模型参数
load_checkpoint(CKPT, network)# 生成一个随机输入,形状为[1, 3, 224, 224]
input = np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32)
# 导出模型
# export(network, Tensor(input), file_name='mobilenetv2.air', file_format='AIR')
# export(network, Tensor(input), file_name='mobilenetv2.pb', file_format='GEIR')
export(network, Tensor(input), file_name='mobilenetv2.onnx', file_format='ONNX')  # 导出为ONNX格式
  1. 模型构建:
    • backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels):
      • 创建MobileNetV2的主干网络,其中last_channel参数设置为配置中的输出通道数量。
    • head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes):
      • 创建MobileNetV2的头部,输入通道设置为主干网络的输出通道,类别数设置为配置中的类别数。
    • network = mobilenet_v2(backbone, head):
      • 将主干和头部组合成完整的MobileNetV2网络模型。
  2. 加载模型参数:
    • load_checkpoint(CKPT, network):
      • 从指定检查点加载模型参数,以恢复之前训练的状态。
  3. 输入数据生成:
    • input = np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32):
      • 生成一个随机输入数据,形状为[1, 3, 224, 224],表示一张224x224的RGB图像,数据类型为float32
  4. 模型导出:
    • export(network, Tensor(input), file_name='mobilenetv2.onnx', file_format='ONNX'):
      • 将网络模型导出为ONNX格式文件,file_name指定导出的文件名,file_format指定导出的格式。
    • 注释掉的行可以用于导出为AIR或GEIR格式,具体取决于需要的模型格式。
  • MobileNetV2Backbone: 构建MobileNetV2主干网络的类。
  • MobileNetV2Head: 构建MobileNetV2头部的类。
  • mobilenet_v2: 组合主干和头部生成完整网络的函数。
  • load_checkpoint: 从检查点加载模型参数的函数。
  • export: 导出模型为指定格式的函数。

以上为基于MobileNetV2的垃圾分类模型的完整实现过程,涵盖了从数据准备、模型训练到推理和导出模型文件的各个环节。

整体代码

# 基于MobileNetv2的垃圾分类
# 本文档主要介绍垃圾分类代码开发的方法。通过读取本地图像数据作为输入,
# 对图像中的垃圾物体进行检测,并且将检测结果图片保存到文件中。## 1、实验目的# - 了解熟悉垃圾分类应用代码的编写(Python语言);
# - 了解Linux操作系统的基本使用;
# - 掌握atc命令进行模型转换的基本操作。## 2、MobileNetv2模型原理介绍# MobileNet网络是由Google团队于2017年提出的专注于移动端、嵌入式或IoT设备的轻量级CNN网络,
# 相比于传统的卷积神经网络,MobileNet网络使用深度可分离卷积(Depthwise Separable Convolution)的思想
# 在准确率小幅度降低的前提下,大大减小了模型参数与运算量。
# 并引入宽度系数 α和分辨率系数 β使模型满足不同应用场景的需求。# 由于MobileNet网络中Relu激活函数处理低维特征信息时会存在大量的丢失,
# 所以MobileNetV2网络提出使用倒残差结构(Inverted residual block)和Linear Bottlenecks来设计网络,
# 以提高模型的准确率,且优化后的模型更小。# 图中Inverted residual block结构是先使用1x1卷积进行升维,
# 然后使用3x3的DepthWise卷积,最后使用1x1的卷积进行降维,
# 与Residual block结构相反。Residual block是先使用1x1的卷积进行降维,
# 然后使用3x3的卷积,最后使用1x1的卷积进行升维。# 说明:
# [详细内容可参见MobileNetV2论文](https://arxiv.org/pdf/1801.04381.pdf)## 3、实验环境# 本案例支持win_x86和Linux系统,CPU/GPU/Ascend均可运行。# 在动手进行实践之前,确保您已经正确安装了MindSpore。
# 不同平台下的环境准备请参考《MindSpore环境搭建实验手册》。## 4、数据处理### 4.1数据准备
# MobileNetV2的代码默认使用ImageFolder格式管理数据集,
# 每一类图片整理成单独的一个文件夹, 数据集结构如下:# └─ImageFolder
#     ├─train
#     │   class1Folder
#     │   ......
#     └─eval
#         class1Folder
#         ......# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14# 查看当前 mindspore 版本
!pip show mindspore
from download import download# 下载data_en数据集
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MindStudio-pc/data_en.zip"
path = download(url, "./", kind="zip", replace=True)# 下载预训练权重文件
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/ComputerVision/mobilenetV2-200_1067.zip"
path = download(url, "./", kind="zip", replace=True)### 4.2数据加载###### 将模块导入,具体如下:
import math
import numpy as np
import os
import randomfrom matplotlib import pyplot as plt
from easydict import EasyDict
from PIL import Image
import mindspore.nn as nn
from mindspore import ops as P
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision as C
import mindspore.dataset.transforms as C2
import mindspore as ms
from mindspore import set_context, load_checkpoint, save_checkpoint, exportos.environ['GLOG_v'] = '3' # Log level includes 3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG).
os.environ['GLOG_logtostderr'] = '0' # 0:输出到文件,1:输出到屏幕
os.environ['GLOG_log_dir'] = '../../log' # 日志目录
os.environ['GLOG_stderrthreshold'] = '2' # 输出到目录也输出到屏幕:3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG).set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0) # 设置采用图模式执行,设备为Ascend####### 配置后续训练、验证、推理用到的参数:# 垃圾分类数据集标签,以及用于标签映射的字典。
garbage_classes = {'干垃圾': ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服'],'可回收物': ['报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张'],'湿垃圾': ['菜叶', '橙皮', '蛋壳', '香蕉皮'],'有害垃圾': ['电池', '药片胶囊', '荧光灯', '油漆桶']
}class_cn = ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服','报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张','菜叶', '橙皮', '蛋壳', '香蕉皮','电池', '药片胶囊', '荧光灯', '油漆桶']class_en = ['Seashell', 'Lighter','Old Mirror', 'Broom','Ceramic Bowl', 'Toothbrush','Disposable Chopsticks','Dirty Cloth','Newspaper', 'Glassware', 'Basketball', 'Plastic Bottle', 'Cardboard','Glass Bottle', 'Metalware', 'Hats', 'Cans', 'Paper','Vegetable Leaf','Orange Peel', 'Eggshell','Banana Peel','Battery', 'Tablet capsules','Fluorescent lamp', 'Paint bucket']index_en = {'Seashell': 0, 'Lighter': 1, 'Old Mirror': 2, 'Broom': 3, 'Ceramic Bowl': 4, 'Toothbrush': 5, 'Disposable Chopsticks': 6, 'Dirty Cloth': 7,'Newspaper': 8, 'Glassware': 9, 'Basketball': 10, 'Plastic Bottle': 11, 'Cardboard': 12, 'Glass Bottle': 13, 'Metalware': 14, 'Hats': 15, 'Cans': 16, 'Paper': 17,'Vegetable Leaf': 18, 'Orange Peel': 19, 'Eggshell': 20, 'Banana Peel': 21,'Battery': 22, 'Tablet capsules': 23, 'Fluorescent lamp': 24, 'Paint bucket': 25}# 训练超参
config = EasyDict({"num_classes": 26,"image_height": 224,"image_width": 224,"backbone_out_channels": 1280,"batch_size": 16,"eval_batch_size": 8,"epochs": 10,"lr_max": 0.05,"momentum": 0.9,"weight_decay": 1e-4,"save_ckpt_epochs": 1,"dataset_path": "./data_en","class_index": index_en,"pretrained_ckpt": "./mobilenetV2-200_1067.ckpt" # mobilenetV2-200_1067.ckpt 
})###### 数据预处理操作# 利用ImageFolderDataset方法读取垃圾分类数据集,并整体对数据集进行处理。
def create_dataset(dataset_path, config, training=True, buffer_size=1000):"""create a train or eval datasetArgs:dataset_path(string): the path of dataset.config(struct): the config of train and eval in different platform.Returns:train_dataset, val_dataset"""data_path = os.path.join(dataset_path, 'train' if training else 'test')ds = de.ImageFolderDataset(data_path, num_parallel_workers=4, class_indexing=config.class_index)resize_height = config.image_heightresize_width = config.image_widthnormalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])change_swap_op = C.HWC2CHW()type_cast_op = C2.TypeCast(mstype.int32)if training:crop_decode_resize = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)color_adjust = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)train_trans = [crop_decode_resize, horizontal_flip_op, color_adjust, normalize_op, change_swap_op]train_ds = ds.map(input_columns="image", operations=train_trans, num_parallel_workers=4)train_ds = train_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)train_ds = train_ds.shuffle(buffer_size=buffer_size)ds = train_ds.batch(config.batch_size, drop_remainder=True)else:decode_op = C.Decode()resize_op = C.Resize((int(resize_width / 0.875), int(resize_width / 0.875)))center_crop = C.CenterCrop(resize_width)eval_trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]eval_ds = ds.map(input_columns="image", operations=eval_trans, num_parallel_workers=4)eval_ds = eval_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)ds = eval_ds.batch(config.eval_batch_size, drop_remainder=True)return ds###### 展示部分处理后的数据:
ds = create_dataset(dataset_path=config.dataset_path, config=config, training=False)
print(ds.get_dataset_size())
data = ds.create_dict_iterator(output_numpy=True)._get_next()
images = data['image']
labels = data['label']for i in range(1, 5):plt.subplot(2, 2, i)plt.imshow(np.transpose(images[i], (1, 2, 0)))plt.title('label: %s' % class_en[labels[i]])plt.xticks([])
plt.show()## 5、MobileNetV2模型搭建# 使用MindSpore定义MobileNetV2网络的各模块时需要继承mindspore.nn.Cell。
# Cell是所有神经网络(Conv2d等)的基类。
# 神经网络的各层需要预先在__init__方法中定义,然后通过定义construct方法来完成神经网络的前向构造。
# 原始模型激活函数为ReLU6,池化模块采用是全局平均池化层。__all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']def _make_divisible(v, divisor, min_value=None):if min_value is None:min_value = divisornew_v = max(min_value, int(v + divisor / 2) // divisor * divisor)if new_v < 0.9 * v:new_v += divisorreturn new_vclass GlobalAvgPooling(nn.Cell):"""Global avg pooling definition."""def __init__(self):super(GlobalAvgPooling, self).__init__()def construct(self, x):x = P.mean(x, (2, 3))return xclass ConvBNReLU(nn.Cell):"""Convolution/Depthwise fused with Batchnorm and ReLU block definition."""def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):super(ConvBNReLU, self).__init__()padding = (kernel_size - 1) // 2in_channels = in_planesout_channels = out_planesif groups == 1:conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding)else:out_channels = in_planesconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad',padding=padding, group=in_channels)layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return outputclass InvertedResidual(nn.Cell):"""Mobilenetv2 residual block definition."""def __init__(self, inp, oup, stride, expand_ratio):super(InvertedResidual, self).__init__()assert stride in [1, 2]hidden_dim = int(round(inp * expand_ratio))self.use_res_connect = stride == 1 and inp == ouplayers = []if expand_ratio != 1:layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))layers.extend([ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),nn.BatchNorm2d(oup),])self.conv = nn.SequentialCell(layers)self.cast = P.Cast()def construct(self, x):identity = xx = self.conv(x)if self.use_res_connect:return P.add(identity, x)return xclass MobileNetV2Backbone(nn.Cell):"""MobileNetV2 architecture."""def __init__(self, width_mult=1., inverted_residual_setting=None, round_nearest=8,input_channel=32, last_channel=1280):super(MobileNetV2Backbone, self).__init__()block = InvertedResidualself.cfgs = inverted_residual_settingif inverted_residual_setting is None:self.cfgs = [# t, c, n, s[1, 16, 1, 1],[6, 24, 2, 2],[6, 32, 3, 2],[6, 64, 4, 2],[6, 96, 3, 1],[6, 160, 3, 2],[6, 320, 1, 1],]input_channel = _make_divisible(input_channel * width_mult, round_nearest)self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)features = [ConvBNReLU(3, input_channel, stride=2)]for t, c, n, s in self.cfgs:output_channel = _make_divisible(c * width_mult, round_nearest)for i in range(n):stride = s if i == 0 else 1features.append(block(input_channel, output_channel, stride, expand_ratio=t))input_channel = output_channelfeatures.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))self.features = nn.SequentialCell(features)self._initialize_weights()def construct(self, x):x = self.features(x)return xdef _initialize_weights(self):self.init_parameters_data()for _, m in self.cells_and_names():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),m.weight.data.shape).astype("float32")))if m.bias is not None:m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))elif isinstance(m, nn.BatchNorm2d):m.gamma.set_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))m.beta.set_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))@propertydef get_features(self):return self.featuresclass MobileNetV2Head(nn.Cell):"""MobileNetV2 architecture."""def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False, activation="None"):super(MobileNetV2Head, self).__init__()head = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)] if not has_dropout else[GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)])self.head = nn.SequentialCell(head)self.need_activation = Trueif activation == "Sigmoid":self.activation = nn.Sigmoid()elif activation == "Softmax":self.activation = nn.Softmax()else:self.need_activation = Falseself._initialize_weights()def construct(self, x):x = self.head(x)if self.need_activation:x = self.activation(x)return xdef _initialize_weights(self):self.init_parameters_data()for _, m in self.cells_and_names():if isinstance(m, nn.Dense):m.weight.set_data(Tensor
(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))if m.bias is not None:m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))@propertydef get_head(self):return self.headclass MobileNetV2(nn.Cell):"""MobileNetV2 architecture."""def __init__(self, num_classes=1000, width_mult=1., has_dropout=False, inverted_residual_setting=None,round_nearest=8, input_channel=32, last_channel=1280):super(MobileNetV2, self).__init__()self.backbone = MobileNetV2Backbone(width_mult=width_mult,inverted_residual_setting=inverted_residual_setting,round_nearest=round_nearest, input_channel=input_channel, last_channel=last_channel).get_featuresself.head = MobileNetV2Head(input_channel=self.backbone.out_channels, num_classes=num_classes, has_dropout=has_dropout).get_headdef construct(self, x):x = self.backbone(x)x = self.head(x)return xclass MobileNetV2Combine(nn.Cell):"""MobileNetV2Combine architecture."""def __init__(self, backbone, head):super(MobileNetV2Combine, self).__init__(auto_prefix=False)self.backbone = backboneself.head = headdef construct(self, x):x = self.backbone(x)x = self.head(x)return xdef mobilenet_v2(backbone, head):return MobileNetV2Combine(backbone, head)## 6、MobileNetV2模型的训练与测试###### 训练策略# 一般情况下,模型训练时采用静态学习率,如0.01。
# 随着训练步数的增加,模型逐渐趋于收敛,对权重参数的更新幅度应该逐渐降低,以减小模型训练后期的抖动。
# 所以,模型训练时可以采用动态下降的学习率,常见的学习率下降策略有:
# - polynomial decay/square decay;
# - cosine decay;
# - exponential decay;
# - stage decay.# 这里使用cosine decay下降策略:
def cosine_decay(total_steps, lr_init=0.0, lr_end=0.0, lr_max=0.1, warmup_steps=0):"""Applies cosine decay to generate learning rate array.Args:total_steps(int): all steps in training.lr_init(float): init learning rate.lr_end(float): end learning ratelr_max(float): max learning rate.warmup_steps(int): all steps in warmup epochs.Returns:list, learning rate array."""lr_init, lr_end, lr_max = float(lr_init), float(lr_end), float(lr_max)decay_steps = total_steps - warmup_stepslr_all_steps = []inc_per_step = (lr_max - lr_init) / warmup_steps if warmup_steps else 0for i in range(total_steps):if i < warmup_steps:lr = lr_init + inc_per_step * (i + 1)else:cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))lr = (lr_max - lr_end) * cosine_decay + lr_endlr_all_steps.append(lr)return lr_all_steps# 在模型训练过程中,可以添加检查点(Checkpoint)用于保存模型的参数,以便进行推理及中断后再训练使用。
# 使用场景如下:
# - 训练后推理场景
# 1) 模型训练完毕后保存模型的参数,用于推理或预测操作。
# 2) 训练过程中,通过实时验证精度,把精度最高的模型参数保存下来,用于预测操作。
# - 再训练场景
# 1) 进行长时间训练任务时,保存训练过程中的Checkpoint文件,防止任务异常退出后从初始状态开始训练。
# 2) Fine-tuning(微调)场景,即训练一个模型并保存参数,基于该模型,面向第二个类似任务进行模型训练。# 这里加载ImageNet数据上预训练的MobileNetv2进行Fine-tuning,只训练最后修改的FC层,并在训练过程中保存Checkpoint。
def switch_precision(net, data_type):if ms.get_context('device_target') == "Ascend":net.to_float(data_type)for _, cell in net.cells_and_names():if isinstance(cell, nn.Dense):cell.to_float(ms.float32)###### 模型训练与测试# 在进行正式的训练之前,定义训练函数,读取数据并对模型进行实例化,定义优化器和损失函数。
# 首先简单介绍损失函数及优化器的概念:
# - 损失函数:又叫目标函数,用于衡量预测值与实际值差异的程度。
# 深度学习通过不停地迭代来缩小损失函数的值。
# 定义一个好的损失函数,可以有效提高模型的性能。
# - 优化器:用于最小化损失函数,从而在训练过程中改进模型。# 定义了损失函数后,可以得到损失函数关于权重的梯度。
# 梯度用于指示优化器优化权重的方向,以提高模型性能。# 在训练MobileNetV2之前对MobileNetV2Backbone层的参数进行了固定,
# 使其在训练过程中对该模块的权重参数不进行更新;只对MobileNetV2Head模块的参数进行更新。# MindSpore支持的损失函数有SoftmaxCrossEntropyWithLogits、L1Loss、MSELoss等。
# 这里使用SoftmaxCrossEntropyWithLogits损失函数。from mindspore.amp import FixedLossScaleManager
import timeLOSS_SCALE = 1024train_dataset = create_dataset(dataset_path=config.dataset_path, config=config)
eval_dataset = create_dataset(dataset_path=config.dataset_path, config=config)
step_size = train_dataset.get_dataset_size()backbone = MobileNetV2Backbone()  # last_channel=config.backbone_out_channels
# Freeze parameters of backbone. You can comment these two lines.
for param in backbone.get_parameters():param.requires_grad = False
# load parameters from pretrained model
load_checkpoint(config.pretrained_ckpt, backbone)head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
network = mobilenet_v2(backbone, head)# define loss, optimizer, and model
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(LOSS_SCALE, drop_overflow_update=False)
lrs = cosine_decay(config.epochs * step_size, lr_max=config.lr_max)
opt = nn.Momentum(network.trainable_params(), lrs, config.momentum, config.weight_decay, loss_scale=LOSS_SCALE)# 定义用于训练的train_loop函数。
def train_loop(model, dataset, loss_fn, optimizer):# 定义正向计算函数def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss# 定义微分函数,使用mindspore.value_and_grad获得微分函数grad_fn,输出loss和梯度。# 由于是对模型参数求导, grad_position 配置为None,传入可训练参数。grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)# 定义 one-step training函数def train_step(data, label):loss, grads = grad_fn(data, label)optimizer(grads)return losssize = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 10 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")# 定义用于测试的test_loop函数。
def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")print("============== Starting Training ==============")
# 由于时间问题,训练过程只进行了2个epoch ,可以根据需求调整。
epoch_begin_time = time.time()
epochs = 2
for t in range(epochs):begin_time = time.time()print(f"Epoch {t + 1}\n-------------------------------")train_loop(network, train_dataset, loss, opt)ms.save_checkpoint(network, "save_mobilenetV2_model.ckpt")end_time = time.time()times = end_time - begin_timeprint(f"per epoch time: {times}s")test_loop(network, eval_dataset, loss)
epoch_end_time = time.time()
times = epoch_end_time - epoch_begin_time
print(f"total time:  {times}s")
print("============== Training Success ==============")## 7、模型推理# 加载模型Checkpoint进行推理
CKPT = "save_mobilenetV2_model.ckpt"def image_process(image):"""处理单张图像。Args:image: 形状为(H, W, C)的图像。"""mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]std = [0.229 * 255, 0.224 * 255, 0.225 * 255]# 进行归一化处理image = (np.array(image) - mean) / std# 转置图像,使其形状变为(C, H, W)image = image.transpose((2, 0, 1))img_tensor = Tensor(np.array([image], np.float32))  # 增加一个维度用于批处理return img_tensordef infer_one(network, image_path):"""对单张图像进行推理。Args:network: 待推理的网络。image_path: 图像路径。"""image = Image.open(image_path).resize((config.image_height, config.image_width))logits = network(image_process(image))  # 进行推理pred = np.argmax(logits.asnumpy(), axis=1)[0]  # 获取预测结果print(image_path, class_en[pred])  # 输出预测结果def infer():"""执行推理过程。"""backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels)head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)network = mobilenet_v2(backbone, head)load_checkpoint(CKPT, network)  # 加载模型参数# 对指定范围的图像进行推理for i in range(91, 100):infer_one(network, f'data_en/test/Cardboard/000{i}.jpg')infer()  # 执行推理## 8、导出AIR/GEIR/ONNX模型文件# 导出AIR模型文件,用于后续Atlas 200 DK上的模型转换与推理。
# 当前仅支持MindSpore+Ascend环境。
backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels)
head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
network = mobilenet_v2(backbone, head)
load_checkpoint(CKPT, network)# 生成一个随机输入,形状为[1, 3, 224, 224]
input = np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32)
# export(network, Tensor(input), file_name='mobilenetv2.air', file_format='AIR')
# export(network, Tensor(input), file_name='mobilenetv2.pb', file_format='GEIR')
export(network, Tensor(input), file_name='mobilenetv2.onnx', file_format='ONNX')  # 导出为ONNX格式

代码解析

  1. 模型创建与训练:
    • 使用MobileNetV2BackboneMobileNetV2Head构建模型的主干和头部,通过mobilenet_v2组合成完整的网络。
    • 训练过程中使用cosine_decay策略动态调整学习率。
  2. 推理与模型导出:
    • 定义image_process用于处理输入图像,并在infer_one中执行推理。
    • 将训练完成的模型保存为ONNX格式以便于后续使用。

API 解析

  • de.ImageFolderDataset: 用于加载图像数据集并按类别组织。
  • ms.value_and_grad: 用于计算损失和梯度的函数。
  • export: 将训练好的模型导出为指定格式的文件,如ONNX。

通过上述代码和解析,可以理解如何使用MobileNetV2进行垃圾分类的开发和模型推理。

相关文章:

昇思MindSpore 应用学习-基于MobileNetv2的垃圾分类

基于MobileNetv2的垃圾分类 本文档主要介绍垃圾分类代码开发的方法。通过读取本地图像数据作为输入&#xff0c;对图像中的垃圾物体进行检测&#xff0c;并将检测结果图片保存到文件中。 1、实验目的 了解熟悉垃圾分类应用代码的编写&#xff08;Python语言&#xff09;&…...

matlab 常用数据类型的转换

目录 一、数据类型1、整型2、浮点型3、逻辑型4、元胞数组5、结构体 二、数据类型转换三、图像数据类型转换四、参考链接 一、数据类型 1、整型 int和unit都是整型&#xff0c;只是前一个有符号&#xff0c;后一个没有符号&#xff0c;比如在16位系统中&#xff0c;int范围是-3…...

Cocos Creator2D游戏开发(6)-飞机大战(4)-敌机产生

敌机产生&玩家发射子弹 敌机产生: 创建一个空节点 创建一个敌机预制体 把敌机图片拖入预制体内 使用代码生成敌机 让敌机动起来 创建一个预制体enemy_prefab双击预制体enemy_prefab,然后拖入一个敌机图片,设置好方向和尺寸,一定要记得保存然后关闭(场景编辑器里面的保存)…...

Hugo部署到Vercel踩大坑——全是XML文件?

问题描述 部署到Vercel全都是XML文件 Vercel是著名PAAS服务&#xff0c;相比于 Github Pages&#xff0c;其中国大陆可直接访问&#xff0c;因此尝试把Hugo站点发布到vercel中&#xff0c;部署后遇到问题&#xff0c;所有页面都为xml文件&#xff0c;如下所示&#xff1a; Ve…...

2024 暑假友谊赛-热身1

[ABC102D] Equal Cut - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 思路:找在区间[2,n-1]中找到i,j,k三个点,把序列分割成4个区间:[1,i],[i1,j],[j1,k],[k1,n] 暴力的做法是枚举i,j,k加上前缀和是o(n^3)的 key:"考虑枚举处于中间的j&#xff0c;然后用i平衡左两个区间,…...

Nginx系列-11 HTTP消息处理流程

背景 了解Nginx处理HTTP请求的11个阶段&#xff0c;有助于理解和配置nginx、自定义模块、基于lua模块自定义功能。按如下配置&#xff0c;执行"curl http://localhost:8001/query/test.html"&#xff0c;如果读者对结果不是很确定&#xff0c;建议阅读本文。 serve…...

前端知识--前端访问后端技术Ajax及框架Axios

一、异步数据请求技术----Ajax Ajax是前端访问后端的技术&#xff0c;为异步请求&#xff08;不刷新页面&#xff0c;请求数据&#xff0c;只更新局部数据&#xff09;。 例如&#xff1a;在京东网站中搜索电脑&#xff0c;就会出现一些联想搜索&#xff0c;但此时页面并没有…...

【前端/js】使用js读取本地文件(xml、二进制)内容

目录 说在前面FileReaderDOMParser文本文件二进制文件 说在前面 浏览器版本&#xff1a;Microsoft Edge 126.0.2 (正式版本) (64 位) FileReader MDNFileReader 接口允许 Web 应用程序异步读取存储在用户计算机上的文件&#xff08;或原始数据缓冲区&#xff09;的内容&#x…...

初步入门C ++之类的概念

文章目录 0 Hello World!1 编译过程2 类2.1 类的概念2.2 构造函数与析构函数 0 Hello World! #include <iostream> //相当于#include <stdio.h>int main(int argc, char argv[]) {char c;std::cout << "Hello World!\n" <<…...

什么是技术作家风格指南?

技术写作风格指南旨在提供必要的格式风格&#xff0c;以帮助技术作家为读者创建引人入胜且一致的内容。然而&#xff0c;技术写作与普通的自由写作有很大不同。目的是将复杂的技术主题分解为易于理解的内容&#xff0c;以帮助读者了解如何使用产品或服务。 在本文中&#xff0…...

WebGIS学习——Cesium|Javascript

1.Cesium学习什么&#xff1a;Cesium实战项目说明-CSDN博客 2.Cesium绘制图形(箭头等):Cesium 态势标绘 _cesium态势标绘-CSDN博客 3.CesiumThree集成 4.Cesium深度图相关&#xff1a;Cesium离屏渲染深度图实验_cesium 离屏渲染-CSDN博客 5.洪涝&#xff1a;cesium淹没分析…...

Qt,获取其他.exe文件的标准输出流的信息(printf/print的输出信息)

比如&#xff0c;通过Python编写爬虫软件功能是运行程序获取豆瓣电影排行榜信息&#xff0c;并通过print打印出来。将其打包成.exe,通过Qt来调用&#xff0c;并获取到.exe程序运行的结果 简单示例代码&#xff1a; // 创建 QProcess 对象QProcess process;// 连接信号槽以获取…...

LeetCode 热题 HOT 100 (010/100)【宇宙最简单版】

【链表】No. 0206 反转链表 【简单】&#x1f449;力扣对应题目指路 希望对你有帮助呀&#xff01;&#xff01;&#x1f49c;&#x1f49c; 如有更好理解的思路&#xff0c;欢迎大家留言补充 ~ 一起加油叭 &#x1f4a6; 欢迎关注、订阅专栏 【力扣详解】谢谢你的支持&#xf…...

Ubuntu24.04安装mysql-server小计,解决mysql_secure_installation时不能重置密码的问题

Ubuntu24.04安装mysql-server小计&#xff0c;解决mysql_secure_installation时不能重置密码的问题 为什么要写这往篇文章&#xff1f; 一般情况下&#xff0c;我安装mysql都用源码编译&#xff0c;以此方便安装更多自定义插件&#xff0c;但这次只需要安装一台开发机&#x…...

unity3d:TabView,UGUI多标签页组件,TreeView树状展开菜单

概述 1.最外层DataForm为空壳编辑数据用。可以有多个DataForm&#xff0c;例如福利DataForm&#xff0c;抽奖DataForm 2.Menu层为左边栏层&#xff0c;每个DataForm可以使用不同样式的MenuForm预制体 3.DataForm中使用ReorderList&#xff0c;可排列配置 4.有定位功能&#xf…...

go语言map底层及扩容机制原理详解(下)

前言 上文对Go map的底层数据结构有所了解&#xff0c;并对其扩容机制的步骤进行简略的描述。本文将会详细地去解释Go map扩容机制的详细原理。 1. 触发扩容操作 在go语言中&#xff0c;当我们插入一个元素到hmap时&#xff0c;会有以下两种情况&#xff1a; 若元素存在&…...

网络协议二 : 使用Cisco Packet Traceer工具模拟网络环境,集线器,网桥,交换机,路由器,IP,同一网段

1. 安装 Cisco Packet Tracer baidu 网盘地址&#xff0c;感谢大神分享 安装&#xff0c;破解&#xff0c;中文化&#xff0c;都有说明&#xff0c;建议使用7.x的那个版本&#xff0c;感觉比8.x的翻译要完整一点 https://pan.baidu.com/s/18iWBOfhJJRhqgQqdNQcfMQ?pwddcch#…...

Aria2 任意文件写入漏洞

目录 Aria2介绍漏洞描述漏洞复现 Aria2介绍 Aria2是一个在命令行下运行&#xff0c;多协议&#xff0c;多来源下载工具&#xff08;HTTP / HTTPS&#xff0c;FTP&#xff0c;BitTorrent&#xff0c;Metalink&#xff09;&#xff0c;内建XML-RPC用户界面。Aria提供RPC服务器&a…...

成为git砖家(4): git status 命令简介

1. untracked 和 tracked 状态 Remember that each file in your working directory can be in one of two states: tracked or untracked. Tracked files are files that were in the last snapshot, as well as any newly staged files; they can be unmodified, modified, o…...

2-48 基于matlab的EM算法聚类可视化程序

基于matlab的EM算法聚类可视化程序&#xff0c;通过期望最大化算法&#xff08;EM&#xff09;优化类别间距&#xff0c;使得类别间距最大、类内间距最小。输出聚类前后结果及收敛曲线。程序已调通&#xff0c;可直接运行。 2-48 期望最大化算法&#xff08;EM&#xff09; 聚类…...

k8s 使用技巧

文章目录 kubectlkubectl 自动补全kubectl 上下文和配置打印当前使用 API 调用过程生成yaml模板强制删除 Pod&#xff08;即使处于Terminating&#xff09; kubectl kubectl 自动补全 source < (kubectl completion bash) # setup autocomplete in bash, bash-completion …...

学习笔记-系统框图传递函数公式推导

目录 *待了解 现代控制理论和自动控制理论区别 自动控制系统的组成 信号流图 1、系统框图 1.1、信号线、分支点、相加点 1.2、系统各环节间的连接 1.3、 相加点和分支点的等效移动&#xff08;比较点、引出点&#xff09; 2、反馈连接公式推导 2.1、前向通路传递函数…...

C++ - 基于多设计模式下的同步异步⽇志系统

1.项目介绍 项⽬介绍 本项⽬主要实现⼀个⽇志系统&#xff0c; 其主要⽀持以下功能: • ⽀持多级别⽇志消息 • ⽀持同步⽇志和异步⽇志 • ⽀持可靠写⼊⽇志到控制台、⽂件以及滚动⽂件中 • ⽀持多线程程序并发写⽇志 • ⽀持扩展不同的⽇志落地⽬标地 2.开发环境 • Cent…...

git 相关内容

...

ElasticSearch(es)倒排索引

目录 一、ElasticSearch 二、倒排索引 1. 正向索引 2. 倒排索引 具体细节 1. 文档分析 2. 索引构建 3. 索引存储 4. 词条编码 5. 索引优化 6. 查询处理 示例 总结 3. 正向和倒排 三、总结 倒排索引的基本概念 为什么倒排索引快 一、ElasticSearch Elasticsear…...

【自然语言处理】概论(一):自然语言处理概要

1.1 概论&#xff1a;&#xff08;一&#xff09;自然语言处理概要 知识点 自然语言的定义&#xff1a;人类交流使用的&#xff0c;包括口语和书面语的信息交流方式。AI的终极目标&#xff1a;使计算机具备理解&#xff08;听、读&#xff09;和生成&#xff08;说、写&#…...

flask 开始

# 导入flask类 from flask import Flask,request,render_template # 使用flask类来创建一个app对象 # __name__ 代表当前app.py 这个模块 app Flask(__name__) # 创建一个路由和视图函数的映射 url http://127.0.0.1:5000/ app.route("/") def hello_word():return …...

仕考网:公务员可以报考军队文职吗?

公务员可以报考军队文职考试&#xff0c;但是需要满足前提条件。 对于已经与国家、地方的用人单位建立劳动关系的社会人才&#xff0c;在获得当前用人单位的许可后才可以申请报考。 在面试过程中&#xff0c;考生必须出示一份由其用人单位出具的且加盖公章的同意报考证明。一…...

Java整理22

1、动态sql 多条件查询 .xml配置文件中sql语句书写<select id"getEmpByCondition",resultType"Emp">select * from t_emp where <if test"empName ! null and empName! ">empName#{empName}</if><if test"age ! nul…...

leetcode 408周赛 3234. 统计 1 显著的字符串的数量

3234. 统计 1 显著的字符串的数量 题目描述 给你一个二进制字符串 s。 请你统计并返回其中 1 显著 的子字符串的数量。 如果字符串中 1 的数量 大于或等于 0 的数量的 平方&#xff0c;则认为该字符串是一个 1 显著 的字符串 。 思路 一个很显然的思路是&#xff0c;我们…...

容器对比虚拟机有哪些不足?

引言 在当今的云计算和微服务架构中&#xff0c;容器技术已成为不可或缺的一部分。它以其轻量级、高效和快速部署的特性&#xff0c;赢得了广大开发者和运维人员的青睐。然而&#xff0c;正如任何技术都有其两面性&#xff0c;容器技术也不例外。本文将对容器技术在安全性、隔离…...

C# 归并排序

栏目总目录 概念 归并排序是一种分而治之的排序算法。它将一个大数组分成两个小数组&#xff0c;递归地对这两个小数组进行排序&#xff0c;然后将排序好的小数组合并成一个有序的大数组。这个过程一直递归进行&#xff0c;直到数组被拆分成只有一个元素的数组&#xff08;自然…...

【请求代理】springboot单机服务基于过滤器Filter实现第三方服务器接口请求代理功能

springboot单机服务基于过滤器Filter实现第三方服务器接口请求代理功能 一、前言二、解决思路三、基于gateway实现四、基于过滤器Filter实现五、问题总结 **注&#xff1a;本文源码获取或者更多资料&#xff0c;关注公众号&#xff1a;技术闲人**一、前言 在项目开发时会遇到w…...

.NET Core异步编程与多线程解析:提升性能与响应能力的关键技术

在.NET Core中&#xff0c;异步编程和多线程是构建高性能应用程序的核心技能。理解这两个概念不仅可以提升应用程序的响应能力&#xff0c;还能优化资源使用。本文将深入剖析异步编程和多线程的关键知识点&#xff0c;提供代码示例&#xff0c;并附上步骤以帮助理解。 1. 异步…...

Photoshop(PS) 抠图简单教程

目录 快速选择 魔棒 钢笔 橡皮擦 蒙版 通道 小结 可以发现&#xff0c;ps逐渐成为必备基础的办公软件。本文让ps新手轻松学会抠图。 快速选择 在抠图之前&#xff0c;先了解下选区的概念。ps中大多数的抠图操作都是基于选区的&#xff0c;先选区再Ctrl J提取选区。而快…...

项目管理中的常用工件(二):可视化工件

项目管理中的常用工件&#xff08;二&#xff09;&#xff1a;可视化工件 亲和图&#xff08;affinity diagram&#xff09;因果图&#xff08;cause-and-effect diagram&#xff09;直方图&#xff08;histogram&#xff09;流程图&#xff08;flowchart&#xff09;散点图&am…...

Git入门与实战:版本控制的艺术

&#x1f341; 作者&#xff1a;知识浅谈&#xff0c;CSDN签约讲师&#xff0c;CSDN博客专家&#xff0c;华为云云享专家&#xff0c;阿里云专家博主 &#x1f4cc; 擅长领域&#xff1a;全栈工程师、爬虫、ACM算法 &#x1f525; 微信&#xff1a;zsqtcyw 联系我领取学习资料 …...

[Mysql-DML数据操作语句]

目录 数据增加&#xff1a;INSERT 全字段插入&#xff1a; 部分字段插入&#xff1a; 一次性添加多条&#xff1a; 数据修改&#xff1a;UPDATE 数据删除&#xff1a;DELECT delete truncate drop 区别 数据增加&#xff1a;INSERT 总体格式&#xff1a;insert into 表…...

Tableau入门|数据可视化与仪表盘搭建

原视频链接&#xff08;up:戴戴戴师兄&#xff09;&#xff0c;文章为笔者的自学笔记&#xff0c;用于复习回顾&#xff0c;原视频下方有原up整理的笔记&#xff0c;更加直观便捷。因为视频中间涉及的细节较多&#xff0c;建议一边操作&#xff0c;一边学习。 整体介绍 可视化…...

API 技术开发分享:连接电商平台数据获取的桥梁

在当今数字化的时代&#xff0c;API&#xff08;Application Programming Interface&#xff0c;应用程序编程接口&#xff09;技术成为了实现不同系统之间通信和数据交换的关键。它就像是一座无形的桥梁&#xff0c;使得各种应用能够相互协作&#xff0c;共享资源&#xff0c;…...

区块链如何助力数字版权保护和内容创作者的权益?

区块链技术可以助力数字版权保护和内容创作者的权益&#xff0c;主要有以下几个方面&#xff1a; 去中心化的版权登记和溯源&#xff1a;区块链可作为一个可信的去中心化数据库&#xff0c;记录并验证数字内容的版权信息。内容创作者可以将自己的作品信息存储在区块链上&#x…...

记一次老旧项目的整体技术升级

最近给公司采购的老旧的 node8 vue2.6 webpack3 npm 项目做构建优化 背景&#xff1a;整个项目 build 一次 20 min &#xff0c;本地冷启动和热更新也忒慢&#xff0c;依赖 npm i 一下也得装个 20 min 众所周知&#xff0c;Node 版本&#xff0c;依赖包管理工具 和 构建工…...

2024年最受欢迎的五大上网审计设备和软件

在2024年的市场上&#xff0c;上网行为审计设备和软件种类繁多&#xff0c;它们帮助企业监控和管理员工的网络活动&#xff0c;确保网络安全并提高工作效率。下面是一些受欢迎的上网行为审计设备和软件。 2024年最受欢迎的上网行为审计设备和软件如下 1.安企神软件&#xff1a…...

sed利用脚本处理文件

一、sed是什么 sed 命令是利用脚本来处理文本文件。它可以依照脚本的指令来处理、编辑文本文件。主要用来自动编 辑一个或多个文件、简化对文件的反复操作、编写转换程序等。 二、sed的原理 读入新的一行内容到缓存空间&#xff1b; 从指定的操作指令中取出第一条指令&…...

泰山派RK3566开发板800x1280MIPI屏设备树补丁

泰山派RK3566开发板800x1280MIPI屏设备树补丁 泰山派下800 X 1280分辨率MIPI屏调试&#xff0c;设备树补丁如下&#xff1a; https://download.csdn.net/download/qq_45143522/89584066 用kernel.patch文件&#xff0c;在泰山派内核源码下打补丁即可完成更新&#xff0c;或者…...

informer中的indexer机制的实现分析与源码解读

1. 背景 client-go工具下的tools/cache.indexer为informer提供缓存与索引的能力。可以实现快速通过索引找到对应的对象(pod, deployment,secret,configmap等)。 indexer再informer机制中的使用图示&#xff1a; indexer包括2部分: 一部分是store用于实际数据的存储&#xff0c;…...

英特尔宣布针对对Llama 3.1进行优化 以提升所有产品的性能

日前Meta正式发布了Llama 3.1开源大模型&#xff0c;以其庞大的参数量和卓越性能&#xff0c;首次在多项基准测试中击败了GPT-4o等业界领先的闭源模型。允许开发者自由地进行微调、蒸馏&#xff0c;甚至在任何地方部署&#xff0c;这种开放性为AI技术的普及和创新提供了无限可能…...

Python3网络爬虫开发实战(1)爬虫基础

一、URL 基础 URL也就是网络资源地址&#xff0c;其满足如下格式规范 scheme://[username:password]hostname[:port][/path][;parameters][?query][#fragment] scheme&#xff1a;协议&#xff0c;常用的协议有 Http&#xff0c;https&#xff0c;ftp等等&#xff1b;usern…...

Redis的五种数据类型与命令

目录 引言 一 Redis的特性 二 Redis的安装 三 Redis的优点 四 Redis的五种数据类型与命令 五 Redis的配置文件 引言 Redis是什么&#xff1f; Remote Dictionary Service(远程字典服务器) Redis 是一个开源的(BSD许可)的&#xff0c;C语言编写的&#xff0c;高性能的数…...

RocketMQ的详细讲解(四种mq的对比(activeMq、rabbitmq、rocketmq、kafka))

20240729 RocketMQ1 mq的三大作用 异步、削峰限流、解耦合2. 四种mq的对比&#xff08;activeMq、rabbitmq、rocketmq、kafka&#xff09;3 rocketmq特点1. 平台无关2. 能提供什么样的功能 4 rocketMq4.1 broker中的标题&#xff0c;来约束读和写4.2 rocketmq的结构4.3 读和写的…...