pytorch学习(6)
拼接与拆分
- cat
- stack
- split
- chunk
cat()
- 连接给定维度中给定的张量序列
- 所有张量必须具有相同的形状(拼接维度除外)或为空
- torch.cat() 可以看作是 torch.split() 和 torch.chunk() 的反运算
- torch.cat(inputs,dim=)
#正确的案例
import torch
a = torch.rand(3,32,8)
b = torch.rand(6,32,8) #b与a除拼接维度外具有相同的形状
c = torch.cat([a,b], dim=0)
print(a.shape)
print(b.shape)
print(c.shape)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 3 but got size 6 for tensor number 1 in the list.
stack()
- create new dim
- 沿着一个新维度对输入张量进行连接,属于扩张再拼接的函数
- 序列中所有张量都要有相同的形状
- torch.stack(sequence,dim=)
import torch
a = torch.rand(6,32,8)
b = torch.rand(6,32,8) #a与b序列应保持相同
c = torch.stack([a,b], dim=0)
print(a.shape)
print(b.shape)
print(c.shape)
torch.Size([6, 32, 8])
2
torch.Size([3, 32, 8])
torch.Size([3, 32, 8])
#案例二
import torch
a = torch.rand(7,32,8)
b = torch.split(a,3,0)
print(a.shape)
print(len(b))
print(b[0].shape)
print(b[1].shape)
print(b[2].shape)
torch.Size([6, 32, 8])
3
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])
运算与统计
- 基础四则运算
- 平方与开方
- 矩阵相乘
- 近似函数
- 数据裁剪函数
基础四则运算(需要满足广播机制)
- add 加法
- sub 减法
- mul 乘法
- div 除法
torch.add()
import torch
a = torch.rand(3,4)
b = torch.rand(4)
c = a+b
d = torch.add(a,b) #两种方式输出结果相同
print(c)
print(d)
tensor([[ 0.3418, -0.0872, -0.4209, -0.1290],
[ 0.3348, -0.0099, -0.2430, -0.3949],
[-0.1601, -0.1792, 0.1060, -0.5435]])
tensor([[ 0.3418, -0.0872, -0.4209, -0.1290],
[ 0.3348, -0.0099, -0.2430, -0.3949],
[-0.1601, -0.1792, 0.1060, -0.5435]])
torch.mul(input, value, out=None)
- 对输入张量 input 逐元素乘以 标量值/张量(value),并返回一个新的张量tensor
import torch
a = torch.rand(3,3)
b = torch.eye(3,3)
c = torch.mul(a,b)
print(a)
print(b)
print(c)
tensor([[2, 2],
[2, 2]])
tensor([[1., 1.],
[1., 1.]])
tensor([[1., 1.],
[1., 1.]])
平方与开方
- pow 函数 & **
- sqrt 函数(平方) & rsqrt 函数
pow 函数 & **
import torch
a = torch.full([2,2],2)
b = a.pow(2)
c = a**2
d = b.pow(0.5)
print(b)
print(c)
print(d)
tensor([[2, 2],
[2, 2]])
tensor([[1.4142, 1.4142],
[1.4142, 1.4142]])
tensor([[0.7071, 0.7071],
[0.7071, 0.7071]])
矩阵相乘
- matmul()
- torch.mm(mat1,mat2,out=None)
- torch.bmm(batch1,batch2,out=None)
- torch.matmul(tensor1,tensor2,out=None)
torch.mm() — 二维矩阵乘法
- mm只能进行矩阵乘法,也就是输入的两个tensor维度只能是 (n x m)和 (m x p)
- (n x m)和(m x p)通过矩阵乘法得到(n x p)
import torch
a = torch.rand(3,4)
b = torch.rand(4,5)
c = torch.mm(a,b)
print(a.shape)
print(b.shape)
print(c.shape)
torch.Size([3, 2, 4])
torch.Size([3, 4, 5])
torch.Size([3, 2, 5])
torch.matmul()
- matmul可以进行张量乘法,输入可以是高维
- 对矩阵mat1和mat2进行相乘
- @符号与matmul效果相同
- 例如:tensor1维度是(j x 1 x n x m),tensor2维度是(k x m x p),输出为(j x k x n x p)
import torch
a = torch.rand(3,1,2,4)
b = torch.rand(5,4,6)
c = torch.matmul(a,b)
d = a @ b
print(a.shape)
print(b.shape)
print(c.shape)
print(d.shape)
tensor(3.1416)
tensor(3.)
tensor(4.)
tensor(3.)
tensor(3.)
tensor(0.1416)
数据裁剪函数
- clamp函数 将输入input张量每个元素值约束到区间[min,max],并返回结果到一个新的tensor,也可以只设定min或只设定max
- torch.clamp(input,min,max)
`python
import torch
a = torch.rand(3,3)*20
b = torch.clamp(a,0,10)
c = torch.clamp(a,7,14)
print(a)
print(b)
print(c)
Original: https://www.cnblogs.com/311dih/p/16583856.html
Author: 叁_311
Title: Nine—pytorch学习—拼接与拆分/运算统计
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/643416/
转载文章受原作者版权保护。转载请注明原作者出处!