当前位置: 首页 > news >正文

计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用,本文我们借助PyTorch,快速构建和训练卷积神经网络(CNN)等模型,以实现街道房屋号码的准确识别。引入并注意力机制,它是一种模仿人类视觉注意机制的方法,在图像处理任务中具有广泛应用。通过引入注意力机制,模型可以自动关注图像中与房屋号码相关的区域,提高识别的准确性和鲁棒性。

一、项目介绍

街道房屋号码识别是计算机视觉中的一个重要任务,通过对街道房屋号码的自动识别,可以对街道图像进行更好的理解和分析。本文将介绍如何使用PyTorch框架和注意力机制,结合SVHN数据集,来实现街道房屋号码的分类识别。

二、SVHN数据集

SVHN(Street View House Numbers)是一个公开的大规模街道数字图像数据集。该数据集包含了从Google Street View中获取的房屋门牌号码图像,可以用于训练和测试机器学习模型,以实现自动识别街道房屋号码的任务。

2.1 数据集下载和加载

首先,我们需要下载并加载SVHN数据集。在PyTorch中,我们可以使用torchvision库中的datasets模块来实现这一步。

数据集的下载与查看:

train_dataset = datasets.SVHN(root='./data', split='train', download=True)images = train_dataset.data[:10]  # shape: (10, 3, 32, 32)
labels = train_dataset.labels[:10]images = np.transpose(images, (0, 2, 3, 1))# Plot the images
fig, axs = plt.subplots(2, 5, figsize=(12, 6))
axs = axs.ravel()for i in range(10):axs[i].imshow(images[i])axs[i].set_title(f"Label: {labels[i]}")axs[i].axis('off')plt.tight_layout()
plt.show()

在这里插入图片描述

数据集的加载,预处理,便于输入模型训练:

import torch
from torchvision import datasets, transforms# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 下载并加载SVHN数据集
trainset = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
testset = datasets.SVHN(root='./data', split='test', download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

三、卷积网络搭建

使用PyTorch搭建卷积神经网络。卷积神经网络(Convolutional Neural Network, CNN)是一种主要用于处理具有类似网格结构的数据的神经网络,如图像(2D网格的像素点)或者文本(1D网格的单词)。

3.1 网络结构定义

下面是一个基础的卷积神经网络模型,包含两个卷积层、两个最大池化层和两个全连接层。

from torch import nnclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.drop_out = nn.Dropout()self.fc1 = nn.Linear(7 * 7 * 64, 1000)self.fc2 = nn.Linear(1000, 10)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = out.reshape(out.size(0), -1)out = self.drop_out(out)out = self.fc1(out)return self.fc2(out)

四、加入注意力机制

注意力机制是一种能够改进模型性能的技术。在我们的模型中,我们将添加一个注意力层来帮助模型更好地专注于输入图像中的重要部分。

4.1 注意力层定义

我将实现基本的注意力层,这个层将会生成一个和输入同样大小的注意力图,然后将输入和这个注意力图对应元素相乘,以此来实现对输入的加权。

注意力机制层的数学原理:
注意力机制的数学原理可以用以下公式表示:

给定输入张量 x ∈ R b × c × h × w x \in \mathbb{R}^{b \times c \times h \times w} xRb×c×h×w,其中 b b b 是批量大小, c c c 是通道数, h h h 是高度, w w w 是宽度。注意力机制分为两个阶段:特征提取和特征加权。
1.特征提取阶段:
首先,通过自适应平均池化层(AdaptiveAvgPool2d)将输入张量 x x x 在高度和宽度上进行平均池化,得到形状为 b × c × 1 × 1 b \times c \times 1 \times 1 b×c×1×1 的张量 y y y。这里使用自适应平均池化是为了使得张量 y y y 在不同尺寸的输入上也能产生相同的输出。
2.特征加权阶段:
接下来,通过全连接层(Linear)和非线性激活函数ReLU对张量 y y y 进行特征变换,减少通道数,并保留重要特征。然后再通过另一个全连接层和Sigmoid激活函数得到权重张量 y ′ ∈ R b × c × 1 × 1 y' \in \mathbb{R}^{b \times c \times 1 \times 1} yRb×c×1×1,表示每个通道的权重值。这里的权重值在0到1之间,用于控制每个通道在后续的计算中所占的比重。将权重张量 y ′ y' y 扩展成与输入张量 x x x 相同的形状,并将其与输入张量相乘,得到经过注意力加权的特征张量。这样就实现了对输入张量的自适应特征加权。

数学表示为:
y = AdaptiveAvgPool2d ( x ) y ′ = Sigmoid ( Linear ( ReLU ( Linear ( y ) ) ) ) output = x ⊙ y ′ y = \text{AdaptiveAvgPool2d}(x) \\ y' = \text{Sigmoid}(\text{Linear}(\text{ReLU}(\text{Linear}(y)))) \\ \text{output} = x \odot y' y=AdaptiveAvgPool2d(x)y=Sigmoid(Linear(ReLU(Linear(y))))output=xy

其中 ⊙ \odot 表示按元素相乘操作。

注意力机制层的搭建代码:

class AttentionLayer(nn.Module):def __init__(self, channel, reduction=16):super(AttentionLayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel// reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)

4.2 在网络中加入注意力层

我们将注意力层加入到ConvNet模型中:

class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),AttentionLayer(32))self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),AttentionLayer(64))self.drop_out = nn.Dropout()self.fc1 = nn.Linear(8 * 8 * 64, 1000)self.fc2 = nn.Linear(1000, 10)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = out.reshape(out.size(0), -1)out = self.drop_out(out)out = self.fc1(out)return self.fc2(out)

五、模型训练与测试

接下来,我们将进行模型的训练和测试。

5.1 模型训练

import torch.optim as optimmodel = ConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)for epoch in range(10):  # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 20 == 0:    # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')

5.2 模型测试

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

六、结论

这篇文章就像是一张奇妙的地图,引领你进入计算机视觉任务的神奇世界。在这个世界里,你将与PyTorch和注意力机制这两位强大的伙伴结伴前行,共同探索街道房屋号码识别的奥秘。

想象一下,你置身于繁忙的街道上,满目琳琅的房屋号码挑战着你的视力。而你却拥有了一种神奇的眼力,能轻松识别出每一个号码。这种超凡能力正是计算机视觉任务的魔法所在。

我们要携手PyTorch这位强大的工具,它如同一把巧妙的魔法棒,能帮助我们构建强大的神经网络模型。通过PyTorch,我们可以灵活地定义模型的结构,设置各种参数,并进行高效的训练和推理。

我们遇到了注意力机制,就像是一盏明亮的灯塔,照亮了我们前进的方向。注意力机制能够使神经网络集中注意力于图像中的重要区域,从而提高识别的准确性。利用这种机制,我们可以让模型更加聪明地注重街道房屋号码所在的位置和细节,从而更好地进行识别。而SVHN数据集则是我们探险的指南,其中包含了大量真实世界中的街道房屋号码图像。通过导入这些数据,我们可以让模型从中学习并提高自己的识别能力。这些图像将带领我们穿越城市的角落,感受不同场景下的挑战和变化。通过这篇文章,我们不仅可以更深入地理解计算机视觉任务的本质,还能获得启发。就像是一次奇妙的冒险,我们将学会如何使用PyTorch和注意力机制来实现街道房屋号码的识别任务。让我们一起跟随这个引人入胜的旅程,开拓视野,追寻新的可能性吧!

相关文章:

计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用,本文我们借助PyTorch,快速构建和训练卷积神经网络(CNN)等模型,…...

正则表达式:学习使用正则表达式提取网页中的目标数据

使用正则表达式提取网页中的目标数据主要有以下几个步骤: 获取网页内容:首先,你需要使用Python的库(如requests)获取网页的HTML内容。 构建正则表达式:根据你想要提取的目标数据的特征,构建相应…...

最长重复子数组(力扣)动态规划 JAVA

给两个整数数组 nums1 和 nums2 ,返回 两个数组中 公共的 、长度最长的子数组的长度 。 示例 1: 输入:nums1 [1,2,3,2,1], nums2 [3,2,1,4,7] 输出:3 解释:长度最长的公共子数组是 [3,2,1] 。 示例 2: 输…...

JavaWeb_LeadNews_Day6-Kafka

JavaWeb_LeadNews_Day6-Kafka Kafka概述安装配置kafka入门kafka高可用方案kafka详解生产者同步异步发送消息生产者参数配置消费者同步异步提交偏移量 SpringBoot集成kafka 自媒体文章上下架实现思路具体实现 来源Gitee Kafka 概述 对比 选择 介绍 producer: 发布消息的对象称…...

ATTCK覆盖度97.1%!360终端安全管理系统获赛可达认证

近日,国际知名第三方网络安全检测服务机构——赛可达实验室(SKD Labs)发布最新测试报告,360终端安全管理系统以ATT&CK V12框架攻击技术覆盖面377个、覆盖度97.1%,勒索病毒、挖矿病毒检出率100%,误报率0…...

透视俄乌网络战之一:数据擦除软件

数据擦除破坏 1. WhisperGate2. HermeticWiper3. IsaacWiper4. WhisperKill5. CaddyWiper6. DoubleZero7. AcidRain8. RURansom 数据是政府、社会和企业组织运行的关键要素。数据擦除软件可以在不留任何痕迹的情况下擦除数据并阻止操作系统恢复摧,达到摧毁或目标系统…...

微服务中间件--Nacos

Nacos 1. Nacos入门a.服务注册到Nacosb.Nacos服务分级存储模型c.NacosRule负载均衡d.服务实例的权重设置e.环境隔离 - namespacef.Nacos和Eureka的对比 2. Nacos配置管理a.统一配置管理b.配置热更新c.多环境配置共享 1. Nacos入门 Nacos是阿里巴巴的产品,现在是Spr…...

驱动开发点亮led灯

头文件 #ifndef __HEAD_H__ #define __HEAD_H__#define PHY_LED_MODER 0X50006000 #define PHY_LED_ODR 0X50006014 #define PHY_LED_RCC 0X50000A28 #define PHY_LED_FMODER 0X50007000 #define PHY_LED_FODR 0X50007014#endif驱动代码 #include <linux/init.h> #incl…...

回归预测 | MATLAB实现IPSO-SVM改进粒子群优化算法优化支持向量机多输入单输出回归预测(多指标,多图)

回归预测 | MATLAB实现IPSO-SVM改进粒子群优化算法优化支持向量机多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xff09; 目录 回归预测 | MATLAB实现IPSO-SVM改进粒子群优化算法优化支持向量机多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xf…...

数学建模之“TOPSIS数学模型”原理和代码详解

一、简介 TOPSIS&#xff08;Technique for Order Preference by Similarity to Ideal Solution&#xff09;是一种多准则决策分析方法&#xff0c;用于解决多个候选方案之间的排序和选择问题。它基于一种数学模型&#xff0c;通过比较每个候选方案与理想解和负理想解之间的相…...

threejs使用gui改变相机的参数

调节相机远近角度 定义相机的配置&#xff1a; const cameraConfg reactive({ fov: 45 }) gui中加入调节fov的方法 const gui new dat.GUI();const cameraFolder gui.addFolder("相机属性设置");cameraFolder.add(cameraConfg, "fov", 0, 100).name(…...

计算机竞赛 图像识别-人脸识别与疲劳检测 - python opencv

文章目录 0 前言1 课题背景2 Dlib人脸识别2.1 简介2.2 Dlib优点2.3 相关代码2.4 人脸数据库2.5 人脸录入加识别效果 3 疲劳检测算法3.1 眼睛检测算法3.3 点头检测算法 4 PyQt54.1 简介4.2相关界面代码 5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是…...

PHP8的字符串操作3-PHP8知识详解

今天继续分享字符串的操作&#xff0c;前面说到了字符串的去除空格和特殊字符&#xff0c;获取字符串的长度&#xff0c;截取字符串、检索字符串。 今天继续分享字符串的其他操作。如&#xff1a;替换字符串、分割和合成字符串。 5、替换字符串 替换字符串就是对指定字符串中…...

Unity VR:XR Interaction Toolkit 输入系统(Input System):获取手柄的输入

文章目录 &#x1f4d5;教程说明&#x1f4d5;Input System 和 XR Input Subsystem&#xff08;推荐 Input System&#xff09;&#x1f4d5;Input Action Asset⭐Actions Maps⭐Actions⭐Action Properties&#x1f50d;Action Type (Value, Button, Pass through) ⭐Binding …...

智慧工地一体化云平台源码:监管端、工地端、危大工程、智慧大屏、物联网、塔机、吊钩、升降机

智慧工地解决方案依托计算机技术、物联网、云计算、大数据、人工智能、VR&AR等技术相结合&#xff0c;为工程项目管理提供先进技术手段&#xff0c;构建工地现场智能监控和控制体系&#xff0c;弥补传统方法在监管中的缺陷&#xff0c;最终实现项目对人、机、料、法、环的全…...

C# 表达式体方法 C#算阶乘

//表达式体方法private int Add(int a, int b) > a b;[Fact]public void Test(){var result1 Factorial(1);//1var result2 Factorial(2);//2var result3 Factorial(3);//6var result4 Factorial(4);//24var result5 Factorial(5);//120var result6 Add(100, 200);//…...

互联网发展历程:保护与隔离,防火墙的安全壁垒

互联网的快速发展&#xff0c;不仅带来了便利和连接&#xff0c;也引发了越来越多的安全威胁。在数字时代&#xff0c;保护数据和网络安全变得尤为重要。然而&#xff0c;在早期的网络中&#xff0c;安全问题常常让人担忧。 安全问题的困扰&#xff1a;网络威胁日益增加 随着互…...

基于IMX6ULLmini的linux裸机开发系列七:中断处理流程

中断上下文 cpu通过内核寄存器来运行指令并进行数据的读写处理的&#xff0c;它在进入中断前一个时刻的具体值&#xff0c;称为中断上下文 中断上下文是指CPU在进入中断之前保存的寄存器状态和其他相关信息。当CPU接收到中断请求时&#xff0c;它会保存当前正在执行的指令的状…...

Postman软件基本用法:浏览器复制请求信息并导入到软件从而测试、发送请求

本文介绍在浏览器中&#xff0c;获取网页中的某一个请求信息&#xff0c;并将其导入到Postman软件&#xff0c;并进行API请求测试的方法。 Postman是一款流行的API开发和测试工具&#xff0c;它提供了一个用户友好的界面&#xff0c;用于创建、测试、调试和文档化API。本文就介…...

react go实现用户历史登录列表页面

refer: http://ip-api.com/ 1.首先需要创建一个保存用户历史的登录的表&#xff0c;然后连接go 2.在用户登录的时候&#xff0c;获取用户的IP IP位置&#xff0c;在后端直接处理数据即可&#xff08;不需要在前端传递数据&#xff09; &#xff08;1&#xff09;增加路由&am…...

如何做好服务性能测试

一、什么是性能测试 新功能上线或切换底层数据库或扩容调优&#xff0c;根据实际业务场景的需要&#xff0c;做必要的性能压测&#xff0c;收集性能数据&#xff0c;作为上线的基准报告。 性能测试一般分一下几个阶段&#xff1a; 1. 性能测试 并发量小&#xff08;jmeter 并…...

速通蓝桥杯嵌入式省一教程:(五)用按键和屏幕实现嵌入式交互系统

一个完整的嵌入式系统&#xff0c;包括任务执行部分和人机交互部分。在前四节中&#xff0c;我们已经讲解了LED、LCD和按键&#xff0c;用这三者就能够实现一个人机交互系统&#xff0c;也即搭建整个嵌入式系统的框架。在后续&#xff0c;只要将各个功能加入到这个交互系统中&a…...

虚拟拍摄,如何用stable diffusion制作自己的形象照?

最近收到了某活动的嘉宾邀请&#xff0c;我将分享&#xff1a; 主题&#xff1a;生成式人工智能的创新实践 简要描述&#xff1a;从品牌营销、智能体、数字内容创作、下一代社区范式等方面&#xff0c;分享LLM与图像等生成式模型的落地应用与实践经验。 领域/研究方向&#xff…...

开启AI创新之旅!“华为云杯”2023人工智能应用创新大赛等你来挑战

简介 近年来&#xff0c;人工智能技术的发展如日中天&#xff0c;深刻地改变着我们的生活方式和产业格局。 为了培养AI人才&#xff0c;持续赋能AI企业&#xff0c;推进国家新一代人工智能开放创新平台建设&#xff0c;打造更加完善的AI技术创新生态&#xff0c;华为&#xf…...

npm和node版本升级教程

cmd中查看本地安装的node版本 node -v //查询node的位置 where node2.官网下载所需要的node版本&#xff0c;安装在刚查出来的文件夹下&#xff0c;即覆盖掉原来的版本 3.查看node版本是否已经更新 4.查看npm版本是否和node版本相匹配 cnpm install -g npm...

C++入门篇9---list

list是带头双向循环链表 一、list的相关接口及其功能 1. 构造函数 函数声明功能说明list(size_type n,const value_type& valvalue_type())构造的list中包含n个值为val的元素list()构造空的listlist(const list& x)拷贝构造list(InputIterator first, InputIterator…...

STM32基于CubeIDE和HAL库 基础入门学习笔记:物联网项目开发流程和思路

文章目录&#xff1a; 第一部分&#xff1a;项目开始前的计划与准备 1.项目策划和开发规范 1.1 项目要求文档 1.2 技术实现文档 1.3 开发规范 2.创建项目工程与日志 第二部分&#xff1a;调通硬件电路与驱动程序 第三部分&#xff1a;编写最基础的应用程序 第四部分&…...

Hive on Spark (1)

spark中executor和driver分别有什么作用&#xff1f; Spark中Executor 在 Apache Spark 中&#xff0c;Executor 是分布式计算框架中的一个关键组件&#xff0c;用于在集群中执行具体的计算任务。每个 Executor 都在独立的 JVM 进程中运行&#xff0c;可以在集群的多台机器上…...

PostgreSQL基本操作总结

安装按PostgreSQL数据库后&#xff0c;会默认创建用户postgres和数据库postgres&#xff0c;这个用户是超级用户&#xff0c;权限最高&#xff0c;可以创建其他用户和权限&#xff0c;在实际开发过程中&#xff0c;会新创建用户和业务数据库&#xff0c;本文主要介绍用户权限和…...

Jakarta 的 Servlet 下BeanUtils的日期处理 和JSTL 的使用

jsp优于性能等问题已经不被spring boot等支持&#xff0c;如果想使用jsp和jstl标签库需要引入一下依赖。 <!-- 用jakarta.servlet.jsp.jstl&#xff0c;用org.glassfish.web--><dependency><groupId>jakarta.servlet.jsp.jstl</groupId><art…...