Pytorch(2):关于torch.cat()与torch.stack()

用法:

torch.cat(): 用于连接两个相同大小的张量

torch.stack(): 用于连接两个相同大小的张量,并扩展维度

实例:

import torch
a = torch.tensor(torch.arange(10)).reshape(3, 3)
b = torch.tensor(torch.arange(10, 100, 10)).reshape(3, 3)

print(a)
Out[7]: 
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

print(b)
Out[10]: 
tensor([[10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]])

torch.cat()

使用不同的参数,输出的结果不同,首先填入一个会返回错误的参数:从返回报错原因可以看到,参数的返回必须是在[-2, 1]之间。

d3 = torch.cat((a, b), dim=2)

# 返回输出如下
Traceback (most recent call last):
  File "/home/franklinpan/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 1, in 
    d3 = torch.cat((a, b), dim=2)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

设置dim=-1,得到如下结果,当参数为-1时,与dim=1的返回结果相同

d_1= torch.cat((a, b), dim=-1)

print(d_1)
Out[25]: 
tensor([[ 1,  2,  3, 10, 20, 30],
        [ 4,  5,  6, 40, 50, 60],
        [ 7,  8,  9, 70, 80, 90]])

d1 = torch.cat((a, b), dim=1)
print(d1)
Out[22]: 
tensor([[ 1,  2,  3, 10, 20, 30],
        [ 4,  5,  6, 40, 50, 60],
        [ 7,  8,  9, 70, 80, 90]])

设置dim=-2,得到如下结果,与dim=0结果相同

d_2= torch.cat((a, b), dim=-2)
print(d_2)
Out[27]: 
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]])

d1 = torch.cat((a, b), dim=0)
print(d1)
Out[20]: 
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]])

可以看到,采用不同的参数,输出的张量维度仍然与原来张量的维度保持一致。

若输入参数的维度不一样,会产生什么结果呢?

当输出张量保持一个维度一致时,若在相同维度的方向进行连接torch.cat操作,则仍然可以张量的合并操作,若在维度不同的方向进行连接操作,会报错。(torch.cat操作没有广播机制)

b2 = torch.tensor(torch.arange(10, 70, 10)).reshape(2, 3)
d2 = torch.cat((a, b2))
Out[32]: 
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 20, 30],
        [40, 50, 60]])

d3 = torch.cat((a, b2), dim=1)
Traceback (most recent call last):
  File "/home/franklinpan/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 1, in 
    d3 = torch.cat((a, b2), dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 3 but got size 2 for tensor number 1 in the list.

torch.stack()

先使用超出范围的参数,得到反馈结果如下:参数范围应该在[-3, 2]之间(此处的参数范围跟输入张量的维度有关

c4 = torch.stack((a, b), dim=3)
Traceback (most recent call last):
  File "/home/franklinpan/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 1, in 
    c4 = torch.stack((a, b), dim=3)
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

此处不再重复dim=-3 or -2等操作,当dim=0时

c1 = torch.stack((a, b), dim=0)

Out[12]: 
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],
        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])

 当dim=1时

c2 = torch.stack((a, b), dim=1)

Out[15]: 
tensor([[[ 1,  2,  3],
         [10, 20, 30]],
        [[ 4,  5,  6],
         [40, 50, 60]],
        [[ 7,  8,  9],
         [70, 80, 90]]])

当 dim=2时

c3 = torch.stack((a, b), dim=2)

Out[17]: 
tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],
        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]],
        [[ 7, 70],
         [ 8, 80],
         [ 9, 90]]])

若在torch.stack中使用不同维度的输入,得到反馈如下:

c5 = torch.stack((a, b2))
Traceback (most recent call last):
  File "/home/franklinpan/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 1, in 
    c5 = torch.stack((a, b2))
RuntimeError: stack expects each tensor to be equal size, but got [3, 3] at entry 0 and [2, 3] at entry 1

从实例可见,torch.stack操作将会增加合并后张量的维度。

各张量维度如下所示:

Pytorch(2):关于torch.cat()与torch.stack()_第1张图片

 总结:

torch.cat()与torch.stack()操作都是对张量进行拼接操作,不同点如下:

torch.stack()将对张量维度进行扩张

torch.cat()可以对只有一个方向维度相同的张量进行合并,而torch.stack()要求输入张量的维度必须一样。

你可能感兴趣的