深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别

Flatten层用来将输入”压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。

就是把高纬度的数组按照 x轴或者y轴 进行拉伸,变成一维的数组

为了更好的理解Flatten层作用,我把这个神经网络进行可视化如下图:(来自网络)

深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别

flatten(),默认缺省参数为0,也就是说flatten()和flatte(0)效果一样。

python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。

比如一个数据的维度是

深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别,flatten(m)后的数据为深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别

案例程序如下:

import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("./data_CIFAR10", train=False,
                                       transform=torchvision.transforms.ToTensor(),download=True)

dataloader = DataLoader(dataset,batch_size=64)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.linear1 = Linear(196608,10)

    def forward(self,input):
        output = self.linear1(input)
        return output

tudui = Tudui()

for data in dataloader:
    imgs, targets = data
    print(imgs.shape)
    # output = torch.reshape(imgs,(1,1,1,-1))
    output = torch.flatten(imgs)
    print(output.shape)
    output = tudui(output)
    print(output.shape)

运行结果如下:

深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别

从上图可以看出,torch_size([64,3,32,32])是print(imgs.shape)打印得到的结果,表示batch_size=64,channel=3,高H=32,宽W=32

上面的结果通过flatten后得到的结果维度大小为torch_size([196608]),其中的196608=643232*3得到的

然后经过神经网络(Tudui)得到的结果维度大小是torch_size([10]),表示输出为10个类别。

如果把flatten改为reshape会出现什么结果呢,程序如下:

import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("./data_CIFAR10", train=False,
                                       transform=torchvision.transforms.ToTensor(),download=True)

dataloader = DataLoader(dataset,batch_size=64)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.linear1 = Linear(196608,10)

    def forward(self,input):
        output = self.linear1(input)
        return output

tudui = Tudui()

for data in dataloader:
    imgs, targets = data
    print(imgs.shape)
    output = torch.reshape(imgs,(1,1,1,-1))
    # output = torch.flatten(imgs)
    print(output.shape)
    output = tudui(output)
    print(output.shape)

运行结果如下:

深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别

我们发现,经过了reshape后,得到的结果尺寸维度是torch_size([1,1,1,196608]),其结果表示batch_size=1,channel=1,高H=1,宽W=196608

上面结果通过了神经网络(Tudui)后得到结果尺寸维度为torch_size([1,1,1,10]),表示输出为10个类别。

Original: https://blog.csdn.net/qq_42233059/article/details/126663501
Author: 清泉_流响
Title: 深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/708043/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球