Nine—pytorch学习—拼接与拆分/运算统计

pytorch学习(6)

拼接与拆分

  • cat
  • stack
  • split
  • chunk

cat()

  • 连接给定维度中给定的张量序列
  • 所有张量必须具有相同的形状(拼接维度除外)或为空
  • torch.cat() 可以看作是 torch.split() 和 torch.chunk() 的反运算
  • torch.cat(inputs,dim=)
#正确的案例
import torch

a = torch.rand(3,32,8)
b = torch.rand(6,32,8) #b与a除拼接维度外具有相同的形状
c = torch.cat([a,b], dim=0)
print(a.shape)
print(b.shape)
print(c.shape)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 3 but got size 6 for tensor number 1 in the list.

stack()

  • create new dim
  • 沿着一个新维度对输入张量进行连接,属于扩张再拼接的函数
  • 序列中所有张量都要有相同的形状
  • torch.stack(sequence,dim=)
import torch

a = torch.rand(6,32,8)
b = torch.rand(6,32,8) #a与b序列应保持相同
c = torch.stack([a,b], dim=0)
print(a.shape)
print(b.shape)
print(c.shape)
torch.Size([6, 32, 8])
2
torch.Size([3, 32, 8])
torch.Size([3, 32, 8])
#案例二
import torch

a = torch.rand(7,32,8)
b = torch.split(a,3,0)
print(a.shape)
print(len(b))
print(b[0].shape)
print(b[1].shape)
print(b[2].shape)
torch.Size([6, 32, 8])
3
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])

运算与统计

  • 基础四则运算
  • 平方与开方
  • 矩阵相乘
  • 近似函数
  • 数据裁剪函数

基础四则运算(需要满足广播机制)

  • add 加法
  • sub 减法
  • mul 乘法
  • div 除法
torch.add()
import torch

a = torch.rand(3,4)
b = torch.rand(4)

c = a+b
d = torch.add(a,b) #两种方式输出结果相同

print(c)
print(d)
tensor([[ 0.3418, -0.0872, -0.4209, -0.1290],
[ 0.3348, -0.0099, -0.2430, -0.3949],
[-0.1601, -0.1792, 0.1060, -0.5435]])
tensor([[ 0.3418, -0.0872, -0.4209, -0.1290],
[ 0.3348, -0.0099, -0.2430, -0.3949],
[-0.1601, -0.1792, 0.1060, -0.5435]])

torch.mul(input, value, out=None)

  • 对输入张量 input 逐元素乘以 标量值/张量(value),并返回一个新的张量tensor
import torch

a = torch.rand(3,3)
b = torch.eye(3,3)

c = torch.mul(a,b)

print(a)
print(b)
print(c)
tensor([[2, 2],
[2, 2]])
tensor([[1., 1.],
[1., 1.]])
tensor([[1., 1.],
[1., 1.]])

平方与开方

  • pow 函数 & **
  • sqrt 函数(平方) & rsqrt 函数
pow 函数 & **
import torch

a = torch.full([2,2],2)
b = a.pow(2)
c = a**2
d = b.pow(0.5)

print(b)
print(c)
print(d)
tensor([[2, 2],
[2, 2]])
tensor([[1.4142, 1.4142],
[1.4142, 1.4142]])
tensor([[0.7071, 0.7071],
[0.7071, 0.7071]])

矩阵相乘

  • matmul()
  • torch.mm(mat1,mat2,out=None)
  • torch.bmm(batch1,batch2,out=None)
  • torch.matmul(tensor1,tensor2,out=None)
torch.mm() — 二维矩阵乘法
  • mm只能进行矩阵乘法,也就是输入的两个tensor维度只能是 (n x m)和 (m x p)
  • (n x m)和(m x p)通过矩阵乘法得到(n x p)
import torch

a = torch.rand(3,4)
b = torch.rand(4,5)

c = torch.mm(a,b)

print(a.shape)
print(b.shape)
print(c.shape)
torch.Size([3, 2, 4])
torch.Size([3, 4, 5])
torch.Size([3, 2, 5])
torch.matmul()
  • matmul可以进行张量乘法,输入可以是高维
  • 对矩阵mat1和mat2进行相乘
  • @符号与matmul效果相同
  • 例如:tensor1维度是(j x 1 x n x m),tensor2维度是(k x m x p),输出为(j x k x n x p)
import torch

a = torch.rand(3,1,2,4)
b = torch.rand(5,4,6)

c = torch.matmul(a,b)
d = a @ b
print(a.shape)
print(b.shape)
print(c.shape)
print(d.shape)
tensor(3.1416)
tensor(3.)
tensor(4.)
tensor(3.)
tensor(3.)
tensor(0.1416)

数据裁剪函数

  • clamp函数 将输入input张量每个元素值约束到区间[min,max],并返回结果到一个新的tensor,也可以只设定min或只设定max
  • torch.clamp(input,min,max)

`python
import torch

a = torch.rand(3,3)*20
b = torch.clamp(a,0,10)
c = torch.clamp(a,7,14)

print(a)
print(b)
print(c)

Original: https://www.cnblogs.com/311dih/p/16583856.html
Author: 叁_311
Title: Nine—pytorch学习—拼接与拆分/运算统计

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/643416/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球