当前位置: 首页 > news >正文

torch一些操作

Pytorch文档

  • Pytorch 官方文档

https://pytorch.org/docs/stable/index.html

  • pytorch 里的一些基础tensor操作讲的不错

https://blog.csdn.net/abc13526222160/category_8614343.html

  • 关于pytorch的Broadcast,合并与分割,数学运算,属性统计以及高阶操作

https://blog.csdn.net/abc13526222160/article/details/103520465

broadcast机制

对于涉及计算的两个tensor, 对于数量不匹配的dim, 可以进行自动重复拷贝,使得dim进行批评

Broadcast它能维度扩展和expand一样,它是自动扩展,并且不需要拷贝数据,能够节省内存。关键思想:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WMOpueEB-1692628397278)(attachment:image.png)]

import sys, osimport torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import pandas as pd
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt


1. Pytorch索引与切片以及维度变换

TODO: https://zhangkaifang.blog.csdn.net/article/details/103517970

index, view, reshape, squeeze, unsqueeze, transpose/t, permute , expand/repeat
transpose()函数表示矩阵的维度交换,接受的参数为要交换的哪两个维度。

permute: transpose()函数一次只能两两交换。【b, c, h, w】=> 【b, w, h, c】,比如原来一个人的图片,交换过后图片可能不是人了,我们还希望变成原来的样子,可以看成多维度交换,其中参数为新的维度顺序。同样的道理permute函数也会把内存的顺序给打乱,因此要是涉及contious这个错误的时候,需要额外添加.contiguous()函数,来把内存的顺序变得连续。

2.1 合并与分割 cat, stack, split, chunk repeat

  • stack一般叠加在新的维度, concatenate在已有的dim上进行扩展, vstack, dstack, hstack分别是在dim = [0, 1, 2]上进行叠加
a = torch.rand([5, 5, 3])print('-'*10 + 'stack' + '-'*10)
t1 = torch.stack([a, a], dim = 0)
print('pre', a.shape, 'post',t1.shape)t1 = torch.stack([a, a], dim = 1)
print('pre', a.shape, 'post',t1.shape)t1 = torch.stack([a, a], dim = 2)
print('pre', a.shape, 'post',t1.shape)t1 = torch.stack([a, a], dim = 3)
print('pre', a.shape, 'post',t1.shape)print('-'*10 + 'concatenate' + '-'*10)
t1 = torch.cat([a, a], dim = 0)
print('pre', a.shape, 'post',t1.shape)t1 = torch.cat([a, a], dim = 1)
print('pre', a.shape, 'post',t1.shape)t1 = torch.cat([a, a], dim = 2)
print('pre', a.shape, 'post',t1.shape)print('-'*10 + 'v-h-d stack' + '-'*10)t1 = torch.vstack([a, a])
print('pre', a.shape, 'post',t1.shape)t1 = torch.hstack([a, a])
print('pre', a.shape, 'post',t1.shape)t1 = torch.dstack([a, a])
print('pre', a.shape, 'post',t1.shape)
----------stack----------
pre torch.Size([5, 5, 3]) post torch.Size([2, 5, 5, 3])
pre torch.Size([5, 5, 3]) post torch.Size([5, 2, 5, 3])
pre torch.Size([5, 5, 3]) post torch.Size([5, 5, 2, 3])
pre torch.Size([5, 5, 3]) post torch.Size([5, 5, 3, 2])
----------concatenate----------
pre torch.Size([5, 5, 3]) post torch.Size([10, 5, 3])
pre torch.Size([5, 5, 3]) post torch.Size([5, 10, 3])
pre torch.Size([5, 5, 3]) post torch.Size([5, 5, 6])
----------v-h-d stack----------
pre torch.Size([5, 5, 3]) post torch.Size([10, 5, 3])
pre torch.Size([5, 5, 3]) post torch.Size([5, 10, 3])
pre torch.Size([5, 5, 3]) post torch.Size([5, 5, 6])
  • torch.split函数 按照长度进行切分

torch.split(tensor, split_size_or_sections, dim=0)

split_size_or_sections:需要切分的大小(int or list )
dim:切分维度

  • torch.chunk函数 按照数量进行切分
a = torch.rand([6, 256, 10, 10 ])
print('a', a.shape)t1 = torch.split(a, [3,3], dim = 0)
print('dim %d'%(0))
for ti in t1:print(ti.shape)print('dim %d'%(1))
t1 = torch.split(a, [128,128], dim = 1)
for ti in t1:print(ti.shape)print('dim %d'%(2))
t1 = torch.split(a, [5,5], dim = 2)
for ti in t1:print(ti.shape)
a torch.Size([6, 256, 10, 10])
dim 0
torch.Size([3, 256, 10, 10])
torch.Size([3, 256, 10, 10])
dim 1
torch.Size([6, 128, 10, 10])
torch.Size([6, 128, 10, 10])
dim 2
torch.Size([6, 256, 5, 10])
torch.Size([6, 256, 5, 10])
a = torch.rand([6, 256, 10, 10 ])
print('a', a.shape)t1 = torch.chunk(a, 2, dim = 0)
print('dim %d'%(0))
for ti in t1:print(ti.shape)print('dim %d'%(1))
t1 = torch.chunk(a, 2, dim = 1)
for ti in t1:print(ti.shape)print('dim %d'%(2))
t1 = torch.chunk(a, 2, dim = 2)
for ti in t1:print(ti.shape)
a torch.Size([6, 256, 10, 10])
dim 0
torch.Size([3, 256, 10, 10])
torch.Size([3, 256, 10, 10])
dim 1
torch.Size([6, 128, 10, 10])
torch.Size([6, 128, 10, 10])
dim 2
torch.Size([6, 256, 5, 10])
torch.Size([6, 256, 5, 10])
  • torch.repeat() & np.tile()

参数为沿着不同维度的扩展, 对于一个给定的tensor 给出其dims以及repeat的次数,类似于np.tile
由后往前进行匹配, 若超出当前的dims, 则新建一个dim, 进行repeat
当超出tensor 本身的 dims时,第一个dim 新起一个dim, 即沿着batch方向进行扩展

参考资料见
https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html?highlight=repeat#torch.Tensor.repeat

a = torch.tensor([1, 2, 3])
print(a.shape)
t1 = a.repeat(5, 2)
print('-'*10);print(t1.shape, a.shape); print(a,'\n', t1)t1 = a.repeat(2, 3, 2)
print('-'*10);print(t1.shape, a.shape); print(a,'\n', t1)a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(a.shape)t1 = a.repeat([1, 1])
print('-'*10);print(t1.shape, a.shape); print(a,'\n', t1)t1 = a.repeat([2, 1]) # dim0 repeat 2
print('-'*10);print(t1.shape, a.shape); print(a,'\n', t1)t1 = a.repeat([1, 2]) # dim1 repeat 2
print('-'*10);print(t1.shape, a.shape); print(a,'\n', t1)t1 = a.repeat([2, 1, 1]) 
print('-'*10);print(t1.shape, a.shape); print(a,'\n', t1)
torch.Size([3])
----------
torch.Size([5, 6]) torch.Size([3])
tensor([1, 2, 3]) tensor([[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3]])
----------
torch.Size([2, 3, 6]) torch.Size([3])
tensor([1, 2, 3]) tensor([[[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3]],[[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3]]])
torch.Size([2, 3])
----------
torch.Size([2, 3]) torch.Size([2, 3])
tensor([[1, 2, 3],[4, 5, 6]]) tensor([[1, 2, 3],[4, 5, 6]])
----------
torch.Size([4, 3]) torch.Size([2, 3])
tensor([[1, 2, 3],[4, 5, 6]]) tensor([[1, 2, 3],[4, 5, 6],[1, 2, 3],[4, 5, 6]])
----------
torch.Size([2, 6]) torch.Size([2, 3])
tensor([[1, 2, 3],[4, 5, 6]]) tensor([[1, 2, 3, 1, 2, 3],[4, 5, 6, 4, 5, 6]])
----------
torch.Size([2, 2, 3]) torch.Size([2, 3])
tensor([[1, 2, 3],[4, 5, 6]]) tensor([[[1, 2, 3],[4, 5, 6]],[[1, 2, 3],[4, 5, 6]]])

2.2 torch生成grid函数

  • torch.arange 类似于 np.arange [left, right), not include right.

  • torch.linspace 包括 left, right, 且分割为若干个点

  • torch.meshgrid叠加后生成 x 和 y (opencv坐标系)下两个 h*w的tensor。为了组装为后期可以用的点,所以需要先翻转后使用。

xs = torch.linspace(-10, 10, 10)
ys = torch.linspace(-5, 5, 10)x, y = torch.meshgrid(xs, ys, indexing='ij')
z = x + yax = plt.axes(projection= '3d')
ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-m1idlV1z-1692628397282)(output_15_0.png)]
请添加图片描述

def coords_grid(batch, ht, wd, device, ifReverse = True):#ht,wd 55, 128coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))if ifReverse:#NOTE, coords 为 x, y (opencv 坐标系下), y, x (numpy 坐标系)coords = torch.stack(coords[::-1], dim=0).float() # coords shape [ 2, ht, wd]else:#NOTE coords 为 x, y (numpy 坐标系下)coords = torch.stack(coords, dim=0).float()# 通过[::-1]转变为 y, x(numpy 坐标系)#coords = torch.reshape(coords, [1, 2, ht, wd])coords = coords.repeat(batch, 1, 1, 1)return coords
DEVICE = 'cuda'
grid_opencv = coords_grid(2, 12, 5, DEVICE, ifReverse=True)
grid_numpy = coords_grid(2, 12, 5, DEVICE, ifReverse= False)print(grid_numpy.shape, grid_opencv.shape)def visualize_grid(grid):# sample 1, x & yplot_x = grid[0, 0, :, :].reshape([-1]).cpu().numpy()plot_y = grid[0, 1, :, :].reshape([-1]).cpu().numpy()return plot_x, plot_yplt.figure(figsize=(8,6))
for pici, grid in enumerate([grid_opencv, grid_numpy]):plt.subplot(2, 1, pici + 1)plt.title(['opencv', 'numpy'][pici])plot_x, plot_y = visualize_grid(grid)for i in range(len(plot_x)):plt.scatter([plot_x[i]], [plot_y[i]], c = 'grey', marker = 'o')plt.annotate(xy = (plot_x[i], plot_y[i]), s= str(i))plt.show()
/home/chwei/.local/lib/python3.6/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]torch.Size([2, 2, 12, 5]) torch.Size([2, 2, 12, 5])

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4XGp96GA-1692628397283)(output_17_2.png)]
请添加图片描述

3.1 数学运算

add/sub/mul/div加减乘除

a = torch.rand(3, 4)
b = torch.rand(4) # 这里基于pytorch 的broadcasting进行了dim 0的 broadcastc = a + bprint(a.shape, b.shape, c.shape)
#print(a+b)
#print(torch.add(a, b))print(torch.all(torch.eq(a-b, torch.sub(a, b))))
print(torch.all(torch.eq(a*b, torch.mul(a, b))))
print(torch.all(torch.eq(a/b, torch.div(a, b))))
torch.Size([3, 4]) torch.Size([4]) torch.Size([3, 4])
tensor(True)
tensor(True)
tensor(True)

3.2 torch.matmul函数

  • 基础case 1: 两个tensor 都是1 dim, If both tensors are 1-dimensional, the dot product (scalar) is returned.

  • 基础case 2: 两个tensor 都是2 dims, If both arguments are 2-dimensional, the matrix-matrix product is returned.

  • 基础case 3: 一个tensor 是1 dim, 另外一个 2dim. tensor with 1 dim进行扩展到2dim, 以适应 matrix-product
    If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
    If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.

  • 基础case 4: Batch Brocasted & batch matrix-product
    思路: 把 first dim (batch dim) 排除, 每个sample 进行 matrix-product or vector product
    若 有一个tensor 其维度不匹配, 沿着batch dim进行扩展broadcast之后计算

If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned.

If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.

If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after.

The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable).

For example, if input is a ( j × 1 × n × n ) ( j × 1 × n × n ) (j \times 1 \times n \times n)(j×1×n×n) (j×1×n×n)(j×1×n×n) tensor and other is a ( k × n × n ) ( k × n × n ) (k \times n \times n)(k×n×n) (k×n×n)(k×n×n) tensor, out will be a ( j × k × n × n ) ( j × k × n × n ) (j \times k \times n \times n)(j×k×n×n) (j×k×n×n)(j×k×n×n) tensor.

这里 j j j 是 batch dim, 而 matrix product 执行的位置在于$ (1 \times n \times n) \times (k \times n \times n) $

Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs are broadcastable, and not the matrix dimensions.

For example, if input is a ( j × 1 × n × m ) ( j × 1 × n × m ) (j \times 1 \times n \times m)(j×1×n×m) (j×1×n×m)(j×1×n×m) tensor and other is a ( k × m × p ) ( k × m × p ) (k \times m \times p)(k×m×p) (k×m×p)(k×m×p) tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different. out will be a ( j × k × n × p ) ( j × k × n × p ) (j \times k \times n \times p)(j×k×n×p) (j×k×n×p)(j×k×n×p) tensor.

# vector x vector case 1
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
out = torch.matmul(tensor1, tensor2)
print('input', tensor1.shape, tensor2.shape, 'output',out.shape)
# matrix x matrix case 2
tensor1 = torch.randn(3, 1)
tensor2 = torch.randn(1, 3)
out = torch.matmul(tensor1, tensor2)
print('input', tensor1.shape, tensor2.shape, 'output',out.shape)# matrix x vector case 3
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
out = torch.matmul(tensor1, tensor2)
print('input', tensor1.shape, tensor2.shape, 'output',out.shape)# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4) # first dim is batch dim
tensor2 = torch.randn(4)
out = torch.matmul(tensor1, tensor2)
print('input', tensor1.shape, tensor2.shape, 'output',out.shape)# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4) # first dim is batch dim
tensor2 = torch.randn(10, 4, 5) # first dim is batch dim
out = torch.matmul(tensor1, tensor2)
# matrix product (3 * 4 ) * (4 * 5) -> 3 * 5
print('input', tensor1.shape, tensor2.shape, 'output',out.shape) # batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4) # first dim is batch dim
tensor2 = torch.randn(4, 5)
out = torch.matmul(tensor1, tensor2)
# matrix product (3 * 4 ) * (4 * 5) -> 3 * 5
print('input', tensor1.shape, tensor2.shape, 'output',out.shape) 
input torch.Size([3]) torch.Size([3]) output torch.Size([])
input torch.Size([3, 1]) torch.Size([1, 3]) output torch.Size([3, 3])
input torch.Size([3, 4]) torch.Size([4]) output torch.Size([3])
input torch.Size([10, 3, 4]) torch.Size([4]) output torch.Size([10, 3])
input torch.Size([10, 3, 4]) torch.Size([10, 4, 5]) output torch.Size([10, 3, 5])
input torch.Size([10, 3, 4]) torch.Size([4, 5]) output torch.Size([10, 3, 5])

3.3 pow矩阵的次方以及sqrt/rsqrt/exp/log

a =torch.full([2, 2], 3) # 使用torch.full函数创建一个shape[2, 2],元素全部为3的张量
print(a.pow(2))
print(torch.pow(a, 2))
print(a**2)b = a**2
print(b.sqrt())
print(b.rsqrt())         # 平方根的导数print('=============================')
a = torch.exp(torch.ones(2, 2))
print(a)
print(torch.log(a))     # 默认以e为底,使用2为底或者其他的,自己设置.
tensor([[9, 9],[9, 9]])
tensor([[9, 9],[9, 9]])
tensor([[9, 9],[9, 9]])
tensor([[3., 3.],[3., 3.]])
tensor([[0.3333, 0.3333],[0.3333, 0.3333]])
=============================
tensor([[2.7183, 2.7183],[2.7183, 2.7183]])
tensor([[1., 1.],[1., 1.]])

3.4 round矩阵近似运算

.floor()向下取整,.ceil()向上取整,.trunc()截取整数,.frac截取小数。

a = torch.tensor(3.14)
# .floor()向下取整,.ceil()向上取整,.trunc()截取整数,.frac截取小数。
print(a.floor(), a.ceil(), a.trunc(), a.frac())print(a.round())
b = torch.tensor(3.5)
print(b.round())
tensor(3.) tensor(4.) tensor(3.) tensor(0.1400)
tensor(3.)
tensor(4.)

3.6. clamp(裁剪)用的多

主要用在梯度裁剪里面,梯度离散(不需要从网络层面解决,因为梯度非常小,接近0)和梯度爆炸(梯度非常大,100已经算是大的了)。因此在网络训练不稳定的时候,可以打印一下梯度的模看看,w.grad.norm(2)表示梯度的二范数(一般100,1000已经算是大的了,一般10以内算是合适的)。

a.clamp(min):表示tensor a中小于10的都赋值为10,表示最小值为10;
grad = torch.rand(2, 3)*15
print(grad)
print(grad.max(), grad.median(), grad.min())
print('============================================')
print(grad.clamp(10))     # 最小值限定为10,小于10的都变为10;print(grad.clamp(8, 15))
print(torch.clamp(grad, 8, 15))
tensor([[ 3.7078, 11.4988,  7.9875],[ 0.6747,  0.6269, 12.3761]])
tensor(12.3761) tensor(3.7078) tensor(0.6269)
============================================
tensor([[10.0000, 11.4988, 10.0000],[10.0000, 10.0000, 12.3761]])
tensor([[ 8.0000, 11.4988,  8.0000],[ 8.0000,  8.0000, 12.3761]])
tensor([[ 8.0000, 11.4988,  8.0000],[ 8.0000,  8.0000, 12.3761]])

4. 统计属性相关操作 TODO

4.1. norm范数,prod张量元素累乘(阶乘)

4.2. mean/sum/max/min/argmin/argmax

4.3. kthvalue()和topk()

这里: topk(3, dim=1)(最大的3个)返回结果如下图所示,如果把largest设置为False就是默认最小的几个。
这里: kthvalue(k,dim=1)表示第k小的(默认表示小的)。下面图中的一共10中可能,第8小就是表示第3大。

4.4. 比较运算符号>,>=,<,<=,!=,==

greater than表示大于等于。equal表示等于eq。

5. batch检索

torch.where, gather, index_select, masked_select, nonzero函数

TODO: https://www.cnblogs.com/liangjianli/p/13754817.html#3-gather%E5%87%BD%E6%95%B0

5.1 torch.where

高阶操作where和gather

5.2 torch.gather

TODO 依旧不是很理解,没有完全理解
gather 收集输入的特定维度指定位置的数值

注意,torch.gather里的index tensor必须是longtensor类型

  • 官方解释
    https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
a = torch.rand([4, 5])
indexes = torch.LongTensor([[4, 2],[3, 1],[2, 4],[1,2]])# 这个在q-learning中的batch actions中常见,注意一下
a_ = torch.gather(a, dim = 1, index = indexes)print(a.shape, a_.shape, '\n', a, '\n', a_)indexes = torch.LongTensor([[ 0, 2, 2], [1, 2, 3]])
a_ = torch.gather(a, dim = 0, index = indexes)
print(a.shape, a_.shape, '\n', a, '\n', a_)
torch.Size([4, 5]) torch.Size([4, 2]) tensor([[0.1330, 0.2177, 0.2177, 0.3287, 0.0809],[0.1803, 0.4438, 0.9827, 0.6353, 0.3548],[0.4980, 0.6792, 0.3885, 0.6338, 0.4985],[0.8367, 0.4682, 0.0805, 0.6161, 0.4861]]) tensor([[0.0809, 0.2177],[0.6353, 0.4438],[0.3885, 0.4985],[0.4682, 0.0805]])
torch.Size([4, 5]) torch.Size([2, 3]) tensor([[0.1330, 0.2177, 0.2177, 0.3287, 0.0809],[0.1803, 0.4438, 0.9827, 0.6353, 0.3548],[0.4980, 0.6792, 0.3885, 0.6338, 0.4985],[0.8367, 0.4682, 0.0805, 0.6161, 0.4861]]) tensor([[0.1330, 0.6792, 0.3885],[0.1803, 0.6792, 0.0805]])

10. torch vision下的一些操作

10.1 F.grid_sample

根据位置 对 batch 里的 feature flow 进行采样,可以理解为index, 但本函数允许index里为float类型(采用bilinear插值实现)
函数

torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
iven an input and a flow-field grid, computes the output using input values and pixel locations from grid.

input: (N, C, Hin, Win)
grid: (N, Hout, Wout, 2) , 其中最后一维为opencv 坐标系下的x & y,

output: (N, C, Hout, Wout)

NOTE,输入的grid,需根据 input 的Hin, Win进行归一化到 -1 ~ 1,

grid specifies the sampling pixel locations normalized by the input spatial dimensions.

Therefore, it should have most values in the range of [-1, 1]. For example, values x = -1, y = -1 is the left-top pixel of input, and values x = 1, y = 1 is the right-bottom pixel of input.

参考资料:https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html?highlight=grid_sample#torch.nn.functional.grid_sample
以下通过图片展示如何使用:

DEVICE = 'cuda'def load_image(imfile):img = np.array(Image.open(imfile)).astype(np.uint8)img = torch.from_numpy(img).permute(2, 0, 1).float() # h, w, dim -> dim, h, wreturn img.to(DEVICE)def visualize(img):#input imgshape: 3, w, h# NOTE, dims convert, device convert, numpy convert, and dtype convertreturn img.permute(1, 2, 0).cpu().numpy().astype(np.uint8)def coords_grid(batch, ht, wd, device, ifReverse = True):#ht,wd 55, 128coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))if ifReverse:#NOTE, coords 为 x, y (opencv 坐标系下), y, x (numpy 坐标系)coords = torch.stack(coords[::-1], dim=0).float() # coords shape [ 2, ht, wd]else:#NOTE coords 为 x, y (numpy 坐标系下)coords = torch.stack(coords, dim=0).float()# 通过[::-1]转变为 y, x(numpy 坐标系)#coords = torch.reshape(coords, [1, 2, ht, wd])coords = coords.repeat(batch, 1, 1, 1)return coords
# 载入图片并显示
img1, img2 = load_image('demo-frames/frame_0016.png'), load_image('demo-frames/frame_0017.png')inputImg = torch.cat([img1.unsqueeze(0), img2.unsqueeze(0)], 0)
print(img1.shape, img2.shape, inputImg.shape)#------ visualize -----------
img1_, img2_ = visualize(img1), visualize(img2)plt.subplot(2, 1, 1)
plt.imshow(img1_)plt.subplot(2, 1, 2)
plt.imshow(img2_)
plt.show()
torch.Size([3, 436, 1024]) torch.Size([3, 436, 1024]) torch.Size([2, 3, 436, 1024])

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zA8xwrlK-1692628397284)(output_36_1.png)]

请添加图片描述

# 生成grid
hImg, wImg = img1.shape[-2:]grid =  coords_grid(batch = 2, ht = 100, wd = 100, device = 'cuda', ifReverse = True)
print(grid.shape) # dim 1: y & x(numpy 坐标系), x & y (opencv坐标系)
torch.Size([2, 2, 100, 100])
# grid归一化操作 
grid = grid.permute(0, 2, 3, 1)
ygrid, xgrid = grid.split([1, 1], dim = -1) # 这里的xgrid, ygrid指的是 numpy 坐标系
print(grid.shape, xgrid.shape, ygrid.shape)ygrid = 2 * ygrid/(wImg - 1) -1
xgrid = 2 * xgrid/(hImg - 1) -1grid = torch.cat([ygrid, xgrid], dim=-1)
grid = grid.to(DEVICE)
# 输入的grid 最后一维为 opencv坐标系下的x & yinputImg_ = F.grid_sample(inputImg, grid, align_corners=True) # size 7040, 1, 9,9
torch.Size([2, 100, 100, 2]) torch.Size([2, 100, 100, 1]) torch.Size([2, 100, 100, 1])
plt.subplot(1, 2, 1)
img_i = visualize(inputImg[0, :, :, :])
plt.imshow(img_i)
plt.subplot(1, 2, 2)img_o = visualize(inputImg_[0, :, :, :])
plt.imshow(img_o)
plt.show()print(inputImg.shape, img_i.shape, img_o.shape)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-l9dvPqBJ-1692628397285)(output_39_0.png)]
请添加图片描述

torch.Size([2, 3, 436, 1024]) (436, 1024, 3) (100, 100, 3)

上/下采样 方法1

F.interpolate

功能:利用插值方法,对输入的张量数组进行上\下采样操作,换句话说就是科学合理地改变数组的尺寸大小,尽量保持数据完整。

官方文档: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html?highlight=f%20interpolate#torch.nn.functional.interpolate

img1, img2 = load_image('demo-frames/frame_0016.png'), load_image('demo-frames/frame_0017.png')
inputImg = torch.cat([img1.unsqueeze(0), img2.unsqueeze(0)], 0)
print(img1.shape, img2.shape, inputImg.shape)
torch.Size([3, 436, 1024]) torch.Size([3, 436, 1024]) torch.Size([2, 3, 436, 1024])
new_size = (2* inputImg.shape[2], 2 * inputImg.shape[3])
mode = 'bilinear'
inputImg_ = F.interpolate(inputImg, size=new_size, mode=mode, align_corners=True)
plt.subplot(2, 1, 1)
img_i = visualize(inputImg[0, :, :, :])
plt.imshow(img_i)
plt.subplot(2, 1, 2)
img_o = visualize(inputImg_[0, :, :, :])
plt.imshow(img_o)
plt.show()print(inputImg.shape, img_i.shape, img_o.shape)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CNtrY6FN-1692628397286)(output_43_0.png)]

请添加图片描述

torch.Size([2, 3, 436, 1024]) (436, 1024, 3) (872, 2048, 3)








相关文章:

torch一些操作

Pytorch文档 Pytorch 官方文档 https://pytorch.org/docs/stable/index.html pytorch 里的一些基础tensor操作讲的不错 https://blog.csdn.net/abc13526222160/category_8614343.html 关于pytorch的Broadcast,合并与分割,数学运算,属性统计以及高阶操作 https://blog.csd…...

ICCV23 | Ada3D:利用动态推理挖掘3D感知任务中数据冗余性

​ 论文地址&#xff1a;https://arxiv.org/abs/2307.08209 项目主页&#xff1a;https://a-suozhang.xyz/ada3d.github.io/ 01. 背景与动因 3D检测(3D Detection)任务是自动驾驶任务中的重要任务。由于自动驾驶任务的安全性至关重要(safety-critic)&#xff0c;对感知算法的延…...

软件工程模型-架构师之路(四)

软件工程模型 敏捷开发&#xff1a; 个体和交互 胜过 过程和工具、可以工作的软件 胜过 面面俱到的文件、客户合作胜过合同谈判、响应变化 胜过 循序计划。&#xff08;适应需求变化&#xff0c;积极响应&#xff09; 敏捷开发与其他结构化方法区别特点&#xff1a;面向人的…...

ubuntu20.04共享文件夹—— /mnt/hgfs里没有共享文件夹

参考文章&#xff1a;https://blog.csdn.net/Edwinwzy/article/details/129580636 虚拟机启用共享文件夹后&#xff0c;/mnt/hgfs下面为空&#xff0c;使用 vmware-hgfsclient 查看设置的共享文件夹名字也是为空。 解决方法&#xff1a; 1. 重新安装vmware tools. 在菜单…...

Redis中的有序集合及其底层跳表

前言 本文着重介绍Redis中的有序集合的底层实现中的跳表 有序集合 Sorted Set Redis中的Sorted Set 是一个有序的无重复值的集合&#xff0c;他底层是使用压缩列表和跳表实现的&#xff0c;和Java中的HashMap底层数据结构&#xff08;1.8&#xff09;链表红黑树异曲同工之妙…...

js 小程序限流函数 return闭包函数执行不了

问题&#xff1a; 调用限流 &#xff0c;没走闭包的函数&#xff1a; checkBalanceReq&#xff08;&#xff09; loadsh.js // 限流 const throttle (fn, context, interval) > {console.log(">>>>cmm throttle", context, interval)let canRun…...

【数据结构】堆的初始化——如何初始化一个大根堆?

文章目录 源码是如何插入的&#xff1f;扩容向上调整实现大根堆代码&#xff1a; 源码是如何插入的&#xff1f; 扩容 在扩容的时候&#xff0c;如果容量小于64&#xff0c;那就2倍多2的扩容&#xff1b;如果大于64&#xff0c;那就1.5倍扩容。 还会进行溢出的判断&#xff0c…...

【韩顺平 零基础30天学会Java】程序流程控制(2days)

day1 程序流程控制&#xff1a;顺序控制、分支控制、循环控制 顺序控制&#xff1a;从上到下逐行地执行&#xff0c;中间没有任何判断和跳转。 Java中定义变量时要采用合法的前向引用。 分支控制if-else&#xff1a;单分支、双分支和多分支。 单分支 import java.util.Scann…...

从入门到精通Python隧道代理的使用与优化

哈喽&#xff0c;Python爬虫小伙伴们&#xff01;今天我们来聊聊如何从入门到精通地使用和优化Python隧道代理&#xff0c;让我们的爬虫程序更加稳定、高效&#xff01;今天我们将对使用和优化进行一个简单的梳理&#xff0c;并且会提供相应的代码示例。 1. 什么是隧道代理&…...

19万字智慧城市总体规划与设计方案WORD

导读&#xff1a;原文《19万字智慧城市总体规划与设计方案WORD》&#xff08;获取来源见文尾&#xff09;&#xff0c;本文精选其中精华及架构部分&#xff0c;逻辑清晰、内容完整&#xff0c;为快速形成售前方案提供参考。 感知基础设施 感知基础设施架构由感知范围、感知手…...

[赛博昆仑] 腾讯QQ_PC端,逻辑漏洞导致RCE漏洞

简介 !! 内容仅供学习,请不要进行非法网络活动,网络不是法外之地!! 赛博昆仑是国内一家较为知名的网络安全公司&#xff0c;该公司今日报告称 Windows 版腾讯 QQ 桌面客户端出现高危安全漏洞&#xff0c;据称“黑客利用难度极低、危害较大”&#xff0c;腾讯刚刚已经紧急发布…...

python Requests

Requests概述 官方文档&#xff1a;http://cn.python-requests.org/zh_CN/latest/,Requests是python的HTTP的库&#xff0c;我们可以安全的使用 Requests安装 pip install Requests -i https://pypi.tuna.tsinghua.edu.cn/simple Requests的使用 Respose的属性 属性说明url响…...

【深入解析:数据结构栈的魅力与应用】

本章重点 栈的概念及结构 栈的实现方式 数组实现栈接口 栈面试题目 概念选择题 一、栈的概念及结构 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端 称为栈顶&#xff0c;另一端称为栈底。栈中的数…...

安卓机显示屏的硬件结构

显示屏的硬件结构 显示屏的硬件结构主要由背光源、液晶面板和驱动电路构成。可以将液晶面板看成一个三明治的结构&#xff0c;即在两片偏振方向互相垂直的偏光片系统中夹着一层液晶层。自然光源通过起偏器&#xff08;偏光片之一&#xff09;后&#xff0c;变成了垂直方向的偏…...

基于swing的超市管理系统java仓库库存进销存jsp源代码mysql

本项目为前几天收费帮学妹做的一个项目&#xff0c;Java EE JSP项目&#xff0c;在工作环境中基本使用不到&#xff0c;但是很多学校把这个当作编程入门的项目来做&#xff0c;故分享出本项目供初学者参考。 一、项目描述 基于swing的超市管理系统 系统有3权限&#xff1a;管…...

常用系统命令

重定向 cat aa.txt > bbb.txt 将输出定向到bbb.txt cat aaa.txt >> bbb.txt 输出并追加查看进程 ps ps -ef 显示所有进程 例⼦&#xff1a;ps -ef | grep mysql |&#xff1a;管道符 kill pid 结束进程&#xff0c; 如 kill 3732&#xff1b;根据进程名结束进程可以先…...

【Spring专题】Spring之Bean生命周期源码解析——阶段四(Bean销毁)(拓展,了解就好)

目录 前言阅读建议 课程内容一、Bean什么时候销毁二、实现自定义的Bean销毁逻辑2.1 实现DisposableBean或者AutoCloseable接口2.2 使用PreDestroy注解2.3 其他方式&#xff08;手动指定销毁方法名字&#xff09; 三、注册销毁Bean过程及方法详解3.1 AbstractBeanFactory#requir…...

配置Docker,漏洞复现

目录 配置Docker 漏洞复现 配置Docker Docker的配置在Linux系统中相对简单&#xff0c;以下是详细步骤&#xff1a; 1.安装Docker&#xff1a;打开终端&#xff0c;运行以下命令以安装Docker。 sudo apt update sudo apt install docker.io 2.启动Docker服务&#xff1a;运…...

微信小程序 游戏水平评估系统的设计与实现_pzbe0

近年来&#xff0c;随着互联网的蓬勃发展&#xff0c;游戏公司对信息的管理提出了更高的要求。传统的管理方式已无法满足现代人们的需求。为了迎合时代需求&#xff0c;优化管理效率&#xff0c;各种各样的管理系统应运而生&#xff0c;随着各行业的不断发展&#xff0c;使命召…...

moba登录不进去提示修改问题问题解决方式

问题&#xff1a; 安装moba后&#xff0c;运行时运行不起来&#xff0c;提示输入密码&#xff0c;安装、卸载多个版本都不行 方法&#xff1a; 使用ResetMasterPassword工具进行重置主密码 官网下载地址&#xff1a; MobaXterm Xserver and tabbed SSH client - resetmaster…...

Unsafe upfileupload

文章目录 client checkMIME Typegetimagesize 文件上传功能在web应用系统很常见&#xff0c;比如很多网站注册的时候需要上传头像、上传附件等等。当用户点击上传按钮后&#xff0c;后台会对上传的文件进行判断 比如是否是指定的类型、后缀名、大小等等&#xff0c;然后将其按…...

机器人制作开源方案 | 滑板助力器

我们可以用一块废滑板做些什么呢&#xff1f; 如今&#xff0c;越来越多的人选择电动滑板作为代步工具或娱乐方式&#xff0c;市场上也涌现出越来越多的电动滑板产品。 &#xff08;图片来源&#xff1a;Backfire Zealot X Belt Drive Electric Skateboard– Backfire Board…...

飞机打方块(二)游戏界面制作

一、背景 1.新建bg节点 二、飞机节点功能实现 1.移动 1.新建plane节点 2.新建脚本GameController.ts,并绑定Canvas GameControll.ts const { ccclass, property } cc._decorator;ccclass export default class NewClass extends cc.Component {property(cc.Node)canvas:…...

自我理解:精度(precision)和召回(recall)

1、精度(precision) 精度是用于评估分类模型的一个重要指标。它反映了模型预测为正例的样本中&#xff0c;实际真正为正例样本的比例。 【注】正例样本指在二分类问题中&#xff0c;被标注为正类的样本。 例如&#xff1a;在垃圾邮件分类任务中&#xff0c;正例样本就是真实的…...

Nginx 使用 HTTPS(准备证书和私钥)

文章目录 Nginx生成自签名证书和配置Nginx HTTPS&#xff08;准备证书和私钥&#xff09;准备证书和私钥 Nginx生成自签名证书和配置Nginx HTTPS&#xff08;准备证书和私钥&#xff09; 准备证书和私钥 生成私钥 openssl genrsa -des3 -out server.key 2048这会生成一个加密…...

Java:集合框架:Set集合、LinkedSet集合、TreeSet集合、哈希值、HashSet的底层原理

Set集合 创建一个Set集合对象&#xff0c;因为Set是一个接口不能直接new一个对象&#xff0c;所以要用一个实现类来接 HashSet来接 无序性只有一次&#xff0c;只要第一次运行出来后&#xff0c;之后再运行的顺序还是第一次的顺序。 用LinkedSet来接 有序 不重复 无索引 用Tree…...

自定义Taro的navBar的宽度和高度

本方法是计算自定义navbar的宽度和高度&#xff0c;输出的参数有 navBarHeight, menuBottom,menuHeight, menuRectWidth,windowWidth, windowHeight,具体代码如下&#xff1a; export function getCustomNavBarRect():| {navBarHeight: number;menuBottom: number;menuHeight:…...

用Python编程实现百度自然语言处理接口的对接,助力你开发智能化处理程序

用Python编程实现百度自然语言处理接口的对接&#xff0c;助力你开发智能化处理程序 随着人工智能的不断进步&#xff0c;自然语言处理&#xff08;Natural Language Processing&#xff0c;NLP&#xff09;成为了解决文本处理问题的重要工具。百度自然语言处理接口提供了一系…...

系统架构设计专业技能 · 系统工程与系统性能

系列文章目录 系统架构设计专业技能 网络技术&#xff08;三&#xff09; 系统架构设计专业技能 系统安全分析与设计&#xff08;四&#xff09;【系统架构设计师】 系统架构设计高级技能 软件架构设计&#xff08;一&#xff09;【系统架构设计师】 系统架构设计高级技能 …...

初识网络原理(笔记)

目录 ​编辑局域网 网络通信基础 IP 地址 端口号 协议 协议分层 TCP / IP 五层网络模型 网络数据传输的基本流程 发送方的情况&#xff1a; 接收方的情况 局域网 搭建网络的时候&#xff0c;需要用到 交换机 和 路由器 路由器上&#xff0c;有 lan 口 和 wan 口 虽…...

嵌入式C语言基本操作方法之经典

C语言一经出现就以其功能丰富、表达能力强、灵活方便、应用面广等特点迅速在全世界普及和推广。 C语言不但执行效率高而且可移植性好&#xff0c;可以用来开发应用软件、驱动、操作系统等。 C语言也是其它众多高级语言的鼻祖语言&#xff0c;所以说学习C语言是进入编程世界的必…...

postgresql \watch实用的使用方法

文章目录 1.介绍2.语法3.实用的使用方法3.1 慢sql监控3.2 长wait事件3.3 日志输出量3.3结合pg_stat_database使用3.4 结合pg_stat_bgwriter使用3.5 其他 1.介绍 \watch Postgres 9.3 版带来的一个有用的命令&#xff0c;与linux watch指令类似&#xff0c;可以帮我们在指定间隔…...

Cocos2d 项目问题记录

环境搭建 正常运行 Android 端的 Cocos2d 项目&#xff0c;本机至少需要 Android SDK、NDK 环境、Android Studio 项目报错总结 CMake Error: CMake was unable to find a build program corresponding to "Ninja" 默认创建工程的 gradle.tools 版本为 3.1.0&…...

系统架构合理性的思考 | 京东云技术团队

最近牵头在梳理部门的系统架构合理性&#xff0c;开始工作之前&#xff0c;我首先想到的是如何定义架构合理性&#xff1f; 从研发的角度来看如果系统上下文清晰、应用架构设计简单、应用拆分合理应该称之为架构合理。 基于以上的定义可以从以下三个方面来梳理评估&#xff1…...

Amelia预订插件:WordPress企业级预约系统

并非所有WordPress预订插件都像他们所设计的那样。其中一些缺乏运行高效预约操作所需的功能&#xff0c;而其他一些则看起来陈旧过时。您不需要其中任何一个&#xff0c;但Amelia预订插件似乎希望确保所有用户都对功能和风格感到满意。 在这篇Amelia企业级预约系统插件评测中&…...

共享门店模式:线下门店的商家如何利用它增加客户

随着数字化时代的到来&#xff0c;商业模式正在不断创新与演变&#xff0c;而共享经济正成为引领这一变革的重要力量。在这个大背景下&#xff0c;共享门店模式作为共享经济的一种体现&#xff0c;正在逐渐走进人们的生活&#xff0c;并为商家和消费者带来了新的商机和体验。 共…...

实现矩阵地图与rviz地图重合

文章目录 一、rviz地图转换矩形地图(只能用于全局规划)二、在rviz上显示地图边界信息,可视化调整,实现重合(只能用于局部规划)一、rviz地图转换矩形地图(只能用于全局规划) 此方法矩形地图可能会与rviz地图不重合,通过改变偏移量x_offset,y_offset接近地图 可以将矩…...

设计模式十九:备忘录模式(Memento Pattern)

备忘录模式是一种行为型设计模式&#xff0c;它允许对象在不暴露其内部状态的情况下捕获和恢复其状态。该模式的主要目标是在不破坏封装性的前提下&#xff0c;实现对象状态的备份和恢复。备忘录模式常用于需要保存对象历史状态、撤销操作或者实现快照功能的情况。 备忘录模式…...

【题解】二叉搜索树与双向链表

二叉搜索树与双向链表 题目链接&#xff1a;二叉搜索树与双向链表 解题思路1&#xff1a;递归中序遍历 首先题目最后要求的是一个的递增的双向链表&#xff0c;而二叉搜索树也是一类非常有特色的树&#xff0c;它的根节点大于所有左侧的节点&#xff0c;同时又小于所有右侧的…...

【真实案例】解决后端接口调用偶尔超时问题

文章目录 背景分析代码分析二次日志分析排查Gateway服务解决解决办法1:添加重试机制解决办法2:优化网关内存分配解决办法3:调整OOM策略背景 项目从虚拟机迁移到k8s云原生平台(RainBond)后,发现偶尔会出现接口调用超时的问题。 统计了一下从上线到现在近一个月的调用失败…...

操作符详解(1)

1. 操作符分类&#xff1a; 算术操作符 移位操作符 位操作符 赋值操作符 单目操作符 关系操作符 逻辑操作符 条件操作符 逗号表达式 下标引用、函数调用和结构成员 2. 算术操作符 - * / % 1. 除了 % 操作符之外&#xff0c;其他的几个操作符可以作用于整数和浮点数。 2. 对…...

<指针进阶>指针数组和数组指针傻傻分不清?

✨Blog&#xff1a;&#x1f970;不会敲代码的小张:)&#x1f970; &#x1f251;推荐专栏&#xff1a;C语言&#x1f92a;、Cpp&#x1f636;‍&#x1f32b;️、数据结构初阶&#x1f480; &#x1f4bd;座右铭&#xff1a;“記住&#xff0c;每一天都是一個新的開始&#x1…...

无代码集成飞书连接更多应用

场景描述&#xff1a; 基于飞书开放平台能力&#xff0c;无代码集成飞书连接更多应用&#xff0c;打通数据孤岛。通过Aboter可轻松搭建业务自动化流程&#xff0c;实现多个应用之间的数据连接。 支持包括飞书事件监听和接口调用的能力&#xff1a; 事件监听&#xff1a; 用…...

三分钟解决AE缓存预览渲染错误、暂停、卡顿问题

一、清除RAM缓存&#xff08;内存&#xff09; 你应该做的第一件事是清除你的RAM。这将清除当前存储在内存中的所有临时缓存文件。要执行此操作&#xff0c;请导航到编辑>清除>所有内存。这将从头开始重置RAM缓存 二、清空磁盘缓存 您也可以尝试清空磁盘缓存。执行此操作…...

朴实无华的数据增强然后训练一下应用在电网异物检测领域,好像有自己的数据集就能发文了

RCNN-based foreign object detection for securing power transmission lines (RCNN4SPTL) Abstract 本文提出了一种新的深度学习网络——RCNN4SPTL (RCNN -based Foreign Object Detection for Securing Power Transmission lines)&#xff0c;该网络适用于检测输电线路上的…...

【使用教程】在Ubuntu下运行CANopen通信PMM伺服电机使用教程(NimServoSDK_V2.0.0)

本教程将指导您在Ubuntu操作系统下使用NimServoSDK_V2.0.0来运行CANopen通信的PMM系列一体化伺服电机。我们将介绍必要的步骤和命令&#xff0c;以确保您能够成功地配置和控制PMM系列一体化伺服电机。 NimServoSDK_V2.0.0是一款用于PMM一体化伺服电机的软件开发工具包。它提供了…...

vue3+ts+vite项目页面初始化loading加载效果

简介 一分钟实现 vue-pure-admin 同款项目加载时的 loading 效果 一、先看效果 1.1 静态效果 1.2 动态效果 二、上代码 核心代码在body里面&#xff0c;代码中已标明。找到你项目的 index.html &#xff0c;复制粘贴进去即可 <!DOCTYPE html> <html lang"en…...

ElasticSearch 数据聚合、自动补全(自定义分词器)、数据同步

文章目录 数据聚合一、聚合的种类二、DSL实现聚合1、Bucket&#xff08;桶&#xff09;聚合2、Metrics&#xff08;度量&#xff09;聚合 三、RestAPI实现聚合 自动补全一、拼音分词器二、自定义分词器三、自动补全查询四、实现搜索款自动补全&#xff08;例酒店信息&#xff0…...

神经网络基础-神经网络补充概念-18-多个样本的向量化

概念 多个样本的向量化通常涉及将一组样本数据组织成矩阵形式&#xff0c;其中每一行代表一个样本&#xff0c;每一列代表样本的特征。这种向量化可以使你更有效地处理和操作多个样本&#xff0c;特别是在机器学习和数据分析中。 代码实现 import numpy as np# 多个样本的数…...

*看门狗1

//while部分是我们在项目中具体需要写的代码&#xff0c;这部分的程序可以用独立看门狗来监控 //如果我们知道这部分代码的执行时间&#xff0c;比如是500ms&#xff0c;那么我们可以设置独立看门狗的 //溢出时间是600ms&#xff0c;比500ms多一点&#xff0c;如果要被监控的程…...