ccc-pytorch-卷积神经网络实战(6)
文章目录
- 一、CIFAR10 与 lenet5
- 二、CIFAR10 与 ResNet
一、CIFAR10 与 lenet5

第一步:准备数据集
lenet5.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transformsdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)if __name__ =='__main__':main()

第二步:确认Lenet5网络流程结构
main.py
import torch
from torch import nn
from torch.nn import functional as Fclass Lenet5(nn.Module):def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 6, ]nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),nn.AvgPool2d(kernel_size=2, stride=2, padding=0),#nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)self.fc_unit = nn.Sequential(nn.Linear(2,120), # 由输出结果反推(拉直打平)nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))#[b,3,32,32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)#[2,16,5,5] 由输出结果得到print('conv out:', out.shape)def main():net = Lenet5()if __name__ == '__main__':main()

第三步:完善lenet5 结构并使用GPU加速
lenet5.py
import torch
from torch import nn
from torch.nn import functional as Fclass Lenet5(nn.Module):def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 6, ]nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),nn.AvgPool2d(kernel_size=2, stride=2, padding=0),#nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)self.fc_unit = nn.Sequential(nn.Linear(16*5*5,120),nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))#[b,3,32,32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)#[b,16,5,5]print('conv out:', out.shape)def forward(self,x):batchsz = x.size(0)# [b, 3, 32, 32] => [b, 16, 5, 5]x = self.conv_unit(x)#[b, 16, 5, 5] => [b,16*5*5]x = x.view(batchsz,16*5*5)# [b, 16*5*5] => [b, 10]logits = self.fc_unit(x)pred = F.softmax(logits,dim=1)return logitsdef main():net = Lenet5()tmp = torch.randn(2, 3, 32, 32)out = net(tmp)print('lenet out:', out.shape)if __name__ == '__main__':main()
main.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = Lenet5().to(device)print(model)if __name__ =='__main__':main()

第四步:计算交叉熵和准确率,完成迭代
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = Lenet5().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(),lr=1e-3)print(model)for epoch in range(1000):for batchidx, (x,label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)# logits: [b, 10]# label: [b]loss = criteon(logits,label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch,'loss:',loss.item())model.eval()with torch.no_grad(): #之后代码不需backproptotal_correct = 0total_num = 0for x ,label in cifar_test:# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum()total_num += x.size(0)acc = total_correct / total_numprint(epoch,acc)if __name__ =='__main__':main()

注意事项:
- 之所以在 测试时 添加 model.eval()是因为eval()时,BN会使用之前计算好的值,并且停止使用DropOut。保证用全部训练的均值和方差
二、CIFAR10 与 ResNet

第一步:构建ResNet18的网络结构
ResNet.py
import torch
from torch import nn
from torch.nn import functional as Fclass ResBlk(nn.Module):def __init__(self,ch_in,ch_out,stride=1):super(ResBlk,self).__init__()self.conv1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_out != ch_in:self.extra = nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))#[b, ch_in, h, w] = > [b, ch_out, h, w]out = self.extra(x) + outout = F.relu((out))return outclass ResNet18(nn.Module):def __init__(self):super(ResNet18, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),nn.BatchNorm2d(64))# followed 4 blocks# [b, 64, h, w] => [b, 128, h ,w]self.blk1 = ResBlk(64,128)# [b, 128, h, w] => [b, 256, h ,w]self.blk2 = ResBlk(128,256)# [b, 256, h, w] => [b, 512, h ,w]self.blk3 = ResBlk(256,512)# [b, 512, h, w] => [b, 1024, h ,w]self.blk4 = ResBlk(512,512)self.outlayer = nn.Linear(512*1*1,10)def forward(self,x):x = F.relu(self.conv1(x))x = self.blk1(x)x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)print('after conv:', x.shape)# [b, 512, h, w] => [b, 512, 1, 1]x = F.adaptive_avg_pool2d(x, [1, 1])print('after pool:', x.shape)x = x.view(x.size(0), -1)x = self.outlayer(x)return xdef main():blk = ResBlk(64,128,stride=2)tmp = torch.randn(2,64,32,32)out = blk(tmp)print('block:',out.shape)x = torch.randn(2,3,32,32)model = ResNet18()out = model(x)print('resnet:',out.shape)if __name__ == '__main__':main()
第二步:代入第一个项目的main函数中即可
main.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from resnet import ResNet18
from torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = ResNet18().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(),lr=1e-3)print(model)for epoch in range(1000):for batchidx, (x,label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)# logits: [b, 10]# label: [b]loss = criteon(logits,label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch,'loss:',loss.item())model.eval()with torch.no_grad(): #之后代码不需backproptotal_correct = 0total_num = 0for x ,label in cifar_test:# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum()total_num += x.size(0)acc = total_correct / total_numprint(epoch,acc)if __name__ =='__main__':main()
网络结构如下:

迭代准确率和交叉熵计算如下:

其他需要注意的地方:
- 并不是ResNet的paper中流程完全相同,但是十分类似
- 可以对数据进行数据增强和归一化等操作进一步提升效果
相关文章:
ccc-pytorch-卷积神经网络实战(6)
文章目录一、CIFAR10 与 lenet5二、CIFAR10 与 ResNet一、CIFAR10 与 lenet5 第一步:准备数据集 lenet5.py import torch from torch.utils.data import DataLoader from torchvision import datasets from torchvision import transformsdef main():batchsz 128C…...
置信椭圆(误差椭圆)详解
文章目录Part.I 预备知识Chap.I 一些概念Chap.II 主成分分析Chap.III Matlab 函数 randnChap.IV Matlab 函数 pcaPart.II 置信椭圆的含义Chap.I 一个 Matlab 实例Sec.I 两个不相关变量的特征Sec.II 两个相关变量的特征Chap.II 变换阵 (解相关矩阵) 的求解ReferencePart.I 预备知…...
FreeSWITCH 智能呼叫流程设计
文章目录1. 智能呼叫流程2. 细节处理1. 呼叫字符串指定拨号计划2. 外呼的拨号计划3. 语音打断的支持1. 智能呼叫流程 用户与机器人对话通常都是以文本的形式进行,但是借助 ASR 和 TTS 技术,以语音电话为载体的智能呼叫系统成为可能。智能呼叫系统涉及到…...
什么是Restful风格
什么是RestFul风格? Restful就是一个资源定位及资源操作的风格。不是标准也不是协议,只是一种风格。基于这个风格设计的软件可以更简洁,更有层次,更易于实现缓存等机制。 REST即Representational State Transfer的缩写࿰…...
sumifs的交叉 表的例子
比如这样,那么冰箱绿山店的栏位中,SUMIFS($D$3:$D$10,$B$3:$B$10,$F3,$C$3:$C$10,G$2)就是把求和范围,条件1设置为固定列的复合引用,条件2设置为固定行的复合引用即可。...
React :一、简单概念
目录 1.什么是React? 2.谁开发的 3.为什么要学React? 4.React的特点? 5.React依赖包 6.第一个React程序 7.虚拟DOM的两种创建方法 8.虚拟DOM和真实DOM 1.什么是React? 用于构建用户界面的JavaScript库,是一个将…...
Actipro WinForms Studio Crack
Actipro WinForms Studio Crack 已验证Microsoft.NET 7兼容性。 添加了MetroDark配色方案。 添加了支持MetroLight和MetroDark颜色方案的MetroScrollBarRenderer。 添加了IWindowsColorScheme接口,该接口将替换对WindowsColorScheme的大多数引用。 添加了IWindowsCo…...
英伦四地到底是什么关系?
英格兰、苏格兰、威尔士和北爱尔兰四地到底是什么关系,为何苏格兰非要独立?故事还要从中世纪说起。大不列颠岛位于欧洲西部,和欧洲大陆隔海相望。在古代,大不列颠岛和爱尔兰属于凯尔特人的领地。凯尔特人是欧洲西部一个庞大的族群…...
Google三大论文之GFS
Google三大论文之GFS Google GFS(Google File System) 文件系统,一个面向大规模数据密集型应用的、可伸缩的分布式文件系统。GFS 虽然运行在廉价的普遍硬件设备上,但是它依然了提供灾难冗余的能力,为大量客户机提供了…...
嵌入式安防监控项目——exynos4412主框架搭建
目录 一、模块化编程思维 二、安防监控项目主框架搭建 一、模块化编程思维 其实我们以前学习32使用keil的时候就是再用模块化的思维。每个硬件都单独有一个实现功能的C文件和声明函数,进行宏定义以及引用需要使用头文件的h文件。 比如简单的加减乘除取余操作我们…...
YOLOv5s网络模型讲解(一看就会)
文章目录前言1、YOLOv5s-6.0组成2、YOLOv5s网络介绍2.1、参数解析2.2、YOLOv5s.yaml2.3、YOLOv5s网络结构图3、附件3.1、yolov5s.yaml 解析表3.2、 yolov5l.yaml 解析表总结前言 最近在重构YOLOv5代码,本章主要介绍YOLOv5s的网络结构 1、YOLOv5s-6.0组成 我们熟知YO…...
kkfileView linux 离线安装
文章目录前言一、安装 LiberOffice二、安装kkfileView1.下载安装包2.启动总结前言 一、安装 LiberOffice 下载https://kkfileview.keking.cn/LibreOffice_7.1.4_Linux_x86-64_rpm.tar.gz 安装 tar -zxvf LibreOffice_7.1.4_Linux_x86-64_rpm.tar.gz cd LibreOffice_7.1.4.2_L…...
如何编写BI项目之ETL文档
XXXXBI项目之ETL文档 xxx项目组 ------------------------------------------------1---------------------------------------------------------------------- 目录 一 、ETL之概述 1、ETL是数据仓库建构/应用中的核心…...
【LeetCode】剑指 Offer 24. 反转链表 p142 -- Java Version
题目链接:https://leetcode.cn/problems/fan-zhuan-lian-biao-lcof/submissions/ 1. 题目介绍(24. 反转链表) 定义一个函数,输入一个链表的头节点,反转该链表并输出反转后链表的头节点。 【测试用例】: 示…...
LAY-EXCEL导出excel并实现单元格合并
通过lay-excel插件实现Excel导出,并实现单元格合并,样式设置等功能。更详细描述,请去lay-excel插件文档查看,地址:http://excel.wj2015.com/_book/docs/%E5%BF%AB%E9%80%9F%E4%B8%8A%E6%89%8B.html一、安装这里使用Vue…...
配置VM虚拟机Centos7网络
配置VM虚拟机Centos7网络 第一步,进入虚拟机设置选中【网络适配器】选择【NAT模式】 第二步,进入windows【控制面板\网络和 Internet\网络连接】设置网络状态。 我们选择【VMnet8】 点击【属性】查看它的网络配置 2 .我们找到【Internet 协议版本 4(TCP…...
Kafka 位移主题
Kafka 位移主题位移格式创建位移提交位移删除位移Kafka 的内部主题 (Internal Topic) : __consumer_offsets (位移主题,Offsets Topic) 老 Consumer 会将位移消息提交到 ZK 中保存 当 Consumer 重启后,能自动从 ZK 中读取位移数据,继续消费…...
详细讲解零拷贝机制的进化过程
一、传统拷贝方式(一)操作系统经过4次拷贝CPU 负责将数据从磁盘搬运到内核空间的 Page Cache 中;CPU 负责将数据从内核空间的 Page Cache 搬运到用户空间的缓冲区;CPU 负责将数据从用户空间的缓冲区搬运到内核空间的 Socket 缓冲区…...
2023年场外个股期权研究报告
第一章 概况 场外个股期权(Over-the-Counter Equity Option),是指由交易双方根据自己的需求和意愿,通过协商确定行权价格、行权日期等条款的股票期权。与交易所交易的标准化期权不同,场外个股期权的合同内容可以根据交…...
k8s pod,ns,pvc 强制删除
一、强制删除pod$ kubectl delete pod <your-pod-name> -n <name-space> --force --grace-period0解决方法:加参数 --force --grace-period0,grace-period表示过渡存活期,默认30s,在删除POD之前允许POD慢慢终止其上的…...
HTML 语义化
目录 HTML 语义化HTML5 新特性HTML 语义化的好处语义化标签的使用场景最佳实践 HTML 语义化 HTML5 新特性 标准答案: 语义化标签: <header>:页头<nav>:导航<main>:主要内容<article>&#x…...
【JVM】- 内存结构
引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...
从零实现STL哈希容器:unordered_map/unordered_set封装详解
本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说,直接开始吧! 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...
【Java学习笔记】BigInteger 和 BigDecimal 类
BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点:传参类型必须是类对象 一、BigInteger 1. 作用:适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...
智能AI电话机器人系统的识别能力现状与发展水平
一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...
快刀集(1): 一刀斩断视频片头广告
一刀流:用一个简单脚本,秒杀视频片头广告,还你清爽观影体验。 1. 引子 作为一个爱生活、爱学习、爱收藏高清资源的老码农,平时写代码之余看看电影、补补片,是再正常不过的事。 电影嘛,要沉浸,…...
MySQL 8.0 事务全面讲解
以下是一个结合两次回答的 MySQL 8.0 事务全面讲解,涵盖了事务的核心概念、操作示例、失败回滚、隔离级别、事务性 DDL 和 XA 事务等内容,并修正了查看隔离级别的命令。 MySQL 8.0 事务全面讲解 一、事务的核心概念(ACID) 事务是…...
FFmpeg avformat_open_input函数分析
函数内部的总体流程如下: avformat_open_input 精简后的代码如下: int avformat_open_input(AVFormatContext **ps, const char *filename,ff_const59 AVInputFormat *fmt, AVDictionary **options) {AVFormatContext *s *ps;int i, ret 0;AVDictio…...
ArcPy扩展模块的使用(3)
管理工程项目 arcpy.mp模块允许用户管理布局、地图、报表、文件夹连接、视图等工程项目。例如,可以更新、修复或替换图层数据源,修改图层的符号系统,甚至自动在线执行共享要托管在组织中的工程项。 以下代码展示了如何更新图层的数据源&…...
【工具教程】多个条形码识别用条码内容对图片重命名,批量PDF条形码识别后用条码内容批量改名,使用教程及注意事项
一、条形码识别改名使用教程 打开软件并选择处理模式:打开软件后,根据要处理的文件类型,选择 “图片识别模式” 或 “PDF 识别模式”。如果是处理包含条形码的 PDF 文件,就选择 “PDF 识别模式”;若是处理图片文件&…...
