完整网络模型训练(一)
文章目录
- 一、网络模型的搭建
- 二、网络模型正确性检验
- 三、创建网络函数
一、网络模型的搭建
以CIFAR10数据集作为训练例子
准备数据集:
#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
查看数据集的长度:
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")
运行结果:
利用DataLoader来加载数据集:
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)
搭建CIFAR10数据集神经网络:
卷积层【1】代码解释:
#第一个数字3表示inputs(可以看到图中为3),第二个数字32表示outputs(图中为32)
#第三个数字5为卷积核(图中为5),第四个数字1表示步长(stride)
#第五个数字表示padding,需要计算,计算公式:
nn.Conv2d(3, 32, 5, 1, 2)
最大池化代码解释:
#数字2表示kernel卷积核
nn.MaxPool2d(2)
读图
卷积层【1】的Inputs 和 Outputs是下图这两个:
最大池化【1】的Inputs 和 Outputs是下图这两个:
卷积层【2】的Inputs 和 Outputs是下图这两个:
以此类推
展平:
Flatten后它会变成64*4 *4的一个结果
线性输出:
线性输入是64*4 *4,线性输出是64,故如下代码
nn.LInear(64 *4 *4,64)
继续线性输出
nn.LInear(64,10)
搭建网络完整代码:
class Sen(nn.Module):def __init__(self):super(Sen, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1 ,2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x
二、网络模型正确性检验
if __name__ == '__main__':sen = Sen()input = torch.ones((64, 3, 32, 32))output = sen(input)print(output.shape)
注释:
input = torch.ones((64, 3, 32, 32))
这一行代码的含义是:创建一个大小为 (64, 3, 32, 32) 的全 1 张量,数据类型为 torch.float32。
64:这是批次大小,代表输入有 64 张图片。
3:这是图片的通道数,通常为 RGB 图像的三个通道 (红、绿、蓝)。
32, 32:这是图片的高和宽,表示每张图片的尺寸为 32x32 像素。
torch.ones 函数用于生成一个全 1 的张量,这里的张量形状适合用于输入图像分类或卷积神经网络(CNN)中常见的 CIFAR-10 或类似的 32x32 像素图像数据。
运行结果:
可以得到成功变成了【64, 10】的结果。
三、创建网络函数
创建网络模型:
sen = Sen()
搭建损失函数:
loss_fn = nn.CrossEntropyLoss()
优化器:
learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)
优化器注释:
使用随机梯度下降(SGD)优化器
learning_rate = 1e-2 这里的1e-2代表的是:1 x (10)^(-2) = 1/100 = 0.01
记录训练的次数:
total_train_step = 0
记录测试的次数:
total_test_step = 0
训练的轮数:
epoch= 10
进行循环训练:
for i in range(epoch):print(f"第{i+1}轮训练开始")for data in train_dataloader:imgs, targets = dataoutputs = sen(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1print(f"训练次数:{total_train_step},Loss:{loss.item()}")
注释:
imgs, targets = data
是解包数据,imgs 是输入图像,targets 是目标标签(真实值)
outputs = sen(imgs)
将输入图像传入模型 ‘sen’,得到模型的预测输出 outputs
loss = loss_fn(outputs, targets)
计算损失值(Loss),loss_fn 是损失函数,它比较outputs的值与targets 是目标标签(真实值)的误差
optimizer.zero_grad()
清除优化器中上一次计算的梯度,以免梯度累积
loss.backward()
反向传播,计算损失相对于模型参数的梯度
optimizer.step()
使用优化器更新模型的参数,以最小化损失
loss.item() 将张量转换为 Python 的数值
loss.item演示:
import torch
a = torch.tensor(5)
print(a)
print(a.item())
运行结果:
因此可以得到:item的作用是将tensor变成真实数字5
本章节完整代码展示:
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoaderclass Sen(nn.Module):def __init__(self):super(Sen, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1 ,2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x
#准备数据集
#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)#length长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)sen = Sen()#损失函数
loss_fn = nn.CrossEntropyLoss()#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch= 10for i in range(epoch):print(f"第{i+1}轮训练开始")for data in train_dataloader:imgs, targets = dataoutputs = sen(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1print(f"训练次数:{total_train_step},Loss:{loss.item()}")
运行结果:
可以看到训练的损失函数在一直进行修正。
相关文章:
完整网络模型训练(一)
文章目录 一、网络模型的搭建二、网络模型正确性检验三、创建网络函数 一、网络模型的搭建 以CIFAR10数据集作为训练例子 准备数据集: #因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集 train_data torchvision.datasets.CIFAR10(root&quo…...
高效便捷,体验不一样的韩语翻译神器
嘿,大家好啊!今天想跟大家聊聊我用过的几款翻译神器,特别是它们在翻译韩语时的那些小感受。作为一个偶尔需要啃啃韩语资料或者跟韩国朋友聊天的普通人,我真心觉得这些翻译工具简直就是我的救星! 一、福昕在线翻译 网址…...
Markdown笔记管理工具Haptic
什么是 Haptic ? Haptic 是一个新的本地优先、注重隐私的开源 Markdown 笔记管理工具。它简约、轻量、高效,旨在提供您所需的一切,而不包含多余的功能。 目前官方提供了 docker 和 Mac 客户端。 Haptic 仍在积极开发中。以下是未来计划的一些…...
网络原理-传输层UDP
上集回顾: 上一篇博客中讲述了应用层如何自定义协议:确定传输信息,确定数据格式 应用层也有一些现成的协议:HTTP协议 这一篇博客中来讲述传输层协议 传输层 socket api都是传输层协议提供的(操作系统内核实现的了…...
C++中,如何使你设计的迭代器被标准算法库所支持。
iterator(读写迭代器) const_iterator(只读迭代器) reverse_iterator(反向读写迭代器) const_reverse_iterator(反向只读迭代器) 以经常介绍的_DList类为例,它的迭代…...
Java NIO 全面详解:掌握 `Path` 和 `Files` 的一切
在 Java 7 中引入的 NIO (New I/O) 为文件系统和流的操作带来了强大的能力,其中 Path 和 Files 是核心部分。Path 作为对文件路径的抽象,提供了灵活的方式处理文件系统中的路径;Files 则通过一系列静态方法,使得文件的读写、复制、…...
bluez免提协议hands-free介绍,全到无法想象,bluez hfp ag介绍
零. 前言 由于Bluez的介绍文档有限,以及对Linux 系统/驱动概念、D-Bus 通信和蓝牙协议都有要求,加上网络上其实没有一个完整的介绍Bluez系列的文档,所以不管是蓝牙初学者还是蓝牙从业人员,都有不小的难度,学习曲线也相对较陡,所以我有了这个想法,专门对Bluez做一个系统…...
关于区块链的安全和隐私
背景 区块链技术在近年来发展迅速,被认为是安全计算的突破,但其安全和隐私问题在不同应用中的部署仍处于争论焦点。 目的 对区块链的安全和隐私进行全面综述,帮助读者深入了解区块链的相关概念、属性、技术和系统。 结构 首先介绍区块链…...
特征工程——一门提高机器学习性能的艺术
当前围绕人工智能(AI)和机器学习(ML)展开的许多讨论以模型为中心,聚焦于 ML和深度学习(DL)的最新进展。这种模型优先的方法往往对用于训练这些模型的数据关注不足,甚至完全忽视。类似MLOps的领域正迅速发展,通过系统性地训练和利用ML模型&…...
Paper解读:工作场所人机协作的团队形成:促进组织变革的目标编程模型
人工智能(AI)具有降低运营成本、提高效率和改善客户体验的潜力。 因此,在组织中组建项目团队至关重要,这样他们就会在决策过程中欢迎人工智能。 当前的技术革命要求公司快速变革,并增加了对团队在促进创新采用方面的作…...
图文深入理解Oracle Network配置管理(一)
List item 本篇图文深入介绍Oracle Network配置管理。 Oracle Network概述 Oracle Net 服务 Oracle Net 监听程序 <oracle_home>/network/admin/listener.ora <oracle_home>/network/admin/sqlnet.ora建立网络连接 要建立客户机或中间层连接,Oracle…...
leetcode-链表篇3
leetcode-61 给你一个链表的头节点 head ,旋转链表,将链表每个节点向右移动 k 个位置。 示例 1: 输入:head [1,2,3,4,5], k 2 输出:[4,5,1,2,3]示例 2: 输入:head [0,1,2], k 4 输出&#x…...
RAG(Retrieval Augmented Generation)及衍生框架:CRAG、Self-RAG与HyDe的深入探讨
近年来,随着大型语言模型(LLMs)的迅猛发展,我们在寻求更精确、更可靠的语言生成能力上取得了显著进展。其中,检索增强生成(Retrieval-Augmented Generation)作为一种创新方法,极大地…...
C语言介绍
什么是C语言 C programing language 能干什么 Hello world? 如何学C语言 no reading no learning...
损失函数篇 | YOLOv10 更换损失函数之 MPDIoU | 《2023 一种用于高效准确的边界框回归的损失函数》
论文地址:https://arxiv.org/pdf/2307.07662v1.pdf 边界框回归(Bounding Box Regression,BBR)在目标检测和实例分割中得到了广泛应用,是目标定位的重要步骤。然而,对于边界框回归的大多数现有损失函数来说,当预测的边界框与真值边界框具有相同的长宽比,但宽度和高度的…...
WMware安装WMware Tools(Linux~Ubuntu)
1、这里终端里面输入sudo apt upgrade用于更新最新的包 sudo apt upgrade 2、安装 open-vm-tools-desktop 包, Ps:这里是以为我已经安装好了。 udo apt install open-vm-tools-desktop -y3、最后重启就大功告成了 reboot 4、测试是否成功:…...
SLAM ORB-SLAM2(30)关键帧跟踪
SLAM ORB-SLAM2(30)关键帧跟踪 1. 关键帧跟踪2. TrackReferenceKeyFrame2.1. 将当前普通帧的描述子转化为BoW向量2.2. 通过词袋BoW加速当前帧与参考帧之间的特征点匹配2.3. 将上一帧的位姿态作为当前帧位姿的初始值2.4. 通过优化3D-2D的重投影误差来获得位姿2.5. 剔除优化后的…...
k8s 部署 prometheus
创建namespace prometheus-namespace.yaml apiVersion: v1 kind: Namespace metadata:name: ns-prometheus拉取镜像 docker pull swr.cn-north-4.myhuaweicloud.com/ddn-k8s/quay.io/prometheus/prometheus:v2.54.0prometheus配置文件configmap prometheus-configmap.yaml …...
使用VBA快速生成Excel工作表非连续列图片快照
Excel中示例数据如下图所示。 现在需要拷贝A2:A15,D2:D15,J2:J15,L2:L15,R2:R15为图片,然后粘贴到A18单元格,如下图所示。 大家都知道VBA中Range对象有CopyPicture方法可以拷贝为图片,但是如果Range对象为非连续区域,那么将产生10…...
解决GitHub下载速度慢
解决GitHub下载速度慢 方法一:使用git clone 地址 --depth 1来下载 depth 1 表示只科隆最新的一次提交,也就是默认主分支,而不是完整地克隆整个代码仓库,这样可以减少下载地数据,加快克隆操作 可以用git clone 地址 …...
【机器学习(五)】分类和回归任务-AdaBoost算法
文章目录 一、算法概念一、算法原理(一)分类算法基本思路1、训练集和权重初始化2、弱分类器的加权误差3、弱分类器的权重4、Adaboost 分类损失函数5、样本权重更新6、AdaBoost 的强分类器 (二)回归算法基本思路1、最大误差的计算2…...
【设计模式-模板】
定义 模板方法模式是一种行为设计模式,它在一个方法中定义了一个算法的骨架,并将一些步骤延迟到子类中实现。通过这种方式,模板方法允许子类在不改变算法结构的情况下重新定义算法中的某些特定步骤。 UML图 组成角色 AbstractClass&#x…...
小程序原生-列表渲染
1. 列表渲染的基础用法 <!--渲染数组列表--> <view wx:for"{{numList}}" wx:key"*this" > 序号:{{index}} - 元素:{{item}}</view> <!--渲染对象属性--> <view wx:for"{{userInfo}}" wx:key&q…...
JAVA认识异常
目录 1. 异常的概念与体系结构 1.1 异常的概念 1. 算术异常 2. 数组越界异常 3. 空指针异常 1.2 异常的分类 1. 编译时异常 2. 运行时异常 2.1 异常的处理 防御式编程 2.2 异常的捕获 2.3.1 异常声明throws 2.3.2 try-catch捕获并处理 2.3.3 finally 总结 1. 异常…...
2024年10月计划(工作为主,Ue5独立游戏为辅,)
我发现一点,就是工作很忙,比如中秋也在远程加班,周末有时也远程加班,国庆节甚至也差点去甲方工作。甚至有可能驻场。可见,小公司确实不能去。 好在,9月份时,通过渲染 除了上班时间外࿰…...
并发、并行和异步设计
译者个人领悟,一家之言: 并发和并行确实可以明确区分出来,因为cpu的速度非常快,在执行一个任务时经常要等其他组件,比如网络,磁盘等,如果一直串行等待这样就会造成很大的浪费. (就类似于烧水的同时,可以切菜,不用等烧水完成了才去切菜,我可以烧一会水,火生起来了水壶放上了,随…...
求职Leetcode题目(12)
1.只出现一次的数字 异或运算满足交换律 a⊕bb⊕a ,即以上运算结果与 nums 的元素顺序无关。代码如下: class Solution {public int singleNumber(int[] nums) {int ans 0;for(int num:nums){ans^num;}return ans;} } 2.只出现一次的数字II 这是今天滴…...
【YashanDB知识库】如何配置jdbc驱动使getDatabaseProductName()返回Oracle
本文转自YashanDB官网,具体内容请见https://www.yashandb.com/newsinfo/7352676.html?templateId1718516 问题现象 某些三方件,例如 工作流引擎activiti,暂未适配yashandb,使用中会出现如下异常: 问题的风险及影响 …...
Hadoop三大组件之MapReduce(一)
Hadoop之MapReduce 1. MapReduce是什么 MapReduce是一个分布式运算程序的编程框架,旨在帮助用户开发基于Hadoop的数据分析应用。它的核心功能是将用户编写的业务逻辑代码与自带的默认组件整合,形成一个完整的分布式运算程序,并并发运行在一…...
SQL Server 分页查询的学习文章
SQL Server 分页查询的学习文章 一、SQL Server 分页查询1. 什么是分页查询?2. SQL Server 的分页查询方法2.1 使用 OFFSET 和 FETCH NEXT语法:示例: 2.2 使用 ROW_NUMBER() 方法语法:示例: 2.3 性能考虑3. 总结 一、S…...
湘潭做网站 搜搜磐石网络/怎样做品牌推广
什么牌子的蓝牙耳机无延迟?无延迟游戏蓝牙耳机分享随着各种自媒体、手游、短视频越来越流行,大家对使用便捷的蓝牙耳机的需求也呈几何倍数开始上升。价格相对亲民的游戏蓝牙耳机尤其受到消费者的喜爱。但是如何挑选适合自己的游戏蓝牙耳机呢?…...
如何保证网站安全/企业营销策划书模板
题库来源:安全生产模拟考试一点通公众号小程序 化工自动化控制仪表考试平台是安全生产模拟考试一点通总题库中随机出的一套化工自动化控制仪表复审模拟考试,在公众号安全生产模拟考试一点通上点击化工自动化控制仪表作业手机同步练习。2021年化工自动化…...
www 上海网站建设/南昌搜索引擎优化
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处作者:流云编辑:流云链接:https://news.mydrivers.com/1/687/687718_all.htm#2来源:快科技 2020-05-07 21:00:39一、前言:时代变了 入门…...
网站dns解析失败/代推广app下载
文章目录 本课题的研究内容:探地雷达原理探地雷达图像预处理图像倾斜矫正均值法去背景原理与实现图像分割技术阈值分割技术的实现腐蚀与膨胀技术探地雷达杂波抑制研究与实现探地雷达合成孔径成像探地雷达目标识别总结本文为论文解读,为2008年发布的基于传统图像处理与识别论文…...
厦门注册公司流程/苏州seo关键词排名
马上就要到中秋节了,提前祝大家中秋节快乐,最近比较忙,考虑到粉丝一直要求我更新文章,我今天就加班更新一下文章。实时数仓如何做数据分层我不喜欢搞什么花里胡哨的词汇,让粉丝听着挠头,我就想用大白话分享…...
广州网站建设培训/游戏优化软件
atitit 商业项目常用模块技术知识点 v3 qc29 条码二维码barcodebarcode 条码二维码qrcodeqrcode 条码二维码dm码生成与识别 条码二维码pdf147码 条码二维码zxing 条码二维码azetec 条码二维码maxicode 自动完成 翻页page 公告管理 小元宵活动刮刮卡 小元宵活动闸金蛋 小元宵活动…...