pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

文章目录

*
前言
1. reshape()
2. view()

+ ① 1 阶变高阶
+
* 1 阶变 2 阶
* 1 阶变 3 阶
* 1 阶变 4 阶
* 1 阶变 m 阶
+ ② 2 阶变 m 阶
+ ③ 3 阶变 m 阶
+ ④ 4 阶变 m 阶
3. transpose()

+ ② 2 阶张量
+ ③ 3 阶张量
+ ④ 4 阶张量
4. permute()
结语

前言

view() 函数是进行张量维度重构的函数,permute() 和 transpose() 是进行张量维度转换的函数,高阶张量由若干低阶张量构成,如结构为 (n, c, h, w)的 4 阶张量由 n 个结构为 (c, h, w) 的 3 阶张量构成,结构为 (c, h, w)的 3 阶张量由 c 个结构为 (h, w) 的 2 阶张量构成,结构为 (h, w)的 2 阶张量又由 h 个长度为 w 的 1 阶张量构成,h 为行数,w 为列数。

1. reshape()

reshape() 函数与 view() 函数都是进行维度重组的函数,使用方法类似,区别在于 view() 函数只能对张量进行操作,而 reshape() 函数既可以对张量进行操作,还可以对 numpy 数组进行操作,代码示例如下,具体原理见 view() 函数。

x = np.array([1, 2, 3, 4, 5, 6])
y = torch.Tensor([1, 2, 3, 4, 5, 6])
print(x.reshape(2, 3))
print(y.reshape(2, 3))

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

2. view()

① 1 阶变高阶

1 阶变 2 阶

对于一个 1 阶张量 x,进行 view(h, w) 操作就是按照索引先后顺序每次从 x 中取出 w 个元素作为作为一行数据,共取 h 次,构成一个 (h, w) 结构的 2 阶张量,具体见示例。

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8])
print(x.view(4, 2))

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

1 阶变 3 阶

对于一个 1 阶张量 x,进行 view(c, h, w) 操作就是按照索引先后顺序每次从 x 中取出 hw 个元素,对这 hw 个元素按照 1 阶张量转 2 阶数张量的方法转为一个 (h, w) 结构的 2 阶张量,共取 c 次,构成一个 (c, h, w) 结构的 3 阶张量,具体见示例。

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
print(x.view(3, 2, 2))

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

1 阶变 4 阶

对于一个 1 阶张量 x,进行 view(n, c, h, w) 操作就是按照索引先后顺序每次从 x 中取出 chw 个元素,对这 chw 个元素按照 1 阶张量转 3 阶张量的方法转为一个 (c, h, w) 结构的 3 阶张量,共取 n 次,最终构成一个 (n, c, h, w) 结构的 4 阶张量,具体见示例。


x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
print(x.view(2, 2, 2, 3))

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

1 阶变 m 阶

对于一个 1 阶张量 x,进行 view(i n i_n i n ​, i n − 1 i_{n-1}i n −1 ​, ···, i 2 i_2 i 2 ​, i 1 i_1 i 1 ​) 操作就是按照索引先后顺序每次从 x 中取出 i n − 1 i_{n-1}i n −1 ​i n − 2 i_{n-2}i n −2 ​···i 2 i_2 i 2 ​i 1 i_1 i 1 ​ 个元素,对这 i n − 1 i_{n-1}i n −1 ​i n − 2 i_{n-2}i n −2 ​···i 2 i_2 i 2 ​i 1 i_1 i 1 ​ 个元素按照 1 阶张量转 m-1 阶张量的方法转为一个 (i n − 1 i_{n-1}i n −1 ​, ···, i 2 i_2 i 2 ​, i 1 i_1 i 1 ​) 结构的 m-1 阶张量,共取 m 次,最终构成一个 (i n i_n i n ​, i n − 1 i_{n-1}i n −1 ​, ···, i 2 i_2 i 2 ​, i 1 i_1 i 1 ​) 结构的 m 阶张量,其中 i n i_n i n ​ 代表张量第 n 个索引的值。

② 2 阶变 m 阶

对于一个 2 阶张量 x,结构为 (h, w),要变成一个 m 阶的新张量,首先将该 2 阶张量 按行展开成一个大小为 h*w 的 1 阶张量,再按照 1 阶变 m 阶的方法变为一个 m 阶张量,按行展开就是在 w 索引方向上进行拼接,2 阶张量变 3 阶张量的代码示例见下,用一个 1 阶张量来验证分析。

x = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [10, 11, 12]])
y = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
print(x.view(2, 2, 3))
print(y.view(2, 2, 3))

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

③ 3 阶变 m 阶

对于一个 3 阶张量 x,结构为 (c, h, w),要变成一个 m 阶的新张量,首先将该 3 阶张量 按行拼接得到一个结构为 (c*h, w) 的 2 阶张量,再按照 2 阶变 1 阶的方法转变为一个 1 阶张量,按行拼接就是在 h 索引方向上进行拼接,示例见图 1.1 和图 1.2。

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数
3 阶张量变 4 阶张量的代码示例见下, 用一个拼接后得到的 2 阶张量来验证前述分析。
x = torch.Tensor([[[1, 2, 3],
                   [4, 5, 6]],

                  [[7, 8, 9],
                   [10, 11, 12]],

                  [[13, 14, 15],
                   [16, 17, 18]],

                  [[19, 20, 21],
                   [22, 23, 24]]])
y = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [10, 11, 12],
                  [13, 14, 15],
                  [16, 17, 18],
                  [19, 20, 21],
                  [22, 23, 24]])
print((y.view(2, 2, 2, 3)).equal(x.view(2, 2, 2, 3)))
print(x.view(2, 2, 2, 3))

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

④ 4 阶变 m 阶

对于一个 4 阶张量 x,结构为 (n, c, h, w),要变成一个 m 阶的新张量,首先将该 m 阶张量在 c 索引方向进行拼接得到一个结构为 (n*c, h, w) 的 3 阶张量,再按照 3 阶张量变 1 阶张量的方法转变为一个 1 阶张量,最后再按照 1阶变 m 阶的方法得到 m 阶张量,4 阶张量在 c 索引方向进行拼接的示意图如图 2.1和图 2.2 所示。

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数
4 阶张量变 2 阶张量的代码示例见下,用一个拼接后的 3 阶张量验证前述分析。
x = torch.Tensor([[[[1, 2, 3],
                    [4, 5, 6]],

                   [[7, 8, 9],
                    [10, 11, 12]]],

                  [[[13, 14, 15],
                    [16, 17, 18]],

                   [[19, 20, 21],
                    [22, 23, 24]]]])
y = torch.Tensor([[[1, 2, 3],
                   [4, 5, 6]],

                   [[7, 8, 9],
                    [10, 11, 12]],

                   [[13, 14, 15],
                    [16, 17, 18]],

                   [[19, 20, 21],
                    [22, 23, 24]]])

print(f'x.size() = {x.size()}')
print(x.view(4, 6))
print(f'y.size() = {y.size()}')
print(y.view(4, 6))

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

3. transpose()

transpose() 函数一次进行两个维度的交换,参数是 0, 1, 2, 3, … ,随着待转换张量的阶数上升参数越来越多。

② 2 阶张量

对于一个 2 阶张量,结构为 (h, w),对应 transpose() 函数中的参数是 (0, 1) 两个索引,进行 transpose(0, 1) 操作就是在交换 h, w 两个维度,得到的结果与常见的矩阵转置相同,具体代码示例见下。

x = torch.Tensor([[1, 2],
                  [3, 4],
                  [5, 6]])
print(f'x.size() = {x.size()}')
y = x.transpose(0, 1)

print(f'y.size() = {y.size()}')
print(y)

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

③ 3 阶张量

对于一个 3 阶张量,结构为 (c, h, w),对应 transpose() 函数中的参数是 (0, 1, 2) 3 个索引,进行 transpose(0, 1) 操作就是在交换 c, h 两个维度,交换 c, h 两个维度的示意图见图 3.1 和图 3.2,其他维度的交换方式同理,实在不明白可以拿几本书放一起比划一下。

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数
3 阶张量交换 c, h 维度的代码示例见下,不难发现对 3 阶张量的 c, h 两个索引进行 transpose() 操作就是以 w 索引方向为轴在进行旋转。
x = torch.Tensor([[[1, 2, 3], [4, 5, 6]],
                  [[7, 8, 9], [10, 11, 12]],
                  [[13, 14, 15], [16, 17, 18]],
                  [[19, 20, 21], [22, 23, 24]]])
print(f'x.size() = {x.size()}')
print(x.transpose(0, 1))

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

④ 4 阶张量

对于一个 4 阶张量,结构为 (n,c, h, w),对应 transpose() 函数中的参数是 (0, 1, 2,3) 4 个索引,对应 transpose() 的操作相对复杂一些,为方便理解这里具体分为 transpose(0, 1),和 transpose(0, 3),和 transpose(1, 2) 三种,具体原因见以下分析。

3.4.1 transpose(0, 1) 操作就是交换 n, c 两个维度,交换 n, c 两个维度的示意图见图 4.1 和图 4.2,实在不明白可以拿几本书比划一下。

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

4 阶张量交换 n, c 维度的代码示例见下,其他维度交换同理,不难发现对 4 阶张量而言进行 transpose(0, 1) 操作就是 n 索引方向上进行通道重新分组,如下代码中原张量 n 索引方向上有 2 组,每组有 3 个通道,交换 n, c 维度后变为 3 组,每组有 2 个通道。

x = torch.Tensor([[[[1, 2], [3, 4]],
                   [[5, 6], [7, 8]],
                   [[9, 10], [11, 12]]],

                  [[[13, 14], [15, 16]],
                   [[17, 18], [19, 20]],
                   [[21, 22], [23, 24]]]])
print(f'x.size() = {x.size()}')
print(x.transpose(0, 1))

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数
3.4.2 transpose(0, 3) 操作就是交换 n, w 两个维度,交换 n, w 两个维度的示意图比较难表示,这里用代码解释一下, transpose(0, 2) 同理,需要说明的是这种变换方式很少会用到。

对于一个结构为 (2, 2, 2, 3) 的 4 阶张量 x,进行 transpose(0, 3) 操作即将原 4 阶张量变成一个结构为 (3, 2, 2, 2) 的新 4 阶张量,可以理解为在保证原 4 阶张量中元素 c , h 索引不变的情况下的将每一个元素的 n, w 进行交换,类似于坐标系变换。

x = torch.Tensor([[[[1, 2, 3], [4, 5, 6]],
                   [[7, 8, 9], [10, 11, 12]]],

                  [[[13, 14, 15], [16, 17, 18]],
                   [[19, 20, 21], [22, 23, 24]]]])
print(f'x.size() = {x.size()}')
y = x.transpose(0, 3)
print(f'y.size() = {y.size()}')
print(y)

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数
3.4.3 transpose(1, 2) 操作就是交换 c, h 两个维度,跟之前的 3 阶张量交换维度的操作一样,4 阶张量只是需要交换 n 个 3 阶张量的维度, transpose(1, 3)transpose(2, 3)同理。

4. permute()

permute() 函数一次可以进行多个维度的交换或者可以成为维度重新排列,参数是 0, 1, 2, 3, … ,随着待转换张量的阶数上升参数越来越多,本质上可以理解为多个 transpose() 操作的叠加,因此理解 permute() 函数的关键在于理解 transpose() 函数,代码示例如下。

x = torch.Tensor([[[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 10, 11, 12]],

                  [[13, 14, 15, 16],
                   [17, 18, 19, 20],
                   [21, 22, 23, 24]]])
print(f'x.size() = {x.size()}')
y = x.permute(2, 0, 1)
z = x.transpose(0, 1).transpose(0, 2)
print(y.equal(z))
print(f'z.size() = {z.size()}')
print(z)

pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

结语

通过以上分析可以得出结论, reshpe()view() 两个函数满足条件时可以根据需要设置维度,而 transpose()permute() 两个函数只能在已有的维度之间进行变换,另外 transpose() 函数在 pytorch 和 numpy 中略有不同,numpy 中的 transpose() 函数相当于 pytorch 中的 permute() 函数。

Original: https://blog.csdn.net/Wenyuanbo/article/details/119779521
Author: 听 风、
Title: pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/670925/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球