PyTorch入门之【CNN】
参考:https://www.bilibili.com/video/BV1114y1d79e/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
书接上回的MLP故本章就不详细解释了
目录
- train
- test
train
import torch
from torchvision.transforms import ToTensor
from torchvision import datasets
import torch.nn as nn# load MNIST dataset
training_data = datasets.MNIST(root='../02_dataset/data',train=True,download=True,transform=ToTensor()
)train_data_loader = torch.utils.data.DataLoader(training_data, batch_size=64, shuffle=True)# define a CNN model
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1),nn.BatchNorm2d(32),nn.ReLU())self.conv_2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1),nn.BatchNorm2d(64),nn.ReLU(),)self.maxpool = nn.MaxPool2d(2)self.flatten = nn.Flatten()self.fc_1 = nn.Sequential(nn.Linear(9216, 128),nn.BatchNorm1d(128),nn.ReLU())self.fc_2 = nn.Linear(128, 10)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc_1(x)logits = self.fc_2(x)return logits# create a CNN model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN().to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()# train the model
num_epochs = 20for epoch in range(num_epochs):print(f'Epoch {epoch+1}\n-------------------------------')for idx, (img, label) in enumerate(train_data_loader):size = len(train_data_loader.dataset)img, label = img.to(device), label.to(device)# compute prediction errorpred = cnn(img)loss = loss_fn(pred, label)# backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if idx % 400 == 0:loss, current = loss.item(), idx*len(img)print(f'loss: {loss:>7f} [{current:>5d}/{size:>5d}]')# save the model
torch.save(cnn.state_dict(), 'cnn.pth')
print('Saved PyTorch Model State to cnn.pth')
test
import torch
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
import torch.nn as nn# load test data
test_data = datasets.MNIST(root='../02_dataset/data',train=False,download=True,transform=ToTensor()
)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)transform = transforms.Compose([transforms.Grayscale(),transforms.RandomRotation(10),transforms.ToTensor()
])
my_mnist = ImageFolder(root='../02_dataset/my-mnist', transform=transform)
my_mnist_loader = torch.utils.data.DataLoader(my_mnist, batch_size=64, shuffle=True)# define a CNN model
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1),nn.BatchNorm2d(32),nn.ReLU())self.conv_2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1),nn.BatchNorm2d(64),nn.ReLU(),)self.maxpool = nn.MaxPool2d(2)self.flatten = nn.Flatten()self.fc_1 = nn.Sequential(nn.Linear(9216, 128),nn.BatchNorm1d(128),nn.ReLU())self.fc_2 = nn.Linear(128, 10)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc_1(x)logits = self.fc_2(x)return logits# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN()
cnn.load_state_dict(torch.load('cnn.pth', map_location=device))
cnn.eval().to(device)# test the pretrained model on MNIST test data
size = len(test_data_loader.dataset)
correct = 0with torch.no_grad():for img, label in test_data_loader:img, label = img.to(device), label.to(device)pred = cnn(img)correct += (pred.argmax(1) == label).type(torch.float).sum().item()correct /= size
print(f'Accuracy on MNIST: {(100*correct):>0.1f}%')# test the pretrained model on my MNIST test data
size = len(my_mnist_loader.dataset)
correct = 0with torch.no_grad():for img, label in my_mnist_loader:img, label = img.to(device), label.to(device)pred = cnn(img)correct += (pred.argmax(1) == label).type(torch.float).sum().item()correct /= size
print(f'Accuracy on my MNIST: {(100*correct):>0.1f}%')
相关文章:
PyTorch入门之【CNN】
参考:https://www.bilibili.com/video/BV1114y1d79e/?spm_id_from333.999.0.0&vd_source98d31d5c9db8c0021988f2c2c25a9620 书接上回的MLP故本章就不详细解释了 目录 traintest train import torch from torchvision.transforms import ToTensor from torchvi…...
马斯洛需求层次模型之安全需求之云安全浅谈
在互联网云服务领域,安全需求是用户首要考虑的因素之一。用户希望在将数据和信息托付给云服务提供商时,这些数据和信息能够得到充分的保护,避免遭受未经授权的访问、泄露或破坏。这种安全需求的满足,对于用户来说是至关重要的&…...
Pikachu靶场——远程命令执行漏洞(RCE)
文章目录 1. RCE1.1 exec "ping"1.1.1 源代码分析1.1.2 漏洞防御 1.2 exec "eval"1.2.1 源代码分析1.2.2 漏洞防御 1.3 RCE 漏洞防御 1. RCE RCE(remote command/code execute)概述: RCE漏洞,可以让攻击者直接向后台服务器远程注入…...
【WSN】无线传感器网络 X-Y 坐标到图形视图和位字符串前缀嵌入方法研究(Matlab代码实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...
Linux定时任务
文章目录 前言设置定时任务流程定时规则例子 终止定时任务列出当前的定时任务重启任务调度 前言 在Linux系统中有时侯需要周期性的自动执行一些命令,这时候Linux定时任务就派上用场了 设置定时任务流程 进入定时任务的编辑模式 crontab -e编辑定时任务ÿ…...
【Overload游戏引擎分析】画场景网格的Shader
Overload引擎地址: GitHub - adriengivry/Overload: 3D Game engine with editor 一、栅格绘制基本原理 Overload Editor启动之后,场景视图中有栅格线,这个在很多软件中都有。刚开始我猜测它应该是通过绘制线实现的。阅读代码发现࿰…...
【JavaEE】多线程进阶(一)饿汉模式和懒汉模式
多线程进阶(一) 文章目录 多线程进阶(一)单例模式饿汉模式懒汉模式 本篇主要引入多线程进阶的单例模式,为后面的大冰山做铺垫 代码案例介绍 单例模式 非常经典的设计模式 啥是设计模式 设计模式好比象棋中的 “棋谱”…...
C++树详解
树 树的定义 树(Tree)是n(n≥0)个结点的有限集。n0时称为空树。在任意一颗非空树中:①有且仅有一个特定的称为根(Root)的结点;②当n>1时,其余结点可分为m(…...
支付环境安全漏洞介绍
1、平台支付逻辑全流程分析 2、平台支付漏洞如何利用?买东西还送钱? 3、BURP抓包分析修改支付金额,伪造交易状态? 4、修改购物车参数实现底价购买商品 5、SRC、CTF、HW项目月入10W副业之路 6、如何构建最适合自己的网安学习路线 1…...
抄写Linux源码(Day16:内存管理)
回忆我们需要做的事情: 为了支持 shell 程序的执行,我们需要提供: 1.缺页中断(不理解为什么要这个东西,只是闪客说需要,后边再说) 2.硬盘驱动、文件系统 (shell程序一开始是存放在磁盘里的,所以需要这两个东…...
Cookie和Session详解以及结合生成登录效果
目录 引言 1.Cookie中的数据从哪来数据长啥样? 2.Cookie有什么作用? 3.cookie与session的工作关联? 4.Cookie到哪去? 5.Cookie如何存? 6.Session 7.Cookie与Session的关联与区别 8.通过代码理解 8.1 相关代码 8.2…...
Spring基础以及核心概念(IoC和DIQ)
1.Spring是什么 Spring是包含了众多工具方法的IoC容器 2.loC(Inversion of Control )是什么 IoC:控制反转,Spring是一个控制反转容器(控制反转对象的生命周期) Spring是一个loC容器,我们之前学过的List/Map就是数据存储的容器,to…...
《C和指针》笔记32:多维数组初始化
文章目录 使用括号进行初始化初始化省略维度 使用括号进行初始化 我们可以给数组赋值一个长长的列表: int matrix[2][3] { 100, 101, 102, 110, 111, 112 };它等价于 matrix[0][0]100; matrix[0][1]101; matrix[0][2]102; matrix[1][0]110; matrix[1][1]111; ma…...
零食食品经营小程序商城的作用是什么
零食几乎可以涵盖每个年龄阶段,同时又是市场中常见的零售批发商品,在多个场景中都有销售/购买属性,对消费者来说,购买零食的渠道多种多样,无论线下还是线上,都可随心而购。 庞大市场升级促进下,…...
Java泛型--什么是泛型?
https://www.bilibili.com/video/BV1xJ411n77R?p5&vd_sourcebb1fced25254581cf052adea5e87a1ff 1.泛型类、接口 1.1.泛型类 泛型类的定义 class 类名称 <泛型标识, 泛型标识, ...> {private 泛型标识 变量名;...... }常用的泛型标识:T、E、K、V jav…...
LabVIEW工业虚拟仪器的标准化实施
LabVIEW工业虚拟仪器的标准化实施 创建计算机化的测试和测量系统,从计算机桌面控制外部测量硬件设备,以及在计算机屏幕上显示的类似仪器的面板上查看来自外部设备的测试或测量数据,所有这些都需要虚拟仪器系统软件。该软件允许用户执行所有这…...
JavaScript系列从入门到精通系列第十七篇:JavaScript中的全局作用域
文章目录 前言 1:什么叫作用域 一:全局作用域 1:全局变量的声明 2:变量声明和使用的顺序 3:方法声明和使用的顺序 前言 1:什么叫作用域 可以起作用的范围 function fun(){var a 1; } fun();consol…...
汇编指令集合
...
TinyWebServer整体流程
从main主函数开始: 一、定义MySQL数据库的账号、密码和用到的数据库名称。 二、调用Config获得服务器初始化属性 在这一步确定触发模式端口等信息。 三、创建服务器实例对象 设置根目录、开辟存放http连接对象的空间,开辟定时器空间。 四、利用Confi…...
【Java项目推荐之黑马头条】自媒体文章实现异步上下架(使用Kafka中间件实现)
自媒体文章上下架功能完成 需求分析 流程说明 接口定义 说明接口路径/api/v1/news/down_or_up请求方式POST参数DTO响应结果ResponseResult DTO Data public class WmNewsDto {private Integer id;/*** 是否上架 0 下架 1 上架*/private Short enable;}ResponseResult 自媒…...
自学(黑客)技术方法————网络安全
如果你想自学网络安全,首先你必须了解什么是网络安全!,什么是黑客!! 1.无论网络、Web、移动、桌面、云等哪个领域,都有攻与防两面性,例如 Web 安全技术,既有 Web 渗透2.也有 Web 防…...
python+playwright 学习-84 Response 接口返回对象
Response 是获取接口响应对象,根据Response 对象可以获取响应的状态码,响应头部,响应正文等内容。 Response 相关操作方法 all_headers 所有响应HTTP标头, 返回Dict 类型 response.all_headers()body 获取 bytes 类型body内容 response.body()json 返回响应主体的 JS…...
GCN详解
a ⃗ \vec{a} a 向量 a ‾ \overline{a} a 平均值 a ‾ \underline{a} a下横线 a ^ \widehat{a} a (线性回归,直线方程) y尖 a ~ \widetilde{a} a a ˙ \dot{a} a˙ 一阶导数 a \ddot{a} a 二阶导数 H(l)表示l层的节点的特征 W(l)表示l层的参数 D ~ \widet…...
总结二:linux面经
文章目录 1、 Linux中查看进程运行状态的指令、查看内存使用情况的指令、tar解压文件的参数。2、文件权限怎么修改?3、说说常用的Linux命令?4、说说如何以root权限运行某个程序?5、 说说软链接和硬链接的区别?6、说说静态库和动态…...
12、【Qlib】【主要组件】Qlib Recorder:实验管理
11、【Qlib】【主要组件】Qlib Recorder:实验管理 简介Qlib RecorderExperiment ManagerExperimentRecorderRecord Template简介 Qlib包含一个名为QlibRecorder的实验管理系统,旨在帮助用户以高效的方式处理实验并分析结果。 该系统有三个组件: 实验管理器(ExperimentMan…...
三一充填泵:煤矿矸石无害化充填,煤炭绿色高效开采的破局利器
富煤贫油少气是我国的能源禀赋特征,决定了我国以煤炭为主的能源结构,煤炭为国民经济发展提供了重要的基础。煤炭开采过程会对土地、地下水、空气等环境造成较大的污染,但大宗固废煤矸石无害化充填的技术手段可以有效改善这样的情况࿰…...
医疗器械标准目录汇编2022版共178页(文中附下载链接!)
为便于更好地应用医疗器械标准,国家药监局医疗器械标准管理中心组织对现行1851项医疗器械国家和行业标准按技术领域,编排形成《医疗器械标准目录汇编(2022版)》 该目录汇编分为通用技术领域和专业技术领域两大类,通用…...
C#和Excel文件的读写交互
C#和Excel文件的读写交互是一项重要的技术,在许多应用程序开发中起着关键作用。C#作为一种现代的面向编程语言,提供了丰富的库和功能,使开发人员能够轻松地处理Excel文件,并进行数据的读取和写入。 首先,让我们了解一下…...
Pytorch目标分类深度学习自定义数据集训练
目录 一,Pytorch简介; 二,环境配置; 三,自定义数据集; 四,模型训练; 五,模型验证; 一,Pytorch简介; PyTorch是一个开源的Python机…...
2023 年 Web 安全最详细学习路线指南,从入门到入职(含书籍、工具包)【建议收藏】
第一个方向:安全研发 你可以把网络安全理解成电商行业、教育行业等其他行业一样,每个行业都有自己的软件研发,网络安全作为一个行业也不例外,不同的是这个行业的研发就是开发与网络安全业务相关的软件。 既然如此,那其…...
网站平台建设要多久/源码网站
HTML5第一章总结 一 Html和CSS的关系 学习web前端开发基础技术需要掌握:HTML、CSS、JavaScript语言。下面我们就来了解下这三门技术都是用来实现什么的: 1. HTML是网页内容的载体。内容就是网页制作者放在页面上想要让用户浏览的信息,可以包…...
网站免费高清素材软件有哪些/国外b站视频推广网站
目录介绍 1.0.0.1 说下Activity的生命周期?屏幕旋转时生命周期?异常条件会调用什么方法?1.0.0.2 后台的Activity被系统回收怎么办?说一下onSaveInstanceState()和onRestoreInstanceState()方法特点?1.0.0.3 如何避免配…...
珠海建设网站公司简介/最新实时大数据
#把datetime转成字符串 def datetime_toString(dt):return dt.strftime("%Y-%m-%d-%H")#把字符串转成datetime def string_toDatetime(string):return datetime.strptime(string, "%Y-%m-%d-%H")#把字符串转成时间戳形式 def string_toTimestamp(strTime):…...
高唐企业做网站推广/附近电脑培训学校
背景:目前自己在本地写的脚本都是基本Python3.x版本的,想要在linux里边运行,必须安装3.x的环境(centos7.4自带的Python版本是2.7 )安装步骤:1.本地下载python 安装包 ,通过ftp上传到服务器2.解压tar包tar -zxvf Pytho…...
网站建设相关博客/网上销售平台
C#自定义事件需要以下步骤: 1、声明关于事件的委托;2、声明事件;3、编写触发事件的函数;4、创建事件处理程序;5、注册事件处理程序;6、在适当的条件下触发事件。 现在我们来编写一个自定义事件的程序。情…...
厦门市建设委员会网站/seo优化视频教程
这里我的环境是在linux服务器上面配置服务器端,然后在把windows当作client来连linux服务器一,下载所需的软件:1,安装所需的编译工具:#apt-get install gcc g make pkg-config libpam0g-dev sasl2-bin 2,下载lzo库[http…...