pytorch tricks

pytorch tricks

社蕙 1463 2022-10-21

1. 存储方式

tensor的存储方式是行主序,这决定了reshape()的工作方式与matlab不同

tensor是有确定维数的,因此tensor(1)tensor([1])tensor([[1]])是不同的,不能像matlab一样用一维索引遍历。

>> a=1
a =
     1
>> a(1,1,1)
ans =
     1

可以使用unsqueeze()升维;上述三个向量都可以通过.item()方法取出唯一存储的标量(非唯一有方法.tolist

2. 构造与生成

2.1. 等距生成

对于matlab代码:

K>> 1:2:5
ans =
     1     3     5
K>> 1:2:6
ans =
     1     3     5

同样输入pytorch函数arange(),得到的东西完全不一样!因为python的最后一位是不取的。

>>> y=torch.arange(1,6,2)
>>> y
tensor([1, 3, 5])
>>> y=torch.arange(1,5,2)
>>> y
tensor([1, 3])

所以,一般的对于start:step:end,可以转换成torch.arange(start,end+1,step)

3. 索引、查找与赋值

matlab索引从1开始,正常的计算机语言从0开始,想必就不用多说了。

3.1. 查找元素索引

stackoverflow:How Pytorch Tensor get the index of specific value

不像matlab可以直接使用find(),pytorch没有办法做到查找元素返回索引,于是可以将tensor做逻辑运算编程boolTensor,在进行布尔运算后,通过nonzero()返回非零索引

>>> import torch
>>> xg = torch.linspace(0,1,20)
>>> xg
tensor([0.0000, 0.0526, 0.1053, 0.1579, 0.2105, 0.2632, 0.3158, 0.3684, 0.4211,
        0.4737, 0.5263, 0.5789, 0.6316, 0.6842, 0.7368, 0.7895, 0.8421, 0.8947,
        0.9474, 1.0000])
>>> xg == xg[3]
tensor([False, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
>>> (xg == xg[3]).nonzero()
tensor([[3]])
>>> (xg > 0.3)&(xg < 0.7)
tensor([False, False, False, False, False, False,  True,  True,  True,  True,
         True,  True,  True,  True, False, False, False, False, False, False])
>>> ((xg > 0.3)&(xg < 0.7)).nonzero()
tensor([[ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13]])

3.2. 多索引访问

# 一维的情况下:
>>> y=torch.arange(20)
>>> y[torch.tensor([6,7,8])]
tensor([6, 7, 8])

# 多维的情况下:
>>> y=y.reshape(4,5)
>>> y
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
>>> y[torch.tensor([2,3])]
tensor([[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
>>> y[(torch.tensor([2,3]),torch.tensor([0,1]))]
tensor([10, 16])
# 这就很变态了,tuple里每个tensor不是坐标,而是每个坐标的在该维度的分量
# 换句话说[2,3],[0,1]代表着坐标(2,0),(3,1)

3.3. 按索引赋值

matlab可以方便的在索引内用矩阵对分立的元素一次性赋值(这得益于matlab可以用一维索引访遍所有元素,以及和find()的配合),但pytorch不能搜索,就很恶心。

# 上一节的index
>>> index
tensor([[ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13]])
>>> index.squeeze()
tensor([ 6,  7,  8,  9, 10, 11, 12, 13])
>>> y=torch.arange(20)
# 直接传向量是不可行的
>>> y.index_put_(index.squeeze(),torch.ones(len(index_tuple)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: index_put_(): argument 'indices' (position 1) must be tuple of Tensors, not Tensor
# 但是你他妈可以通过向量访问到元素
>>> y[torch.tensor([2,3])]=0
>>> y
tensor([ 0,  1,  0,  0,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
# 注意看这个逗号,折腾了很久这是关键,否则不会识别成tuple
>>> y.index_put_((index.squeeze(),),torch.ones(len(index_tuple),dtype=torch.long))
tensor([ 0,  1,  0,  0,  4,  5,  1,  1,  1,  1,  1,  1,  1,  1, 14, 15, 16, 17,
        18, 19])

3.4. 查找并替换

https://blog.csdn.net/qq_34372112/article/details/106219482

终于!可以摆脱判断后接index_put_()的逼操作了:

>>> a=torch.tensor([[1,1,4,5,1,4],[1,9,torch.nan,1,9,torch.nan]])
>>> a
tensor([[1., 1., 4., 5., 1., 4.],
        [1., 9., nan, 1., 9., nan]])
>>> torch.where(a.isnan(),torch.zeros_like(a),a)
tensor([[1., 1., 4., 5., 1., 4.],
        [1., 9., 0., 1., 9., 0.]])

4. 拼接、拓展与维度交换

4.1. 快速拼接

我们都知道matlab的[]运算符可以做到快速拼接矩阵,对于二维矩阵的拼接比较好理解,在更高维的情况下有:

K>> a = ones(2,3,4);
K>> b = ones(2,3,4);
K>> size([a b])% 实际上调用了horzcat
ans =
     2     6     4
K>> size([a,b])
ans =
     2     6     4
K>> size([a;b])% 实际上调用了vertcat
ans =
     4     3     4

则对于pytorch,cat或者concat函数在指定维度上连接(并增长尺寸)

  • horzcat(a,b) <=> cat((a,b),1)
  • vertcat(a,b) <=> cat((a,b),0)

4.2. 拼接并拓展维度

在matlab里面[a;b]还能做到将两个一维行向量接成两行矩阵,但是cat()做不到这一点,因为pytorch不认为存在一个新维度,但办法也是有的,stack()可以拓展再拼接:

>>> x = torch.randn(3, 4)
>>> torch.stack((x, x), 0).size()
torch.Size([2, 3, 4])
>>> torch.stack((x, x), 1).size()
torch.Size([3, 2, 4])
>>> torch.cat((x, x), 2).size()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

4.3. 维度的拓展与复制

在维度与拓展方向相等的情况下,pytorch的tile()与matlab的repmat()行为是一致的,但是在维度不同的情况下,它们的行为是不一致的,在matlab下面:

>> x=rand(1,10);
>> size(x)
ans =
     1    10
>> xn = repmat(x,[1,1,2]);
>> size(xn)
ans =
     1    10    2

对于pytorch:

>>> x=torch.rand(1,10)
>>> x.size()
torch.Size([1, 10])
>>> xn=x.tile(1,1,2)
>>> xn.size()
torch.Size([1, 1, 20])

这他妈就非常恶心了,matlab的repmat()在这方面更符合直觉,将输入矩阵(1,10)拓展为(1,10,1),迎合拓展尺寸;而pytorch将前面作为更高维,因此拓展为(1,1,10)。也就是说,新维度诞生的位置是不一样的。

在实操的时候一般遇到这个问题的时候是不同维数的矩阵相乘,在matlab中自动广播的机制做的比较好,但到了pytorch里面相对而言就比较拉跨,下面是实操中的例子:

# brdx refers to broadcast
# x_cor.shape = torch.Size([12, 9, 1])
# xn.shape = torch.Size([1, 1, 229])
# nnterm.shape = torch.Size([1, 1, 1, 42])

x_cor_brdx = (
    x_cor[i * seg_bound :, :, :]
    .unsqueeze(3) # 先手拓展到四维(在3位置处增加一维)
    .tile(1, 1, xn.size(2), nnterm.size(3)) # 适应其他维度的尺寸,即广播
)
xn_brdx = xn.unsqueeze(3).tile(
    x_cor_brdx.size(0), x_cor_brdx.size(1), 1, nnterm.size(3)
)
terms = (
    (-1) ** nnterm # nnterm的前三维不用广播也能顺利运算,但需要广播其他的张量的第四维
    * torch.sin((nnterm - 0.5) * pi * (x_cor_brdx - xn_brdx)) # 广播前三维就是为了这个减法
    * torch.exp(-((nnterm - 0.5) ** 2) * pi * pi * t)
).sum(3)
terms_out = torch.cat((terms_out, terms), 0)

4.4. 维度的交换

repmat()函数至今没有遇到使用问题,看来两者行为比较一致。唯一需要注意的是4.3. 维度的拓展与复制tile()的行为带来的维数打乱的问题。

4.5. 映射与变换

类似于map的,pytorch给了两个语焉不详的方法:TORCH.TENSOR.APPLY_TORCH.TENSOR.MAP_,看起来apply是一参数的,map是两参数的。注意他们是作用于自身的方法,如果是f:x->y的使用并且不想失去x,记得用clone()方法新开张量。

5. 模型与训练

5.1 模型的断点重训

参考PyTorch实现断点继续训练 - 知乎