基于 ResNet18 架构使用 deformable convolution的车道线检测
下面是一个基于关键点的车道线检测网络的 PyTorch 代码示例,其中使用了 deformable convolution。该代码示例基于 ResNet18 架构,可以根据实际情况进行修改。
首先,需要导入必要的库和模块:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameterfrom torchvision.models.resnet import resnet18
然后,定义一个基于 ResNet18 架构的车道线检测网络模型:
class LaneDetectionNet(nn.Module):def __init__(self, num_classes=1, deformable_groups=2):super(LaneDetectionNet, self).__init__()# load ResNet18self.resnet = resnet18(pretrained=True)# replace the first conv layerself.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)# add deformable convolutionsself.resnet.layer1[0].conv1 = DeformConv2d(64, 64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer1[0].conv2 = DeformConv2d(64, 64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer2[0].conv1 = DeformConv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer2[0].conv2 = DeformConv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer3[0].conv1 = DeformConv2d(256, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer3[0].conv2 = DeformConv2d(256, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer4[0].conv1 = DeformConv2d(512, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer4[0].conv2 = DeformConv2d(512, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)# add the output layersself.fc1 = nn.Linear(512, 512)self.fc2 = nn.Linear(512, num_classes)def forward(self, x):x = self.resnet(x)x = F.relu(self.fc1(x))x = self.fc2(x)return x
其中,DeformConv2d 是一个 deformable convolution 的实现类。其代码如下:
class DeformConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, deformable_groups=1):super(DeformConv2d, self).__init__()self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)self.weight = Parameter(torch.Tensor(out_channels, in_channels, kernel_size[0], kernel_size[1]))if bias:self.bias = Parameter(torch.Tensor(out_channels))else:self.register_parameter('bias', None)self.reset_parameters()self.stride = _pair(stride)self.padding = _pair(padding)self.dilation = _pair(dilation)self.groups = groupsself.deformable_groups = deformable_groupsdef reset_parameters(self):nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))if self.bias is not None:fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)bound = 1 / math.sqrt(fan_in)nn.init.uniform_(self.bias, -bound, bound)def forward(self, x):offset = self.offset_conv(x)output = deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.deformable_groups)return output
最后,定义一个 deformable convolution 的实现函数 deform_conv2d,代码如下:
def deform_conv2d(input, offset, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1):# get shapes and parametersbatch_size, in_channels, in_h, in_w = input.size()out_channels, _, kernel_h, kernel_w = weight.size()stride_h, stride_w = _pair(stride)pad_h, pad_w = _pair(padding)dilation_h, dilation_w = _pair(dilation)input_padded = F.pad(input, (pad_w, pad_w, pad_h, pad_h))# calculate output shapeout_h = (in_h + 2*pad_h - dilation_h*(kernel_h-1) - 1) // stride_h + 1out_w = (in_w + 2*pad_w - dilation_w*(kernel_w-1) - 1) // stride_w + 1# unfold input and offsetoffset = offset.view(batch_size, deformable_groups, 2 * kernel_h * kernel_w, out_h, out_w)input_unfolded = F.unfold(input_padded, (kernel_h, kernel_w), dilation=dilation, stride=stride)# calculate outputoutput = torch.zeros(batch_size, out_channels, out_h, out_w).to(input.device)weight = weight.view(1, out_channels, in_channels // groups, kernel_h, kernel_w).repeat(batch_size, 1, 1, 1, 1)for h in range(out_h):for w in range(out_w):input_region = input_unfolded[:, :, h, w].view(batch_size, -1, 1, 1)offset_region = offset[:, :, :, h, w]weight_region = weightoutput_region = F.conv2d(input_region, weight_region, bias=None, stride=1, padding=0, dilation=1, groups=deformable_groups)output_region = deformable_conv2d_compute(output_region, offset_region)output[:, :, h, w] = output_region.squeeze()if bias is not None:output += bias.view(1, -1, 1, 1)return output
其中,deformable_conv2d_compute 函数是 deformable convolution 的计算函数。它的代码如下:
def deformable_conv2d_compute(input, offset):# get shapes and parametersbatch_size, out_channels, out_h, out_w = input.size()in_channels = offset.size(1) // 2# sample input according to offsetgrid_h = torch.linspace(-1, 1, out_h).view(1, 1, out_h, 1).to(input.device)grid_w = torch.linspace(-1, 1, out_w).view(1, 1, 1, out_w).to(input.device)offset_h = offset[:, :in_channels, :, :]offset_w = offset[:, in_channels:, :, :]sample_h = torch.add(grid_h, offset_h)sample_w = torch.add(grid_w, offset_w)sample_h = sample_h.clamp(-1, 1)sample_w = sample_w.clamp(-1, 1)sample_h = ((sample_h + 1) / 2) * (out_h - 1)sample_w = ((sample_w + 1) / 2) * (out_w - 1)sample_h_floor = sample_h.floor().long()sample_w_floor = sample_w.floor().long()sample_h_ceil = sample_h.ceil().long()sample_w_ceil = sample_w.ceil().long()sample_h_floor = sample_h_floor.clamp(0, out_h - 1)sample_w_floor = sample_w_floor.clamp(0, out_w - 1)sample_h_ceil = sample_h_ceil.clamp(0, out_h - 1)sample_w_ceil = sample_w_ceil.clamp(0, out_w - 1)# gather input values according to sampled indicesinput_flat = input.view(batch_size, in_channels, out_h * out_w)index_base = torch.arange(0, batch_size, device=input.device).view(batch_size, 1, 1) * out_h * out_windex_base = index_base.expand(batch_size, in_channels, out_h * out_w)index_offset = torch.arange(0, out_h * out_w, device=input.device).view(1, 1, -1)index_offset = index_offset.expand(batch_size, in_channels, out_h * out_w)indices_a = (sample_h_floor + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)indices_b = (sample_w_floor + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)indices_c = (sample_h_ceil + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)indices_d = (sample_w_ceil + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)value_a = input_flat.gather(2, indices_a.unsqueeze(1).repeat(1, out_channels, 1))value_b = input_flat.gather(2, indices_b.unsqueeze(1).repeat(1, out_channels, 1))value_c = input_flat.gather(2, indices_c.unsqueeze(1).repeat(1, out_channels, 1))value_d = input_flat.gather(2, indices_d.unsqueeze(1).repeat(1, out_channels, 1))# calculate interpolation weights and outputw_a = ((sample_w_ceil - sample_w) * (sample_h_ceil - sample_h)).view(batch_size, 1, out_h, out_w)w_b = ((sample_w - sample_w_floor) * (sample_h_ceil - sample_h)).view(batch_size, 1, out_h, out_w)w_c = ((sample_w_ceil - sample_w) * (sample_h - sample_h_floor)).view(batch_size, 1, out_h, out_w)w_d = ((sample_w - sample_w_floor) * (sample_h - sample_h_floor)).view(batch_size, 1, out_h, out_w)output = w_a * value_a + w_b * value_b + w_c * value_c + w_d * value_dreturn output
最后,可以使用以下代码进行网络的测试:
net = LaneDetectionNet(num_classes=1, deformable_groups=2) # create the network
input = torch.randn(1, 3, 100, 100) # create a random input tensor
output = net(input) # feed it through the network
print(output.shape) # print the output shape
输出的结果应该为 (1, 1, 1, 1)。这说明网络已经成功地将 100*100 的像素图压缩成了一个标量。可以根据实际情况进行调整和优化,来达到更好的性能。
相关文章:
基于 ResNet18 架构使用 deformable convolution的车道线检测
下面是一个基于关键点的车道线检测网络的 PyTorch 代码示例,其中使用了 deformable convolution。该代码示例基于 ResNet18 架构,可以根据实际情况进行修改。 首先,需要导入必要的库和模块: import torch import torch.nn as nn…...
C++in/out输入输出流[IO流]
文章目录 1. C语言的输入与输出2.C的IO流2.1流的概念2.2CIO流2.3刷题常见while(cin >> str)重载强制类型转换运算符模拟while(cin >> str) 2.4C标准IO流2.5C文件IO流1.ifstream 1. C语言的输入与输出 C语言用到最频繁的输入输出方式就是scanf ()与printf()。 scanf…...
MongoDB的安装
MongoDB的安装 1、Windows下MongoDB的安装及配置 1.1 下载Mongodb安装包 下载地址: https://www.mongodb.com/try/download http://www.mongodb.org/dl/win32 MongoDB Windows系统64位下载地址:http://www.mongodb.org/dl/win32/x86_64 MongoDB W…...
SQL查询优化---如何查询截取分析
慢查询日志 1、慢查询日志是什么 MySQL的慢查询日志是MySQL提供的一种日志记录,它用来记录在MySQL中响应时间超过阀值的语句,具体指运行时间超过long_query_time值的SQL,则会被记录到慢查询日志中。 具体指运行时间超过long_query_time值的…...
vue3基础流程
目录 1. 安装和创建项目 2. 项目结构 3. 主要文件解析 3.1 main.js 3.2 App.vue 4. 组件和Props 5. 事件处理 6. 生命周期钩子 7. Vue 3的Composition API 8. 总结和结论 响应式系统: 组件化: 易于学习: 灵活性: 社…...
Vue 数据绑定 和 数据渲染
目录 一、Vue快速入门 1.简介 : 2.MVVM : 3.准备工作 : 二、数据绑定 1.实例 : 2.验证 : 三、数据渲染 1.单向渲染 : 2.双向渲染 : 一、Vue快速入门 1.简介 : (1) Vue[/vju/],是Vue.js的简称,是一个前端框架,常用于构建前端用户…...
【原创】解决Kotlin无法使用@Slf4j注解的问题
前言 主要还是辟谣之前的网上的用法,当然也会给出最终的使用方法。这可是Kotlin,关Slf4j何事!? 辟谣内容:创建注解来解决这个问题 例如: Target(AnnotationTarget.CLASS) Retention(AnnotationRetentio…...
CDN是如何实现全球节点同步的
当谈到内容交付网络(Content Delivery Network,CDN)加速时,我们必须了解CDN是如何实现全球节点同步的。CDN是一种网络架构,通过将内容分发到全球各地的服务器节点,以降低用户访问网站或应用程序时的延迟和提…...
Centos7 Linux系统下生成https的crt和key证书
linux下生成https的crt和key证书 步骤如下: x509证书一般会用到三类文,key,csr,crt Key 是私用密钥openssl格,通常是rsa算法。 Csr 是证书请求文件,用于申请证书。在制作csr文件的时,必须使…...
性能测试工具——Jmeter的安装【超详细】
目录 1、性能测试工具:JMeter和LoadRunner对比 2、为什么学习JMeter? 3、JMeter环境搭建 3.1、安装JDK 3.2、下载安装JMeter 3.3、配置环境变量 2.4、启动验证JMeter是否安装成功 4、认识JMeter的目录结构 1)bin目录:存放…...
系列三十、Spring AOP vs AspectJ AOP
一、关系 (1)当在Spring中要使用Aspect、Before、After等注解时,需要添加AspectJ的相关依赖,如下 <dependency><groupId>cglib</groupId><artifactId>cglib</artifactId><version>3.1</…...
面向对象设计模式——策略模式
策略设计模式(Strategy Pattern)是一种行为型设计模式,它允许在运行时选择算法的行为。该模式定义了一系列算法,将每个算法封装到一个独立的类中,使它们可以相互替换。策略模式使算法独立于客户端而变化,客…...
Kubernetes - Ingress HTTP 负载搭建部署解决方案(新版本v1.21+)
在看这一篇之前,如果不了解 Ingress 在 K8s 当中的职责,建议看之前的一篇针对旧版本 Ingress 的部署搭建,在开头会提到它的一些简介Kubernetes - Ingress HTTP 负载搭建部署解决方案_放羊的牧码的博客-CSDN博客 开始表演 1、kubeasz 一键安装…...
刚刚:腾讯云3年轻量2核2G4M服务器优惠价格366元三年
腾讯云3年轻量2核2G4M服务器,2023双十一优惠价格366元三年,自带4M公网带宽,下载速度可达512KB/秒,300GB月流量,50GB SSD盘系统盘,腾讯云百科txybk.com分享腾讯云轻量2核2G4M服务器性能、优惠活动、购买条件…...
`include指令【FPGA】
案例: 在Verilog中,include指令可以将一个文件的内容插入到当前文件中。 这个指令通常用于将一些常用的代码片段或者模块定义放在单独的文件中, 然后在需要使用的地方通过include指令将其插入到当前文件中。 这样可以提高代码的复用性和可维…...
iphone备份后怎么转到新手机,iphone备份在哪里查看
iphone备份会备份哪些东西?iphone可根据需要备份设备数据、应用数据、苹果系统等。根据不同的备份数据,可备份的数据类型不同,有些工具可整机备份,有些工具可单项数据备份。本文会详细讲解苹果手机备份可以备份哪些东西。 一、ip…...
JAVA毕业设计106—基于Java+Springboot的外卖系统(源码+数据库)
基于JavaSpringboot的外卖系统(源码数据库)106 一、系统介绍 本系统分为用户端和管理端角色 前台用户功能: 登录、菜品浏览,口味选择,加入购物车,地址管理,提交订单。 管理后台: 登录,员工管…...
SpringCore完整学习教程4,入门级别
本章从第4章开始 4. Logging Spring Boot使用Commons Logging进行所有内部日志记录,但保留底层日志实现开放。为Java Util Logging、Log4J2和Logback提供了默认配置。在每种情况下,记录器都预先配置为使用控制台输出和可选的文件输出。 默认情况下&…...
如何能在项目具体编码实现之前能尽可能早的发现问题并解决问题
在项目的具体编码实现之前尽可能早地发现并解决问题,可以大大节省时间和资源,提高项目的成功率。以下是一些策略和方法: 1. 明确需求和预期: 确保所有的项目需求都是清晰和明确的。需求模糊不清是项目失败的常见原因之一。与利益…...
Windows server服务器允许多用户远程的设置
在Windows Server上允许多用户同时进行远程桌面连接,您需要配置远程桌面服务以支持多用户并确保许可证和授权允许多用户连接。以下是在Windows Server上允许多用户远程桌面连接的步骤: 注意:这些步骤适用于 Windows Server 2012、Windows Ser…...
(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)
题目:3442. 奇偶频次间的最大差值 I 思路 :哈希,时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况,哈希表这里用数组即可实现。 C版本: class Solution { public:int maxDifference(string s) {int a[26]…...
基于大模型的 UI 自动化系统
基于大模型的 UI 自动化系统 下面是一个完整的 Python 系统,利用大模型实现智能 UI 自动化,结合计算机视觉和自然语言处理技术,实现"看屏操作"的能力。 系统架构设计 #mermaid-svg-2gn2GRvh5WCP2ktF {font-family:"trebuchet ms",verdana,arial,sans-…...
k8s从入门到放弃之Ingress七层负载
k8s从入门到放弃之Ingress七层负载 在Kubernetes(简称K8s)中,Ingress是一个API对象,它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress,你可…...
visual studio 2022更改主题为深色
visual studio 2022更改主题为深色 点击visual studio 上方的 工具-> 选项 在选项窗口中,选择 环境 -> 常规 ,将其中的颜色主题改成深色 点击确定,更改完成...
C/C++ 中附加包含目录、附加库目录与附加依赖项详解
在 C/C 编程的编译和链接过程中,附加包含目录、附加库目录和附加依赖项是三个至关重要的设置,它们相互配合,确保程序能够正确引用外部资源并顺利构建。虽然在学习过程中,这些概念容易让人混淆,但深入理解它们的作用和联…...
GitFlow 工作模式(详解)
今天再学项目的过程中遇到使用gitflow模式管理代码,因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存,无论是github还是gittee,都是一种基于git去保存代码的形式,这样保存代码…...
快刀集(1): 一刀斩断视频片头广告
一刀流:用一个简单脚本,秒杀视频片头广告,还你清爽观影体验。 1. 引子 作为一个爱生活、爱学习、爱收藏高清资源的老码农,平时写代码之余看看电影、补补片,是再正常不过的事。 电影嘛,要沉浸,…...
Linux nano命令的基本使用
参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时,显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...
区块链技术概述
区块链技术是一种去中心化、分布式账本技术,通过密码学、共识机制和智能合约等核心组件,实现数据不可篡改、透明可追溯的系统。 一、核心技术 1. 去中心化 特点:数据存储在网络中的多个节点(计算机),而非…...
Axure 下拉框联动
实现选省、选完省之后选对应省份下的市区...
