pytorch04:网络模型创建
目录
- 一、模型创建过程
- 1.1 以LeNet网络为例
- 1.2 LeNet结构
- 1.3 nn.Module
- 二、网络层容器(Containers)
- 2.1 nn.Sequential
- 2.1.1 常规方法实现
- 2.1.2 OrderedDict方法实现
- 2.2 nn.ModuleList
- 2.3 nn.ModuleDict
- 2.4 三种容器构建总结
- 三、AlexNet网络构建
一、模型创建过程
1.1 以LeNet网络为例
网络代码如下:
class LeNet(nn.Module):def __init__(self, classes):super(LeNet, self).__init__() # 调用父类方法,作用是调用nn.Module类的构造函数,# 确保LeNet类被正确地初始化,并继承了nn.Module 的所有属性和方法self.conv1 = nn.Conv2d(3, 6, 5) # 卷积层self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, classes)def forward(self, x):out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out
1.2 LeNet结构
LeNet:conv1–>pool1–>conv2–>pool2–>fc1–>fc2–>fc3
1.3 nn.Module
Module是nn模块中的功能,nn模块还有Parameter、functional等模块。
nn.Module主要有以下参数:
• parameters : 存储管理nn.Parameter类
• modules : 存储管理nn.Module类
• buffers:存储管理缓冲属性,如BN层中的running_mean
二、网络层容器(Containers)
2.1 nn.Sequential
nn.Sequential 是 nn.module的容器,也是最常用的容器,用于按顺序包装一组网络层
• 顺序性:各网络层之间严格按照顺序构建
• 自带forward():自带的forward里,通过for循环依次执行前向传播运算
2.1.1 常规方法实现
LeNet网络由两部分构成,中间的卷积池化特征提取部分(features),以及最后的分类部分(classifier)。
具体代码如下:
class LeNetSequential(nn.Module):def __init__(self, classes):super(LeNetSequential, self).__init__()self.features = nn.Sequential( #特征提取部分nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential( #分类部分nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, classes),)def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x
打印网络层:
2.1.2 OrderedDict方法实现
使用有序字典的方法构建Sequential
代码如下:
class LeNetSequentialOrderDict(nn.Module):def __init__(self, classes):super(LeNetSequentialOrderDict, self).__init__()self.features = nn.Sequential(OrderedDict({'conv1': nn.Conv2d(3, 6, 5),'relu1': nn.ReLU(inplace=True),'pool1': nn.MaxPool2d(kernel_size=2, stride=2),'conv2': nn.Conv2d(6, 16, 5),'relu2': nn.ReLU(inplace=True),'pool2': nn.MaxPool2d(kernel_size=2, stride=2),}))self.classifier = nn.Sequential(OrderedDict({'fc1': nn.Linear(16 * 5 * 5, 120),'relu3': nn.ReLU(),'fc2': nn.Linear(120, 84),'relu4': nn.ReLU(inplace=True),'fc3': nn.Linear(84, classes),}))def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x
先看一下Sequential函数中init初始化的两种方法,当我们使用OrderedDict方法时,会进行判断,使用self.add_module(key, module)方法将字典中的key和value取出来添加到Sequential中。
class Sequential(Module):def __init__(self, *args):super().__init__()if len(args) == 1 and isinstance(args[0], OrderedDict):for key, module in args[0].items():self.add_module(key, module)else:for idx, module in enumerate(args):self.add_module(str(idx), module)
通过这种方法构建可以给每一网络层添加一个名称,网络输出结果如下:
2.2 nn.ModuleList
nn.ModuleList是 nn.module的容器,用于包装一组网络层,以迭代方式调用网络层
主要方法:
• append():在ModuleList后面添加网络层
• extend():拼接两个ModuleList
• insert():指定在ModuleList中位置插入网络层
使用列表生成式,通过一行代码就能构建20个网络层。
代码演示:
class ModuleList(nn.Module):def __init__(self):super(ModuleList, self).__init__()# 使用列表生成式构建20个全连接层,每个全连接层10个神经元的网络self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])def forward(self, x):for i, linear in enumerate(self.linears):x = linear(x)return xnet = ModuleList()
2.3 nn.ModuleDict
nn.ModuleDict是 nn.module的容器,用于包装一组网络层,以索引方式调用网络层,可以用过参数的形式选取想要调用的网络层。
主要方法:
• clear():清空ModuleDict
• items():返回可迭代的键值对(key-value pairs)
• keys():返回字典的键(key)
• values():返回字典的值(value)
• pop():返回一对键值,并从字典中删除
代码展示,只选取conv和relu两个网络层:
class ModuleDict(nn.Module):def __init__(self):super(ModuleDict, self).__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})# 激活函数self.activations = nn.ModuleDict({'relu': nn.ReLU(),'prelu': nn.PReLU()})def forward(self, x, choice, act): # 传入两个参数 用来选择网络层x = self.choices[choice](x)x = self.activations[act](x)return x
net = ModuleDict()
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu') #只选取conv和relu两个网络层。
print(output)
2.4 三种容器构建总结
• nn.Sequential:顺序性,各网络层之间严格按顺序执行,常用于block构建
• nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建
• nn.ModuleDict:索引性,常用于可选择的网络层
三、AlexNet网络构建
AlexNet:2012年以高出第二名10多个百分点的准确率获得ImageNet分类任务冠军,开创了卷积神经网络的新时代
AlexNet特点如下:
- 采用ReLU:替换饱和激活函数,减轻梯度消失
- 采用LRN(Local Response Normalization):对数据归一化,减轻梯度消失
- Dropout:提高全连接层的鲁棒性,增加网络的泛化能力
- Data Augmentation:TenCrop,色彩修改
网络结构图如下:
构建代码:
import torch.nn as nn
import torch
from torchsummary import summary
# 定义一个名为AlexNet的神经网络模型,继承自nn.Module基类
class AlexNet(nn.Module):# 构造函数,初始化网络的参数def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:# 调用父类的构造函数super().__init__()# 定义神经网络的特征提取部分,包含多个卷积层和池化层self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), # 输入通道3,输出通道64,卷积核大小11x11,步长4,填充2nn.ReLU(inplace=True), # 使用ReLU激活函数,inplace=True表示原地操作,节省内存nn.MaxPool2d(kernel_size=3, stride=2), # 最大池化层,核大小3x3,步长2nn.Conv2d(64, 192, kernel_size=5, padding=2), # 输入通道64,输出通道192,卷积核大小5x5,填充2nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)# 定义自适应平均池化层,将输入的任意大小的特征图池化为固定大小6x6self.avgpool = nn.AdaptiveAvgPool2d((6, 6))# 定义分类器部分,包含全连接层和Dropout层self.classifier = nn.Sequential(nn.Dropout(p=dropout), # 使用Dropout进行正则化,随机丢弃一部分神经元以防止过拟合nn.Linear(256 * 6 * 6, 4096), # 输入大小为256*6*6,输出大小为4096nn.ReLU(inplace=True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes), # 最后的全连接层输出类别数)# 前向传播函数,定义数据在网络中的传播过程def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.features(x) # 特征提取x = self.avgpool(x) # 平均池化x = torch.flatten(x, 1) # 将特征图展平成一维向量x = self.classifier(x) # 分类器return xif __name__ == '__main__':net = AlexNet().cuda()summary(net, (3, 256, 256))
打印出的网络结构图如下:
相关文章:

pytorch04:网络模型创建
目录 一、模型创建过程1.1 以LeNet网络为例1.2 LeNet结构1.3 nn.Module 二、网络层容器(Containers)2.1 nn.Sequential2.1.1 常规方法实现2.1.2 OrderedDict方法实现 2.2 nn.ModuleList2.3 nn.ModuleDict2.4 三种容器构建总结 三、AlexNet网络构建 一、模型创建过程 1.1 以LeNe…...

用js让用户输入一个数累加和
需求:用户输入一个数, 计算 1 到这个数的和。 比如 用户输入的是 5, 则计算 1~5 之间的累加和 并且输出到控制台 <body><script>let numprompt(请输入一个数)let sum0for(let i1;i<num;i){sumi}console.log(sum)</script…...

踩坑记录-安装nuxt3报错:Error: Failed to download template from registry: fetch failed;
报错复现 安装nuxt3报错:Error: Failed to download template from registry: fetch failednpx nuxi init nuxt-demo 初始化nuxt 项目 报错 Error: Failed to download template from registry: fetch faile 解决方法 配置hosts Mac电脑:/etc/hostswin电…...
大数据学习(31)-Spark非常用及重要特性
&&大数据学习&& 🔥系列专栏: 👑哲学语录: 承认自己的无知,乃是开启智慧的大门 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言📝支持一下博主哦ᾑ…...

【教学类-43-14】 20240103 (4宫格数独:正确版:576套) 不重复的基础模板数量:576套
作品展示::——4宫格 576套不重复模板(48页*12套题) 背景需求: 生成4宫格基础模板768套,观看64页内容时,明显看到有错误 【教学类-43-13】 20240103 (4宫格数独:错误版…...

AIGC开发:调用openai的API接口实现简单机器人
简介 开始进行最简单的使用:通过API调用openai的模型能力 OpenAI的能力如下图: 文本生成模型 OpenAI 的文本生成模型(通常称为生成式预训练 Transformer 或大型语言模型)经过训练可以理解自然语言、代码和图像。这些模型提供文…...
c基础(二)
指针: 含义:是一个值,一个值代表着一个内存地址,类似于存放路径 * 运算符 : 1 字符*表示指针 作用:通常跟在类型关键字的后面,表示指针指向的是什么类型的值 int * foo, * bar;声明指针后会…...

人工智能趋势报告解读:ai野蛮式生长的背后是机遇还是危机?
近期,Enterprise WordPress发布了生成式人工智能在营销中的应用程度的报告,这是一个人工智能迅猛发展的时代,目前人工智能已经广泛运用到内容创作等领域,可以预见的是人工智能及其扩展应用还将延伸到我们工作与生活中的方方面面。…...
三、C语言中的分支与循环—goto语句 (10) (完)
在C语言中,goto语句允许程序无条件地跳转到同一函数内的标记位置。这个标记位置通过一个标签和冒号(:)来标示。goto语句可以用于从深层嵌套的循环或条件语句中直接跳出,或者跳过某些代码执行。尽管goto语句在某些情况下可以使程序逻辑变得清晰࿰…...
RabbitMQ 常见问题
1. 如何保证消息顺序消费 在RabbitMQ中,消息最终会保存在队列中,在同一个队列中,消息是顺序的,保持先进先出的原则,这个由Rabbitmq保证。而不同队列中的消息,RabbitMQ 是无法保证其顺序性。顺序消费主要是…...

阶段二-Day10-日期类
日期类结构: 1.java.util.Date是日期类 2.DateFormat是日期格式类、SimpleDateFormat是日期格式类的子类 Timezone代表时区 3.Calendar是日历类,GregorianCalendar是日历的子类 一. 常用类-Date 1.1 Date构造方法 Date(long date) 使用给定的毫秒时间价值构建…...

多任务并行处理相关面试题
我自己面试时被问过两次多任务并行相关的问题: 假设现在有10个任务,要求同时处理,并且必须所有任务全部完成才返回结果 这个面试题的难点是: 既然要同时处理,那么肯定要用多线程。怎么设计多线程同时处理任务呢&…...

Shell脚本学习笔记
1. 写在前面 工作中,需要用到写一些shell脚本去完成一些简单的重复性工作, 于是就想系统的学习下shell脚本的相关知识, 本篇文章是学习shell脚本整理的学习笔记,内容参考主要来自C语言中文网, 学习过程中,…...

ROS-安装xacro
安装 运行下列命令进行安装,xxxxxx处更改为自己的版本 sudo apt-get install ros-xxxxxx-xacro运行 输入下列命令 roscd xacro如果没有报错,并且进入了xacro软件包的目录,则表示安装成功。 参考: [1]https://wenku.csdn.net/ans…...

为什么说 $mash 是 Solana 上最正统的铭文通证?
早在 2023 年的 11 月,包括 Solana、Avalanche、Polygon、Arbitrum、zkSync 等生态正在承接比特币铭文生态外溢的价值。当然,因铭文赛道过于火爆,当 Avalanche、BNB Chain 以及 Polygon 等链上 Gas 飙升至极值,Arbitrum、zkSync 等…...

安装elasticsearch、kibana、IK分词器、扩展IK词典
安装elasticsearch、kibana、IK分词器、扩展IK词典 后面还会安装kibana,这个会提供可视化界面方面学习。 需要注意的是elasticsearch和kibana版本一定要一样!!! 否则就像这样 elasticsearch 1、创建网络 因为我们还需要部署k…...

Spring中常见的BeanFactory后处理器
常见的BeanFacatory后处理器 先给出没有添加任何BeanFactory后处理器的测试代码 public class TestBeanFactoryPostProcessor {public static void main(String[] args) {GenericApplicationContext context new GenericApplicationContext();context.registerBean("co…...
FPGA LCD1602驱动代码 (已验证)
一.需求解读 1.需求 在液晶屏第一行显示“HELLO FPGA 1234!” 2. 知识背景 1602 液晶也叫 1602 字符型液晶,它是一种专门用来显示字母、数字、符号等的点阵 型液晶模块。它由若干个 5X7 或者 5X11 等点阵字符位组成,每个点阵字符位都可以显示一 个字符,每位之间有一个点距的…...

c++编程要养成的好习惯
1、缩进 你说有缩进看的清楚还是没缩进看的清楚 2、i和i i运行起来和i更快 3、 n%20和n&1 不要再用n%20来判断n是不是偶数了,又慢又土,用n&10,如果n&10就说明n是偶数 同理,n&11说明n是奇数 4、*2和<<…...
后台管理项目的多数据源方案
引言 在互联网开发公司中,往往伴随着业务的快速迭代,程序员可能没有过多的时间去思考技术扩展的相关问题,长久下来导致技术过于单一。为此最近在学习互联网思维,从相对简单的功能开始做总结,比如非常常见的基础数据的…...

工业安全零事故的智能守护者:一体化AI智能安防平台
前言: 通过AI视觉技术,为船厂提供全面的安全监控解决方案,涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面,能够实现对应负责人反馈机制,并最终实现数据的统计报表。提升船厂…...

通过Wrangler CLI在worker中创建数据库和表
官方使用文档:Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后,会在本地和远程创建数据库: npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库: 现在,您的Cloudfla…...
线程与协程
1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指:像函数调用/返回一样轻量地完成任务切换。 举例说明: 当你在程序中写一个函数调用: funcA() 然后 funcA 执行完后返回&…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析
Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...

使用Spring AI和MCP协议构建图片搜索服务
目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式(本地调用) SSE模式(远程调用) 4. 注册工具提…...
现有的 Redis 分布式锁库(如 Redisson)提供了哪些便利?
现有的 Redis 分布式锁库(如 Redisson)相比于开发者自己基于 Redis 命令(如 SETNX, EXPIRE, DEL)手动实现分布式锁,提供了巨大的便利性和健壮性。主要体现在以下几个方面: 原子性保证 (Atomicity)ÿ…...
Java数值运算常见陷阱与规避方法
整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...

从 GreenPlum 到镜舟数据库:杭银消费金融湖仓一体转型实践
作者:吴岐诗,杭银消费金融大数据应用开发工程师 本文整理自杭银消费金融大数据应用开发工程师在StarRocks Summit Asia 2024的分享 引言:融合数据湖与数仓的创新之路 在数字金融时代,数据已成为金融机构的核心竞争力。杭银消费金…...

Qemu arm操作系统开发环境
使用qemu虚拟arm硬件比较合适。 步骤如下: 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载,下载地址:https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...
嵌入式常见 CPU 架构
架构类型架构厂商芯片厂商典型芯片特点与应用场景PICRISC (8/16 位)MicrochipMicrochipPIC16F877A、PIC18F4550简化指令集,单周期执行;低功耗、CIP 独立外设;用于家电、小电机控制、安防面板等嵌入式场景8051CISC (8 位)Intel(原始…...