当前位置: 首页 > news >正文

[Machine Learning]pytorch手搓一个神经网络模型

因为之前虽然写过一点点关于pytorch的东西,但是用的还是他太少了。

这次从头开始,尝试着搓出一个神经网络模型

(因为没有什么训练数据,所以最后的训练部分使用可能不太好跑起来的代码作为演示,如果有需要自己连上数据集合进行修改捏)

1.先阐述一下什么是神经网络块(block)

一般来说,我们之前遇到的一些神经网络,网络中是这样子的结构

net----> layer ----> neuron

而块的存在,就是给这样一个神经网络的整体做了一个封装操作,让神经网络能复合实现一些功能。

结构就变成了这个样子(图片来自D2l)

这样子,神经网络结构就变成了四层

  block ----》 net ---》layer ---》neuron

这样子自然是可以使用诸如一些奇怪的方法,通过三层索引去进行调用什么的,不过这个我们到后面再说。先看一下如何构建一个块。

我们这里构建了一个类,这个类的计算方法实际就是实现了几个层的输入和输出,相当与封装了一个神经网络。

class MLP(nn.Module):# 用模型参数声明层。这里,我们声明两个全连接的层def __init__(self):# 调用MLP的父类Module的构造函数来执行必要的初始化。# 这样,在类实例化时也可以指定其他函数参数,例如模型参数params(稍后将介绍)super().__init__()self.hidden = nn.Linear(20, 256)  # 隐藏层self.out = nn.Linear(256, 10)  # 输出层# 定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self, X):# 注意,这里我们使用ReLU的函数版本,其在nn.functional模块中定义。return self.out(F.relu(self.hidden(X)))# 这个东西就相当与先隐藏层,然后relu,然后最后进行一次输出#创建这个神经网路块,然后开始输出
net = MLP()
net(X)

这段代码没有使用squential容器进行封装,但是可以很清楚地看到我们定义了两层(隐藏层256个神经元,输出层10个神经元,不知道为什么没用softmax函数),并且在返回函数计算的时候,中间还经过了一步‘relu’激活函数的操作

(注意和tf不同,pytorch框架下面是不能把激活函数存入层中的,需要单独作为一个‘层’来进行一个输入和输出的控制)

注意下(在后面自定义层的时候也是这样子)由于继承了nn.Module这个类 , 所以我们必须要实现两个函数,首先是_init_,这个在python中是最终要的构造函数。其次就是forward,我对py不是很了解,不过这应该是通过面向对象实现的集成。forward这个方法就是向前传播,也就是接受参数,内部计算,然后返回值传递下去。

我们直接给net对象传递我们随机生成的两条数据的时候,底层时就调用了这个函数。

其他的一些比如sequential的实现方法,在这里我们就不加以赘述了。

为了更好的解释forward这个函数的作用,在这里我们自己创建一个单层,通过类创建,仍然是获取一个集成nn.Module的类,然后内部设置好初始化(为了创建对象),设置好向前传播(为了用来调用)

# 自定义一个不需要参数的层
class CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):    #该层向前传播的方法return X - X.mean()
# 这一层最终也是返回一个张量
# sequential是一个简单的线性封装容器,所以只要是符合输入张量,输出张量
# 并且在内部会调用他们的forward方法layer = CenteredLayer()
print('自定义层,每个元素都 - 平均值2',layer(torch.FloatTensor([1, 2, 3, 4, 5])))

这个单层的效果就是对每个元素,都减去平均值。

并且如果想的话,我们也可以创建一些拥有自己属性的层


#现在创建一个带有权重和偏好
class MyLinear(nn.Module):def __init__(self, in_units, units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, units)) #这个需要手动输入一下输入特征数目还有神经元数目self.bias = nn.Parameter(torch.randn(units,))            def forward(self, X):linear = torch.matmul(X, self.weight.data) + self.bias.data  #向前传播其实就是接受输入return F.relu(linear)
#创建一个层,这个层可以直接用在sequential之中
linear = MyLinear(5, 3) #五个输入三个神经元

这里可以看到只要重写了forward方法,那么这个类就能变成一个能用来计算的类,甚至是一个层可以单独计算。并且这样子写好以后是可以放在sequential容器中,作为一个统一训练的。

因此,如果我们有多个块的话,也是可以自己去写一个容器,进行组合。

#创建一个新类
class MySequential(nn.Module): #()就是py中的继承语法def __init__(self, *args):super().__init__()for idx, module in enumerate(args):# 这里,module是Module子类的一个实例。我们把它保存在'Module'类的成员# 变量_modules中。_module的类型是OrderedDictself._modules[str(idx)] = moduledef forward(self, X):# OrderedDict保证了按照成员添加的顺序遍历它们for block in self._modules.values():X = block(X)return X#这个类实现的效果就类似原声的sequential
net = MySequential(net1,net2,net3)
print(net(X))

这个就大概是在拼接块,层的时候,内部所做的底层原理。

当然直接用sequential容器是更加省力气的方法,对吧

2.关于参数如何进行检查

假设现在有一个单独的神经网络

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

众所周知,这个神经网络是两层(中间的一层是激活函数我们不做讨论)

我们可以通过索引来调用和获取某个层的属性

#返回结果是这个全链接层的weight和bias,正好对应八个神经元
print(net[0].state_dict())
#检查参数
print(net[2].bias) #还会返回一些具体的属性
print(net[2].bias.data) #单纯的数据

对于block组成的神经网路社区中(我也不知道很多块组在一起应该叫什么了),仍然是一个嵌套的结构,我们可以创建这样一个社区

#这段代码其实也能看出来,sequential也是一个能容纳block的东西
#     容器 --》 block --》 layer --》 神经元   这三层架构(或者说四层)
def block1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),nn.Linear(8, 4), nn.ReLU())
def block2():net = nn.Sequential()for i in range(4):# 在这里嵌套net.add_module(f'block {i}', block1())return netrgnet = nn.Sequential(block2(), nn.Linear(4, 1))

然后我们对这个rgnet进行打印,可以直接看到工作状态

#这样子打印会展示整个网络的状态
print('检查工作状态',rgnet)

可以很清晰地看到,这样一个嵌套结构

所以比如说我们想要访问第一个社区中,第2个块,中的第一个层中的参数,我们可以直接这样子读取

rgnet[0][1][0].bias.data

另外如果想要对已经形成的模型做初始化,这里还有一个例子

def init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)   #平均值0,标准差为0.01nn.init.zeros_(m.bias)                        #偏移直接设置为0
net.apply(init_normal)
print('手动初始化的效果为',net[0].weight.data[0],'手动初始化bias:', net[0].bias.data[0])

函数实现的功能是先检测传进来的是不是正常的线性层,然后分别初始化。

补充一下,apply函数和js里的用法差不多,对内部的每个单元进行遍历,然后做一些操作。

(当然这不是唯一一种方法,自然还有别的)。

3.关于张量的保存和获取

在pytorch中,张量的保存主要有两种形式,第一种是保存数据,用于其他模型的训练

#===========读写张量===========#
x = torch.arange(4)       #[0,1,2,3],创建了一个张量
torch.save(x, 'x-file')   #这是保存在x-file这个文件下面的
loaded_x = torch.load('x-file')  #反过来加载
print(loaded_x)                  #输出
#这样子读取列表和读出,也可以使用字典{x:x,Y:y}或者列表[x,y],反正是变成文件形式了

另一种是保存模型的参数,可以直接套在其他模型上

#=====读写参数并且保存在内存=====#class MLP(nn.Module):   #手动创建多层感知机def __init__(self):super().__init__()self.hidden = nn.Linear(20, 256)self.output = nn.Linear(256, 10)def forward(self, x):return self.output(F.relu(self.hidden(x)))net = MLP()           #构建对象
print('MLP的参数',net.state_dict()) #这里输出一下参数#保存这个模型的参数
torch.save(net.state_dict(), 'mlp.params')#然后对一个新模型使用这个参数
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))  #内置函数加载参数clone.eval()#设置为评估模式,禁止训练什么的,这应该是module中附带的功能print('clone的参数',clone.state_dict())#可以看到参数被完全复制了

但是注意一个问题,如果使用另一个模型初始化自身的时候,要保证两个模型的结构一致

相关文章:

[Machine Learning]pytorch手搓一个神经网络模型

因为之前虽然写过一点点关于pytorch的东西,但是用的还是他太少了。 这次从头开始,尝试着搓出一个神经网络模型 (因为没有什么训练数据,所以最后的训练部分使用可能不太好跑起来的代码作为演示,如果有需要自己连上数据…...

KdMapper扩展实现之Dell(pcdsrvc_x64.pkms)

1.背景 KdMapper是一个利用intel的驱动漏洞可以无痕的加载未经签名的驱动,本文是利用其它漏洞(参考《【转载】利用签名驱动漏洞加载未签名驱动》)做相应的修改以实现类似功能。需要大家对KdMapper的代码有一定了解。 2.驱动信息 驱动名称pcds…...

python和go相互调用的两种方法

前言 Python 和 Go 语言是两种不同的编程语言,它们分别有自己的优势和适用场景。在一些项目中,由于团队内已有的技术栈或者某一部分业务的需求,可能需要 Python 和 Go 相互调用,以此来提升效率和性能。 性能优势 Go 通常比 Python 更高效&…...

c# 分部视图笔记

Html.Partial("**", 1) public ActionResult **(int page) { ViewBag.page page; return PartialView("**"); }...

Vue3最佳实践 第七章 TypeScript 中

Vue组件中TypeScript 在Vue组件中,我们可以使用TypeScript进行各种类型的设置,包括props、Reactive和ref等。下面,让我们详细地探讨一下这些设置。 设置描述设置props在Vue中,props本身就具有类型设定的功能。但如果你希望使用Ty…...

(三)行为模式:8、状态模式(State Pattern)(C++示例)

目录 1、状态模式(State Pattern)含义 2、状态模式的UML图学习 3、状态模式的应用场景 4、状态模式的优缺点 (1)优点 (2)缺点 5、C实现状态模式的实例 1、状态模式(State Pattern&#x…...

nginx的配置文件概述及简单demo(二)

默认配置文件 当安装完nginx后,它的目录下通常有默认的配置文件 #user nobody; worker_processes 1;#error_log logs/error.log; #error_log logs/error.log notice; #error_log logs/error.log info;#pid logs/nginx.pid;events {worker_connection…...

Apollo Planning2.0决策规划算法代码详细解析 (2): vscode gdb单步调试环境搭建

前言: apollo planning2.0 在新版本中在降低学习和二次开发成本上进行了一些重要的优化,重要的优化有接口优化、task插件化、配置参数改造等。 GNU symbolic debugger,简称「GDB 调试器」,是 Linux 平台下最常用的一款程序调试器。GDB 编译器通常以 gdb 命令的形式在终端…...

flex 布局:元素/文字靠右

前言 略 使用flex的justify-content属性控制元素的摆放位置 靠右 <view class"more">展开更多<text class"iconfont20231007 icon-zhankai"></text></view>.more {display: flex;flex-direction: row;color: #636363;justify-co…...

java基础-第1章-走进java世界

一、计算机基础知识 常用的DOS命令 二、计算机语言介绍 三、Java语言概述 四、Java环境的搭建 JDK安装图解 环境变量的配置 配置环境变量意义 配置环境变量步骤 五、第一个Java程序 编写Java源程序 编译Java源文件 运行Java程序 六、Java语言运行机制 核心机制—Java虚拟机 核…...

jvm 堆内存 栈内存 大小设置

4种方式配置不同作用域的jvm的堆栈内存。 1、Eclise 中设置jvm内存: 改动eclipse的配置文件,对全部project都起作用 改动eclipse根文件夹下的eclipse.ini文件 -vmargs //虚拟机设置 -Xms40m //初始内存 -Xmx256m //最大内存 -Xmn16m //最小内存 -XX:PermSize=128M //非堆内…...

免杀对抗-反沙盒+反调试

反VT-沙盒检测-Go&Python 介绍&#xff1a; 近年来&#xff0c;各类恶意软件层出不穷&#xff0c;反病毒软件也更新了各种检测方案以提高检率。 其中比较有效的方案是动态沙箱检测技术&#xff0c;即通过在沙箱中运行程序并观察程序行为来判断程序是否为恶意程序。简单来说…...

QTimer类的使用方法

本文介绍QTimer类的使用方法。 1.单次触发 在某些情况下&#xff0c;定时器只运行一次&#xff0c;可使用单次触发方式。 QTimer *timer new QTimer(this); connect(timer, &QTimer::timeout, this, &MainWindow::timeout); timer->setSingleShot(true); timer-…...

(三)行为模式:9、空对象模式(Null Object Pattern)(C++示例)

目录 1、空对象模式&#xff08;Null Object Pattern&#xff09;含义 2、空对象模式的主要涉及以下几个角色 3、空对象模式的应用场景 4、空对象模式的优缺点 &#xff08;1&#xff09;优点 &#xff08;2&#xff09;缺点 5、C实现空对象模式的实例 1、空对象模式&am…...

Django实战项目-学习任务系统-用户登录

第一步&#xff1a;先创建一个Django应用程序框架代码 1&#xff0c;先创建一个Django项目 django-admin startproject mysite将创建一个目录&#xff0c;其布局如下&#xff1a;mysite/manage.pymysite/__init__.pysettings.pyurls.pyasgi.pywsgi.py 2&#xff0c;再创建一个…...

【动手学深度学习-Pytorch版】Transformer代码总结

本文是纯纯的撸代码讲解&#xff0c;没有任何Transformer的基础内容~ 是从0榨干Transformer代码系列&#xff0c;借用的是李沐老师上课时讲解的代码。 本文是根据每个模块的实现过程来进行讲解的。如果您想获取关于Transformer具体的实现细节&#xff08;不含代码&#xff09;可…...

做外贸独立站选Shopify还是WordPress?

现在确实会有很多新人想做独立站&#xff0c;毕竟跨境电商平台内卷严重&#xff0c;平台规则限制不断升级&#xff0c;脱离平台“绑架”布局独立站&#xff0c;才能获得更多流量、订单、塑造品牌价值。然而&#xff0c;在选择建立外贸独立站的过程中&#xff0c;选择适合的建站…...

echarts的bug,在series里写tooltip,不起作用,要在全局先写tooltip:{}才起作用,如果在series里写的不起作用就写到全局里

echarts的bug&#xff0c;在series里写tooltip&#xff0c;不起作用&#xff0c;要在全局先写tooltip&#xff1a;{show:true}才起作用&#xff0c;如果在series里写的不起作用就写到全局里 series里写tooltip不起作用&#xff0c;鼠标悬浮在echarts图表上时不显示提示 你需要…...

jmeter分布式压测

一、什么是压力测试&#xff1f; 压力测试&#xff08;Stress Test&#xff09;&#xff0c;也称为强度测试、负载测试&#xff0c;属于性能测试的范畴。 压力测试是模拟实际应用的软硬件环境及用户使用过程的系统负荷&#xff0c;长时间或超大负荷地运行被测软件系统&#xff…...

consulmanage部署

一、部署consul 使用yum方式部署consul yum install -y yum-utils yum-config-manager --add-repo https://rpm.releases.hashicorp.com/RHEL/hashicorp.repo yum -y install consul 执行以下命令获取uuid密钥并记录下来 uuidgen 编辑consul配置文件 vi /etc/consul.d/consul.h…...

大数据软件项目的验收流程

大数据软件项目的验收流程是确保项目交付符合预期需求和质量标准的关键步骤。以下是一般的大数据软件项目验收流程&#xff0c;希望对大家有所帮助。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff0c;欢迎交流合作。 1.项目验收计划制定&#xff1a; 在…...

《第一行代码Andorid》阅读笔记-第一章

这篇文章是我自己的《第一行代码Andorid》的阅读笔记&#xff0c;虽然大量参考了别人已经写好的一些笔记和代码但是也有自己的提炼和新的问题在里面&#xff0c;我也会放上参考文章链接。 学习重点 Android系统的四大组件&#xff1a; &#xff08;1&#xff09;活动&#xff…...

Educational Codeforces Round 146 (Rated for Div. 2)(VP)

写个题解 A. Coins void solve(){ll n, k; cin >> n >> k;bl ok true;if (n &1 && k %2 0) ok false;print(ok ? yes : no); } B. Long Legs void solve(){db x, y; cin >> x >> y;if (x < y) swap(x, y);int t1 ceil(sqrt(x))…...

9.30国庆

消息队列完成进程间通信 #include <myhead.h>#define size sizeof(msg_ds)-sizeof(long) //正文大小//消息结构体 typedef struct {long msgtype; //消息类型char data[1024]; //消息正文 }msg_ds;//创建子线程构造体 void *task1(void *arg) {//创造第二个key值ke…...

java基础-第4章-面向对象(二)

一、static关键字 静态&#xff08;static&#xff09;可以修饰属性和方法。 称为静态属性&#xff08;类属性&#xff09;、静态方法&#xff08;类方法&#xff09;。 静态成员是全类所有对象共享的成员。 在全类中只有一份&#xff0c;不因创建多个对象而产生多份。 不必创…...

flex加 grid 布局笔记

<style> .flex-container { display: flex; height: 100%; /* 设置容器的高度 */ } .wide { display: flex; padding: 10px; border: 1px solid lightgray; text-align: center; justify-content: …...

最高评级!华为云CodeArts Board获信通院软件研发效能度量平台先进级认证

9月26日&#xff0c;华为云CodeArts Board获得了中国信通院《云上软件研发效能度量分级模型》的先进级最高级评估&#xff0c;达到了软件研发效能度量平台评估的通用效能度量能力、组织效能模型、项目效能模型、资源效能模型、个人效能模型、研发效能评价模型、项目管理域、开发…...

图像上传功能实现

一、后端 文件存放在images.path路径下 package com.like.common;import jakarta.servlet.ServletOutputStream; import jakarta.servlet.http.HttpServletResponse; import org.springframework.beans.factory.annotation.Value; import org.springframework.web.bind.annot…...

03_Node.js模块化开发

1 Node.js的基本使用 1.1 NPM nodejs安装完成后&#xff0c;会跟随着自动安装另外一个工具npm。 NPM的全称是Node Package Manager&#xff0c;是一个NodeJS包管理和分发工具&#xff0c;已经成为了非官方的发布Node模块&#xff08;包&#xff09;的标准。 2020年3月17日&…...

Nginx支持SNI证书,已经ssl_server_name的使用

整理了一些网上的资料&#xff0c;这里记录一下&#xff0c;供大家参考 什么是SNI&#xff1f; 传统的应用场景中&#xff0c;一台服务器对应一个IP地址&#xff0c;一个域名&#xff0c;使用一张包含了域名信息的证书。随着云计算技术的普及&#xff0c;在云中的虚拟机有了一…...

扁平化wordpress主题/成都搜狗seo

Spring Mvc在所有内部日志中使用Commons Logging 默认情况下&#xff0c;Spring Boot会用Logback来记录日志&#xff0c; 假如maven依赖中添加了spring-boot-starter-logging&#xff1a; 那么&#xff0c;我们的Spring Boot应用将自动使用logback作为应用日志框架&#xff…...

滁州网站开发/互联网运营推广

知识点1----ALTER 下列代码意义&#xff1a;向已存在的表my_foods中新增自动排列的列 作为主键 ALTER TABLE my_contacts  --表名称ADD COLUMN id INT NOT NULL AUTO_INCREMENT FIRST,   --新的 列 id&#xff0c;自动排列&#xff0c;该列于第一位 ADD PRIMARY KEY (id);…...

建立网站的技术路径/免费网站外链推广

for循环&#xff1a;语法 for循环的好基友是数组 for&#xff08;初始值&#xff08;var a 1&#xff09;&#xff1b;循环条件&#xff08;a<10&#xff09;;改变条件&#xff08;a&#xff09;&#xff09; { 写在内面表示一起循环 }外面只一次 初始值定义变量可以定义多个…...

武汉网站模板/百分百营销软件官网

二次联通门 : BZOJ 1858: [Scoi2010]序列操作 /*BZOJ 1858: [Scoi2010]序列操作已经...没有什么好怕的的了...16K的代码...调个MMP啊...*/ #include <cstdio>void read (int &now) {now 0;register char word getchar ();while (word < 0 || word > 9)word …...

丽水网站建设报价/百度广告客服电话

目录 HMR是什么 使用场景 配置使用HMR 配置webpack解析webpack打包后的文件内容配置HMRHMR原理 debug服务端源码 服务端简易实现服务端调试阶段 debug客户端源码 客户端简易实现客户端调试阶段问题总结 HMR是什么 HMR即Hot Module Replacement是指当你对代码修改并保存后&…...

商务网站建设实训总结/如何做游戏推广

外部中断&#xff0c;通过按键来控制LED灯...