自己做的网站服务器在哪里/百度广告联系方式
目录
一、构建神经网络
二、数据准备
三、损失函数和优化器
四、训练模型
五、保存模型
一、构建神经网络
首先,需要构建一个神经网络模型。在示例代码中,构建了一个名为Tudui的卷积神经网络(CNN)模型。这个模型包括卷积层、池化层和全连接层,用于处理图像分类任务。
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init()self.mode1 = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.mode1(x)return x
二、数据准备
训练深度学习模型需要数据集。在示例中,使用CIFAR-10数据集作为示例数据。数据集的准备包括下载、预处理和分割成训练集和测试集。
import torch
import torchvision
from torch.utils.data import DataLoader# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)train_data_size = len(train_data)
test_data_size = len(test_data)
三、损失函数和优化器
在训练中,需要定义损失函数和优化器。损失函数用于度量模型的输出与真实标签之间的差距,而优化器用于更新模型的参数以减小损失。
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)
四、训练模型
模型训练分为多轮迭代,每轮包括训练和测试步骤。在训练步骤中,通过反向传播算法更新模型参数,以最小化损失函数。在测试步骤中,用测试集验证模型性能。
for epoch in range(10): # 训练的轮数tudui.train()for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()tudui.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss += loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))
五、保存模型
最后,可以保存训练好的模型,以备后续使用。示例代码展示了两种保存模型的方式,包括保存整个模型和仅保存模型参数。
# 保存方式一
torch.save(tudui, "tudui_{}.pth".format(epoch))
# 保存方式二(官方推荐)
# torch.save(tudui.state_dict(), 'tudui_{}.pth'.format(epoch))
完整代码如下:
import torch
from torch import nn# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui,self).__init__()self.mode1 = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32,64,5,1,2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4,64),nn.Linear(64,10))def forward(self, x):x = self.mode1(x)return xif __name__ == '__main__':tudui = Tudui()input = torch.ones((64,3,32,32))output = tudui(input)print(output.shape)
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from P27_model import *
import time# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)# length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)# 如果train_data_size=10,训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用DataLoader 来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
tudui = Tudui()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0# 训练的轮数
epoch = 10# 添加tensorboard
writer = SummaryWriter("logs_train")
# 添加开始时间
strat_time = time.time()for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始tudui.train() # 这两个层,只对一部分层起作用,比如 dropout层;如果有这些特殊的层,才需要调用这个语句for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad() # 优化器,梯度清零loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:end_time = time.time() # 结束时间print(end_time - strat_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss.item())) # 这里用到的 item()方法,有说法的,其实加不加都行,就是输出的形式不一样而已writer.add_scalar("train_loss", loss.item(),total_train_step)# 每训练完一轮,进行测试,在测试集上测试,以测试集的损失或者正确率,来评估有没有训练好,测试时,就不要调优了,就是以当前的模型,进行测试,所以不用再使用梯度(with no_grad 那句)# 测试步骤开始tudui.eval() # 这两个层,只对一部分层起作用,比如 dropout层;如果有这些特殊的层,才需要调用这个语句total_test_loss = 0total_accuracy = 0with torch.no_grad(): # 这样后面就没有梯度了, 测试的过程中,不需要更新参数,所以不需要梯度?for data in test_dataloader: # 在测试集中,选取数据imgs, targets = dataoutputs = tudui(imgs) # 分类的问题,是可以这样的,用一个output进行绘制loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item() # 为了查看总体数据上的 loss,创建的 total_test_loss,初始值是0accuracy = (outputs.argmax(1) == targets).sum() # 正确率,这是分类问题中,特有的一种,评价指标,语义分割之类的,不一定非要有这个东西,这里是存疑的,再看。total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size)) # 即便是输出了上一行的 loss,也不能很好的表现出效果。# 在分类问题上比较特有,通常使用正确率来表示优劣。因为其他问题,可以可视化地显示在tensorboard中。writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)total_test_step = total_test_step + 1# print(total_test_step)# 保存方式一,其实后缀都可以自己取,习惯用 .pth。torch.save(tudui, "tudui_{}.pth".format(i))# 保存方式2(官方推荐)# torch.save(model.state_dict(), pth_dir + '/model_{}.pth'.format(i)print("模型已保存")writer.close()
参考资料:
视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
相关文章:

PyTorch入门学习(十七):完整的模型训练套路
目录 一、构建神经网络 二、数据准备 三、损失函数和优化器 四、训练模型 五、保存模型 一、构建神经网络 首先,需要构建一个神经网络模型。在示例代码中,构建了一个名为Tudui的卷积神经网络(CNN)模型。这个模型包括卷积层、…...

《 Hello 算法 》 - 免费开源的数据结构与算法入门教程电子书,包含大量动画、图解,通俗易懂
这本学习算法的电子书应该是我看过这方面最好的书了,代码例子有多种编程语言,JavaScript 也支持。 《 Hello 算法 》,英文名称是 Hello algo,是一本关于编程中数据解构和算法入门的电子书,作者是毕业于上海交通大学的…...

数据库之事务
数据库之事务 事务的特点: ACID 原子性 一致性:数据库的完整性约束,不能被破坏 隔离性 持久性:数据一旦提交,事务的效果将会被永久的保留在数据库中。而且不会被回滚 主从复制 高可用 备份 权限控制 脏读&am…...

NOIP2023模拟12联测33 B. 游戏
NOIP2023模拟12联测33 B. 游戏 文章目录 NOIP2023模拟12联测33 B. 游戏题目大意思路code 题目大意 期望题 思路 二分答案 m i d mid mid ,我们只关注学生是否能够使得被抓的人数 ≤ m i d \le mid ≤mid 那我们就只关心 a > m i d a > mid a>mid 的房…...

代码随想录打卡第六十三天|84.柱状图中最大的矩形
84.柱状图中最大的矩形 题目:给定 n 个非负整数,用来表示柱状图中各个柱子的高度。每个柱子彼此相邻,且宽度为 1 。求在该柱状图中,能够勾勒出来的矩形的最大面积。 提示: 1 < heights.length <105 0 < h…...

python tempfile 模块使用
在Python中,tempfile 模块用于创建临时文件和目录,它们可以用于存储中间处理数据,不需要长期保存。该模块提供了几种不同的类和函数来创建临时文件和目录。 下面是几个常用的 tempfile 使用方法: 临时文件 使用 NamedTemporary…...

【软件测试】接口测试实战详解
最近找到了几个问题,都还比较有代表性。 作为一个初级测试,想学接口测试,但是一点头绪都没有。求教大神指点,有没有好的书或者工具推荐?如何做接口测试呢?接口测试有哪些工具做接口测试的流程一般是怎么样…...

轻量封装WebGPU渲染系统示例<20>- 美化一下元胞自动机之生命游戏(源码)
当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/GameOfLifePretty.ts 系统特性: 1. 用户态与系统态隔离。 2. 高频调用与低频调用隔离。 3. 面向用户的易用性封装。 4. 渲染数据(内外部相关资源)和渲染机制分离…...

Nodejs的安装以及配置(node-v12.16.1-x64.msi)
Nodejs的安装以及配置 1、安装 node-v12.16.1-x64.msi点击安装,注意以下步骤 本文设置nodejs的安装的路径:D:\soft\nodejs 继续点击next,选中Add to PATH ,旁边的英文告诉我们会把 环境变量 给我们配置好 当然也可以只选择 Nod…...

03【保姆级】-GO语言变量和数据类型和相互转换
03【保姆级】-GO语言变量和数据类型 一、变量1.1 变量的定义:1.2 变量的声明、初始化、赋值1.3 变量使用的注意事项 插播-关于fmt.Printf格式打印%的作用二、 变量的数据类型2.1整数的基本类型2.1.1 有符号类型 int8/16/32/642.1.2 无符号类型 int8/16/32/642.1.3 整…...

mermaid学习第一天/更改主题颜色和边框颜色/《需求解释流程图》
mermaid 在线官网: https://mermaid-js.github.io/ 在线学习文件: https://mermaid.js.org/syntax/quadrantChart.html 1、今天主要是想做需求解释的流程图,又不想自己画,就用了,框框不能直接进行全局配置࿰…...

SAP MASS增加PR字段-删除标识
MASS->BUS2105->发现没有找到PR删除标识字段 SAP MASS增加PR字段-删除标识 1.tcode:MASSOBJ 选中BUS2105 点“应用程序表” 点“字段列表” 2.选中一行进行参考 3.修改字段为删除标识 LOEKZ,保存即可。 4.然后MASS操作,批量设置删除标识&…...

【手把手教你】训练YOLOv8分割模型
1.下载文件 在github上下载YOLOV8模型的文件,搜索yolov8,star最多这个就是 2. 准备环境 环境要求python>3.8,PyTorch>1.8,自行安装ptyorch环境即可 2. 制作数据集 制作数据集,需要使用labelme这个包&#…...

物料主数据增强屏幕绘制器DUMP
问题描述 在做完物料主数据增强后,配置和代码传Q,在Q进入增强屏幕绘制器报错。 错误 CALLBACK_REJECTED_BY_WHITELIST RFC callback call rejected by positive list An RFC callback has been prevented due to no corresponding positive list en…...

vue 实现在线预览Excel-LuckyExcel/LuckySheet实现方案
一、准备工作 1. npm安装 luckyexcel npm i -D luckyexcel 2.引入luckysheet 注意:引入luckysheet,只能通过CDN或者直接引入静态资源的形式,不能npm install。 个人建议直接下载资源引入。我给你们提供一个下载资源的地址: …...

AIGPT重大升级,界面重新设计,功能更加饱满,用户体验升级
AIGPT AIGPT是一款功能强大的人工智能技术处理软件,不但拥有其他模型处理文本认知的能力还有AI绘画模型、拥有自身的插件库。 我们都知道使用ChatGPT是需要账号以及使用魔法的,实现其中的某一项对我们一般的初学者来说都是一次巨大的挑战,但…...

Web逆向-mtgsig1.2简单分析
{"a1": "1.2", # 加密版本"a2": new Date().valueOf() - serverTimeDiff, # 加密过程中用到的时间戳. 这次服主变坏了, 时间戳需要减去一个 serverTimeDiff(见a3) ! "a3": "这是把xxx信息加密后提交给服务器, 服主…...

【蓝桥杯省赛真题41】Scratch电脑开关机 蓝桥杯少儿编程scratch图形化编程 蓝桥杯省赛真题讲解
目录 scratch电脑开关机 一、题目要求 编程实现 二、案例分析 1、角色分析...

第10章 Java常用类
目录 内容说明 章节内容 一、Object类 二、String类和StringBuffer类 三、Math类和Random类...

Android 11 getPackageManager().getPackageInfo 返回null
Android11 之后, 在查找用户手机是否有安装app,进行查询包名是否存在时,getPackageManager().getPackageInfo()这个函数一直返回null ,Android 11增加了权限要求。 1、只是查询指定的App 包 只需要在Andro…...

4、数据结构
数据结构01 数值处理 取整 日常用的四种 / 整数除法,截取整数部分math.Ceil 向上取整 “理解为天花板,向上取值”math.Floor 向下取整 “理解为地板,向下取值”math.Round 四舍五入 / 整数除法,截取整数部分 func main() { f…...

qt5.15.2+vs2019源码调试开发环境搭建
说明 一些qt文件不进行源码调试无法知道其中的原理。提高软件质量,从概念原理及应用角度看待必须知道qt类运行原理。 1.安装 在网上找到qt安装包qt-unified-windows-x64-4.5.1-online.exe,安装qt5.15.2,有选择Qt Debug Information Files …...

【数据结构】单链表之--无头单向非循环链表
前言:前面我们学习了动态顺序表并且模拟了它的实现,今天我们来进一步学习,来学习单链表!一起加油各位,后面的路只会越来越难走需要我们一步一个脚印! 💖 博主CSDN主页:卫卫卫的个人主页 &#x…...

网络中使用最多的图片格式有哪些
互联网中的图片格式五花八门吧,我常常分不清各种格式的使用场景和区别,有些常见的格式和很少见的,在此总结。 常见格式 常见的图片格式,有 JPEG、PNG、GIF、BMP、WebP、SVG、TIFF、ICO等, 少见的比如:HD…...

个人常用Linux命令
来自 linux命令学习-2023-8-1 153913.md等 1、切换目录 cd //切换目录 cd change directory cd 目录名 cd .. 返回上一级目录 pwd显示当前所处目录cd 绝对路径 cd ~ 表示一个用户的home目录 cd - 表示上一次访问的目录 cd / 表示进入根目录下//新建目录/data,并且进入/data…...

数据结构——常见简答题汇总
目录 1、绪论 2、线性表 3、栈、队列和数组 4、串 5、树与二叉树 6、图 7、查找 8、排序 1、绪论 什么是数据结构? 数据结构是相互之间存在一种或多种特定关系的数据元素的集合。数据结构包括三个方面:逻辑结构、存储结构、数据的运算。 逻辑结…...

josef约瑟低电压继电器 DY-110 10-109V 辅助电源·DC110V 嵌入式面板安装
DY-110/110V电压继电器 系列型号 DY-110电压继电器;GY-110电压继电器; GDY-110电压继电器;DY-110/AC电压继电器; GY-110/AC电压继电器;GDY-110/AC电压继电器; DL-110电压继电器;GL-110电压…...

Visual Studio Code将中文写入变量时,中文老是乱码问题
对于这个问题,我也是弄了很久才知道,编码格式的问题 在此之前我们要先下载个插件 照这以上步骤,最后按F6运行即可,按F6是利用我们刚刚下载的插件进行编译,唯一有一点不好就是,用这种插件运行的话ÿ…...

各省市30米分辨率DEM数据,推荐下载!
今天给大家推荐一个新数据 —— 各省市30米分辨率DEM数据! 各省市30米分辨率DEM数据广泛应用于国土资源调查、水利水电工程、地质灾害预警、城市规划等领域,对于了解区域内的地形地貌、地形分析、土地利用等具有非常重要的意义。 网站搜索“citybox城市…...

操作系统引论(一)
操作系统的地位和目标 计算机系统的组成 系统软件是和硬件相关的,这是它本质的特征。 操作系统在计算机系统中的地位 操作系统的设计目标 可扩充性是面向未来的。 操作系统的作用 1)用户与计算机硬件系统之间的接口 2)计算机系统资源的管…...