当前位置: 首页 > news >正文

神经网络分类和回归任务实战

学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程

直接上案例,先来跑,遇到什么解决什么

数据集Minist 数据集

做简单的任务 Minist 分类任务

总体代码(可以跑通)

from pathlib import Path
import requests
import pickle
import gzip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim
import numpy as np
bs=64
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
from matplotlib import pyplot
import numpy as np
print(x_train.shape)# pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#获取训练数据和测试数据
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
#设置模型结构
class Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
# print(net)
# for name, parameter in net.named_parameters():
#     print(name, parameter,parameter.size())
#设置数据集和数据集加载器
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64 * 2)
def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)loss_func = F.cross_entropy
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
#训练参数
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(20, model, loss_func, opt, train_dl, valid_dl)
corret=0
total=0
for xb,yb in valid_dl:outputs=model(xb)_,predicted=torch.max(outputs.data,1)total+=yb.size(0)corret+=(predicted==yb).sum().item()
print('准确率是:%d %%'%(100*corret/total))

1.首先我们从最终实现的fit 函数开始看,

 在fit h函数之前有一个get_model 函数 得到model和优化器

model, opt = get_model()

得到模型的优化器以后

需要把训练轮数 模型 损失函数 训练数据 测试数据传入fit 训练函数

fit(20, model, loss_func, opt, train_dl, valid_dl)

fit 函数

def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

xb是从dataloader 中取64个训练数据图片 也就是64*784维度,784代表28*28的手写数字图片的展平成一维向量

yb是64个图片对应的数字值

loss_batch(model, loss_func, xb, yb, opt)

我们看一下loss_batch 函数

loss 反向传播——更新优化器——优化器梯度归0

loss 是一个带有梯度的tensor    .item()返回的是loss 的值 len(xb )是为了求精度的时候算

输入是模型损失函数xb ,yb 和优化器

loss_func = F.cross_entropy 损失函数是交叉熵损失函数

将xb 经过model 得到输出后 和xb 求损失函数  model 是定义的一个简单的模型有两个隐藏层一个输出层 784——128——256——10

再回到fit 函数的验证部分model.evl()

先来一个 

with torch.no_grad()

不去计算梯度

zip(*的意思是解压缩 分别得到losses 和nums)

loss 和num 鲜橙 求每64个batch 的总loss 再将datalosder的所有batch 相加除以总数得到训练损失

相关文章:

神经网络分类和回归任务实战

学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程 直接上案例,先来跑,遇到什么解决什么 数据集Minist 数据集 做简单的任务 Minist 分类任务 总体代码(可以跑通) from pathlib import …...

【数据结构】考研真题攻克与重点知识点剖析 - 第 4 篇:串

前言 本文基础知识部分来自于b站:分享笔记的好人儿的思维导图与王道考研课程,感谢大佬的开源精神,习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析,本人技术…...

深入浅出 -- 系统架构之分布式多形态的存储型集群

一、多形态的存储型集群 在上阶段,我们简单聊了下集群的基本知识,以及快速过了一下逻辑处理型集群的内容,下面重点来看看存储型集群,毕竟这块才是重头戏,集群的形态在其中有着多种多样的变化。 逻辑处理型的应用&…...

STL —— list

博主首页: 有趣的中国人 专栏首页: C专栏 本篇文章主要讲解 list模拟实现的相关内容 1. list简介 列表(list)是C标准模板库(STL)中的一个容器,它是一个双向链表数据结构&#xff0c…...

申请SSL证书

有很多方法可以确保您的网站安全。添加SSL证书可针对恶意攻击提供额外且关键的保护层。 即使网站不接受交易,您仍然需要保护用户的登录详细信息、地址和其他个人信息。 没有SSL证书的网站使用HTTP(一种基于文本的协议),这意味着…...

深入浅出 -- 系统架构之负载均衡Nginx环境搭建

引入负载均衡技术可带来的收益: 系统的高可用:当某个节点宕机后可以迅速将流量转移至其他节点。系统的高性能:多台服务器共同对外提供服务,为整个系统提供了更高规模的吞吐。系统的拓展性:当业务再次出现增长或萎靡时…...

notepad++绿色版添加右键菜单

解压路径 D:\Green\notepad_v8.0_x64_绿色版 添加右键菜单.reg 新建nodepad添加右键菜单.reg文件 Windows Registry Editor Version 5.00[HKEY_CLASSES_ROOT\*\shell\NotePad] "Edit with &Notepad" "Icon""D:\\Green\\notepad_v8.0_x64_绿色版…...

7 个 iMessage 恢复应用程序/软件可轻松恢复文本

由于误操作、iOS 升级中断、越狱失败、设备损坏等原因,您可能会丢失 iPhone/iPad 上的 iMessages。意外删除很大程度上增加了这种可能性。更糟糕的是,这种情况经常发生在 iDevice 缺乏备份的情况下。 (iPhone消息消失还占用空间?&…...

DockerFile启动jar程序

1.创建Dockerfile 在项目的根目录下创建一个名为Dockerfile的文件,并使用文本编辑器打开它。Dockerfile的内容如下: # 基础镜像 FROM openjdk:8-jre # 创建目录 RUN mkdir -p /usr/app/ # 设置工作目录 WORKDIR /usr/app # 将JAR文件复制到容器中,注:…...

基于R、Python的Copula变量相关性分析及AI大模型应用

在工程、水文和金融等各学科的研究中,总是会遇到很多变量,研究这些相互纠缠的变量间的相关关系是各学科的研究的重点。虽然皮尔逊相关、秩相关等相关系数提供了变量间相关关系的粗略结果,但这些系数都存在着无法克服的困难。例如,…...

鸿蒙组件学习_Tabs组件

说明 该组件从API Version 7 开始支持。 子组件 仅可包含子组件TabContent 参数 barPosition 设置Tabs的页签位置,默认值: BarPosition.StartStart vertical属性方法设置为true时,页签位于容器左侧;vertical属性方法设置为false时,页签位于容器顶部。End vertic…...

【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略

【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略 AutoGPTBaby AGIHuggingGPTLangChain 目前是将基于 CAMEL 框架的代理定义为 Simulation Agents(模拟代理)。这种代理在模拟环境中进行角色扮演,试图模拟特定场景或行为,而不是在真实世界中完成具体…...

thinkphp6入门(21)-- 如何删除图片、文件

假设文件的位置在 /*** 删除文件* $file_name avatar/20240208/d71d108bc1086b498df5191f9f925db3.jpg*/ function deleteFile($file_name) {// 要删除的文件路径$file app()->getRootPath() . public/uploads/ . $file_name; $result [];if (is_file($file)) {if (unlin…...

虚拟内存知识详解

虚拟内存 单片机的 CPU 是直接操作内存的「物理地址」 在这种情况下,要想在内存中同时运行两个程序是不可能的 操作系统是如何解决这个问题呢? 关键的问题是这两个程序都引用了绝对物理地址,而这正是我们最需要避免的。 可以把进程所使用的…...

数据结构初阶:顺序表和链表

线性表 线性表 ( linear list ) 是 n 个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使 用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串 ... 线性表在逻辑上是线性结构,也就说是连续的一条直线。但是在物理结构上并不一定是连续的, 线性…...

在flutter中添加video_player【视频播放插件】

添加插件依赖 dependencies:video_player: ^2.8.3插件的用途 在Flutter框架中,video_player 插件是一个专门用于播放视频的插件。它允许开发者在Flutter应用中嵌入视频播放器,并提供了一系列功能来控制和定制视频播放体验。这个插件对于需要在应用中展…...

golang微服务框架特性分析及选型

目录 一、微服务框架特性(10个)包括:Istio、go-zero、go-kit、go-kratos、go-micro、rpcx、kitex、goa、jupiter、dubbo-go、tarsgo 1、特性及使用场景2、比较 二、web框架特性(7个)包括:gin、fiber、beego…...

苹果cmsV10 MXProV4.5自适应PC手机影视站主题模板苹果cms模板mxone pro

演示站:http://a.88531.cn:8016 MXPro 模板主题(又名:mxonepro)是一款基于苹果 cms程序的一款全新的简洁好看 UI 的影视站模板类似于西瓜视频,不过同对比 MxoneV10 魔改模板来说功能没有那么多,也没有那么大气,但是比较且可视化功…...

GPU的了解

3D动画揭秘显卡的GPU是如何工作的_哔哩哔哩_bilibili 位于显卡中。 与CPU区别: 100名小学生和1位数学博士 做100道非常简单的算术题,小朋友一个人一道题,比博士快。 做1道非常复杂的数学问题,只有博士可以做出来。 CPU主要用于快…...

鸿蒙实战开发-如何使用Stage模型卡片

介绍 本示例展示了Stage模型卡片提供方的创建与使用。 用到了卡片扩展模块接口,ohos.app.form.FormExtensionAbility 。 卡片信息和状态等相关类型和枚举接口,ohos.app.form.formInfo 。 卡片提供方相关接口的能力接口,ohos.app.form.for…...

蓝桥杯刷题 前缀和与差分-[2128]重新排序(C++)

问题描述 给定一个数组 A 和一些查询 L**i, R**i,求数组中第 L**i 至第 R**i 个元素之和。 小蓝觉得这个问题很无聊,于是他想重新排列一下数组,使得最终每个查询结果的和尽可能地大。小蓝想知道相比原数组,所有查询结果的总和最多…...

STM32重要参考资料

stm32f103c8t6 一、引脚定义图 二、时钟树 三、系统结构图 四、启动配置 (有时候不小心短接VCC和GND,芯片会锁住,可以BOOT0拉高试试(用跳线帽接)) 五、最小系统原理图 可用于PCB设计 六、常见折腾人bug…...

[StartingPoint][Tier0]Preignition

Task 1 Directory Brute-forcing is a technique used to check a lot of paths on a web server to find hidden pages. Which is another name for this? (i) Local File Inclusion, (ii) dir busting, (iii) hash cracking. (目录暴力破解是一种用于检查 Web 服务器上的大…...

数据库系统概论(超详解!!!)第三节 关系数据库标准语言SQL(Ⅴ)

1.数据更新 1.插入数据 1.插入元组 语句格式 INSERT INTO <表名> [(<属性列1>[,<属性列2 >…)] VALUES (<常量1> [,<常量2>]… ); 功能&#xff1a;将新元组插入指定表中 INTO子句 &#xff1a; 指定要插入数据的表名及…...

SpringBoot | Spring Boot“整合Redis“

目录: 1. Redis 介绍2. Redis 下载安装3. Redis “服务开启”和“连接配置”4. Spring Boot整合Redis的“前期准备” :① 编写实体类② 编写Repository 接口③ 在“全局配置文件”中添加 “Redis数据库” 的 “相关配置信息” 5. Spring Boot整合“Redis” (案例展示) 作者简介…...

SV学习笔记(四)

OCP Open Closed Principle 开闭原则 文章目录 随机约束和分布为什么需要随机&#xff1f;为什么需要约束&#xff1f;我们需要随机什么&#xff1f;声明随机变量的类什么是约束权重分布集合成员和inside条件约束双向约束 约束块控制打开或关闭约束内嵌约束 随机函数pre_random…...

Midjourney艺术家分享|By Moebius

Moebius&#xff0c;本名让吉拉德&#xff08;Jean Giraud&#xff09;&#xff0c;是一位极具影响力的法国漫画家和插画师&#xff0c;以其独特的科幻和幻想风格而闻名于世。他的艺术作品不仅在漫画领域内受到高度评价&#xff0c;也为电影、时尚和广告等多个领域提供了灵感。…...

Vue - 1( 13000 字 Vue 入门级教程)

一&#xff1a;Vue 导语 1.1 什么是 Vue Vue.js&#xff08;通常称为Vue&#xff09;是一款流行的开源JavaScript框架&#xff0c;用于构建用户界面。Vue由尤雨溪在2014年开发&#xff0c;是一个轻量级、灵活的框架&#xff0c;被广泛应用于构建单页面应用&#xff08;SPA&am…...

Vue关键知识点

watch侦听器 Vue.js 中的侦听器&#xff08;Watcher&#xff09;是 Vue提供的一种响应式系统的核心机制之一。 监听数据的变化&#xff0c;并在数据发生变化时执行相应的回调函数。 目的:数据变化能够自动更新到视图中 原理&#xff1a; Vue 的侦听器通过观察对象的属性&#…...

Prometheus+grafana环境搭建redis(docker+二进制两种方式安装)(四)

由于所有组件写一篇幅过长&#xff0c;所以每个组件分一篇方便查看&#xff0c;前三篇 Prometheusgrafana环境搭建方法及流程两种方式(docker和源码包)(一)-CSDN博客 Prometheusgrafana环境搭建rabbitmq(docker二进制两种方式安装)(二)-CSDN博客 Prometheusgrafana环境搭建m…...

骏域建网站/官方网站百度一下

2019独角兽企业重金招聘Python工程师标准>>> 说道La(n)mp开发&#xff0c;其实我个人觉得应该把这个单词倒过来&#xff0c;可能才更能表现出这套技术栈的主次和先后顺序&#xff1a;pmalphpmysqlapache\nginxlinux。 PHP 手册要仔细读 &#xff08;高&#xff09; …...

wordpress字体路径设置/网站制作河南

说到夜晚发版这个事&#xff0c;有些时候事真的想不明白&#xff0c;为什么发版要夜晚发版。 有的人说&#xff0c;夜晚发版&#xff0c;影响到的用户数是少数的。why&#xff1f;发版为什么不灰度发&#xff1f;大多数项目都是集群的吧&#xff0c;6台机器&#xff0c;先发3台…...

网站注册界面设计/和生活app下载安装最新版

面向对象 面向对象程序语言的特点&#xff1a; 1.抽象&#xff0c; 2.封装&#xff0c; 3.继承&#xff0c; 4.多态 类&#xff0c;对象和成员的使用方法及区别&#xff1a; 声明一个类&#xff1a; class 人类 {public:void 获得身高&#xff08;&#xff09;&#…...

做网站有什么意义/白百度一下你就知道

2405: 牛顿迭代法求根 Description 用牛顿迭代法求根。方程为ax3bx2cxd0。系数a,b,c,d的值一次为1,2,3,4&#xff0c;由主函数输入。求x在1附近的一个实根。求出根后由主函数输出。结果保留两位小数。 Input 系数a,b,c,d的值 Output x在1附近的一个实根 Sample Input 1…...

一般做个网站要多少钱/网站大全

strcat -字符串追加 char* strcat(char* destination,const char* source) 1.源字符串必须以\0结束 2.目标空间必须足够大 能容纳下源字符串的内容 int main(){char arr1[] "hello";//目标空间不够大char arr2[] "world";strcat(arr1, arr2);printf(&q…...

邢台哪儿专业做网站/北京seo公司司

手工的功能测试用例也可以用3A原则来编写。 Arrange: 准备被测功能相关的测试数据&#xff0c;比如往系统里录入一批工单以便测试工单的分页功能Act : 调用被测的功能&#xff0c;实际上这就是我们一直讲的测试步骤Assert: 断言举个例子 # arrange and act 打开chrome浏览器并跳…...