pytorch04:网络模型创建
目录
- 一、模型创建过程
- 1.1 以LeNet网络为例
- 1.2 LeNet结构
- 1.3 nn.Module
- 二、网络层容器(Containers)
- 2.1 nn.Sequential
- 2.1.1 常规方法实现
- 2.1.2 OrderedDict方法实现
- 2.2 nn.ModuleList
- 2.3 nn.ModuleDict
- 2.4 三种容器构建总结
- 三、AlexNet网络构建
一、模型创建过程
1.1 以LeNet网络为例
网络代码如下:
class LeNet(nn.Module):def __init__(self, classes):super(LeNet, self).__init__() # 调用父类方法,作用是调用nn.Module类的构造函数,# 确保LeNet类被正确地初始化,并继承了nn.Module 的所有属性和方法self.conv1 = nn.Conv2d(3, 6, 5) # 卷积层self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, classes)def forward(self, x):out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out
1.2 LeNet结构
LeNet:conv1–>pool1–>conv2–>pool2–>fc1–>fc2–>fc3
1.3 nn.Module
Module是nn模块中的功能,nn模块还有Parameter、functional等模块。
nn.Module主要有以下参数:
• parameters : 存储管理nn.Parameter类
• modules : 存储管理nn.Module类
• buffers:存储管理缓冲属性,如BN层中的running_mean
二、网络层容器(Containers)
2.1 nn.Sequential
nn.Sequential 是 nn.module的容器,也是最常用的容器,用于按顺序包装一组网络层
• 顺序性:各网络层之间严格按照顺序构建
• 自带forward():自带的forward里,通过for循环依次执行前向传播运算
2.1.1 常规方法实现
LeNet网络由两部分构成,中间的卷积池化特征提取部分(features),以及最后的分类部分(classifier)。
具体代码如下:
class LeNetSequential(nn.Module):def __init__(self, classes):super(LeNetSequential, self).__init__()self.features = nn.Sequential( #特征提取部分nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential( #分类部分nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, classes),)def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x
打印网络层:
2.1.2 OrderedDict方法实现
使用有序字典的方法构建Sequential
代码如下:
class LeNetSequentialOrderDict(nn.Module):def __init__(self, classes):super(LeNetSequentialOrderDict, self).__init__()self.features = nn.Sequential(OrderedDict({'conv1': nn.Conv2d(3, 6, 5),'relu1': nn.ReLU(inplace=True),'pool1': nn.MaxPool2d(kernel_size=2, stride=2),'conv2': nn.Conv2d(6, 16, 5),'relu2': nn.ReLU(inplace=True),'pool2': nn.MaxPool2d(kernel_size=2, stride=2),}))self.classifier = nn.Sequential(OrderedDict({'fc1': nn.Linear(16 * 5 * 5, 120),'relu3': nn.ReLU(),'fc2': nn.Linear(120, 84),'relu4': nn.ReLU(inplace=True),'fc3': nn.Linear(84, classes),}))def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x
先看一下Sequential函数中init初始化的两种方法,当我们使用OrderedDict方法时,会进行判断,使用self.add_module(key, module)方法将字典中的key和value取出来添加到Sequential中。
class Sequential(Module):def __init__(self, *args):super().__init__()if len(args) == 1 and isinstance(args[0], OrderedDict):for key, module in args[0].items():self.add_module(key, module)else:for idx, module in enumerate(args):self.add_module(str(idx), module)
通过这种方法构建可以给每一网络层添加一个名称,网络输出结果如下:
2.2 nn.ModuleList
nn.ModuleList是 nn.module的容器,用于包装一组网络层,以迭代方式调用网络层
主要方法:
• append():在ModuleList后面添加网络层
• extend():拼接两个ModuleList
• insert():指定在ModuleList中位置插入网络层
使用列表生成式,通过一行代码就能构建20个网络层。
代码演示:
class ModuleList(nn.Module):def __init__(self):super(ModuleList, self).__init__()# 使用列表生成式构建20个全连接层,每个全连接层10个神经元的网络self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])def forward(self, x):for i, linear in enumerate(self.linears):x = linear(x)return xnet = ModuleList()
2.3 nn.ModuleDict
nn.ModuleDict是 nn.module的容器,用于包装一组网络层,以索引方式调用网络层,可以用过参数的形式选取想要调用的网络层。
主要方法:
• clear():清空ModuleDict
• items():返回可迭代的键值对(key-value pairs)
• keys():返回字典的键(key)
• values():返回字典的值(value)
• pop():返回一对键值,并从字典中删除
代码展示,只选取conv和relu两个网络层:
class ModuleDict(nn.Module):def __init__(self):super(ModuleDict, self).__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})# 激活函数self.activations = nn.ModuleDict({'relu': nn.ReLU(),'prelu': nn.PReLU()})def forward(self, x, choice, act): # 传入两个参数 用来选择网络层x = self.choices[choice](x)x = self.activations[act](x)return x
net = ModuleDict()
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu') #只选取conv和relu两个网络层。
print(output)
2.4 三种容器构建总结
• nn.Sequential:顺序性,各网络层之间严格按顺序执行,常用于block构建
• nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建
• nn.ModuleDict:索引性,常用于可选择的网络层
三、AlexNet网络构建
AlexNet:2012年以高出第二名10多个百分点的准确率获得ImageNet分类任务冠军,开创了卷积神经网络的新时代
AlexNet特点如下:
- 采用ReLU:替换饱和激活函数,减轻梯度消失
- 采用LRN(Local Response Normalization):对数据归一化,减轻梯度消失
- Dropout:提高全连接层的鲁棒性,增加网络的泛化能力
- Data Augmentation:TenCrop,色彩修改
网络结构图如下:
构建代码:
import torch.nn as nn
import torch
from torchsummary import summary
# 定义一个名为AlexNet的神经网络模型,继承自nn.Module基类
class AlexNet(nn.Module):# 构造函数,初始化网络的参数def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:# 调用父类的构造函数super().__init__()# 定义神经网络的特征提取部分,包含多个卷积层和池化层self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), # 输入通道3,输出通道64,卷积核大小11x11,步长4,填充2nn.ReLU(inplace=True), # 使用ReLU激活函数,inplace=True表示原地操作,节省内存nn.MaxPool2d(kernel_size=3, stride=2), # 最大池化层,核大小3x3,步长2nn.Conv2d(64, 192, kernel_size=5, padding=2), # 输入通道64,输出通道192,卷积核大小5x5,填充2nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)# 定义自适应平均池化层,将输入的任意大小的特征图池化为固定大小6x6self.avgpool = nn.AdaptiveAvgPool2d((6, 6))# 定义分类器部分,包含全连接层和Dropout层self.classifier = nn.Sequential(nn.Dropout(p=dropout), # 使用Dropout进行正则化,随机丢弃一部分神经元以防止过拟合nn.Linear(256 * 6 * 6, 4096), # 输入大小为256*6*6,输出大小为4096nn.ReLU(inplace=True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes), # 最后的全连接层输出类别数)# 前向传播函数,定义数据在网络中的传播过程def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.features(x) # 特征提取x = self.avgpool(x) # 平均池化x = torch.flatten(x, 1) # 将特征图展平成一维向量x = self.classifier(x) # 分类器return xif __name__ == '__main__':net = AlexNet().cuda()summary(net, (3, 256, 256))
打印出的网络结构图如下:
相关文章:
pytorch04:网络模型创建
目录 一、模型创建过程1.1 以LeNet网络为例1.2 LeNet结构1.3 nn.Module 二、网络层容器(Containers)2.1 nn.Sequential2.1.1 常规方法实现2.1.2 OrderedDict方法实现 2.2 nn.ModuleList2.3 nn.ModuleDict2.4 三种容器构建总结 三、AlexNet网络构建 一、模型创建过程 1.1 以LeNe…...
用js让用户输入一个数累加和
需求:用户输入一个数, 计算 1 到这个数的和。 比如 用户输入的是 5, 则计算 1~5 之间的累加和 并且输出到控制台 <body><script>let numprompt(请输入一个数)let sum0for(let i1;i<num;i){sumi}console.log(sum)</script…...
踩坑记录-安装nuxt3报错:Error: Failed to download template from registry: fetch failed;
报错复现 安装nuxt3报错:Error: Failed to download template from registry: fetch failednpx nuxi init nuxt-demo 初始化nuxt 项目 报错 Error: Failed to download template from registry: fetch faile 解决方法 配置hosts Mac电脑:/etc/hostswin电…...
大数据学习(31)-Spark非常用及重要特性
&&大数据学习&& 🔥系列专栏: 👑哲学语录: 承认自己的无知,乃是开启智慧的大门 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言📝支持一下博主哦ᾑ…...
【教学类-43-14】 20240103 (4宫格数独:正确版:576套) 不重复的基础模板数量:576套
作品展示::——4宫格 576套不重复模板(48页*12套题) 背景需求: 生成4宫格基础模板768套,观看64页内容时,明显看到有错误 【教学类-43-13】 20240103 (4宫格数独:错误版…...
AIGC开发:调用openai的API接口实现简单机器人
简介 开始进行最简单的使用:通过API调用openai的模型能力 OpenAI的能力如下图: 文本生成模型 OpenAI 的文本生成模型(通常称为生成式预训练 Transformer 或大型语言模型)经过训练可以理解自然语言、代码和图像。这些模型提供文…...
c基础(二)
指针: 含义:是一个值,一个值代表着一个内存地址,类似于存放路径 * 运算符 : 1 字符*表示指针 作用:通常跟在类型关键字的后面,表示指针指向的是什么类型的值 int * foo, * bar;声明指针后会…...
人工智能趋势报告解读:ai野蛮式生长的背后是机遇还是危机?
近期,Enterprise WordPress发布了生成式人工智能在营销中的应用程度的报告,这是一个人工智能迅猛发展的时代,目前人工智能已经广泛运用到内容创作等领域,可以预见的是人工智能及其扩展应用还将延伸到我们工作与生活中的方方面面。…...
三、C语言中的分支与循环—goto语句 (10) (完)
在C语言中,goto语句允许程序无条件地跳转到同一函数内的标记位置。这个标记位置通过一个标签和冒号(:)来标示。goto语句可以用于从深层嵌套的循环或条件语句中直接跳出,或者跳过某些代码执行。尽管goto语句在某些情况下可以使程序逻辑变得清晰࿰…...
RabbitMQ 常见问题
1. 如何保证消息顺序消费 在RabbitMQ中,消息最终会保存在队列中,在同一个队列中,消息是顺序的,保持先进先出的原则,这个由Rabbitmq保证。而不同队列中的消息,RabbitMQ 是无法保证其顺序性。顺序消费主要是…...
阶段二-Day10-日期类
日期类结构: 1.java.util.Date是日期类 2.DateFormat是日期格式类、SimpleDateFormat是日期格式类的子类 Timezone代表时区 3.Calendar是日历类,GregorianCalendar是日历的子类 一. 常用类-Date 1.1 Date构造方法 Date(long date) 使用给定的毫秒时间价值构建…...
多任务并行处理相关面试题
我自己面试时被问过两次多任务并行相关的问题: 假设现在有10个任务,要求同时处理,并且必须所有任务全部完成才返回结果 这个面试题的难点是: 既然要同时处理,那么肯定要用多线程。怎么设计多线程同时处理任务呢&…...
Shell脚本学习笔记
1. 写在前面 工作中,需要用到写一些shell脚本去完成一些简单的重复性工作, 于是就想系统的学习下shell脚本的相关知识, 本篇文章是学习shell脚本整理的学习笔记,内容参考主要来自C语言中文网, 学习过程中,…...
ROS-安装xacro
安装 运行下列命令进行安装,xxxxxx处更改为自己的版本 sudo apt-get install ros-xxxxxx-xacro运行 输入下列命令 roscd xacro如果没有报错,并且进入了xacro软件包的目录,则表示安装成功。 参考: [1]https://wenku.csdn.net/ans…...
为什么说 $mash 是 Solana 上最正统的铭文通证?
早在 2023 年的 11 月,包括 Solana、Avalanche、Polygon、Arbitrum、zkSync 等生态正在承接比特币铭文生态外溢的价值。当然,因铭文赛道过于火爆,当 Avalanche、BNB Chain 以及 Polygon 等链上 Gas 飙升至极值,Arbitrum、zkSync 等…...
安装elasticsearch、kibana、IK分词器、扩展IK词典
安装elasticsearch、kibana、IK分词器、扩展IK词典 后面还会安装kibana,这个会提供可视化界面方面学习。 需要注意的是elasticsearch和kibana版本一定要一样!!! 否则就像这样 elasticsearch 1、创建网络 因为我们还需要部署k…...
Spring中常见的BeanFactory后处理器
常见的BeanFacatory后处理器 先给出没有添加任何BeanFactory后处理器的测试代码 public class TestBeanFactoryPostProcessor {public static void main(String[] args) {GenericApplicationContext context new GenericApplicationContext();context.registerBean("co…...
FPGA LCD1602驱动代码 (已验证)
一.需求解读 1.需求 在液晶屏第一行显示“HELLO FPGA 1234!” 2. 知识背景 1602 液晶也叫 1602 字符型液晶,它是一种专门用来显示字母、数字、符号等的点阵 型液晶模块。它由若干个 5X7 或者 5X11 等点阵字符位组成,每个点阵字符位都可以显示一 个字符,每位之间有一个点距的…...
c++编程要养成的好习惯
1、缩进 你说有缩进看的清楚还是没缩进看的清楚 2、i和i i运行起来和i更快 3、 n%20和n&1 不要再用n%20来判断n是不是偶数了,又慢又土,用n&10,如果n&10就说明n是偶数 同理,n&11说明n是奇数 4、*2和<<…...
后台管理项目的多数据源方案
引言 在互联网开发公司中,往往伴随着业务的快速迭代,程序员可能没有过多的时间去思考技术扩展的相关问题,长久下来导致技术过于单一。为此最近在学习互联网思维,从相对简单的功能开始做总结,比如非常常见的基础数据的…...
视频美颜SDK趋势畅想:未来发展方向与应用场景
当下,视频美颜SDK正不断演进,本文将深入探讨视频美颜SDK的发展趋势,探讨未来可能的方向和广泛的应用场景。 1.深度学习与视频美颜的融合 未来,我们可以期待看到更多基于深度学习算法的视频美颜SDK,为用户提供更高质量…...
C++ const 限定符的全面介绍
C const 限定符的全面介绍 1. const 修饰基本数据类型 定义 const 修饰的基本数据类型变量,值不可改变。 语法 const type variable value;特点 不可变性,增加代码可读性。 作用 定义不可修改的常量。 使用场景 全局常量、配置项。 注意事项…...
Vue 中的 ref 与 reactive:让你的应用更具响应性(上)
🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…...
华为云CCE-集群内访问-根据ip访问同个pod
华为云CCE-集群内访问-根据ip访问同个pod 问题描述:架构如下:解决方法: 问题描述: 使用service集群内访问时,由于启用了两个pod,导致请求轮询在两个pod之间,无法返回正确的结果。 架构如下&am…...
Kasada p.js (x-kpsdk-cd、x-kpsdk-ct、integrity)
提供x-kpsdk-cd的API服务 详细请私信~ 可试用~ V:zhzhsgg 一、简述 integrity是通过身份验证Kasada检测机器人流量后获得的一个检测结果(数据完整性) x-kpsdk-cd 是经过编码计算等等获得。当你得到正确的解决验证码值之后,解码会看到如下图…...
Thinkphp 5框架学习
TP框架主要是通过URL实现操作 http://servername/index.php/模块/控制器/操作/参数/值.. index.php 为入口文件,在 public 目录内的 index.php 文件; 模块在 application 目录下默认有一个 index 目录,这就是一个模块; 而在 index 目录下有一个 contro…...
麒麟云增加计算节点
一、安装基座系统并配置好各项设置 追加的计算节点服务器,安装好系统,把主机名、网络网线(网线要和其他网线插的位置一样)、hosts这些配置好,在所有节点的/etc/hosts里面添加信息 在控制节点添加/kylincloud/multinod…...
使用Redis进行搜索
文章目录 构建反向索引 构建反向索引 在Begin-End区域编写 tokenize(content) 函数,实现文本标记化的功能,具体参数与要求如下: 方法参数 content 为待标记化的文本; 文本标记的实现:使用正则表达式提取全小写化后的…...
Oracle修改用户密码
文章目录 Oracle修改用户密码Oracle用户锁定常见的两种状态Oracle用户锁定和解锁 Oracle修改用户密码 使用sys或system使用sysdba权限登录,然后执行以下命令修改密码: alter user 用户名 identified by 密码;密码过期导致的锁定,也通过修改…...
LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
目录链接: 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目: https://github.com/September26/java-algorithms 原题链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 描述: 圣诞活动预…...
网站推广好做吗/广州网站制作实力乐云seo
public int findTheWinner(int n, int k) {int ans0;for(int i2;i<n;i){//最后的人一定在0位置处,每次去掉一人,他的位置左移k,要想知道他的初始位置则要kans(ansk)%i;}return ans1;}...
贝壳找房网站做销售/服务外包平台
关于模块报错 查看报错详细信息,右击打开报错如下图: window7 iis7 发布asp.net网站时报“HTTP 错误 500.21 - Internal Server Error 处理程序“PageHandlerFactory-Integrated”在其模块列表中有一个错误模块“ManagedPipelineHandler””的错…...
企业推广软件/seo软件工具
分段式多级离心泵的工作原理:分段式多级离心泵中部的每个叶轮都配有一个导向轮,该导向轮是一个固定盘。 其功能是通过减速将从叶轮喷出的液体的部分动能转换为静能,并收集沿径向返回的液体,将其引导至下一个叶轮组。 叶轮搜索的前…...
武汉网站推广公司招聘/公司网站设计的内容有哪些
前言 上节讲到qt for android开发百度地图,已经可以打开地图了,里面的一些功能,如自定义搜索栏,添加控件等等,这些百度地图官方开发文档都提供了例子,可以自定义开发。但是问题也来了,我们开发百…...
静态企业网站下载/app优化建议
应盆友需求,自己在闲暇时间整个了Unifier 用户的接口的DEMO,版本 v19.12 当然,通过导入用户EXCEL系统原生也是可以实现的,这里更多考虑的是业务逻辑和操作效率(体验),毕竟以集成接口的方式&…...