ccc-pytorch-卷积神经网络实战(6)
文章目录
- 一、CIFAR10 与 lenet5
- 二、CIFAR10 与 ResNet
一、CIFAR10 与 lenet5

第一步:准备数据集
lenet5.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transformsdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)if __name__ =='__main__':main()

第二步:确认Lenet5网络流程结构
main.py
import torch
from torch import nn
from torch.nn import functional as Fclass Lenet5(nn.Module):def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 6, ]nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),nn.AvgPool2d(kernel_size=2, stride=2, padding=0),#nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)self.fc_unit = nn.Sequential(nn.Linear(2,120), # 由输出结果反推(拉直打平)nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))#[b,3,32,32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)#[2,16,5,5] 由输出结果得到print('conv out:', out.shape)def main():net = Lenet5()if __name__ == '__main__':main()

第三步:完善lenet5 结构并使用GPU加速
lenet5.py
import torch
from torch import nn
from torch.nn import functional as Fclass Lenet5(nn.Module):def __init__(self):super(Lenet5, self).__init__()self.conv_unit = nn.Sequential(# x: [b, 3, 32, 32] => [b, 6, ]nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),nn.AvgPool2d(kernel_size=2, stride=2, padding=0),#nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),)self.fc_unit = nn.Sequential(nn.Linear(16*5*5,120),nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))#[b,3,32,32]tmp = torch.randn(2, 3, 32, 32)out = self.conv_unit(tmp)#[b,16,5,5]print('conv out:', out.shape)def forward(self,x):batchsz = x.size(0)# [b, 3, 32, 32] => [b, 16, 5, 5]x = self.conv_unit(x)#[b, 16, 5, 5] => [b,16*5*5]x = x.view(batchsz,16*5*5)# [b, 16*5*5] => [b, 10]logits = self.fc_unit(x)pred = F.softmax(logits,dim=1)return logitsdef main():net = Lenet5()tmp = torch.randn(2, 3, 32, 32)out = net(tmp)print('lenet out:', out.shape)if __name__ == '__main__':main()
main.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = Lenet5().to(device)print(model)if __name__ =='__main__':main()

第四步:计算交叉熵和准确率,完成迭代
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = Lenet5().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(),lr=1e-3)print(model)for epoch in range(1000):for batchidx, (x,label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)# logits: [b, 10]# label: [b]loss = criteon(logits,label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch,'loss:',loss.item())model.eval()with torch.no_grad(): #之后代码不需backproptotal_correct = 0total_num = 0for x ,label in cifar_test:# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum()total_num += x.size(0)acc = total_correct / total_numprint(epoch,acc)if __name__ =='__main__':main()

注意事项:
- 之所以在 测试时 添加 model.eval()是因为eval()时,BN会使用之前计算好的值,并且停止使用DropOut。保证用全部训练的均值和方差
二、CIFAR10 与 ResNet

第一步:构建ResNet18的网络结构
ResNet.py
import torch
from torch import nn
from torch.nn import functional as Fclass ResBlk(nn.Module):def __init__(self,ch_in,ch_out,stride=1):super(ResBlk,self).__init__()self.conv1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_out != ch_in:self.extra = nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))#[b, ch_in, h, w] = > [b, ch_out, h, w]out = self.extra(x) + outout = F.relu((out))return outclass ResNet18(nn.Module):def __init__(self):super(ResNet18, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),nn.BatchNorm2d(64))# followed 4 blocks# [b, 64, h, w] => [b, 128, h ,w]self.blk1 = ResBlk(64,128)# [b, 128, h, w] => [b, 256, h ,w]self.blk2 = ResBlk(128,256)# [b, 256, h, w] => [b, 512, h ,w]self.blk3 = ResBlk(256,512)# [b, 512, h, w] => [b, 1024, h ,w]self.blk4 = ResBlk(512,512)self.outlayer = nn.Linear(512*1*1,10)def forward(self,x):x = F.relu(self.conv1(x))x = self.blk1(x)x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)print('after conv:', x.shape)# [b, 512, h, w] => [b, 512, 1, 1]x = F.adaptive_avg_pool2d(x, [1, 1])print('after pool:', x.shape)x = x.view(x.size(0), -1)x = self.outlayer(x)return xdef main():blk = ResBlk(64,128,stride=2)tmp = torch.randn(2,64,32,32)out = blk(tmp)print('block:',out.shape)x = torch.randn(2,3,32,32)model = ResNet18()out = model(x)print('resnet:',out.shape)if __name__ == '__main__':main()
第二步:代入第一个项目的main函数中即可
main.py
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from resnet import ResNet18
from torch import nn, optimdef main():batchsz = 128CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)x,label = iter(cifar_train).next()print('x',x.shape,'label:',label.shape)device = torch.device('cuda')model = ResNet18().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(),lr=1e-3)print(model)for epoch in range(1000):for batchidx, (x,label) in enumerate(cifar_train):# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)# logits: [b, 10]# label: [b]loss = criteon(logits,label)# backpropoptimizer.zero_grad()loss.backward()optimizer.step()print(epoch,'loss:',loss.item())model.eval()with torch.no_grad(): #之后代码不需backproptotal_correct = 0total_num = 0for x ,label in cifar_test:# [b, 3, 32, 32]# [b]x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum()total_num += x.size(0)acc = total_correct / total_numprint(epoch,acc)if __name__ =='__main__':main()
网络结构如下:

迭代准确率和交叉熵计算如下:

其他需要注意的地方:
- 并不是ResNet的paper中流程完全相同,但是十分类似
- 可以对数据进行数据增强和归一化等操作进一步提升效果
相关文章:
ccc-pytorch-卷积神经网络实战(6)
文章目录一、CIFAR10 与 lenet5二、CIFAR10 与 ResNet一、CIFAR10 与 lenet5 第一步:准备数据集 lenet5.py import torch from torch.utils.data import DataLoader from torchvision import datasets from torchvision import transformsdef main():batchsz 128C…...
置信椭圆(误差椭圆)详解
文章目录Part.I 预备知识Chap.I 一些概念Chap.II 主成分分析Chap.III Matlab 函数 randnChap.IV Matlab 函数 pcaPart.II 置信椭圆的含义Chap.I 一个 Matlab 实例Sec.I 两个不相关变量的特征Sec.II 两个相关变量的特征Chap.II 变换阵 (解相关矩阵) 的求解ReferencePart.I 预备知…...
FreeSWITCH 智能呼叫流程设计
文章目录1. 智能呼叫流程2. 细节处理1. 呼叫字符串指定拨号计划2. 外呼的拨号计划3. 语音打断的支持1. 智能呼叫流程 用户与机器人对话通常都是以文本的形式进行,但是借助 ASR 和 TTS 技术,以语音电话为载体的智能呼叫系统成为可能。智能呼叫系统涉及到…...
什么是Restful风格
什么是RestFul风格? Restful就是一个资源定位及资源操作的风格。不是标准也不是协议,只是一种风格。基于这个风格设计的软件可以更简洁,更有层次,更易于实现缓存等机制。 REST即Representational State Transfer的缩写࿰…...
sumifs的交叉 表的例子
比如这样,那么冰箱绿山店的栏位中,SUMIFS($D$3:$D$10,$B$3:$B$10,$F3,$C$3:$C$10,G$2)就是把求和范围,条件1设置为固定列的复合引用,条件2设置为固定行的复合引用即可。...
React :一、简单概念
目录 1.什么是React? 2.谁开发的 3.为什么要学React? 4.React的特点? 5.React依赖包 6.第一个React程序 7.虚拟DOM的两种创建方法 8.虚拟DOM和真实DOM 1.什么是React? 用于构建用户界面的JavaScript库,是一个将…...
Actipro WinForms Studio Crack
Actipro WinForms Studio Crack 已验证Microsoft.NET 7兼容性。 添加了MetroDark配色方案。 添加了支持MetroLight和MetroDark颜色方案的MetroScrollBarRenderer。 添加了IWindowsColorScheme接口,该接口将替换对WindowsColorScheme的大多数引用。 添加了IWindowsCo…...
英伦四地到底是什么关系?
英格兰、苏格兰、威尔士和北爱尔兰四地到底是什么关系,为何苏格兰非要独立?故事还要从中世纪说起。大不列颠岛位于欧洲西部,和欧洲大陆隔海相望。在古代,大不列颠岛和爱尔兰属于凯尔特人的领地。凯尔特人是欧洲西部一个庞大的族群…...
Google三大论文之GFS
Google三大论文之GFS Google GFS(Google File System) 文件系统,一个面向大规模数据密集型应用的、可伸缩的分布式文件系统。GFS 虽然运行在廉价的普遍硬件设备上,但是它依然了提供灾难冗余的能力,为大量客户机提供了…...
嵌入式安防监控项目——exynos4412主框架搭建
目录 一、模块化编程思维 二、安防监控项目主框架搭建 一、模块化编程思维 其实我们以前学习32使用keil的时候就是再用模块化的思维。每个硬件都单独有一个实现功能的C文件和声明函数,进行宏定义以及引用需要使用头文件的h文件。 比如简单的加减乘除取余操作我们…...
YOLOv5s网络模型讲解(一看就会)
文章目录前言1、YOLOv5s-6.0组成2、YOLOv5s网络介绍2.1、参数解析2.2、YOLOv5s.yaml2.3、YOLOv5s网络结构图3、附件3.1、yolov5s.yaml 解析表3.2、 yolov5l.yaml 解析表总结前言 最近在重构YOLOv5代码,本章主要介绍YOLOv5s的网络结构 1、YOLOv5s-6.0组成 我们熟知YO…...
kkfileView linux 离线安装
文章目录前言一、安装 LiberOffice二、安装kkfileView1.下载安装包2.启动总结前言 一、安装 LiberOffice 下载https://kkfileview.keking.cn/LibreOffice_7.1.4_Linux_x86-64_rpm.tar.gz 安装 tar -zxvf LibreOffice_7.1.4_Linux_x86-64_rpm.tar.gz cd LibreOffice_7.1.4.2_L…...
如何编写BI项目之ETL文档
XXXXBI项目之ETL文档 xxx项目组 ------------------------------------------------1---------------------------------------------------------------------- 目录 一 、ETL之概述 1、ETL是数据仓库建构/应用中的核心…...
【LeetCode】剑指 Offer 24. 反转链表 p142 -- Java Version
题目链接:https://leetcode.cn/problems/fan-zhuan-lian-biao-lcof/submissions/ 1. 题目介绍(24. 反转链表) 定义一个函数,输入一个链表的头节点,反转该链表并输出反转后链表的头节点。 【测试用例】: 示…...
LAY-EXCEL导出excel并实现单元格合并
通过lay-excel插件实现Excel导出,并实现单元格合并,样式设置等功能。更详细描述,请去lay-excel插件文档查看,地址:http://excel.wj2015.com/_book/docs/%E5%BF%AB%E9%80%9F%E4%B8%8A%E6%89%8B.html一、安装这里使用Vue…...
配置VM虚拟机Centos7网络
配置VM虚拟机Centos7网络 第一步,进入虚拟机设置选中【网络适配器】选择【NAT模式】 第二步,进入windows【控制面板\网络和 Internet\网络连接】设置网络状态。 我们选择【VMnet8】 点击【属性】查看它的网络配置 2 .我们找到【Internet 协议版本 4(TCP…...
Kafka 位移主题
Kafka 位移主题位移格式创建位移提交位移删除位移Kafka 的内部主题 (Internal Topic) : __consumer_offsets (位移主题,Offsets Topic) 老 Consumer 会将位移消息提交到 ZK 中保存 当 Consumer 重启后,能自动从 ZK 中读取位移数据,继续消费…...
详细讲解零拷贝机制的进化过程
一、传统拷贝方式(一)操作系统经过4次拷贝CPU 负责将数据从磁盘搬运到内核空间的 Page Cache 中;CPU 负责将数据从内核空间的 Page Cache 搬运到用户空间的缓冲区;CPU 负责将数据从用户空间的缓冲区搬运到内核空间的 Socket 缓冲区…...
2023年场外个股期权研究报告
第一章 概况 场外个股期权(Over-the-Counter Equity Option),是指由交易双方根据自己的需求和意愿,通过协商确定行权价格、行权日期等条款的股票期权。与交易所交易的标准化期权不同,场外个股期权的合同内容可以根据交…...
k8s pod,ns,pvc 强制删除
一、强制删除pod$ kubectl delete pod <your-pod-name> -n <name-space> --force --grace-period0解决方法:加参数 --force --grace-period0,grace-period表示过渡存活期,默认30s,在删除POD之前允许POD慢慢终止其上的…...
使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式
一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明:假设每台服务器已…...
HTML 语义化
目录 HTML 语义化HTML5 新特性HTML 语义化的好处语义化标签的使用场景最佳实践 HTML 语义化 HTML5 新特性 标准答案: 语义化标签: <header>:页头<nav>:导航<main>:主要内容<article>&#x…...
模型参数、模型存储精度、参数与显存
模型参数量衡量单位 M:百万(Million) B:十亿(Billion) 1 B 1000 M 1B 1000M 1B1000M 参数存储精度 模型参数是固定的,但是一个参数所表示多少字节不一定,需要看这个参数以什么…...
macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用
文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...
剑指offer20_链表中环的入口节点
链表中环的入口节点 给定一个链表,若其中包含环,则输出环的入口节点。 若其中不包含环,则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...
Redis数据倾斜问题解决
Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中,部分节点存储的数据量或访问量远高于其他节点,导致这些节点负载过高,影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...
Java编程之桥接模式
定义 桥接模式(Bridge Pattern)属于结构型设计模式,它的核心意图是将抽象部分与实现部分分离,使它们可以独立地变化。这种模式通过组合关系来替代继承关系,从而降低了抽象和实现这两个可变维度之间的耦合度。 用例子…...
CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝
目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为:一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...
GitHub 趋势日报 (2025年06月06日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 590 cognee 551 onlook 399 project-based-learning 348 build-your-own-x 320 ne…...
Ubuntu系统复制(U盘-电脑硬盘)
所需环境 电脑自带硬盘:1块 (1T) U盘1:Ubuntu系统引导盘(用于“U盘2”复制到“电脑自带硬盘”) U盘2:Ubuntu系统盘(1T,用于被复制) !!!建议“电脑…...
