当前位置: 首页 > news >正文

实现pytorch版的mobileNetV1

mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客

这里是根据网络结构,搭建模型,用于图像分类任务。

1. 网络结构和基本组件

2. 搭建组件

(1)普通的卷积组件:CBL = Conv2d + BN + ReLU6;

(2)深度可分离卷积:DwCBL  = Conv dw+ Conv dp;

Conv dw+ Conv dp = {Conv2d(3x3) + BN + ReLU6 }  + {Conv2d(1x1) + BN + ReLU6};

Conv dw是3x3的深度卷积,通过步长控制是否进行下采样;

Conv dp是1x1的逐点卷积,通过控制输出通道数,控制通道维度的变化;

# 普通卷积
class CBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(CBN, self).__init__()self.conv = nn.Conv2d(in_c, out_c, 3, stride, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_c)self.relu = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return x
# 深度可分离卷积: 深度卷积(3x3x1) + 逐点卷积(1x1xc卷积)
class DwCBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(DwCBN, self).__init__()# conv3x3x1, 深度卷积,通过步长,只控制是否缩小特征hwself.conv3x3 = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False)self.bn1 = nn.BatchNorm2d(in_c)self.relu1 = nn.ReLU6(inplace=True)# conv1x1xc, 逐点卷积,通过控制输出通道数,控制通道维度的变化self.conv1x1 = nn.Conv2d(in_c, out_c, 1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(out_c)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv3x3(x)x = self.bn1(x)x = self.relu1(x)x = self.conv1x1(x)x = self.bn2(x)x = self.relu2(x)return x

3. 搭建网络

class MobileNetV1(nn.Module):def __init__(self, class_num=1000):super(MobileNetV1, self).__init__()self.stage1 = torch.nn.Sequential(CBN(3, 32, 2),  # 下采样/2DwCBN(32, 64, 1))self.stage2 = torch.nn.Sequential(DwCBN(64, 128, 2),  # 下采样/4DwCBN(128, 128, 1))self.stage3 = torch.nn.Sequential(DwCBN(128, 256, 2),  # 下采样/8DwCBN(256, 256, 1))self.stage4 = torch.nn.Sequential(DwCBN(256, 512, 2),  # 下采样/16DwCBN(512, 512, 1),  # 5个DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),)self.stage5 = torch.nn.Sequential(DwCBN(512, 1024, 2),  # 下采样/32DwCBN(1024, 1024, 1))# classifierself.avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))self.fc = torch.nn.Linear(1024, class_num, bias=True)# self.classifier = torch.nn.Softmax()  # 原始的softmax值# torch.log_softmax 首先计算 softmax 然后再取对数,因此在数值上更加稳定。# 在分类网络在训练过程中,通常使用交叉熵损失函数(Cross-Entropy Loss)。# torch.nn.CrossEntropyLoss 会在内部进行 softmax 操作,因此在网络的最后一层不需要手动加上 softmax 操作。def forward(self, x):scale1 = self.stage1(x)  # /2scale2 = self.stage2(scale1)scale3 = self.stage3(scale2)scale4 = self.stage4(scale3)scale5 = self.stage5(scale4)  # /32. 7x7x = self.avg_pooling(scale5)  # (b,1024,7,7)->(b,1024,1,1)x = torch.flatten(x, 1)  # (b,1024,1,1)->(b,1024,)x = self.fc(x)  # (b,1024,)  -> (b,1000,)return xif __name__ == '__main__':m1 = MobileNetV1(class_num=1000)input_data = torch.randn(64, 3, 224, 224)output = m1.forward(input_data)print(output.shape)

4. 训练验证

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optimfrom mobilenetv1 import MobileNetV1def validate(model, val_loader, criterion, device):model.eval()  # Set the model to evaluation modetotal_correct = 0total_samples = 0with torch.no_grad():for val_inputs, val_labels in val_loader:val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)val_outputs = model(val_inputs)_, predicted = torch.max(val_outputs, 1)total_samples += val_labels.size(0)total_correct += (predicted == val_labels).sum().item()accuracy = total_correct / total_samplesmodel.train()  # Set the model back to training modereturn accuracyif __name__ == '__main__':# 下载并准备数据集# Define image transformations (adjust as needed)transform = transforms.Compose([transforms.Resize((224, 224)),  # Resize images to a consistent sizetransforms.ToTensor(),  # converts to PIL Image to a Pytorch Tensor and scales values to the range [0, 1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Adjust normalization values. val = (val - mean) / std.])# Create ImageFolder datasetdata_folder = r"D:\zxq\data\car_or_dog"dataset = torchvision.datasets.ImageFolder(root=data_folder, transform=transform)# Optionally, split the dataset into training and validation sets# Adjust the `split_ratio` as neededsplit_ratio = 0.8train_size = int(split_ratio * len(dataset))val_size = len(dataset) - train_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])# Create DataLoader for training and validationtrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)# 初始化模型、损失函数和优化器net = MobileNetV1(class_num=2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)# 训练模型device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)net.to(device)for epoch in range(20):  # 例如,训练 20 个周期for i, data in enumerate(train_loader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)  # 将数据移动到GPUoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()if i % 100 == 0:print("epoch/step: {}/{}: loss: {}".format(epoch, i, loss.item()))# Validation after each epochval_accuracy = validate(net, val_loader, criterion, device)print("Epoch {} - Validation Accuracy: {:.2%}".format(epoch, val_accuracy))print('Finished Training')

待续。。。

相关文章:

实现pytorch版的mobileNetV1

mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客 这里是根据网络结构,搭建模型,用于图像分类任务。 1. 网络结构和基本组件 2. 搭建组件 (1)普通的卷积组件:CBL …...

vue多tab页面全部关闭后自动退出登录

业务场景:主项目是用vue写的单页面应用,但是有多开页面的需求,现在需要在用户关闭了所有的浏览器标签页面后,自动退出登录。 思路:因为是不同的tab页面,我只能用localStorage来通信,新打开一个…...

记一个集群环境部署不完整导致的BUG

一 背景 产品有三个环境:开发测试环境、验收环境、生产环境。 开发测试环境,保持最新的更新; 验收环境,阶段待发布内容; 生产环境,部署稳定内容。 产品为BS架构,后端采用微服务&#xf…...

Go zero copy,复制文件

这里使用零拷贝技术复制文件,从内核态操作源文件和目标文件。避免了在用户态开辟缓冲区,然后从内核态复制文件到用户态的问题。 由内核态完成文件复制操作。 调用的是syscall.Sendfile系统调用函数。 //go:build linuxpackage zero_copyimport ("f…...

http协议九种请求方法介绍及常见状态码

http1.0定义了三种: GET: 向服务器获取资源,比如常见的查询请求POST: 向服务器提交数据而发送的请求Head: 和get类似,返回的响应中没有具体的内容,用于获取报头 http1.1定义了六种 PUT:一般是用于更新请求,…...

详解flink exactly-once和两阶段提交

以下是我们常见的三种 flink 处理语义: 最多一次(At-most-Once):用户的数据只会被处理一次,不管成功还是失败,不会重试也不会重发。 至少一次(At-least-Once):系统会保…...

Qt/QML编程学习之心得:QDbus实现service接口调用(28)

D-Bus协议用于进程间通讯的。 QString value = retrieveValue();QDBusPendingCall pcall = interface->asyncCall(QLatin1String("Process"), value);QDBusPendingCallWatcher *watcher = new QDBusPendingCallWatcher(pcall, this);QObject::connect(watcher, SI…...

前端nginx配置指南

前端项目发布后,有些接口需要在服务器配置反向代理,资源配置gzip压缩,配置跨域允许访问等 配置文件模块概览 配置示例 反向代理 反向代理是Nginx的核心功能之一,是指客户端发送请求到代理服务器,代理服务器再将请求…...

接口测试到底怎么做,5分钟时间看完这篇文章彻底搞清楚

01、通用的项目架构 02、什么是接口 接口:服务端程序对外提供的一种统一的访问方式,通常采用HTTP协议,通过不同的url,不同的请求类型(GET、POST),不同的参数,来执行不同的业务逻辑。…...

显示管理磁盘分区 fdisk

显示管理磁盘分区 fdisk fdisk是用于检查一个磁盘上分区信息最通用的命令。 fdisk可以显示分区信息及一些细节信息,比如文件系统类型等。 设备的名称通常是/dev/sda、/dev/sdb 等。 对于以前的设备有可能还存在设备名为 /dev/hd* (IDE)的设备,这个设…...

Hyperledger Fabric 管理链码 peer lifecycle chaincode 指令使用

链上代码(Chaincode)简称链码,包括系统链码和用户链码。系统链码(System Chaincode)指的是 Fabric Peer 中负责系统配置、查询、背书、验证等平台功能的代码逻辑,运行在 Peer 进程内,将在第 14 …...

L1-011 A-B(Java)

题目 本题要求你计算A−B。不过麻烦的是,A和B都是字符串 —— 即从字符串A中把字符串B所包含的字符全删掉,剩下的字符组成的就是字符串A−B。 输入格式: 输入在2行中先后给出字符串A和B。两字符串的长度都不超过10的四次方,并且…...

系列七、Ribbon

一、Ribbon 1.1、概述 Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具,是Netflix发布的一款开源项目,其主要功能是提供客户端的软件负载均衡算法和服务调用,Ribbon客户端组件提供一系列完善的配置项,例如&#xff1a…...

山东名岳轩印刷包装携专业包装袋盛装亮相2024济南生物发酵展

山东名岳轩印刷包装有限公司盛装亮相2024第12届国际生物发酵展,3月5-7日山东国际会展中心与您相约! 展位号:1号馆F17 山东名岳轩印刷包装有限公司是一家拥有南北两个生产厂区,设计、制版、印刷,营销策划为一体的专业…...

BGP公认必遵属性——Next-hop(一)

BGP公认必遵属性共有三个,分别是:Next-hop、Origin、As-path,本期介绍Next-hop 点赞关注,持续更新!!! Next-hop 华为BGP路由下一跳特点: 默认情况下传给EBGP邻居的BGP路由的下一跳…...

增强Wi-Fi信号的10种方法,值得去尝试

Wi-Fi信号丢失,无线盲区。在一个对一些人来说,上网和呼吸一样必要的世界里,这些问题中的每一个都令人抓狂。 如果你觉得你的Wi-Fi变得迟钝,有很多工具可以用来测试你的互联网速度。你还可以尝试一些技巧来解决网络问题。然而,如果你能获得良好接收的唯一方法是站在无线路…...

第十五章 ECMAScript6新增的常用语法

文章目录 一、声明关键字二、箭头函数三、解构赋值四、展开运算符五、对字符的补充六、Symbol七、对象的简写语法八、Set和Map九、for-of 一、声明关键字 ES6新增的声明关键字: let,const:声明变量class:声明类import&#xff0c…...

vulhub中的Apache SSI 远程命令执行漏洞

Apache SSI 远程命令执行漏洞 1.cd到ssi-rce cd /opt/vulhub/httpd/ssi-rce/ 2.执行docker-compose up -d docker-compose up -d 3.查看靶场是否开启成功 dooker ps 拉取成功了 4.访问url 这里已经执行成功了,注意这里需要加入/upload.php 5.写入一句话木马 &…...

MSB20M-ASEMI迷你贴片整流桥MSB20M

编辑:ll MSB20M-ASEMI迷你贴片整流桥MSB20M 型号:MSB20M 品牌:ASEMI 封装:UMSB-4 特性:贴片、整流桥 最大平均正向电流:2A 最大重复峰值反向电压:1000V 恢复时间:&#xff1…...

工程管理系统功能设计与实践:实现高效、透明的工程管理

在现代化的工程项目管理中,一套功能全面、操作便捷的系统至关重要。本文将介绍一个基于Spring Cloud和Spring Boot技术的Java版工程项目管理系统,结合Vue和ElementUI实现前后端分离。该系统涵盖了项目管理、合同管理、预警管理、竣工管理、质量管理等多个…...

【C#】网址不进行UrlEncode编码会存在一些问题

欢迎来到《小5讲堂》,大家好,我是全栈小5。 这是2024年第3篇文章,此篇文章是C#知识点实践序列文章,博主能力有限,理解水平有限,若有不对之处望指正! 目录 前言数据丢失效果请求端代码接口端代码…...

深入Pandas(二):高级数据处理技巧

文章目录 系列文章目录引言时间序列分析可视化示例 高级数据分析技术分组与聚合操作时间序列分析 高级数据操作数据合并与重塑示例:数据合并merge示例:数据合并concat示例:数据重塑 - 透视表 高级索引技巧 结论 系列文章目录 Python数据分析…...

实验8 分析HTTP协议和DNS

实验8 分析HTTP协议和DNS 一、 实验目的及任务 熟悉并掌握wireshark的基本操作,了解网络协议实体间的交互以及报文交换。分析HTTP协议分析DNS协议 二、 实验设备 与因特网连接的计算机网络系统;主机操作系统为Windows;wireshark等软件。 …...

Talk | EMNLP 2023 最佳长论文:以标签为锚-从信息流动的视角分析上下文学习

本期为TechBeat人工智能社区第561期线上Talk。 北京时间1月4日(周四)20:00,北京大学博士生—王乐安的Talk已准时在TechBeat人工智能社区开播! 他与大家分享的主题是: “以标签为锚-从信息流动的视角分析上下文学习”,介绍了他的团队在上下文学…...

2024年中国电子学会青少年编程等级考试安排的通知

各有关单位、全体考生: 中国电子学会青少年等级考试(以下简称等级考试)是中国电子学会为落实《全民科学素质行动规划纲要》,提升青少年电子信息科学素质水平而开展的社会化评价项目。等级考试自2011年启动以来,作为中国电子学会科…...

[足式机器人]Part3 机构运动学与动力学分析与建模 Ch00-2(2) 质量刚体的在坐标系下运动

本文仅供学习使用,总结很多本现有讲述运动学或动力学书籍后的总结,从矢量的角度进行分析,方法比较传统,但更易理解,并且现有的看似抽象方法,两者本质上并无不同。 2024年底本人学位论文发表后方可摘抄 若有…...

【亚马逊云科技】自家的AI助手 - Amazon Q

写在前面:博主是一只经过实战开发历练后投身培训事业的“小山猪”,昵称取自动画片《狮子王》中的“彭彭”,总是以乐观、积极的心态对待周边的事物。本人的技术路线从Java全栈工程师一路奔向大数据开发、数据挖掘领域,如今终有小成…...

网络安全—SSL安全访问应用

文章目录 网络拓扑部署CA服务器颁发证书开启Web服务安装IIS服务修改Web默认网页 申请Web证书前提准备申请文件生成申请web证书开始安装web证书 客户机访问web默认网站使用HTTP使用HTTPS 为客户机安装浏览器证书 环境:Windows Server 2003 网络拓扑 这里使用NAT还是…...

Qt5.14.2实现将html文件转换为pdf文件

文章目录 简介源码widget.cppwidget.uihtml文件演示效果简介 QPdfWriter是Qt框架中用于创建和写入PDF文件的类。它允许您在您的Qt应用程序中动态生成并输出PDF文档,以便进行打印、保存或导出。 QPdfWriter类提供了以下一些常用的函数和方法,可以让您创建和定制PDF文件: 构…...

Minecraft教程:使用MCSM面板搭建我的世界私服并实现远程联机

文章目录 前言1. 安装JAVA2. MCSManager安装3.局域网访问MCSM4.创建我的世界服务器5.局域网联机测试6.安装cpolar内网穿透7. 配置公网访问地址8.远程联机测试9. 配置固定远程联机端口地址9.1 保留一个固定tcp地址9.2 配置固定公网TCP地址9.3 使用固定公网地址远程联机 前言 Li…...

wordpress functions.php/百度客服转人工

等级:文件 25MB格式 rar漂亮幼儿园教室3D模型下载 包括:原始3d模型、材质、贴图。 漂亮幼儿园教室立即下载等级:文件 3MB格式 su文件格式:su版本号:sketchup7内容简介 教室SketchUp模型下载 教室 教室立即下载等级…...

未来网站发展方向/友链互换平台推荐

转载请注明出处:http://blog.csdn.net/guoyjoe/article/details/9887591正解答案是:B引用akagea同学的解释http://www.itpub.net/thread-1808740-1-1.html“根据题意要求,需要存储的是一个时间间隔的数据,且方便加减,所…...

乡镇医院网站建设/厉害的seo顾问

第一印象 SolrCloud是Solr4.0引入的,主要应对与商业场景。它很像master-slave,却能自动化的完成以前需要手动完成的操作。利用ZooKeeper这个工具去监控整个Solr集群,以了解集群间各个机器的工作状态。 配置的区别 从配置来看,Solr…...

网站挣钱怎么做/百度百科官网首页

前言 持续集成(Continuous Intergration)做为软件开发重要环节,已经被提出、实施了很多年,在国内IT行业也已普及,它是为敏捷开发而创建。通过引入CI,可以减少重复工作、尽早暴露系统问题。 需求 1, 能够自动从源码管理…...

vs和sql怎么做网站/百度刷排名seo

我曾在《》一文中介绍过HttpCient的使用,这里就不在累述了,感兴趣的朋友能够去看一下。在这里主要介绍怎样通过HttpClient实现文件上传。1.预备知识:在HttpCient4.3之前上传文件主要使用MultipartEntity这个类,但如今这个类已经不…...

委托做网站违反广告法/营销策划公司收费明细

一、RabbitMQ基础知识 一、背景 RabbitMQ是一个由erlang开发的AMQP(Advanced Message Queue )的开源实现。AMQP 的出现其实也是应了广大人民群众的需求,虽然在同步消息通讯的世界里有很多公开标准(如 COBAR的 IIOP ,或…...