当前位置: 首页 > news >正文

简易机器学习笔记(八)关于经典的图像分类问题-常见经典神经网络LeNet

前言

图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉的核心,是物体检测、图像分割、物体跟踪、行为分析、人脸识别等其他高层次视觉任务的基础。图像分类在许多领域都有着广泛的应用,如:安防领域的人脸识别和智能视频分析等,交通领域的交通场景识别,互联网领域基于内容的图像检索和相册自动归类,医学领域的图像识别等。

这里简单讲讲LeNet

我的推荐是可以看看这个视频,可视化的查看卷积神经网络是如何一层一层地抽稀获得特征,最后将所有的图像展开成一个一维的轴,再通过全连接神经网络预测得到一个最后的预测值。

手写数字识别 1.4 LeNet-5-哔哩哔哩

在这里插入图片描述

计算过程

前置知识:

  1. 步长 Stride & 加边 Padding

卷积后尺寸=(输入尺寸-卷积核大小+加边像素数)/步长 + 1

默认Padding = ‘valid’ (丢弃),strides = 1
在这里插入图片描述

正式计算

  1. 卷积层1:

第一层我们给定的图像时32 * 32,使用六个5 x 5的卷积核,步长为1

第一层中没有加边,那么卷积后的尺寸就是(32 - 5 + 0 )/1 + 1 =28,那么输出的图像就是 28*28的边长

在第一层中,由于我们使用了六个卷积核,我们得到的输出为:62828,可以理解为一个六层厚的图像

  1. 池化层1:

我们在池化层内在2x2的图像内选取了一个最大值或者平均值,也就是图片整体缩水到原先的二分之一,所以我们得到池化层的输出为 6 x 14 x 14

  1. 卷积层2:

还是按照公式,卷积后尺寸=(输入-卷积核+加边像素数)/步长 + 1,这个时候输入为6 x 14 x 14,这一次我们给定了16个卷积核,得到输出后的尺寸为(14 - 5 + 0)/1 + 1 = 10,得到输出为161010

关于这个16个卷积核是怎么来的,可以见图:

问了下组里的大佬,大佬说这个卷积核数目和层数很多是经验值,即你寻求更多或者更少的卷积核数目或者层数,实际效果不一定有经验值更好,反正都是离散值,就随便试试就行了。

其中:卷积输出尺寸nout:nin为输入原图尺寸大小;s是步长(一次移动几个像素);p补零圈数,

我们这里输入的值

  1. 池化层2

得到 输出后尺寸为16 * 5 * 5

  1. 全连接层1:

输入为16 * 5 * 5 ,有120个5*5卷积核,步长为1,输出尺寸为(5 - 5 + 0)/1 + 1 =1,这时候输出的就是一条直线的一维输出了

  1. 全连接层2:

输入为120,使用了84个神经元,

  1. 输出层

输入84,输出为10

比如我们如图所示,在代码中是这样的:

# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear## 组网
import paddle.nn.functional as F
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST
#定义LeNet网络结构# 定义 LeNet 网络结构
class LeNet(paddle.nn.Layer):def __init__(self, num_classes=1):super(LeNet,self).__init__()#创建卷积层和池化层#创建第一个卷积层self.conv1 = Conv2D(in_channels=1,out_channels=6,kernel_size=5)self.max_pool1 = MaxPool2D(kernel_size=2,stride=2)#尺寸的逻辑:池化层未改变通道数,当前通道为6#创建第二个卷积层self.conv2 = Conv2D(in_channels=6,out_channels=16,kernel_size=5)self.max_pool2 = MaxPool2D(kernel_size=2,stride=2)#创建第三个卷积层self.conv3 = Conv2D(in_channels=16,out_channels=120,kernel_size=4)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]# 输入size是[28,28],经过三次卷积和两次池化之后,C*H*W等于120self.fc1 = Linear(in_features=120, out_features=64)# 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分类标签的类别数self.fc2 = Linear(in_features=64, out_features=num_classes)# 网络的前向计算过程def forward(self, x):x = self.conv1(x)# 每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化x = F.sigmoid(x)x = self.max_pool1(x)x = F.sigmoid(x)x = self.conv2(x)x = self.max_pool2(x)x = self.conv3(x)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]x = paddle.reshape(x, [x.shape[0], -1])x = self.fc1(x)x = F.sigmoid(x)x = self.fc2(x)return x
# 飞桨会根据实际图像数据的尺寸和卷积核参数自动推断中间层数据的W和H等,只需要用户表达通道数即可。
# 下面的程序使用随机数作为输入,查看经过LeNet-5的每一层作用之后,输出数据的形状。# 输入数据形状是 [N, 1, H, W]
# 这里用np.random创建一个随机数组作为输入数据
x = np.random.randn(*[3,1,28,28])
x = x.astype('float32')# 创建LeNet类的实例,指定模型名称和分类的类别数目
model = LeNet(num_classes=10)# 通过调用LeNet从基类继承的sublayers()函数,
# 查看LeNet中所包含的子层
print(model.sublayers())
x = paddle.to_tensor(x)for item in model.sublayers():#item是LeNet类中的一个子层#查看经过子层之后的输出数据形状try:x = item(x)except:x = paddle.reshape(x, [x.shape[0], -1])x = item(x)if len(item.parameters())==2:# 查看卷积和全连接层的数据和参数的形状,# 其中item.parameters()[0]是权重参数w,item.parameters()[1]是偏置参数bprint(item.full_name(), x.shape, item.parameters()[0].shape, item.parameters()[1].shape)else:# 池化层没有参数print(item.full_name(), x.shape)# 设置迭代轮数
EPOCH_NUM = 5
#定义训练过程 
def train(model,opt,train_loader,valid_loader):print("start training ... ")model.train()for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):img = data[0]label = data[1] #计算模型输出# 计算模型输出logits = model(img)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)avg_loss = paddle.mean(loss)if batch_id % 2000 == 0:print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))#反向传播avg_loss.backward()opt.step()opt.clear_grad()model.eval()accuracies = []losses = []for batch_id, data in enumerate(valid_loader()):img = data[0]label = data[1]# 计算模型输出logits = model(img)pred = F.softmax(logits)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)acc = paddle.metric.accuracy(pred, label)accuracies.append(acc.numpy())losses.append(loss.numpy())print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))model.train()# 保存模型参数paddle.save(model.state_dict(), 'mnist.pdparams')    # 创建模型
model = LeNet(num_classes=10)
# 设置迭代轮数
EPOCH_NUM = 5
# 设置优化器为Momentum,学习率为0.001
opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())
# 定义数据读取器
train_loader = paddle.io.DataLoader(MNIST(mode='train', transform=ToTensor()), batch_size=10, shuffle=True)
valid_loader = paddle.io.DataLoader(MNIST(mode='test', transform=ToTensor()), batch_size=10)
# 启动训练过程
train(model, opt, train_loader, valid_loader)

相关文章:

简易机器学习笔记(八)关于经典的图像分类问题-常见经典神经网络LeNet

前言 图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉的核心,是物体检测、图像分割、物体跟踪、行为分析、人脸识别等其他高层次视觉任务的基础。图像分类在许多领域都有着广泛的应用,如:安防领域的人脸识别和…...

pytest conftest通过fixture实现变量共享

conftest.py scope"module" 只对当前执行的python文件 作用 pytest.fixture(scope"module") def global_variable():my_dict {}yield my_dict test_case7.py import pytestlist1 []def test_case001(global_variable):data1 123global_variable.u…...

系列五、搭建Naco(集群版)

一、搭建Naco(集群版) 1.1、前置说明 (1)64位Red Hat7 Linux 系统; (2)64位JDK1.8;备注:如果没有安装JDK,请参考【系列二、Linux中安装JDK】 (3&…...

JavaScript中alert、prompt 和 confirm区别及使用【通俗易懂】

✨前言✨   本篇文章主要在于,让我们看几个与用户交互的函数:alert,prompt 和confirm的使用及区别 🍒欢迎点赞 👍 收藏 ⭐留言评论 📝私信必回哟😁 🍒博主将持续更新学习记录收获&…...

【GoLang入门教程】Go语言几种标准库介绍(四)

编程语言的未来? 文章目录 编程语言的未来?前言几种库fmt库 (格式化操作)关键函数:示例 Go库标准库第三方库示例 html库(HTML 转义及模板系统)主要功能:示例 总结专栏集锦写在最后 前言 上一篇,我们介绍了debug、enco…...

面试算法:快速排序

题目 快速排序是一种非常高效的算法,从其名字可以看出这种排序算法最大的特点就是快。当表现良好时,快速排序的速度比其他主要对手(如归并排序)快2~3倍。 分析 快速排序的基本思想是分治法,排序过程如下…...

航空业数字化展翅高飞,开源网安专业服务保驾护航

​某知名航空公司是中国首批民营航空公司之一,运营国内外航线200多条,也是国内民航最高客座率的航空公司之一。在数字化发展中,该航空公司以数据驱动决策,通过精细化管理、数字创新和模式优化等方式,实现了精准营销和个…...

SpringBoot学习(三)-员工管理系统开发(重在理解)

注:此为笔者学习狂神说SpringBoot的笔记,其中包含个人的笔记和理解,仅做学习笔记之用,更多详细资讯请出门左拐B站:狂神说!!! 本文是基于狂神老师SpringBoot教程中的员工管理系统从0到1的实践和理解。该系统应用SpringB…...

2 Windows网络编程

1 基础概念 1.1 socket概念 Socket 的原意是“插座”,在计算机通信领域,socket 被翻译为“套接字”,它是计算机之间进行通信的一种约定或一种方式。Socket本质上是一个抽象层,它是一组用于网络通信的API,包括了一系列…...

uniapp选择android非图片文件的方案踩坑记录

这个简单的问题我遇到下面6大坑,原始需求是选择app如android的excel然后读取到页面并上传表格数据json 先看看效果 uniapp 选择app excel文件读取 1.uniapp自带不支持 uniapp选择图片和视频非常方便自带已经支持可以直接上传和读取 但是选择word excel的时候就出现…...

前端发开的性能优化 请求级:请求前(资源预加载和预读取)

预加载 预加载:是优化网页性能的重要技术,其目的就是在页面加载过程中先提前请求和获取相关的资源信息,减少用户的等待时间,提高用户的体验性。预加载的操作可以尝试去解决一些类似于减少首次内容渲染的时间,提升关键资…...

B01、类加载子系统-02

JVM架构图-英文版 中文版见下图: 1、概述类的加载器及类加载过程 1.1、类加载子系统的作用 类加载器子系统负责从文件系统或者网络中加载Class文件,class文件在文件开头有特定的文件标识。ClassLoader只负责class文件的加载,至于它是否可以运行,则由Execution Engi…...

用PHP搭建一个绘画API

【腾讯云AI绘画】用PHP搭建一个绘画API 大家好!今天我要给大家推荐的是如何用PHP搭建一个绘画API,让你的网站或应用瞬间拥有强大的绘画能力!无论你是想要让用户在网页上绘制自己的创意,还是想要实现自动绘画生成特效,这…...

西安人民检察院 | OLED翻页查询一体机

产品:55寸OLED柔性屏 项目时间:2023年12月 项目地点:西安 在2023年12月,西安人民检察院引入了OLED翻页查询一体机,为来访者提供了一种全新的信息查询方式。 这款一体机采用55寸OLED柔性屏,具有高清晰度、…...

superset利用mysql物化视图解决不同数据授权需要写好几次中文别名的问题

背景 在使用superset时,给不同的人授权不同的数据,需要不同的数据源,可视化字段希望是中文,所以导致不同的人需要都需要去改表的字段,因此引入视图,将视图中字段名称设置为中文 原表数据 select * from …...

输入输出流

1.输入输出流 输入/输出流类:iostream---------i input(输入) o output(输出) stream:流 iostream: istream类:输入流类-------------cin:输入流类的对象 ostream类…...

IOS:Safari无法播放MP4(H.264编码)

一、问题描述 MP4使用H.264编码通常具有良好的兼容性,因为H.264是一种广泛支持的视频编码标准。它可以在许多设备和平台上播放,包括电脑、移动设备和流媒体设备。 使用caniuse查询H.264兼容性,看似确实具有良好的兼容性: 然而…...

Pycharm恢复默认设置

window 系统 找到下方目录-->删除. 再重新打开Pycharm C:\Users\Administrator\.PyCharm2023.3 你的不一定和我名称一样 只要是.PyCharm*因为版本不同后缀可能不一样 mac 系统 请根据需要删除下方目录 # Configuration rm -rf ~/Library/Preferences/PyCharm* # Caches …...

简单计算器实现,包括两个数

正在加载中... 简单计算器实现,包括两个数 ❤ 厾罗 简单计算器实现,包括两个数 以下代码用于实现简单计算器实现,包括两个数基本的加减乘除运算: 实例(Python 3.0) # Filename : test.py # author by : www.dida100.com …...

竞赛保研 基于机器视觉的手势检测和识别算法

0 前言 🔥 优质竞赛项目系列,今天要分享的是 基于深度学习的手势检测与识别算法 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng…...

Android App从备案到上架全过程

不知道大家注意没有,最近几年来,新的移动App想要上架是会非常困难的,并且对于个人开发者和小企业几乎是难如登天,各种备案和审核。但是到底有多难,或许只有上架过的才会有所体会。 首先是目前各大应用市场陆续推出新的声明,各种备案截止日期到12月就要到最后期限责令整改…...

用邮件及时获取变更的公网IP--------python爬虫+打包成exe文件

参考获取PC机公网IP并发送至邮箱 零、找一个发送邮件的邮箱 本文用QQ邮箱为发送邮箱,网易等邮箱一般也有这个功能,代码也是通用的。 第一步:在设置中找到账户,找到POP3/IMAP/SMTP/Exchange/CardDAV/CalDAV服务,点击获…...

c++学习:函数模板+实战

目录 函数模板 思考 如果两个参数的类型不一样可以下面这么写 如果有指定返回参数可以下面这么写 实战 找出三个数中最大的一个 函数模板 实际上就是建立一个通用函数,其函数返回值类型和形参类型不具体指定,用一个虚拟的类型来代表template 是一个…...

three.js gltf后处理颜色异常(伽马校正)

效果&#xff1a; 应用了伽马校正&#xff0c;好像效果不明显 代码&#xff1a; <template><div><el-container><el-main><div class"box-card-left"><div id"threejs" style"border: 1px solid red"><…...

面试经典150题(55-58)

leetcode 150道题 计划花两个月时候刷完&#xff0c;今天&#xff08;第二十四天&#xff09;完成了4道(55-58)150&#xff1a; 55.&#xff08;19. 删除链表的倒数第 N 个结点&#xff09;题目描述&#xff1a; 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff…...

如果一个n位正整数等于其各位数字的n次方之和

❤ 厾罗 如果一个n位正整数等于其各位数字的n次方之和 如果一个n位正整数等于其各位数字的n次方之和,则称该数为阿姆斯特朗数。 例如1^3 5^3 3^3 153。 1000以内的阿姆斯特朗数&#xff1a; 1, 2, 3, 4, 5, 6, 7, 8, 9, 153, 370, 371, 407。 以下代码用于检测用户输…...

solidity显示以太坊美元价格

看过以太坊白皮书的都知道&#xff0c;以太坊比较比特币而言所提升的地方中&#xff0c;我认为最重要的一点就是能够访问外部的数据&#xff0c;这一点在赌博、金融领域应用会很广泛&#xff0c;但是区块链是一个确定的系统&#xff0c;包括里面的所有数值包括交易ID等都是确定…...

ChatGPT学习笔记——大模型基础理论体系

1、ChatGPT的背景与意义 近期,ChatGPT表现出了非常惊艳的语言理解、生成、知识推理能力, 它可以极好的理解用户意图,真正做到多轮沟通,并且回答内容完整、重点清晰、有概括、有条理。 ChatGPT 是继数据库和搜索引擎之后的全新一代的 “知识表示和调用方式”如下表所示。 …...

Termius for Mac/Win:一款功能强大的终端模拟器、SSH 和 SFTP 客户端软件

随着远程工作和云技术的普及&#xff0c;对于高效安全的远程访问和管理服务器变得至关重要。Termius&#xff0c;一款强大且易用的终端模拟器、SSH 和 SFTP 客户端软件&#xff0c;正是满足这一需求的理想选择。 Termius 提供了一站式的解决方案&#xff0c;允许用户通过单一平…...

python如何读取被压缩的图像

读取压缩的图像数据&#xff1a; PackBits 压缩介绍&#xff1a; CCITT T.3 压缩介绍&#xff1a; 读取压缩的图像数据&#xff1a; 在做图像处理的时候&#xff0c;平时都是使用 函数io.imread() 或者是 函数cv2.imread( ) 函数来读取图像数据&#xff0c;很少用PIL.Image…...

郑州网站优化的微博_腾讯微博/微信crm系统

你希望你的网站更有说服力吗&#xff1f;说服的能力是演说家、作家和营销人员梦寐以求的技能。在你的网站应用一个或多个这种增强说服力的技术&#xff0c;可以让你游刃有余地控制转化率。 **以下是心理学中最具魅力和说服力的21种说服技巧。**有了这些技巧&#xff0c;就可以…...

信主网站/广告开户

题目&#xff1a;用*号输出字母C的图案。 程序分析&#xff1a;可先用*号在纸上写出字母C&#xff0c;再分行输出。 程序代码&#xff1a; #include <stdio.h> int main() {printf("用 * 号输出字母 C!\n");printf(" ****\n");printf(" *\n&…...

成都百度公司在哪里/广州网站优化方式

这是一款支持多种平台去水印的一款微信小程序源码 支持短视频去水印,还有图集去水印等 内含多平台去水印接口,响应的速度也是非常的快 这是一款非常值得推荐的一款小程序源码 另外还支持多种流量主模式收益,大家只需要替换对应的流量主ID即可 小程序源码下载地址&#xff1…...

个人主页网站模板/google推广平台怎么做

package com.njue.mis;public class GetNumber {static String s "32fdsfd8fds0fdsf9323k32k";public static void main(String[] args){String a s.replaceAll("[^0-9]", "");System.out.print(a);}}...

清远做网站seo/以图搜图百度识图网页版

先准备MySQL的安装包&#xff0c;可以去官网下载&#xff0c;也可以在其它镜像资源网站下载。搜狐的http://mirrors.sohu.com/mysql/MySQL-5.7/。MySQL5.7需要boost_1.59,指明了就这个版本的&#xff0c;其它版本不行。也可以下载包含boost的mysql&#xff1a;mysql-boost-5.7.…...

个人做金融网站能赚钱吗/珠海关键词优化软件

本文适合初学者阅读 5.3 切片切片本身并非动态数组或数组指针. 它内部通过指针引用底屋数组, 设定相关属性将数据读写操作限定在指定区域内.切片本身是只只读对象, 其工作机制类似数据指针的一种包装.可基于数组或数组指针创建切片, 以开始和结束索引位置确定所引用的数组片段…...