当前位置: 首页 > news >正文

LeNet

概念

代码

model

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()  # super()继承父类的构造函数self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x): x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

forward:定义正向传播的过程。

ReLU:激活哈数

观察网络中的参数传递:发现传递的都是channel通道数,最后output在softmax函数里展开的也是展开的通道数。

train

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)for epoch in range(5):  # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499:    # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image)  # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()

predict.py

import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))im = Image.open('1.jpg').convert('RGB')im = transform(im)  # [C, H, W]im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()# predict = torch.softmax(outputs,dim=1)# print(predict)# tensor([[9.9884e-01, 1.9386e-04, 3.8757e-04, 2.0671e-05, 2.5372e-04, 3.6199e-05,# 3.7643e-05, 1.7624e-04, 2.0138e-05, 3.4801e-05]])print(classes[int(predict)])if __name__ == '__main__':main()

知识点:

增加新的维度: 

im = torch.unsqueeze(im, dim=0)  # [N, C, H, W] 

predict = torch.max(outputs, dim=1)[1].numpy():

这一行代码使用torch.max()函数找到outputs张量在第一个维度上的最大值,并返回最大值和对应的索引。dim=1表示在第一个维度上进行最大值的计算,即对每个样本的输出进行比较。[1]表示返回最大值对应的索引。最后,.numpy()将结果转换为NumPy数组。 

更换:

predict = torch.softmax(outputs,dim=1)

print:tensor([[9.9884e-01, 1.9386e-04, 3.8757e-04, 2.0671e-05, 2.5372e-04, 3.6199e-05,
         3.7643e-05, 1.7624e-04, 2.0138e-05, 3.4801e-05]])

Pytorch使用

相关文章:

LeNet

概念 代码 model import torch.nn as nn import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__() # super()继承父类的构造函数self.conv1 nn.Conv2d(3, 16, 5)self.pool1 nn.MaxPool2d(2, 2)self.conv2 nn.Conv2d(16…...

JavaScript 简单理解原型和创建实例时 new 操作符的执行操作

function Person(){// 构造函数// 当函数创建,prototype 属性指向一个原型对象时,在默认情况下,// 这个原型对象将会获得一个 constructor 属性,这个属性是一个指针,指向 prototype 所在的函数对象。 } // 为原型对象添…...

生成对抗网络——研讨会

时隔一年,再跟着李沐大师学习了GAN之后,仍旧没能在离散优化中实现通用的应用,实在惭愧,借着组内研讨会的机会,再队GAN的前世今生做一个简单的综述。 GAN产生的背景 目前与GAN相关的应用 去reddit社区的机器学习板块…...

Ubuntu 20.04 安装 mysql8 LTS

Ubuntu 20.04 安装 mysql8 LTS sudo apt-get update sudo apt-get install mysql-server -y mysql --version mysql Ver 8.0.35-0ubuntu0.20.04.1 for Linux on x86_64 ((Ubuntu)) Ubuntu20.04 是自带了 MySQL8. 几版本的,低于 20.04 则默认安装是 MySQL5.7.33…...

蓝桥杯:货物摆放

小蓝有一个超大的仓库,可以摆放很多货物。 现在,小蓝有 n 箱货物要摆放在仓库,每箱货物都是规则的正方体。小蓝规定了长、宽、高三个互相垂直的方向,每箱货物的边都必须严格平行于长、宽、高。 小蓝希望所有的货物最终摆成一个大…...

ganache部署智能合约报错VM Exception while processing transaction: invalid opcode

这是因为编译的字节码不正确,ganache和remix编译时需要选择相同的evm version 如下图所示: remix: ganache: 确保两者都选择london或者其他evm,只要确保EVM一致就可以正确编译并部署, 不会再出现VM Exception while processing…...

金融银行业更适合申请哪种SSL证书?

在当今数字化时代,金融行业的重要性日益增加。越来越多的金融交易和敏感信息在线进行,金融银行机构必须采取必要的措施来保护客户数据的安全。SSL证书作为一种重要的安全技术工具,可以帮助金融银行机构加密数据传输,验证网站身份&…...

文心一言API(高级版)使用

文心一言API高级版使用 一、百度文心一言API(高级版)二、使用步骤1、接口2、请求参数3、请求参数示例4、接口 返回示例 三、 如何获取appKey和uid1、申请appKey:2、获取appKey和uid 四、重要说明 一、百度文心一言API(高级版) 基于百度文心一言语言大模型的智能文本对话AI机器…...

C# 任务并行类库Parallel调用示例

写在前面 Task Parallel Library 是微软.NET框架基础类库(BCL)中的一个,主要目的是为了简化并行编程,可以实现在不同的处理器上并行处理不同任务,以提升运行效率。Parallel常用的方法有For/ForEach/Invoke三个静态方法…...

2024年江苏省职业院校技能大赛信息安全管理与评估 第二阶段学生组(样卷)

2024年江苏省职业院校技能大赛信息安全管理与评估 第二阶段学生组(样卷) 竞赛项目赛题 本文件为信息安全管理与评估项目竞赛-第二阶段样题,内容包括:网络安全事件响应、数字取证调查、应用程序安全。 本次比赛时间为180分钟。 …...

飞天使-linux操作的一些技巧与知识点3

http工作原理 http1.0 协议 使用的是短连接,建立一次tcp连接,发起一次http的请求,结束,tcp断开 http1.1 协议使用的是长连接,建立一次tcp的连接,发起多次http的请求,结束,tcp断开ngi…...

Appium获取toast方法封装

一、前置说明 toast消失的很快,并且通过uiautomatorviewer也不能获取到它的定位信息,如下图: 二、操作步骤 toast的class name值为android.widget.Toast,虽然toast消失的很快,但是它终究是在Dom结构中出现过&…...

Google Guava简析

Google Guava 是Google开源的一个Java类库,对基本类库做了扩充。感觉最大的价值点在于其 集合类、Cache和String工具类。 github项目地址:GitHub - google/guava: Google core libraries for Java github文档地址:Home google/guava Wiki …...

反序列化漏洞详解(二)

目录 pop链前置知识,魔术方法触发规则 pop构造链解释(开始烧脑了) 字符串逃逸基础 字符减少 字符串逃逸基础 字符增加 实例获取flag 字符串增多逃逸 字符串减少逃逸 延续反序列化漏洞(一)的内容 pop链前置知识,魔术方法触…...

React全站框架Next.js使用入门

Next.js是一个基于React的服务器端渲染框架,它可以帮助我们快速构建React应用程序,并具有以下优势: 1. 支持服务器端渲染,提高页面渲染速度和SEO; 2. 自带webpack开发环境,实现即插即用的特性;…...

【操作系统笔记】-文件系统

引言 之前已经学习过数据在内存中是如何表示,如何存储,但是这些存储在PC断电后数据便消失。因此我们需要一个可以持久化存储并且容量远远大于内存的结构,这一篇我们将学习,文件是如何被组织和操作的,这是一个操作系统…...

第二十一章 网络通信

计算机网络实现了堕胎计算机间的互联,使得它们彼此之间能够进行数据交流。网络应用程序就是再已连接的不同计算机上运行的程序,这些程序借助于网络协议,相互之间可以交换数据,编写网络应用程序前,首先必须明确网络协议…...

【漏洞复现】万户协同办公平台ezoffice wpsservlet接口存在任意文件上传漏洞 附POC

漏洞描述 万户ezOFFICE集团版协同平台以工作流程、知识管理、沟通交流和辅助办公四大核心应用 万户ezOFFICE协同管理平台是一个综合信息基础应用平台。 万户协同办公平台ezoffice wpsservlet接口存在任意文件上传漏洞。 免责声明 技术文章仅供参考,任何个人和组织使用网络应…...

【uniapp】小程序中input输入框的placeholder-class不生效解决办法

问题描述 uniapp微信小程序&#xff0c;使用input组件时&#xff0c;想要改变提示词 placeholder 的样式&#xff0c;但是使用placeholder-class 改变不了 如下&#xff1a; <input type"text" placeholder"搜索" placeholder-class"placeholde…...

SimplePIR——目前最快单服务器匿踪查询方案

一、介绍 这篇论文旨在实现高效的单服务器隐私信息检索&#xff08;PIR&#xff09;方案&#xff0c;以解决在保护用户隐私的同时快速检索数据库的问题。为了实现这一目标&#xff0c;论文提出了两种新的PIR方案&#xff1a;SimplePIR和DoublePIR。这两种方案的实现基于学习与错…...

【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器

一.自适应梯度算法Adagrad概述 Adagrad&#xff08;Adaptive Gradient Algorithm&#xff09;是一种自适应学习率的优化算法&#xff0c;由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率&#xff0c;适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...

边缘计算医疗风险自查APP开发方案

核心目标:在便携设备(智能手表/家用检测仪)部署轻量化疾病预测模型,实现低延迟、隐私安全的实时健康风险评估。 一、技术架构设计 #mermaid-svg-iuNaeeLK2YoFKfao {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg…...

Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务

通过akshare库&#xff0c;获取股票数据&#xff0c;并生成TabPFN这个模型 可以识别、处理的格式&#xff0c;写一个完整的预处理示例&#xff0c;并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务&#xff0c;进行预测并输…...

P3 QT项目----记事本(3.8)

3.8 记事本项目总结 项目源码 1.main.cpp #include "widget.h" #include <QApplication> int main(int argc, char *argv[]) {QApplication a(argc, argv);Widget w;w.show();return a.exec(); } 2.widget.cpp #include "widget.h" #include &q…...

第 86 场周赛:矩阵中的幻方、钥匙和房间、将数组拆分成斐波那契序列、猜猜这个单词

Q1、[中等] 矩阵中的幻方 1、题目描述 3 x 3 的幻方是一个填充有 从 1 到 9 的不同数字的 3 x 3 矩阵&#xff0c;其中每行&#xff0c;每列以及两条对角线上的各数之和都相等。 给定一个由整数组成的row x col 的 grid&#xff0c;其中有多少个 3 3 的 “幻方” 子矩阵&am…...

全志A40i android7.1 调试信息打印串口由uart0改为uart3

一&#xff0c;概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本&#xff1a;2014.07&#xff1b; Kernel版本&#xff1a;Linux-3.10&#xff1b; 二&#xff0c;Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01)&#xff0c;并让boo…...

现有的 Redis 分布式锁库(如 Redisson)提供了哪些便利?

现有的 Redis 分布式锁库&#xff08;如 Redisson&#xff09;相比于开发者自己基于 Redis 命令&#xff08;如 SETNX, EXPIRE, DEL&#xff09;手动实现分布式锁&#xff0c;提供了巨大的便利性和健壮性。主要体现在以下几个方面&#xff1a; 原子性保证 (Atomicity)&#xff…...

android13 app的触摸问题定位分析流程

一、知识点 一般来说,触摸问题都是app层面出问题,我们可以在ViewRootImpl.java添加log的方式定位;如果是touchableRegion的计算问题,就会相对比较麻烦了,需要通过adb shell dumpsys input > input.log指令,且通过打印堆栈的方式,逐步定位问题,并找到修改方案。 问题…...

探索Selenium:自动化测试的神奇钥匙

目录 一、Selenium 是什么1.1 定义与概念1.2 发展历程1.3 功能概述 二、Selenium 工作原理剖析2.1 架构组成2.2 工作流程2.3 通信机制 三、Selenium 的优势3.1 跨浏览器与平台支持3.2 丰富的语言支持3.3 强大的社区支持 四、Selenium 的应用场景4.1 Web 应用自动化测试4.2 数据…...

【Linux】自动化构建-Make/Makefile

前言 上文我们讲到了Linux中的编译器gcc/g 【Linux】编译器gcc/g及其库的详细介绍-CSDN博客 本来我们将一个对于编译来说很重要的工具&#xff1a;make/makfile 1.背景 在一个工程中源文件不计其数&#xff0c;其按类型、功能、模块分别放在若干个目录中&#xff0c;mak…...