当前位置: 首页 > news >正文

实现pytorch版的mobileNetV1

mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客

这里是根据网络结构,搭建模型,用于图像分类任务。

1. 网络结构和基本组件

2. 搭建组件

(1)普通的卷积组件:CBL = Conv2d + BN + ReLU6;

(2)深度可分离卷积:DwCBL  = Conv dw+ Conv dp;

Conv dw+ Conv dp = {Conv2d(3x3) + BN + ReLU6 }  + {Conv2d(1x1) + BN + ReLU6};

Conv dw是3x3的深度卷积,通过步长控制是否进行下采样;

Conv dp是1x1的逐点卷积,通过控制输出通道数,控制通道维度的变化;

# 普通卷积
class CBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(CBN, self).__init__()self.conv = nn.Conv2d(in_c, out_c, 3, stride, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_c)self.relu = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return x
# 深度可分离卷积: 深度卷积(3x3x1) + 逐点卷积(1x1xc卷积)
class DwCBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(DwCBN, self).__init__()# conv3x3x1, 深度卷积,通过步长,只控制是否缩小特征hwself.conv3x3 = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False)self.bn1 = nn.BatchNorm2d(in_c)self.relu1 = nn.ReLU6(inplace=True)# conv1x1xc, 逐点卷积,通过控制输出通道数,控制通道维度的变化self.conv1x1 = nn.Conv2d(in_c, out_c, 1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(out_c)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv3x3(x)x = self.bn1(x)x = self.relu1(x)x = self.conv1x1(x)x = self.bn2(x)x = self.relu2(x)return x

3. 搭建网络

class MobileNetV1(nn.Module):def __init__(self, class_num=1000):super(MobileNetV1, self).__init__()self.stage1 = torch.nn.Sequential(CBN(3, 32, 2),  # 下采样/2DwCBN(32, 64, 1))self.stage2 = torch.nn.Sequential(DwCBN(64, 128, 2),  # 下采样/4DwCBN(128, 128, 1))self.stage3 = torch.nn.Sequential(DwCBN(128, 256, 2),  # 下采样/8DwCBN(256, 256, 1))self.stage4 = torch.nn.Sequential(DwCBN(256, 512, 2),  # 下采样/16DwCBN(512, 512, 1),  # 5个DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),)self.stage5 = torch.nn.Sequential(DwCBN(512, 1024, 2),  # 下采样/32DwCBN(1024, 1024, 1))# classifierself.avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))self.fc = torch.nn.Linear(1024, class_num, bias=True)# self.classifier = torch.nn.Softmax()  # 原始的softmax值# torch.log_softmax 首先计算 softmax 然后再取对数,因此在数值上更加稳定。# 在分类网络在训练过程中,通常使用交叉熵损失函数(Cross-Entropy Loss)。# torch.nn.CrossEntropyLoss 会在内部进行 softmax 操作,因此在网络的最后一层不需要手动加上 softmax 操作。def forward(self, x):scale1 = self.stage1(x)  # /2scale2 = self.stage2(scale1)scale3 = self.stage3(scale2)scale4 = self.stage4(scale3)scale5 = self.stage5(scale4)  # /32. 7x7x = self.avg_pooling(scale5)  # (b,1024,7,7)->(b,1024,1,1)x = torch.flatten(x, 1)  # (b,1024,1,1)->(b,1024,)x = self.fc(x)  # (b,1024,)  -> (b,1000,)return xif __name__ == '__main__':m1 = MobileNetV1(class_num=1000)input_data = torch.randn(64, 3, 224, 224)output = m1.forward(input_data)print(output.shape)

4. 训练验证

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optimfrom mobilenetv1 import MobileNetV1def validate(model, val_loader, criterion, device):model.eval()  # Set the model to evaluation modetotal_correct = 0total_samples = 0with torch.no_grad():for val_inputs, val_labels in val_loader:val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)val_outputs = model(val_inputs)_, predicted = torch.max(val_outputs, 1)total_samples += val_labels.size(0)total_correct += (predicted == val_labels).sum().item()accuracy = total_correct / total_samplesmodel.train()  # Set the model back to training modereturn accuracyif __name__ == '__main__':# 下载并准备数据集# Define image transformations (adjust as needed)transform = transforms.Compose([transforms.Resize((224, 224)),  # Resize images to a consistent sizetransforms.ToTensor(),  # converts to PIL Image to a Pytorch Tensor and scales values to the range [0, 1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Adjust normalization values. val = (val - mean) / std.])# Create ImageFolder datasetdata_folder = r"D:\zxq\data\car_or_dog"dataset = torchvision.datasets.ImageFolder(root=data_folder, transform=transform)# Optionally, split the dataset into training and validation sets# Adjust the `split_ratio` as neededsplit_ratio = 0.8train_size = int(split_ratio * len(dataset))val_size = len(dataset) - train_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])# Create DataLoader for training and validationtrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)# 初始化模型、损失函数和优化器net = MobileNetV1(class_num=2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)# 训练模型device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)net.to(device)for epoch in range(20):  # 例如,训练 20 个周期for i, data in enumerate(train_loader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)  # 将数据移动到GPUoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()if i % 100 == 0:print("epoch/step: {}/{}: loss: {}".format(epoch, i, loss.item()))# Validation after each epochval_accuracy = validate(net, val_loader, criterion, device)print("Epoch {} - Validation Accuracy: {:.2%}".format(epoch, val_accuracy))print('Finished Training')

待续。。。

相关文章:

实现pytorch版的mobileNetV1

mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客 这里是根据网络结构,搭建模型,用于图像分类任务。 1. 网络结构和基本组件 2. 搭建组件 (1)普通的卷积组件:CBL …...

vue多tab页面全部关闭后自动退出登录

业务场景:主项目是用vue写的单页面应用,但是有多开页面的需求,现在需要在用户关闭了所有的浏览器标签页面后,自动退出登录。 思路:因为是不同的tab页面,我只能用localStorage来通信,新打开一个…...

记一个集群环境部署不完整导致的BUG

一 背景 产品有三个环境:开发测试环境、验收环境、生产环境。 开发测试环境,保持最新的更新; 验收环境,阶段待发布内容; 生产环境,部署稳定内容。 产品为BS架构,后端采用微服务&#xf…...

Go zero copy,复制文件

这里使用零拷贝技术复制文件,从内核态操作源文件和目标文件。避免了在用户态开辟缓冲区,然后从内核态复制文件到用户态的问题。 由内核态完成文件复制操作。 调用的是syscall.Sendfile系统调用函数。 //go:build linuxpackage zero_copyimport ("f…...

http协议九种请求方法介绍及常见状态码

http1.0定义了三种: GET: 向服务器获取资源,比如常见的查询请求POST: 向服务器提交数据而发送的请求Head: 和get类似,返回的响应中没有具体的内容,用于获取报头 http1.1定义了六种 PUT:一般是用于更新请求,…...

详解flink exactly-once和两阶段提交

以下是我们常见的三种 flink 处理语义: 最多一次(At-most-Once):用户的数据只会被处理一次,不管成功还是失败,不会重试也不会重发。 至少一次(At-least-Once):系统会保…...

Qt/QML编程学习之心得:QDbus实现service接口调用(28)

D-Bus协议用于进程间通讯的。 QString value = retrieveValue();QDBusPendingCall pcall = interface->asyncCall(QLatin1String("Process"), value);QDBusPendingCallWatcher *watcher = new QDBusPendingCallWatcher(pcall, this);QObject::connect(watcher, SI…...

前端nginx配置指南

前端项目发布后,有些接口需要在服务器配置反向代理,资源配置gzip压缩,配置跨域允许访问等 配置文件模块概览 配置示例 反向代理 反向代理是Nginx的核心功能之一,是指客户端发送请求到代理服务器,代理服务器再将请求…...

接口测试到底怎么做,5分钟时间看完这篇文章彻底搞清楚

01、通用的项目架构 02、什么是接口 接口:服务端程序对外提供的一种统一的访问方式,通常采用HTTP协议,通过不同的url,不同的请求类型(GET、POST),不同的参数,来执行不同的业务逻辑。…...

显示管理磁盘分区 fdisk

显示管理磁盘分区 fdisk fdisk是用于检查一个磁盘上分区信息最通用的命令。 fdisk可以显示分区信息及一些细节信息,比如文件系统类型等。 设备的名称通常是/dev/sda、/dev/sdb 等。 对于以前的设备有可能还存在设备名为 /dev/hd* (IDE)的设备,这个设…...

Hyperledger Fabric 管理链码 peer lifecycle chaincode 指令使用

链上代码(Chaincode)简称链码,包括系统链码和用户链码。系统链码(System Chaincode)指的是 Fabric Peer 中负责系统配置、查询、背书、验证等平台功能的代码逻辑,运行在 Peer 进程内,将在第 14 …...

L1-011 A-B(Java)

题目 本题要求你计算A−B。不过麻烦的是,A和B都是字符串 —— 即从字符串A中把字符串B所包含的字符全删掉,剩下的字符组成的就是字符串A−B。 输入格式: 输入在2行中先后给出字符串A和B。两字符串的长度都不超过10的四次方,并且…...

系列七、Ribbon

一、Ribbon 1.1、概述 Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具,是Netflix发布的一款开源项目,其主要功能是提供客户端的软件负载均衡算法和服务调用,Ribbon客户端组件提供一系列完善的配置项,例如&#xff1a…...

山东名岳轩印刷包装携专业包装袋盛装亮相2024济南生物发酵展

山东名岳轩印刷包装有限公司盛装亮相2024第12届国际生物发酵展,3月5-7日山东国际会展中心与您相约! 展位号:1号馆F17 山东名岳轩印刷包装有限公司是一家拥有南北两个生产厂区,设计、制版、印刷,营销策划为一体的专业…...

BGP公认必遵属性——Next-hop(一)

BGP公认必遵属性共有三个,分别是:Next-hop、Origin、As-path,本期介绍Next-hop 点赞关注,持续更新!!! Next-hop 华为BGP路由下一跳特点: 默认情况下传给EBGP邻居的BGP路由的下一跳…...

增强Wi-Fi信号的10种方法,值得去尝试

Wi-Fi信号丢失,无线盲区。在一个对一些人来说,上网和呼吸一样必要的世界里,这些问题中的每一个都令人抓狂。 如果你觉得你的Wi-Fi变得迟钝,有很多工具可以用来测试你的互联网速度。你还可以尝试一些技巧来解决网络问题。然而,如果你能获得良好接收的唯一方法是站在无线路…...

第十五章 ECMAScript6新增的常用语法

文章目录 一、声明关键字二、箭头函数三、解构赋值四、展开运算符五、对字符的补充六、Symbol七、对象的简写语法八、Set和Map九、for-of 一、声明关键字 ES6新增的声明关键字: let,const:声明变量class:声明类import&#xff0c…...

vulhub中的Apache SSI 远程命令执行漏洞

Apache SSI 远程命令执行漏洞 1.cd到ssi-rce cd /opt/vulhub/httpd/ssi-rce/ 2.执行docker-compose up -d docker-compose up -d 3.查看靶场是否开启成功 dooker ps 拉取成功了 4.访问url 这里已经执行成功了,注意这里需要加入/upload.php 5.写入一句话木马 &…...

MSB20M-ASEMI迷你贴片整流桥MSB20M

编辑:ll MSB20M-ASEMI迷你贴片整流桥MSB20M 型号:MSB20M 品牌:ASEMI 封装:UMSB-4 特性:贴片、整流桥 最大平均正向电流:2A 最大重复峰值反向电压:1000V 恢复时间:&#xff1…...

工程管理系统功能设计与实践:实现高效、透明的工程管理

在现代化的工程项目管理中,一套功能全面、操作便捷的系统至关重要。本文将介绍一个基于Spring Cloud和Spring Boot技术的Java版工程项目管理系统,结合Vue和ElementUI实现前后端分离。该系统涵盖了项目管理、合同管理、预警管理、竣工管理、质量管理等多个…...

深蓝词库转换器:3分钟掌握30+输入法词库互转的终极指南

深蓝词库转换器:3分钟掌握30输入法词库互转的终极指南 【免费下载链接】imewlconverter ”深蓝词库转换“ 一款开源免费的输入法词库转换程序 项目地址: https://gitcode.com/gh_mirrors/im/imewlconverter 你是否曾因更换输入法而丢失多年积累的个人词库&am…...

AI工厂令牌生产加速:统一服务与实时AI架构

使用统一服务和实时AI加速AI工厂中的令牌生产 在当今的AI工厂环境中,性能并非理论概念,而是经济、竞争和生存的关键。可用GPU时间下降1%,可能意味着每小时损失数百万令牌。几分钟的拥塞可能演变成数小时的恢复时间。机架级功率过载会导致功率…...

腾讯HY-MT1.5翻译模型应用案例:多语言文档翻译实战

腾讯HY-MT1.5翻译模型应用案例:多语言文档翻译实战 1. 模型概述与核心能力 1.1 模型架构与版本 腾讯开源的HY-MT1.5翻译模型包含两个版本: HY-MT1.5-1.8B:18亿参数版本,专为边缘计算和实时翻译场景优化HY-MT1.5-7B&#xff1a…...

零基础玩转AI春联生成:手把手教你Windows WSL2部署达摩院春联模型

零基础玩转AI春联生成:手把手教你Windows WSL2部署达摩院春联模型 春节将至,家家户户都开始准备贴春联。但每年想一副既传统又有新意的对联可不容易——要么是市场上买的千篇一律,要么自己绞尽脑汁也想不出好句子。今天,我将带你…...

M2FP镜像升级指南:如何从基础服务扩展到视频流实时解析?

M2FP镜像升级指南:如何从基础服务扩展到视频流实时解析? 1. 从静态图像到视频流解析的技术演进 多人人体解析技术正在从静态图片处理向动态视频分析快速演进。传统的M2FP服务虽然能出色完成单张图片的语义分割,但面对视频流实时处理时&…...

个人知识库管家:OpenClaw+Gemma-3-12b-it自动整理Obsidian笔记

个人知识库管家:OpenClawGemma-3-12b-it自动整理Obsidian笔记 1. 为什么需要自动化笔记整理 作为一个长期使用Obsidian管理技术笔记的用户,我发现自己逐渐陷入"收集容易整理难"的困境。每天新增的Markdown文档堆积在Vault文件夹中&#xff0…...

HunyuanVideo-Foley多模态交互案例:结合文本与视觉输入生成场景化音效

HunyuanVideo-Foley多模态交互案例:结合文本与视觉输入生成场景化音效 1. 效果亮点开场 想象一下这样的场景:你上传一张古堡图片,输入"添加一些神秘感",系统就能自动生成风声、吱呀作响的木门、隐约的钟声等复合音效。…...

YOLO12入门必看:位置感知器与FlashAttention推理加速原理图解

YOLO12入门必看:位置感知器与FlashAttention推理加速原理图解 1. YOLO12模型概述 1.1 新一代目标检测架构 YOLO12是2025年发布的最新一代目标检测模型,代表了计算机视觉领域的重要突破。这个模型采用了全新的注意力为中心架构,在保持实时推…...

seo外包公司如何提高网站的用户体验_seo外包公司有哪些常见的优化方法

seo外包公司如何提高网站的用户体验 在当前的数字化时代,网站的用户体验(User Experience, UX)已经成为网站成功的关键因素之一。优秀的用户体验不仅能提升网站的流量,还能增加用户的黏性和转化率。对于那些选择了外包SEO服务的企…...

Qwen2.5-1.5B效果展示:金融术语解释+财报摘要生成准确率实测

Qwen2.5-1.5B效果展示:金融术语解释财报摘要生成准确率实测 1. 测试背景与目的 在金融领域,准确理解专业术语和快速分析财务报告是两项核心需求。传统方式需要专业人士花费大量时间进行解释和分析,而AI模型的出现让自动化处理成为可能。 本…...