当前位置: 首页 > news >正文

【深度学习实验七】 自动梯度计算

目录

一、利用预定义算子重新实现前馈神经网络

(1)使用pytorch的预定义算子来重新实现二分类任务 

 (2)完善Runner类

(3) 模型训练 

(4)性能评价

二、增加一个3个神经元的隐藏层,再次实现二分类,并与1做对比

三、自定义隐藏层层数和每个隐藏层中的神经元个数,尝试找到最优超参数完成二分类。可以适当修改数据集,便于探索超参数。


一、利用预定义算子重新实现前馈神经网络

点击查看已经实现的前馈神经网络

(1)使用pytorch的预定义算子来重新实现二分类任务 

导入必要的库和模块: 

from data import make_moons
from nndl import accuracy
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
from Runner2_1 import RunnerV2_2
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

定义的网络结构 Model_MLP_L2_V2

class Model_MLP_L2_V2(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Model_MLP_L2_V2, self).__init__()# 定义第一层线性层self.fc1 = nn.Linear(input_size, hidden_size)# 使用正态分布初始化权重和偏置self.fc1.weight.data = torch.normal(mean=0.0, std=1.0, size=self.fc1.weight.data.size())self.fc1.bias.data.fill_(0.0)  # 常量初始化偏置为0# 定义第二层线性层self.fc2 = nn.Linear(hidden_size, output_size)self.fc2.weight.data = torch.normal(mean=0.0, std=1.0, size=self.fc2.weight.data.size())self.fc2.bias.data.fill_(0.0)  # 常量初始化偏置为0# 定义Logistic激活函数self.act_fn = torch.sigmoidself.layers = [self.fc1, self.act_fn, self.fc2,self.act_fn]# 前向计算def forward(self, inputs):z1 = self.fc1(inputs)a1 = self.act_fn(z1)z2 = self.fc2(a1)a2 = self.act_fn(z2)return a2

数据集构建和划分:

# 数据集构建
n_samples = 1000
X, y = make_moons(n_samples=n_samples, shuffle=True, noise=0.2)
# 划分数据集
num_train = 640  # 训练集样本数量
num_dev = 160    # 验证集样本数量
num_test = 200   # 测试集样本数量
# 根据指定数量划分数据集
X_train, y_train = X[:num_train], y[:num_train]  # 训练集
X_dev, y_dev = X[num_train:num_train + num_dev], y[num_train:num_train + num_dev]  # 验证集
X_test, y_test = X[num_train + num_dev:], y[num_train + num_dev:]  # 测试集
# 调整标签的形状,将其转换为[N, 1]的格式
y_train = y_train.reshape([-1, 1])
y_dev = y_dev.reshape([-1, 1])
y_test = y_test.reshape([-1, 1])可视化生成的数据集
plt.figure(figsize=(5, 5))  # 设置图形大小
plt.scatter(x=X[:, 0], y=X[:, 1], marker='*', c=y, cmap='viridis')  # 绘制散点图
plt.xlim(-3, 4)  # 设置x轴范围
plt.ylim(-3, 4)  # 设置y轴范围
plt.grid(True, linestyle='--', alpha=0.3)  # 添加网格
plt.show()  # 显示图形

 (2)完善Runner类

        基于上一节实现的 RunnerV2_1 类,本节的 RunnerV2_2 类在训练过程中使用自动梯度计算;模型保存时,使用state_dict方法获取模型参数;模型加载时,使用set_state_dict方法加载模型参数.

import os
import torch
class RunnerV2_2(object):def __init__(self, model, optimizer, metric, loss_fn, **kwargs):self.model = modelself.optimizer = optimizerself.loss_fn = loss_fnself.metric = metric# 记录训练过程中的评估指标变化情况self.train_scores = []self.dev_scores = []# 记录训练过程中的评价指标变化情况self.train_loss = []self.dev_loss = []def train(self, train_set, dev_set, **kwargs):# 将模型切换为训练模式self.model.train()# 传入训练轮数,如果没有传入值则默认为0num_epochs = kwargs.get("num_epochs", 0)# 传入log打印频率,如果没有传入值则默认为100log_epochs = kwargs.get("log_epochs", 100)# 传入模型保存路径,如果没有传入值则默认为"best_model.pdparams"save_path = kwargs.get("save_path", "best_model.pdparams")# log打印函数,如果没有传入则默认为"None"custom_print_log = kwargs.get("custom_print_log", None)# 记录全局最优指标best_score = 0# 进行num_epochs轮训练for epoch in range(num_epochs):X, y = train_set# 获取模型预测logits = self.model(X)# 计算交叉熵损失trn_loss = self.loss_fn(logits, y)self.train_loss.append(trn_loss.item())# 计算评估指标trn_score = self.metric(logits, y)self.train_scores.append(trn_score)# 自动计算参数梯度trn_loss.backward()if custom_print_log is not None:# 打印每一层的梯度custom_print_log(self)# 参数更新self.optimizer.step()# 清空梯度self.optimizer.zero_grad()dev_score, dev_loss = self.evaluate(dev_set)# 如果当前指标为最优指标,保存该模型if dev_score > best_score:self.save_model(save_path)print(f"[Evaluate] best accuracy performence has been updated: {best_score:.5f} --> {dev_score:.5f}")best_score = dev_scoreif log_epochs and epoch % log_epochs == 0:print(f"[Train] epoch: {epoch}/{num_epochs}, loss: {trn_loss.item()}")# 模型评估阶段,使用'paddle.no_grad()'控制不计算和存储梯度@torch.no_grad()def evaluate(self, data_set):# 将模型切换为评估模式self.model.eval()X, y = data_set# 计算模型输出logits = self.model(X)# 计算损失函数loss = self.loss_fn(logits, y).item()self.dev_loss.append(loss)# 计算评估指标score = self.metric(logits, y)self.dev_scores.append(score)return score, loss# 模型测试阶段,使用'paddle.no_grad()'控制不计算和存储梯度@torch.no_grad()def predict(self, X):# 将模型切换为评估模式self.model.eval()return self.model(X)# 使用'model.state_dict()'获取模型参数,并进行保存def save_model(self, saved_path):torch.save(self.model.state_dict(), saved_path)# 使用'model.set_state_dict'加载模型参数def load_model(self, model_path):state_dict = torch.load(model_path ,weights_only=True)self.model.load_state_dict(state_dict)

(3) 模型训练 

实例化RunnerV2类,并传入训练配置,代码实现如下:

# 定义训练参数
epoch_num = 1000  # 训练轮数
model_saved_dir = "best_model.pdparams"  # 模型保存目录
# 网络参数
input_size = 2  # 输入层维度为2
hidden_size = 5  # 隐藏层维度为5
output_size = 1  # 输出层维度为1
# 定义多层感知机模型
model = Model_MLP_L2_V2(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# 定义损失函数
loss_fn =F.binary_cross_entropy
# 定义优化器,设置学习率
learning_rate = 0.2
optimizer = torch.optim.SGD(params=model.parameters(),lr=learning_rate)
# 定义评价方法
metric = accuracy
# 实例化RunnerV2_1类,并传入训练配置
runner = RunnerV2_2(model, optimizer, metric, loss_fn)
# 训练模型
runner.train([X_train, y_train], [X_dev, y_dev], num_epochs=epoch_num, log_epochs=50, save_dir=model_saved_dir)

输出结果:

[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.48125
[Train] epoch: 0/1000, loss: 0.7482572793960571
[Evaluate] best accuracy performence has been updated: 0.48125 --> 0.50000
[Evaluate] best accuracy performence has been updated: 0.50000 --> 0.53750
[Evaluate] best accuracy performence has been updated: 0.53750 --> 0.60625
[Evaluate] best accuracy performence has been updated: 0.60625 --> 0.71250
[Evaluate] best accuracy performence has been updated: 0.71250 --> 0.73750
[Evaluate] best accuracy performence has been updated: 0.73750 --> 0.77500
[Evaluate] best accuracy performence has been updated: 0.77500 --> 0.78750
[Evaluate] best accuracy performence has been updated: 0.78750 --> 0.79375
[Evaluate] best accuracy performence has been updated: 0.79375 --> 0.80000
[Evaluate] best accuracy performence has been updated: 0.80000 --> 0.81250
[Train] epoch: 50/1000, loss: 0.4034937918186188
[Train] epoch: 100/1000, loss: 0.36812323331832886
[Train] epoch: 150/1000, loss: 0.3453332781791687
[Evaluate] best accuracy performence has been updated: 0.81250 --> 0.81875
[Evaluate] best accuracy performence has been updated: 0.81875 --> 0.82500
[Evaluate] best accuracy performence has been updated: 0.82500 --> 0.83125
[Evaluate] best accuracy performence h

相关文章:

【深度学习实验七】 自动梯度计算

目录 一、利用预定义算子重新实现前馈神经网络 (1)使用pytorch的预定义算子来重新实现二分类任务 (2)完善Runner类 (3) 模型训练 (4)性能评价 二、增加一个3个神经元的隐藏层,再次实现二分类,并与1做对比 三、自定义隐藏层层数和每个隐藏层中的神经元个数,尝…...

JAVA毕业设计192—基于Java+Springboot+vue的个人博客管理系统(源代码+数据库+万字论文+开题+任务书)

毕设所有选题: https://blog.csdn.net/2303_76227485/article/details/131104075 基于JavaSpringbootvue的个人博客管理系统(源代码数据库万字论文开题任务书)192 一、系统介绍 本项目前后端分离,分为用户、管理员两种角色,角色菜单可自行…...

must be ‘pom‘ but is ‘jar‘解决思路

这个错误信息表明在 Maven 的 pom.xml 文件中,定义的父 POM 的 packaging 类型设置不正确。具体来说,它应该是 pom 类型,但当前设置为 jar。这个问题通常会导致构建失败。以下是解决这个问题的步骤。 解决步骤 检查父 POM 的 packaging 类型…...

STM32启动文件浅析

目录 STM32启动文件简介启动文件中的一些指令 启动文件代码详解栈空间的开辟堆空间的开辟中断向量表定义(简称:向量表)复位程序对于weak的理解对于_main函数的分析 中断服务程序用户堆栈初始化 系统启动流程 STM32启动文件简介 STM32启动文件…...

h5页面与小程序页面互相跳转

小程序跳转h5页面 一个home页 /pages/home/home 一个含有点击事件的元素&#xff1a;<button type"primary" bind:tap"toWebView">点击跳转h5页面</button>toWebView(){ wx.navigateTo({ url: /pages/webview/webview }) } 一个webView页 /pa…...

探索 JavaScript 事件机制(四):React 合成事件系统

前言 在前端开发中&#xff0c;事件处理是不可或缺的一部分。在众多的前端框架中&#xff0c;React 凭借其高效和灵活性受到众多开发者的喜爱。React 的事件处理系统&#xff0c;即“合成事件系统”&#xff0c;是其性能优化的一大亮点。 本文将带你深入浅出地探索 React 的合…...

openlayers 封装加载本地geojson数据 - vue3

Geojson数据是矢量数据&#xff0c;主要是点、线、面数据集合 Geojson数据获取&#xff1a;DataV.GeoAtlas地理小工具系列 实现代码如下&#xff1a; import {ref,toRaw} from vue; import { Vector as VectorLayer } from ol/layer.js; import { Vector as VectorSource } fr…...

手机号码携号转网查询接口-在线手机号码携号转网查询-手机号码携号转网查询API

接口简介&#xff1a;通过手机号精准查询该号码转网前及转网后所归属运营商 可查询号码是否为虚拟手机号 可查询到号码归属地信息 高准确率&#xff0c;实时查询运营商数据库 多用于营销场景&#xff0c;如运营商业务办理、客户信息查询、携号转网、电话营销等 接口地址&#x…...

yolo目标检测和姿态识别和目标追踪

要检测摄像头画面中有多少人&#xff0c;人一排排坐着&#xff0c;像教室那样。由于摄像头高度和角度的原因&#xff0c;有的人会被遮挡。 yolo v5 首先需要下载yolo v5官方代码&#xff0c;可以克隆或下载主分支的代码&#xff0c;或者下载release中发布的。 简单说一下环境…...

Docker搭建开源Web云桌面操作系统Puter和DaedalOS

文章目录 Puter 操作系统说明基于 Docker 启动 Puter 操作系统拉取镜像运行容器基于 Docker-Compose 启动 Puter操作系统创建目录编写docker-compose.yml运行在本地直接运行puter操作系统puter界面截图puter个人使用总结构建自己的Puter镜像daedalos基于web的操作系统说明技术特…...

FAQ-为什么交换机发给服务器的日志显示的时间少8小时

问题描述 配置交换机向日志服务器发送日志&#xff0c;在交换机上面查看日志显示的时间比日志服务器显示的时间快8个小时 解决方案 根据公司全球化整改的要求&#xff0c;syslog默认发送的是UTC时间。 当前设备上配置了时区UTC8&#xff0c;因此&#xff0c;设备上显示的本地…...

[表达式]真假计算

题目描述 有一棵树&#xff0c;不一定是二叉树。 所有叶子节点都是 True 或者 False。 对于从上往下奇数层的非叶子节点是 and&#xff0c;偶数层非叶子节点为 or。 树上每个节点的值是所有孩子节点的值进行该节点的运算操作。 判断一棵树能否砍掉&#xff0c;最快的方法就是从…...

记录一次线上环境svchost.exe antimalware service executable 进程占用CPU过高问题

博主介绍&#xff1a; 大家好&#xff0c;我是想成为Super的Yuperman&#xff0c;互联网宇宙厂经验&#xff0c;17年医疗健康行业的码拉松奔跑者&#xff0c;曾担任技术专家、架构师、研发总监负责和主导多个应用架构。 技术范围&#xff1a; 目前专注java体系&#xff0c;有多…...

Docker 部署 EMQX 一分钟极速部署

部署 EMQX ( Docker ) [Step 1] : 拉取 EMQX 镜像 docker pull emqx/emqx:latest[Step 2] : 创建目录 ➡️ 创建容器 ➡️ 拷贝文件 ➡️ 授权文件 ➡️ 删除容器 # 创建目录 mkdir -p /data/emqx/{etc,data,log}# 创建容器 docker run -d --name emqx -p 1883:1883 -p 1808…...

STL-常用容器-list

1list基本概念 **功能&#xff1a;**将数据进行链式存储 链表&#xff08;list&#xff09;是一种物理存储单元上非连续的存储结构&#xff0c;数据元素的逻辑顺序是通过链表中的指针链接实现的 链表的组成&#xff1a;链表由一系列结点组成 结点的组成&#xff1a;一个是存储…...

Lambda 架构

Lambda架构是一种用于构建可扩展、容错和实时数据处理系统的架构模式。 它由三个主要部分组成&#xff1a;批处理层&#xff08;Batch Layer&#xff09;、实时层&#xff08;Speed Layer&#xff09;和服务层&#xff08;Serving Layer&#xff09;。 Lambda架构旨在结合批处…...

Windows电脑设置网络唤醒(Wake-on-LAN)

1. 启用 Windows 电脑的 Wake-on-LAN 功能 首先&#xff0c;你需要确保你的 Windows 电脑支持并启用了 Wake-on-LAN&#xff1a; BIOS/UEFI 设置(具体看自己电脑主板如何设置): 启动 Windows 电脑&#xff0c;进入 BIOS/UEFI 设置。找到网络适配器相关的设置&#xff0c;启用 …...

前端项目构建流程

1. 需求分析 目标&#xff1a;明确项目目标、核心功能和用户需求。 产品需求讨论&#xff1a; 与产品经理、客户、业务部门讨论项目的需求和目标&#xff0c;理解产品的功能、业务流程以及用户需求。定义用户角色&#xff08;Persona&#xff09;&#xff0c;明确不同用户的功…...

支持国密算法的数字证书-国密SSL证书详解

在互联网中&#xff0c;数字证书作为标志通讯各方身份信息的数字认证而存在&#xff0c;常见的数字证书大都采用国际算法&#xff0c;比如RSA算法、ECC算法、SHA2算法等。随着我国加强网络安全技术自主可控的大趋势&#xff0c;也出现了支持国密算法的数字证书-国密SSL证书。那…...

【EndNote使用教程】创建文献库、导入文献、文献分类

1、创建文献库 打开“EndNote”&#xff0c;点击“文件”&#xff0c;点击“新建”&#xff0c;选择保存文件路径。 2、导入文献 &#xff08;1&#xff09;可以选择导入电脑上的PDF文件&#xff0c;如下图所示。 &#xff08;2&#xff09; 也可以选择直接在浏览器网页上面直…...

双十一电容笔选哪个好?!西圣、益博思、吉玛仕电容笔实测对比!

当数码测评博主几年年&#xff0c;我也实测过不下10款电容笔了&#xff0c;对电容笔这个品类也算是半个内行人了。提到电容笔&#xff0c;在平替品牌的追逐中&#xff0c;西圣、益博思、吉玛仕这三款作为国货黑马一直备受瞩目&#xff0c;综合各大电商平台的销量榜、好评口碑榜…...

房地产网络安全:主要风险及缓解建议

房地产行业已开始数字化转型&#xff0c;因此极易受到网络犯罪的攻击。潜在风险的清单很长&#xff1a;从客户敏感信息的数据泄露到勒索软件攻击&#xff0c;网络犯罪分子将房地产公司视为其所携带的所有类型敏感信息的高价值目标。 在本文中&#xff0c;我们将探讨房地产领域…...

玩转大模型的第一步——提示词(Prompt)工程【抛砖篇】

前言 AI大模型提示词工程&#xff0c;又名 LLM prompts Project&#xff0c;指的是在使用大型语言模型&#xff08;如OpenAI的GPT系列&#xff09;时&#xff0c;用于引导模型生成特定响应的输入&#xff0c;是在使用AI大模型过程中非常重要的一个环节&#xff0c;是模型生成文…...

火山引擎数据飞轮线上研讨会即将开启,助力消费品牌双十一造爆款

随着双十一的临近&#xff0c;各大品牌方的备战工作已进入紧张而有序的倒计时阶段。这场持续十多年的电商大促&#xff0c;对消费者来说是购物狂欢节&#xff0c;对各大品牌方来说&#xff0c;则是更是品牌实力与策略的比拼。面对日益激烈的市场竞争&#xff0c;如何更好地撬动…...

【python实战】利用代理ip爬取Alibaba海外版数据

引言 在跨境电商的业务场景中&#xff0c;数据采集是分析市场、了解竞争对手以及优化经营策略的重要环节。然而&#xff0c;随着越来越多企业依赖数据驱动决策&#xff0c;许多跨境电商平台为了保护自身数据&#xff0c;采取了更严格的防护措施。这些平台通过屏蔽大陆IP地址或部…...

FFMPEG录屏(20)--- 枚举macOS下的窗口和屏幕列表,并获取名称缩略图等信息

在 macOS 下获取可屏幕共享的窗口和屏幕 在 macOS 下&#xff0c;我们可以通过使用 Core Graphics 和 Cocoa 框架来获取当前系统中可屏幕共享的窗口和屏幕信息。本文将详细介绍如何获取窗口和屏幕的 ID、标题、坐标、进程图标和缩略图等信息。 前提条件 在开始之前&#xff…...

Redis 命令集 (超级详细)

目录 Redis 常用命令集 string类型 hash类型 list类型 set类型 zset类型 bitmap 类型 geo 类型 GEOADD (添加地理位置的坐标) GEOPOS (获取地理位置的坐标) GEODIST (计算两个位置之间的距离) GEOHASH (返回一个或多个位置对象的 geohash 值) GEORADIUS (根据用户…...

Spring Cloud --- GateWay和Sentinel集成实现服务限流

pom添加依赖 <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-gateway</artifactId> </dependency> <dependency><groupId>com.alibaba.csp</groupId><artifactId>s…...

python excel如何转成json,并且如何解决excel转成json时中文汉字乱码的问题

1.解决excel转成json时中文汉字乱码的问题 真的好久没有打开这个博客也好久没有想起来记录一下问题了&#xff0c;今天将表格测试集转成json格式的时候遇到了汉字都变成了乱码的问题&#xff0c;虽然这不是个大问题&#xff0c;但是编码问题挺烦人的&#xff0c;乱码之后像下图…...

【MySQL】实战篇—数据库设计与实现:根据需求设计数据库架构

在设计数据库架构时&#xff0c;开发者需要遵循一系列步骤&#xff0c;以确保数据库能够高效、可靠地满足系统需求。以下是设计数据库架构的理论知识和步骤说明。 1. 需求分析 需求分析是数据库设计的第一步&#xff0c;旨在理解系统的功能需求和数据需求。通过与利益相关者&…...

中英文企业网站/放单平台大全app

前言 每天一篇博客&#xff0c;高产似那啥&#xff0c;如有错误&#xff0c;欢迎指出 什么是tmpfs mounts bind mount与volume是持久化数据到磁盘中&#xff0c;在linux上&#xff0c;也可以使用tmpfs mounts&#xff0c;tmpfs mounts是将数据咱存在内存中&#xff0c;当容器…...

wordpress 登陆后跳转/seo数据是什么意思

第一步&#xff1a;找到下图两个组件&#xff0c;卸载。第二步&#xff1a;NuGet下载下图组件。第三步&#xff1a;在连接数据库OnConfiguring方法处&#xff0c;做如下修改&#xff1a;protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder){var co…...

wordpress插件小蜜蜂/网页制作软件免费版

建立拓扑&#xff0c;在实验4所实现的二层交换机基础之上 将Router0的gi0/0/0接口与switch0的fa0/22相连&#xff0c;且由于fa0/22是单臂路由的接口&#xff0c;需要传递两个vlan之间的信息&#xff08;和实验4中连接两台交换机的fa0/21一样&#xff09;&#xff0c;因此fa0/22…...

当当网站建设优点/全媒体运营师报考官网在哪里

引文&#xff1a; 个人名言&#xff1a;“同一条河里淹死两次的人&#xff0c;是傻子&#xff0c;淹死三次及三次以上的人是超人”。经历过上次悲催的面试&#xff0c;决定沉下心来&#xff0c;好好的补充一下基础知识点。本文是这一系列第一篇&#xff1a;进程间通讯之mmap。 …...

登录器显的窗口网站怎么做/网络营销推广方案

数据结构实验之图论三&#xff1a;判断可达性 Description 在古老的魔兽传说中&#xff0c;有两个军团&#xff0c;一个叫天灾&#xff0c;一个叫近卫。在他们所在的地域&#xff0c;有n个隘口&#xff0c;编号为1…n&#xff0c;某些隘口之间是有通道连接的。其中近卫军团在1…...

网站中文域名到期有没有影响/郑州seo外包阿亮

前提 Cordova Android 7.0.0开始改变了项目安卓平台的架构。新建一个空项目分别添加Android 6.4.0 和 Android 7.0.0平台: cordova platform add android6.4.0 cordova platform add android7.0.0生成的安卓平台结构分别为&#xff1a; 可以看到Cordova从7.0.0项目结构开始和原…...