当前位置: 首页 > news >正文

基于ROPNet项目训练modelnet40数据集进行3d点云的配置

项目地址: https://github.com/zhulf0804/ROPNet 在 MVP Registration Challenge (ICCV Workshop 2021)(ICCV Workshop 2021)中获得了第二名。项目可以在win10环境下运行。
论文地址: https://arxiv.org/abs/2107.02583

网络简介: 一种新的深度学习模型,该模型利用具有区别特征的代表性重叠点进行配准,将部分到部分配准转换为部分完全配准。基于pointnet输出的特征设计了一个上下文引导模块,使用一个编码器来提取全局特征来预测点重叠得分。为了更好地找到有代表性的重叠点,使用提取的全局特征进行粗对齐。然后,引入一种变压器来丰富点特征,并基于点重叠得分和特征匹配去除非代表性点。在部分到完全的模式下建立相似度矩阵,最后采用加权支持向量差来估计变换矩阵。
在这里插入图片描述
实施效果: 从数据上看ROPNet与RPMNet与保持了断崖式的领先地位
在这里插入图片描述

1、运行环境安装

1.1 项目下载

打开https://github.com/zhulf0804/ROPNet,点Download ZIP然后将代码解压到指定目录下即可。
在这里插入图片描述

1.2 依赖项安装

在装有pytorch的环境终端,进入ROPNet-master/src目录,执行以下安装命令。如果已经安装了torch 环境和open3d包,则不用再进行安装了

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118pip install open3d

1.3 模型与数据下载

modelnet40数据集 here [435M]
数据集下载后存储为以下路径即可。
在这里插入图片描述

官网预训练模型,无。
第三方预训练模型:使用ROPNet项目在modelnet40数据集上训练的模型

2、关键代码

2.1 dataloader

作者所提供的dataloader只能加载https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip 数据集,其所返回的tgt_cloud, src_cloud实质上是基于一个点云采样而来的。 其中的self.label2cat, self.cat2label, self.symmetric_labels等对象代码实际上是没有任何作用的。

import copy
import h5py
import math
import numpy as np
import os
import torchfrom torch.utils.data import Dataset
import sysBASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOR_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOR_DIR)
from utils import  random_select_points, shift_point_cloud, jitter_point_cloud, \generate_random_rotation_matrix, generate_random_tranlation_vector, \transform, random_crop, shuffle_pc, random_scale_point_cloud, flip_pchalf1 = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl','car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser','flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp']
half1_symmetric = ['bottle', 'bowl', 'cone', 'cup', 'flower_pot', 'lamp']half2 = ['laptop', 'mantel', 'monitor', 'night_stand', 'person', 'piano','plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 'stool','table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
half2_symmetric = ['tent', 'vase']class ModelNet40(Dataset):def __init__(self, root, split, npts, p_keep, noise, unseen, ao=False,normal=False):super(ModelNet40, self).__init__()self.single = False # for specific-class visualizationassert split in ['train', 'val', 'test']self.split = splitself.npts = nptsself.p_keep = p_keepself.noise = noiseself.unseen = unseenself.ao = ao # Asymmetric Objectsself.normal = normalself.half = half1 if split in 'train' else half2self.symmetric = half1_symmetric + half2_symmetricself.label2cat, self.cat2label = self.label2category(os.path.join(root, 'shape_names.txt'))self.half_labels = [self.cat2label[cat] for cat in self.half]self.symmetric_labels = [self.cat2label[cat] for cat in self.symmetric]files = [os.path.join(root, 'ply_data_train{}.h5'.format(i))for i in range(5)]if split == 'test':files = [os.path.join(root, 'ply_data_test{}.h5'.format(i))for i in range(2)]self.data, self.labels = self.decode_h5(files)print(f'split: {self.split}, unique_ids: {len(np.unique(self.labels))}')if self.split == 'train':self.Rs = [generate_random_rotation_matrix() for _ in range(len(self.data))]self.ts = [generate_random_tranlation_vector() for _ in range(len(self.data))]def label2category(self, file):with open(file, 'r') as f:label2cat = [category.strip() for category in f.readlines()]cat2label = {label2cat[i]: i for i in range(len(label2cat))}return label2cat, cat2labeldef decode_h5(self, files):points, normal, label = [], [], []for file in files:f = h5py.File(file, 'r')cur_points = f['data'][:].astype(np.float32)cur_normal = f['normal'][:].astype(np.float32)cur_label = f['label'][:].flatten().astype(np.int32)if self.unseen:idx = np.isin(cur_label, self.half_labels)cur_points = cur_points[idx]cur_normal = cur_normal[idx]cur_label = cur_label[idx]if self.ao and self.split in ['val', 'test']:idx = ~np.isin(cur_label, self.symmetric_labels)cur_points = cur_points[idx]cur_normal = cur_normal[idx]cur_label = cur_label[idx]if self.single:idx = np.isin(cur_label, [8])cur_points = cur_points[idx]cur_normal = cur_normal[idx]cur_label = cur_label[idx]points.append(cur_points)normal.append(cur_normal)label.append(cur_label)points = np.concatenate(points, axis=0)normal = np.concatenate(normal, axis=0)data = np.concatenate([points, normal], axis=-1).astype(np.float32)label = np.concatenate(label, axis=0)return data, labeldef compose(self, item, p_keep):tgt_cloud = self.data[item, ...]if self.split != 'train':np.random.seed(item)R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()else:tgt_cloud = flip_pc(tgt_cloud)R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()src_cloud = random_crop(copy.deepcopy(tgt_cloud), p_keep=p_keep[0])src_size = math.ceil(self.npts * p_keep[0])tgt_size = self.nptsif len(p_keep) > 1:tgt_cloud = random_crop(copy.deepcopy(tgt_cloud),p_keep=p_keep[1])tgt_size = math.ceil(self.npts * p_keep[1])src_cloud_points = transform(src_cloud[:, :3], R, t)src_cloud_normal = transform(src_cloud[:, 3:], R)src_cloud = np.concatenate([src_cloud_points, src_cloud_normal],axis=-1)src_cloud = random_select_points(src_cloud, m=src_size)tgt_cloud = random_select_points(tgt_cloud, m=tgt_size)if self.split == 'train' or self.noise:src_cloud[:, :3] = jitter_point_cloud(src_cloud[:, :3])tgt_cloud[:, :3] = jitter_point_cloud(tgt_cloud[:, :3])tgt_cloud, src_cloud = shuffle_pc(tgt_cloud), shuffle_pc(src_cloud)return src_cloud, tgt_cloud, R, tdef __getitem__(self, item):src_cloud, tgt_cloud, R, t = self.compose(item=item,p_keep=self.p_keep)if not self.normal:tgt_cloud, src_cloud = tgt_cloud[:, :3], src_cloud[:, :3]return tgt_cloud, src_cloud, R, tdef __len__(self):return len(self.data)

2.2 模型设计

模型设计如下:
在这里插入图片描述

2.3 loss设计

其主要包含Init_loss、Refine_loss和Ol_loss。
其中Init_loss是用于计算 预测点 云 0 预测点云_0 预测点0与目标点云的mse或mae loss,
Refine_loss用于计算 预测点 云 [ 1 : ] 预测点云_{[1:]} 预测点[1:]与目标点云的加权mae loss
Ol_loss用于计算两个输入点云输出的重叠分数,使两个点云对应点的重叠分数是一样的。
在这里插入图片描述

具体实现代码如上:


import math
import torch
import torch.nn as nn
from utils import square_distsdef Init_loss(gt_transformed_src, pred_transformed_src, loss_type='mae'):losses = {}num_iter = 1if loss_type == 'mse':criterion = nn.MSELoss(reduction='mean')for i in range(num_iter):losses['mse_{}'.format(i)] = criterion(pred_transformed_src[i],gt_transformed_src)elif loss_type == 'mae':criterion = nn.L1Loss(reduction='mean')for i in range(num_iter):losses['mae_{}'.format(i)] = criterion(pred_transformed_src[i],gt_transformed_src)else:raise NotImplementedErrortotal_losses = []for k in losses:total_losses.append(losses[k])losses = torch.sum(torch.stack(total_losses), dim=0)return lossesdef Refine_loss(gt_transformed_src, pred_transformed_src, weights=None, loss_type='mae'):losses = {}num_iter = len(pred_transformed_src)for i in range(num_iter):if weights is None:losses['mae_{}'.format(i)] = torch.mean(torch.abs(pred_transformed_src[i] - gt_transformed_src))else:losses['mae_{}'.format(i)] = torch.mean(torch.sum(weights * torch.mean(torch.abs(pred_transformed_src[i] -gt_transformed_src), dim=-1)/ (torch.sum(weights, dim=-1, keepdim=True) + 1e-8), dim=-1))total_losses = []for k in losses:total_losses.append(losses[k])losses = torch.sum(torch.stack(total_losses), dim=0)return lossesdef Ol_loss(x_ol, y_ol, dists):CELoss = nn.CrossEntropyLoss()x_ol_gt = (torch.min(dists, dim=-1)[0] < 0.05 * 0.05).long() # (B, N)y_ol_gt = (torch.min(dists, dim=1)[0] < 0.05 * 0.05).long() # (B, M)x_ol_loss = CELoss(x_ol, x_ol_gt)y_ol_loss = CELoss(y_ol, y_ol_gt)ol_loss = (x_ol_loss + y_ol_loss) / 2return ol_lossdef cal_loss(gt_transformed_src, pred_transformed_src, dists, x_ol, y_ol):losses = {}losses['init'] = Init_loss(gt_transformed_src,pred_transformed_src[0:1])if x_ol is not None:losses['ol'] = Ol_loss(x_ol, y_ol, dists)losses['refine'] = Refine_loss(gt_transformed_src,pred_transformed_src[1:],weights=None)alpha, beta, gamma = 1, 0.1, 1if x_ol is not None:losses['total'] = losses['init'] + beta * losses['ol'] + gamma * losses['refine']else:losses['total'] = losses['init'] + losses['refine']return losses

3、训练与预测

先进入src目录,并将modelnet40_ply_hdf5_2048.zip解压在src目录下
在这里插入图片描述

3.1 训练

训练命令及训练输出如下所示

python train.py --root modelnet40_ply_hdf5_2048/ --noise --unseen

python请添加图片描述
在训练过程中会在work_dirs\models\checkpoints目录下生成两个模型文件
在这里插入图片描述

3.2 验证

训练命令及训练输出如下所示

python eval.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --cuda --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

请添加图片描述

3.3 测试

测试训练数据的命令如下

python vis.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

具体配准效果如下所示,其中绿色点云为输入点云,红色点云为参考点云,蓝色点云为配准后的点云。可以看到蓝色点云基本与红色点云重合,可以确定其配准效果十分完好。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.4 处理自己的数据集

基于该项目训练并处理自己数据的教程后续会给出。

相关文章:

基于ROPNet项目训练modelnet40数据集进行3d点云的配置

项目地址&#xff1a; https://github.com/zhulf0804/ROPNet 在 MVP Registration Challenge (ICCV Workshop 2021)&#xff08;ICCV Workshop 2021&#xff09;中获得了第二名。项目可以在win10环境下运行。 论文地址&#xff1a; https://arxiv.org/abs/2107.02583 网络简介…...

力扣215. 数组中的第K个最大元素

堆排序 前言 面试中著名的 TopK 排序&#xff1b;常见的解法有冒泡排序、堆排序&#xff1b;更深入的思路可以参考&#xff1a;拜托&#xff0c;面试别再问我TopK了&#xff01;&#xff01;&#xff01;使用了堆排序的算法&#xff0c;关于堆可以参考&#xff1a;堆数据结构的…...

轻量封装WebGPU渲染系统示例<40>- 多层材质的Mask混合(源码)

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/MaskTextureEffect.ts 当前示例运行效果: 两层材质效果: 三层材质效果: 此示例基于此渲染系统实现&#xff0c;当前示例TypeScript源码如下&#xff1a; export c…...

程序员的实用网站导航与推荐

当你遇到问题时 Stack Overflow&#xff1a;订阅他们的每周新闻和任何你感兴趣的主题Google&#xff1a;全球最大搜索引擎必应&#xff1a;在你无法使用Google的时候CSDN&#xff1a;聊胜于无AI导航一号AI导航二号 新闻篇 OSCHINA&#xff1a;中文开源技术交流社区 针对初学…...

上午面了个腾讯拿 38K 出来的,让我见识到了基础的天花板

今年的校招基本已经进入大规模的开奖季了&#xff0c;很多小伙伴收获不错&#xff0c;拿到了心仪的 offer。 各大论坛和社区里也看见不少小伙伴慷慨地分享了常见的面试题和八股文&#xff0c;为此咱这里也统一做一次大整理和大归类&#xff0c;这也算是划重点了。 俗话说得好…...

【halcon】C# halcon 内存暴增

1 读取图片需要及时手动释放 一个6M的图片通过halcon进行加载&#xff0c;大约会消耗200M的内存&#xff0c;如果等待GC回收&#xff0c;而你又在不停的读取图片&#xff0c;你的内存占用&#xff0c;将在短时间内飙升。 2 halcon控件显示图片需要清空。 /// <summary>…...

LeetCode130. Surrounded Regions

文章目录 一、题目二、题解 一、题目 Given an m x n matrix board containing ‘X’ and ‘O’, capture all regions that are 4-directionally surrounded by ‘X’. A region is captured by flipping all O’s into X’s in that surrounded region. Example 1: Input…...

【实战教程】PHP如何轻松对接腾讯云COS,实现文件上传下载?

腾讯云提供了一系列丰富的云服务&#xff0c;其中包括对象存储&#xff08;Cloud Object Storage&#xff0c;简称COS&#xff09;&#xff0c;它是一种高可靠性、可扩展性强的云存储服务。本文将介绍如何使用PHP对接腾讯云COS存储服务&#xff0c;实现文件的上传和下载功能。 …...

pytorch学习10-网络模型的保存和加载

系列文章目录 pytorch学习1-数据加载以及Tensorboard可视化工具pytorch学习2-Transforms主要方法使用pytorch学习3-torchvisin和Dataloader的使用pytorch学习4-简易卷积实现pytorch学习5-最大池化层的使用pytorch学习6-非线性变换&#xff08;ReLU和sigmoid&#xff09;pytorc…...

SQL Server 2016(分离和附加数据库)

1、实验环境。 基于上一个实验《SQL Server&#xff08;创建数据库&#xff09;》 2、需求描述。 class数据库的数据文件和事务日志文件都位于C:\db_class目录下。现在需要把class数据库的数据文件和事务日志文件分开存放&#xff0c;数据文件class.mdf存放于原位置&#xff0…...

用友U8 Cloud RegisterServlet SQL注入漏洞复现

0x01 产品简介 用友U8 Cloud是用友推出的新一代云ERP,主要聚焦成长型、创新型企业,提供企业级云ERP整体解决方案。 0x02 漏洞概述 用友U8 Cloud RegisterServlet接口处存在SQL注入漏洞,未授权的攻击者可通过此漏洞获取数据库权限,从而盗取用户数据,造成用户信息泄露。 …...

coding创建远程分支。并拉取远程新分支+推送代码

进入coding ----项目----代码仓库---点击 下拉之后查看全部----创建分支 创建分支之后执行下面命令 git branch -a // 查看所有分支 这个时候发现自己创建的分支没有显示这是因为自己在远程创建了分支但是本地还没有分支 执行 git fetch命令 用于从远程仓库获取最新的提交…...

坚鹏:中国工商银行内蒙古分行数字化转型发展现状与成功案例培训

中国工商银行围绕“数字生态、数字资产、数字技术、数字基建、数字基因”五维布局&#xff0c;深入推进数字化转型&#xff0c;加快形成体系化、生态化实施路径&#xff0c;促进科技与业务加速融合&#xff0c;以“数字工行”建设推动“GBC”&#xff08;政务、企业、个人&…...

AIGC发展史

1 AIGC概况 1.1 AIGC定义 AIGC&#xff08;AI Generated Content&#xff09;是指利用人工智能技术生成的内容。它也被认为是继PGC,UGC之后的新型内容生产方式&#xff0c;AI绘画、AI写作等都属于AIGC的具体形式。2022年AIGC发展速度惊人&#xff0c;迭代速度更是呈现指数级发…...

面试题库之JAVA基础篇(二)

String 只读字符串。每次操作会隐式的在内存中new一个跟原字符串一样的StringBuilder对象&#xff0c;然后append号后面的字符串。 StringBuilder 可变字符串对象。线程不安全。 StringBuffer 可变字符串对象。线程安全。 数组 一种线性数据结构&#xff0c;使用连续的…...

[Rust] 可迭代类型, 迭代器, 如何正确的创建自定义可迭代类型

在 Rust 中, for 语句的执行依赖于类型对于 IntoIterator 的实现, 如果某类型实现了这个 trait, 那么它就可以直接使用 for 进行循环. 直接实现 在 Rust 中, 如果一个类型实现了 Iterator, 那么它会被同时实现 IntoIterator, 具体逻辑是返回自身, 因为自身就是迭代器. 但是如…...

MySQL中,text,mediumtext, 和 longtext字符类型

需求 由于项目需要&#xff0c;需要在mysql数据库&#xff0c;储存长文本&#xff0c;长文本格式可能为markdown也可能为html。 思路 测试存入html时&#xff0c;字符类型为varcar 255。很明显字符长度达不到要求。数据库抛错&#xff0c;修改字符类型 解决方案 将原本的字…...

网页开发 JS基础

目录 JS概述 基本语法 数据类型内置方法 DOM对象 查找标签 绑定事件 操作标签 jQuery 查找标签 绑定事件 操作标签 Ajax请求 数据接口 前后端分离 ajax的使用 JS概述 一门弱类型的编程语言,属于基于对象和基于原型的脚本语言. 1 直接编写<script>console…...

如何在财税行业查找批量客户?

现在市场上代记账公司也不算少&#xff0c;做过这行的都知道&#xff0c;最初呢行业竞争不强&#xff0c;都是靠地推、老客户转介绍&#xff0c;或者长期以往的蹲守各个地区的工商注册服务中心&#xff0c;找那些才注册企业的老板或者创业者。但是&#xff0c;随着市场经济的发…...

IntelliJ IDEA详细完整安装教程

IntelliJ IDEA 是一款强大的Java集成开发环境&#xff0c;以下是安装和使用教程&#xff1a; 1. 下载IntelliJ IDEA&#xff1a;访问JetBrains官网&#xff08;jetbrains.com&#xff09;&#xff0c;点击“Download”按钮&#xff0c;选择适合自己操作系统的版本进行下载。 2.…...

FFmpeg 低延迟同屏方案

引言 在实时互动需求激增的当下&#xff0c;无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作&#xff0c;还是游戏直播的画面实时传输&#xff0c;低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架&#xff0c;凭借其灵活的编解码、数据…...

线程同步:确保多线程程序的安全与高效!

全文目录&#xff1a; 开篇语前序前言第一部分&#xff1a;线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分&#xff1a;synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分&#xff…...

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…...

Linux离线(zip方式)安装docker

目录 基础信息操作系统信息docker信息 安装实例安装步骤示例 遇到的问题问题1&#xff1a;修改默认工作路径启动失败问题2 找不到对应组 基础信息 操作系统信息 OS版本&#xff1a;CentOS 7 64位 内核版本&#xff1a;3.10.0 相关命令&#xff1a; uname -rcat /etc/os-rele…...

20个超级好用的 CSS 动画库

分享 20 个最佳 CSS 动画库。 它们中的大多数将生成纯 CSS 代码&#xff0c;而不需要任何外部库。 1.Animate.css 一个开箱即用型的跨浏览器动画库&#xff0c;可供你在项目中使用。 2.Magic Animations CSS3 一组简单的动画&#xff0c;可以包含在你的网页或应用项目中。 3.An…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)

一、OpenBCI_GUI 项目概述 &#xff08;一&#xff09;项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台&#xff0c;其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言&#xff0c;首次接触 OpenBCI 设备时&#xff0c;往…...

django blank 与 null的区别

1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是&#xff0c;要注意以下几点&#xff1a; Django的表单验证与null无关&#xff1a;null参数控制的是数据库层面字段是否可以为NULL&#xff0c;而blank参数控制的是Django表单验证时字…...

Ubuntu Cursor升级成v1.0

0. 当前版本低 使用当前 Cursor v0.50时 GitHub Copilot Chat 打不开&#xff0c;快捷键也不好用&#xff0c;当看到 Cursor 升级后&#xff0c;还是蛮高兴的 1. 下载 Cursor 下载地址&#xff1a;https://www.cursor.com/cn/downloads 点击下载 Linux (x64) &#xff0c;…...

掌握 HTTP 请求:理解 cURL GET 语法

cURL 是一个强大的命令行工具&#xff0c;用于发送 HTTP 请求和与 Web 服务器交互。在 Web 开发和测试中&#xff0c;cURL 经常用于发送 GET 请求来获取服务器资源。本文将详细介绍 cURL GET 请求的语法和使用方法。 一、cURL 基本概念 cURL 是 "Client URL" 的缩写…...

深入浅出Diffusion模型:从原理到实践的全方位教程

I. 引言&#xff1a;生成式AI的黎明 – Diffusion模型是什么&#xff1f; 近年来&#xff0c;生成式人工智能&#xff08;Generative AI&#xff09;领域取得了爆炸性的进展&#xff0c;模型能够根据简单的文本提示创作出逼真的图像、连贯的文本&#xff0c;乃至更多令人惊叹的…...