当前位置: 首页 > news >正文

PyTorch入门之【AlexNet】

参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。

目录

  • 大纲
  • dataloader
  • model
  • train
  • test

大纲

在这里插入图片描述
各个文件的作用:

  • data就是数据集
  • dataloader.py就是数据集的加载以及实例初始化
  • model.py就是AlexNet模块的定义
  • train.py就是模型的训练
  • test.py就是模型的测试

dataloader

import torch
import torchvision
import torchvision.transforms as transformsimport matplotlib.pyplot as plt
import numpy as np# define the dataloader
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])batch_size = 16trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,shuffle=False)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')if __name__ == '__main__':# get some random training imagesdataiter = iter(train_loader)images, labels = next(dataiter)# print labelsprint(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))# show imagesimg_grid = torchvision.utils.make_grid(images)img_grid = img_grid / 2 + 0.5npimg = img_grid.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()

model

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),nn.BatchNorm2d(96),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_2 = nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.conv_3 = nn.Sequential(nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_4 = nn.Sequential(nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.conv_5 = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.fc_1 = nn.Sequential(nn.Dropout(0.5),nn.Linear(9216, 4096),nn.ReLU())self.fc_2 = nn.Sequential(nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU())self.fc_3= nn.Sequential(nn.Linear(4096, num_classes))def forward(self, x):out = self.conv_1(x)out = self.conv_2(out)out = self.conv_3(out)out = self.conv_4(out)out = self.conv_5(out)out = out.reshape(out.size(0), -1)out = self.fc_1(out)out = self.fc_2(out)out = self.fc_3(out)return outif __name__ == '__main__':model = AlexNet()print(model)x = torch.randn(1, 3, 224, 224)y = model(x)print(y.size())

train

import torch
import torch.nn as nnfrom dataloader import train_loader, test_loader
from model import AlexNet# define the hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
num_epochs = 20
learning_rate = 1e-3# load the model
model = AlexNet(num_classes).to(device)# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # train the model
total_len = len(train_loader)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# move tensors to the configured deviceimages = images.to(device)labels = labels.to(device)# forward passoutputs = model(images)loss = criterion(outputs, labels)# backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_len, loss.item()))# Validationwith torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()model.train()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))# save the model checkpoint
torch.save(model.state_dict(), 'alexnet.pth')

test

import torchfrom dataloader import test_loader, classes
from model import AlexNet# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
model.load_state_dict(torch.load('alexnet.pth', map_location=device))# test the pretrained model on CIFAR-10 test data
with torch.no_grad():model.eval()correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the {} validation images: {} %'.format(10000, 100 * correct / total))

相关文章:

PyTorch入门之【AlexNet】

参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from333.999.0.0&vd_source98d31d5c9db8c0021988f2c2c25a9620 AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。 目录 大纲dataloadermodeltraintest 大纲 各个文件的作用&…...

(六)正点原子STM32MP135移植——内核移植

目录 一、概述 二、编译官方代码 三、移植 四、编译 一、概述 前面已经移植好了TF-A、optee、u-boot,在u-boot能正常跑起来的情况下,现在来移植内核。 二、编译官方代码 进入kernel目录 2.1 解压源码、打补丁 /* 解压源码 */ tar xf linux-6.1.28.…...

自媒体工作内容管理助手

内容助手 访问地址:editor.yunwow.cn 背景介绍 最近在学习流量运营, 流量运营的第一站是内容创作, 我试过不少原创内容,都是跟生活相关的例如:录一段联琴的视频、录一段秋天的风景、写一段生活感悟、发一段小宠物的生…...

Echarts 教程一

Echarts 教程一 可视化大屏幕适配方案可视化大屏幕布局方案Echart 图表通用配置部分解决方案1. titile2. tooltip3. xAxis / yAxis 常用配置4. legend5. grid6. series7.color Echarts API 使用全局echarts对象echarts实例对象 可视化大屏幕适配方案 rem flexible.js 关于flex…...

【Kubernetes】Kubernetes 对象是什么?

什么是 Kubernetes 对象?常见的 Kubernetes 对象参考🔎感谢 💖 什么是 Kubernetes 对象? Kubernetes 对象是持久化的实体,用于描述整个集群的状态和配置。它们是在 etcd 等持久化存储中存储的,因此它们的状…...

【C++设计模式之模板模式】分析及示例

C之模板模式 描述实现原理示例步骤1步骤1 分析步骤2步骤2 分析调用输出结果 结论 描述 模板模式(Template Pattern)是设计模式中的一种行为型模式。 该模式定义一个操作中的算法骨架,而将具体的算法实现延迟到子类中。 模板模式使得子类可以…...

C#捕捉全局异常

1.运行图片 2.源码 using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using System.Windows.Forms;namespace 捕捉全局异常 {internal static class Program{/// <summary>/// 应用程序的主入口点。/// </summary…...

java.text.ParseException: Unparseable date: “2023-09-06T09:08:18“

问题描述&#xff1a; java.text.ParseException: Unparseable date: “2023-09-06T09:08:18” 这是在String类型转Date类型出现的错误,主要是String类型时间中间有一个T在转换的过程出现问题. 解决方法&#xff1a; SimpleDateFormat simpleDateFormat new SimpleDateFormat…...

macOS 下如何优雅的使用 Burp Suite 汉化

转载 https://www.sqlsec.com/2019/11/macbp.html 主要内容是根据上面的来的 下面总结个人出现错误的地方 主要是优雅配置方面 不要直接复制粘贴 看清楚人家的内容 下面的可以直接复制粘贴 --add-opensjava.desktop/javax.swingALL-UNNAMED --add-opensjava.base/java.lang…...

进程同步与进程互斥

1.进程同步 知识点回顾: 进程具有异步性的特征。 异步性是指&#xff0c;各并发执行的进程以各自独立的、不可预知的速度向前推进。 如何解决这种异步问题&#xff0c;就是“进程同步”所讨论的内容。 同步亦称直接制约关系&#xff0c;它是指为完成某种任务而建立的两个或多…...

公司安防工程简要介绍及系统需求分析

多年来 从事安保监控领域的经验&#xff0c;在系统的功能要求、设备选型、施 工控制、 后期维护、人员配备等各方面反复论证&#xff0c;最终形成了本方案。在系统 的硬件选择上&#xff0c;把系统的稳定性、安全性、可靠性放在第一位。根据 招标文件的要求选用当今安防行业具…...

JMETER自适应高分辨率的显示器

系列文章目录 历史文章 每天15分钟JMeter入门篇&#xff08;一&#xff09;&#xff1a;Hello JMeter 每天15分钟JMeter入门篇&#xff08;二&#xff09;&#xff1a;使用JMeter实现并发测试 每天15分钟JMeter入门篇&#xff08;三&#xff09;&#xff1a;认识JMeter的逻辑控…...

Linux工具(三)

继Linux工具&#xff08;一&#xff09;和Linux工具&#xff08;二&#xff09;&#xff0c;下面我们就来讲解Linux最后的两个工具&#xff0c;分别是代码托管的版本控制器git和代码调试器gdb。 目录 1.git-版本控制器 从0到1的实现git代码托管 检测并安装git 新建git仓库…...

基于SSM+Vue的鲜花销售系统设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用Vue技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…...

矢量图形编辑软件illustrator 2023 mac特点介绍

illustrator 2023 mac是一款矢量图形编辑软件&#xff0c;用于创建和编辑排版、图标、标志、插图和其他类型的矢量图形。 illustrator mac软件特点 矢量图形&#xff1a;illustrator创建的图形是矢量图形&#xff0c;可以无限放大而不失真&#xff0c;这与像素图形编辑软件&am…...

【计算机网络面试题(62道)】

文章目录 计算机网络面试题&#xff08;62道&#xff09;基础1.**说下计算机网络体系结构2.说一下每一层对应的网络协议有哪些&#xff1f;3.那么数据在各层之间是怎么传输的呢&#xff1f; 网络综合4.**从浏览器地址栏输入 url 到显示主页的过程&#xff1f;5.说说 DNS 的解析…...

JVM-满老师

JVM 前言程序计数器&#xff0c;栈&#xff0c;虚拟机栈&#xff1a;本地方法栈&#xff1a;堆&#xff0c;方法区&#xff1a;堆内存溢出方法区运行时常量池 垃圾回收垃圾回收算法分代回收 前言 JVM 可以理解的代码就叫做字节码&#xff08;即扩展名为 .class 的文件&#xff…...

加锁常见的问题

锁其是用来控制在某些场景下让代码串行的工具。我们为了充分利用计算机的硬件性能&#xff0c;发明了多线程&#xff0c;多线程有好处&#xff0c;但同时也有它复杂的一面&#xff0c;必须控制好多个线程的执行&#xff0c;才能驯服这个有能力也有脾气的烈马。 一、加锁范围误区…...

【LeetCode力扣】LCR170 使用归并排序的思想解决逆序对问题(详细图解)

目录 1、题目介绍 2、解题思路 2.1、暴力破解法 2.2、归并排序思想 2.2.1、画图详细讲解 2.2.2、归并排序解决逆序对的代码实现 1、题目介绍 首先阅读题目可以得出要点&#xff0c;即当前数大于后数时则当作一个【逆序对】&#xff0c;而题目是要求在一个数组中计算一共存…...

python经典百题之一个素数能被几个9整除

题目:判断一个素数能被几个9整除。 首先&#xff0c;我们需要明确素数的定义&#xff1a;素数是大于1&#xff0c;且只能被1和自身整除的整数。 下面将分别介绍三种实现方法&#xff0c;每种方法附上解题思路、实现代码、以及优缺点。最后&#xff0c;将对这三种方法进行总结…...

Thymeleaf 内联语法使用教程

1 表达式内联 Thymeleaf标准方言允许使用标签属性(th:)来实现很多的功能&#xff0c;但在有些场景之下&#xff0c;需要将表达式直接写入HTML 代码中和CSS代码中及JavaScript代码中【代码和html文件在一起&#xff0c;分能不开&#xff0c;待验证&#xff0c;有验证的朋友可…...

Django学习笔记-实现聊天系统

笔记内容转载自 AcWing 的 Django 框架课讲义&#xff0c;课程链接&#xff1a;AcWing Django 框架课。 CONTENTS 1. 实现聊天系统前端界面2. 实现后端同步函数 1. 实现聊天系统前端界面 聊天系统整体可以分为两部分&#xff1a;输入框与历史记录。 我们需要先修改一下之前代…...

C++转换函数

什么是转换函数? C转换函数是一种特殊的成员函数&#xff0c;用于将一个类的对象转换为另一个类型。它是通过在类中定义特定的函数来实现的。 转换函数的用途&#xff1a; 类型转换&#xff1a;转换函数可以将一个类的对象从一种类型转换为另一种类型。这样可以方便地在不同…...

Spring Boot中的@Controller使用教程

一 Controller使用方法&#xff0c;如下所示&#xff1a; Controller是SpringBoot里最基本的组件&#xff0c;他的作用是把用户提交来的请求通过对URL的匹配&#xff0c;分配个不同的接收器&#xff0c;再进行处理&#xff0c;然后向用户返回结果。下面通过本文给大家介绍Spr…...

【17】c++设计模式——>原型模式

原型模式的定义 c中的原型模式&#xff08;Prototype Pattern&#xff09;是一种创建型设计模式&#xff0c;其目的是通过复制&#xff08;克隆&#xff09;已有对象来创建新的对象&#xff0c;而不需要显示的使用构造函数创建对象&#xff0c;原型模式适用于创建复杂对象时&a…...

金三银四好像消失了,IT行业何时复苏!

文章目录 1. 宏观经济形势2. 技术发展趋势3. 教育与培训4. 远程工作和自由职业5. 行业需求和公司招聘计划结论 &#x1f389;欢迎来到Java面试技巧专栏~金三银四好像消失了&#xff0c;IT行业何时复苏&#xff01; ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&#x1f379;✨博客主页&…...

PDF文件超出上传大小?三分钟学会PDF压缩

PDF作为一种流行的文档格式&#xff0c;被广泛用于各种场合&#xff0c;然而有时候PDF文件的大小超出了上传限制&#xff0c;这时候我们就需要采取一些措施来减小PDF文件的大小&#xff0c;下面就给大家分享几个方法&#xff0c;一起来学习下吧~ 方法一&#xff1a;嗨格式压缩大…...

java入坑之国际化编程

一、字符编码 1.1概述 字符编码 --字符&#xff1a;0&#xff0c;a,我&#xff0c;①&#xff0c;,… --计算机只用0和1,1bit(0或者1) --ASCIL码(American Standard Code for Information Interchange) 美国信息交换标准代码&#xff0c;奠定计算机编码基础用一个字节(1Byte8b…...

Kafka客户端核心参数详解

这一部分主要是从客户端使用的角度来理解 Kakfa 的重要机制。重点依然是要建立自己脑海中的 Kafka 消费模型。Kafka 的 HighLevel API 使用是非常简单的&#xff0c;所以梳理模型时也要尽量简单化&#xff0c;主线清晰&#xff0c;细节慢慢扩展。 一、从基础的客户端说起 Kaf…...

踩大坑ssh免密登录详细讲解

目 录 问题背景 环境说明 免密登录流程说明 1.首先要在对应的用户主机名的情况下生成密钥对&#xff0c;在A服务器执行 2.将A服务器d公钥拷贝到B服务器对应的位置 3.在A服务器访问B服务器 免密登录流程 0.用户说明 1.目前现状演示 2.删除B服务器.ssh 文件夹下面的…...

昆明做网站建设方案/专业网站seo推广

32位 计算机中的位数指的是CPU一次能处理的最大位数。32位计算机的CPU一次最多能处理32位数据&#xff0c;例如它的EAX寄存器就是32位的&#xff0c;当然32位计算机通常也可以处理16位和8位数据。在Intel由16位的286升级到386的时候&#xff0c;为了和16位系统兼容&#xff0c;…...

河南做网站推广哪个好/湖北网站设计

导语 Google I/O大会于上个月发布了新一代的Android N操作系统&#xff0c;虽然目前尚未推送正式版&#xff0c;新系统的版本号尚未正式确定&#xff0c;并且从I/O大会之前发布的preview版到发布会结束后相当一段时间内&#xff0c;全新的系统被大家认定为会命名为Android7.0。…...

做旅游网站多少钱/网络推广公司方案

在Java中计算两个日期间的天数&#xff0c;大致有2种方法&#xff1a;一是使用原生JDK进行计算&#xff0c;在JDK8中提供了更为直接和完善的方法&#xff1b;二是使用第三方库。 1、使用原生的JDK private static long daysBetween(Date one, Date two) {long difference (o…...

wordpress shopify/武汉网站seo公司

Android自定义日期区间选择器&#xff0c;类似于途家等酒店、旅游日期区间选择器&#xff1a;重写PopupWindow制定区间日历添加日历日期选中监听封装插件化github开源CustomDatePicker类似于途家等酒店日期选择器&#xff0c;弹出自定义的PopupWindow,监听日期选中&#xff0c;…...

做网页前端接活网站/让顾客进店的100条方法

使用hMailServer搭建邮件系统&#xff0c;使用webmail实现web收发邮件&#xff0c;但是又个问题是在webmail中用户自己无法修改密码。 可以使用hMailServer自带的PhpWebAdmin来实现让用户自己可以修改密码。 把hMailServer的PhpWebAdmin放在用户可访问的目录下&#xff0c;用户…...

别墅庭院园林景观设计公司/网站优化推广平台

对于以亚马逊为代表的跨境平台卖家&#xff0c;多个店铺早已成了标配。海外用户购物需求旺盛&#xff0c;中国供应链发达且完善&#xff0c;对接的就这几个大平台。然而&#xff0c;平台一般只允许一个公司一个账号。 除了AMZ&#xff0c;还要养FB营销号、开发者账号、FB广告账…...