PyTorch入门之【AlexNet】
参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。
目录
- 大纲
- dataloader
- model
- train
- test
大纲

各个文件的作用:
- data就是数据集
- dataloader.py就是数据集的加载以及实例初始化
- model.py就是AlexNet模块的定义
- train.py就是模型的训练
- test.py就是模型的测试
dataloader
import torch
import torchvision
import torchvision.transforms as transformsimport matplotlib.pyplot as plt
import numpy as np# define the dataloader
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])batch_size = 16trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,shuffle=False)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')if __name__ == '__main__':# get some random training imagesdataiter = iter(train_loader)images, labels = next(dataiter)# print labelsprint(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))# show imagesimg_grid = torchvision.utils.make_grid(images)img_grid = img_grid / 2 + 0.5npimg = img_grid.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()
model
import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),nn.BatchNorm2d(96),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_2 = nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_3 = nn.Sequential(nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_4 = nn.Sequential(nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_5 = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.fc_1 = nn.Sequential(nn.Dropout(0.5),nn.Linear(9216, 4096),nn.ReLU())self.fc_2 = nn.Sequential(nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU())self.fc_3= nn.Sequential(nn.Linear(4096, num_classes))def forward(self, x):out = self.conv_1(x)out = self.conv_2(out)out = self.conv_3(out)out = self.conv_4(out)out = self.conv_5(out)out = out.reshape(out.size(0), -1)out = self.fc_1(out)out = self.fc_2(out)out = self.fc_3(out)return outif __name__ == '__main__':model = AlexNet()print(model)x = torch.randn(1, 3, 224, 224)y = model(x)print(y.size())
train
import torch
import torch.nn as nnfrom dataloader import train_loader, test_loader
from model import AlexNet# define the hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
num_epochs = 20
learning_rate = 1e-3# load the model
model = AlexNet(num_classes).to(device)# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # train the model
total_len = len(train_loader)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# move tensors to the configured deviceimages = images.to(device)labels = labels.to(device)# forward passoutputs = model(images)loss = criterion(outputs, labels)# backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_len, loss.item()))# Validationwith torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()model.train()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))# save the model checkpoint
torch.save(model.state_dict(), 'alexnet.pth')
test
import torchfrom dataloader import test_loader, classes
from model import AlexNet# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
model.load_state_dict(torch.load('alexnet.pth', map_location=device))# test the pretrained model on CIFAR-10 test data
with torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))
相关文章:
PyTorch入门之【AlexNet】
参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from333.999.0.0&vd_source98d31d5c9db8c0021988f2c2c25a9620 AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。 目录 大纲dataloadermodeltraintest 大纲 各个文件的作用&…...
(六)正点原子STM32MP135移植——内核移植
目录 一、概述 二、编译官方代码 三、移植 四、编译 一、概述 前面已经移植好了TF-A、optee、u-boot,在u-boot能正常跑起来的情况下,现在来移植内核。 二、编译官方代码 进入kernel目录 2.1 解压源码、打补丁 /* 解压源码 */ tar xf linux-6.1.28.…...
自媒体工作内容管理助手
内容助手 访问地址:editor.yunwow.cn 背景介绍 最近在学习流量运营, 流量运营的第一站是内容创作, 我试过不少原创内容,都是跟生活相关的例如:录一段联琴的视频、录一段秋天的风景、写一段生活感悟、发一段小宠物的生…...
Echarts 教程一
Echarts 教程一 可视化大屏幕适配方案可视化大屏幕布局方案Echart 图表通用配置部分解决方案1. titile2. tooltip3. xAxis / yAxis 常用配置4. legend5. grid6. series7.color Echarts API 使用全局echarts对象echarts实例对象 可视化大屏幕适配方案 rem flexible.js 关于flex…...
【Kubernetes】Kubernetes 对象是什么?
什么是 Kubernetes 对象?常见的 Kubernetes 对象参考🔎感谢 💖 什么是 Kubernetes 对象? Kubernetes 对象是持久化的实体,用于描述整个集群的状态和配置。它们是在 etcd 等持久化存储中存储的,因此它们的状…...
【C++设计模式之模板模式】分析及示例
C之模板模式 描述实现原理示例步骤1步骤1 分析步骤2步骤2 分析调用输出结果 结论 描述 模板模式(Template Pattern)是设计模式中的一种行为型模式。 该模式定义一个操作中的算法骨架,而将具体的算法实现延迟到子类中。 模板模式使得子类可以…...
C#捕捉全局异常
1.运行图片 2.源码 using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using System.Windows.Forms;namespace 捕捉全局异常 {internal static class Program{/// <summary>/// 应用程序的主入口点。/// </summary…...
java.text.ParseException: Unparseable date: “2023-09-06T09:08:18“
问题描述: java.text.ParseException: Unparseable date: “2023-09-06T09:08:18” 这是在String类型转Date类型出现的错误,主要是String类型时间中间有一个T在转换的过程出现问题. 解决方法: SimpleDateFormat simpleDateFormat new SimpleDateFormat…...
macOS 下如何优雅的使用 Burp Suite 汉化
转载 https://www.sqlsec.com/2019/11/macbp.html 主要内容是根据上面的来的 下面总结个人出现错误的地方 主要是优雅配置方面 不要直接复制粘贴 看清楚人家的内容 下面的可以直接复制粘贴 --add-opensjava.desktop/javax.swingALL-UNNAMED --add-opensjava.base/java.lang…...
进程同步与进程互斥
1.进程同步 知识点回顾: 进程具有异步性的特征。 异步性是指,各并发执行的进程以各自独立的、不可预知的速度向前推进。 如何解决这种异步问题,就是“进程同步”所讨论的内容。 同步亦称直接制约关系,它是指为完成某种任务而建立的两个或多…...
公司安防工程简要介绍及系统需求分析
多年来 从事安保监控领域的经验,在系统的功能要求、设备选型、施 工控制、 后期维护、人员配备等各方面反复论证,最终形成了本方案。在系统 的硬件选择上,把系统的稳定性、安全性、可靠性放在第一位。根据 招标文件的要求选用当今安防行业具…...
JMETER自适应高分辨率的显示器
系列文章目录 历史文章 每天15分钟JMeter入门篇(一):Hello JMeter 每天15分钟JMeter入门篇(二):使用JMeter实现并发测试 每天15分钟JMeter入门篇(三):认识JMeter的逻辑控…...
Linux工具(三)
继Linux工具(一)和Linux工具(二),下面我们就来讲解Linux最后的两个工具,分别是代码托管的版本控制器git和代码调试器gdb。 目录 1.git-版本控制器 从0到1的实现git代码托管 检测并安装git 新建git仓库…...
基于SSM+Vue的鲜花销售系统设计与实现
末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用Vue技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…...
矢量图形编辑软件illustrator 2023 mac特点介绍
illustrator 2023 mac是一款矢量图形编辑软件,用于创建和编辑排版、图标、标志、插图和其他类型的矢量图形。 illustrator mac软件特点 矢量图形:illustrator创建的图形是矢量图形,可以无限放大而不失真,这与像素图形编辑软件&am…...
【计算机网络面试题(62道)】
文章目录 计算机网络面试题(62道)基础1.**说下计算机网络体系结构2.说一下每一层对应的网络协议有哪些?3.那么数据在各层之间是怎么传输的呢? 网络综合4.**从浏览器地址栏输入 url 到显示主页的过程?5.说说 DNS 的解析…...
JVM-满老师
JVM 前言程序计数器,栈,虚拟机栈:本地方法栈:堆,方法区:堆内存溢出方法区运行时常量池 垃圾回收垃圾回收算法分代回收 前言 JVM 可以理解的代码就叫做字节码(即扩展名为 .class 的文件ÿ…...
加锁常见的问题
锁其是用来控制在某些场景下让代码串行的工具。我们为了充分利用计算机的硬件性能,发明了多线程,多线程有好处,但同时也有它复杂的一面,必须控制好多个线程的执行,才能驯服这个有能力也有脾气的烈马。 一、加锁范围误区…...
【LeetCode力扣】LCR170 使用归并排序的思想解决逆序对问题(详细图解)
目录 1、题目介绍 2、解题思路 2.1、暴力破解法 2.2、归并排序思想 2.2.1、画图详细讲解 2.2.2、归并排序解决逆序对的代码实现 1、题目介绍 首先阅读题目可以得出要点,即当前数大于后数时则当作一个【逆序对】,而题目是要求在一个数组中计算一共存…...
python经典百题之一个素数能被几个9整除
题目:判断一个素数能被几个9整除。 首先,我们需要明确素数的定义:素数是大于1,且只能被1和自身整除的整数。 下面将分别介绍三种实现方法,每种方法附上解题思路、实现代码、以及优缺点。最后,将对这三种方法进行总结…...
RestClient
什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端,它允许HTTP与Elasticsearch 集群通信,而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级ÿ…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...
解锁数据库简洁之道:FastAPI与SQLModel实战指南
在构建现代Web应用程序时,与数据库的交互无疑是核心环节。虽然传统的数据库操作方式(如直接编写SQL语句与psycopg2交互)赋予了我们精细的控制权,但在面对日益复杂的业务逻辑和快速迭代的需求时,这种方式的开发效率和可…...
《通信之道——从微积分到 5G》读书总结
第1章 绪 论 1.1 这是一本什么样的书 通信技术,说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号(调制) 把信息从信号中抽取出来&am…...
【python异步多线程】异步多线程爬虫代码示例
claude生成的python多线程、异步代码示例,模拟20个网页的爬取,每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程:允许程序同时执行多个任务,提高IO密集型任务(如网络请求)的效率…...
2023赣州旅游投资集团
单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...
springboot整合VUE之在线教育管理系统简介
可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生,小白用户,想学习知识的 有点基础,想要通过项…...
深度学习水论文:mamba+图像增强
🧀当前视觉领域对高效长序列建模需求激增,对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模,以及动态计算优势,在图像质量提升和细节恢复方面有难以替代的作用。 🧀因此短时间内,就有不…...
vulnyx Blogger writeup
信息收集 arp-scan nmap 获取userFlag 上web看看 一个默认的页面,gobuster扫一下目录 可以看到扫出的目录中得到了一个有价值的目录/wordpress,说明目标所使用的cms是wordpress,访问http://192.168.43.213/wordpress/然后查看源码能看到 这…...
LabVIEW双光子成像系统技术
双光子成像技术的核心特性 双光子成像通过双低能量光子协同激发机制,展现出显著的技术优势: 深层组织穿透能力:适用于活体组织深度成像 高分辨率观测性能:满足微观结构的精细研究需求 低光毒性特点:减少对样本的损伤…...
