pytorch学习(3)
索引与切片
- 普通索引
- 冒号索引(切片)
- index_select 选择特定索引
- masked_select 选择符合条件的索引
- take 索引
普通索引
- index(有负索引)
import torch
a = torch.Tensor(2,3,32,32)
print(a.shape)
print(a[0].shape)
print(a[0][0].shape)
print(a[0][0][0][0].shape)
print(a[0][0][0][0])
torch.Size([1, 3, 32, 32])
torch.Size([2, 3, 10, 10])
torch.Size([2, 3, 16, 16])
torch.Size([2, 3, 16, 16])
torch.Size([2, 3, 32, 32])
index_select 选择特定索引
- torch.index_select(x, 维度, torch.tensor([a,b]))
- x代表目标张量,[a,b]代表从a到b
import torch
a = torch.linspace(0, 12, steps=12)
#创建一个列表从0到12的浮点型
print(a)
c = a.view(3,4) #将a进行维度转换变为三行四列的二维张量
b = torch.index_select(c, 0, torch.tensor([0,2]))
#索引张量c的1维,是行,即为索引第0行以及第2行
print(b)
d = torch.index_select(c, 1, torch.tensor([1,3]))
#索引c的2维,是列,即为索引第1列和第3列
print(d)
tensor([[ 1.1909, 0.2912, -0.1066],
[ 0.9496, 0.6031, 1.5957],
[-0.2447, 0.0101, 2.4906]])
tensor([[ True, False, False],
[False, True, False],
[False, False, True]])
tensor([1.1909, 0.6031, 2.4906])
take索引
- torch.take(input, index) take索引是在原来tensor的shape基础上打平,然后按照index在打平后的tensor上索取对应位置的元素
`python
import torch
a = torch.randn(3,3) #创建一个三行三列的正态分布的矩阵
print(a)
c = torch.take(a, torch.tensor([0,2,4,6,8]))
print(c)
Original: https://www.cnblogs.com/311dih/p/16583850.html
Author: 叁_311
Title: Six—pytorch学习—索引与切片
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/643687/
转载文章受原作者版权保护。转载请注明原作者出处!