Pytorch避坑之:RuntimeError: Input type(torch.cuda.FloatTensor) and weight type(torch.FloatTensor) shoul

问题分析

Pytorch避坑之:RuntimeError: Input type(torch.cuda.FloatTensor) and weight type(torch.FloatTensor) shoul
  • 就像是字面意思那样,这个错误是因为模型中的 weights 没有被转移到 cuda 上,而模型的数据转移到了 cuda 上而造成的
  • 但是造成这个问题的原因却没有那么简单。
  • 绝大多数时候,造成这个的原因是因为你定义好模型之后,没有对模型进行 to(device) 而造成的,但是,也有可能,是因为你的模型在定义的时候,没有定义好, *导致模型的一部分在加载的时候没有办法转移到 cuda上。

; 细节举例

  • 比如我现在定义了一个模型 A,B,它们的结构如下:

import torch.nn as nn
import torch
import torch.utils.data as Data
from tqdm import tqdm
from torchvision import transforms,datasets
import numpy as np
import torchvision
from torch.optim import lr_scheduler

class A(nn.Module):
    def __init__(self):
        super(A,self).__init__()
        self.conv = nn.Conv2d(in_channels=3
                              ,out_channels=8
                              ,kernel_size=3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self,x):
        out = self.conv(x)
        out = self.relu(out)
        B_model = B()
        out = B_model(out)
        return out

class B(nn.Module):
    def __init__(self):
        super(B,self).__init__()
        self.conv = nn.Conv2d(in_channels=8
                              ,out_channels=16
                              ,kernel_size=3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.relu(out)
        return out

Pytorch避坑之:RuntimeError: Input type(torch.cuda.FloatTensor) and weight type(torch.FloatTensor) shoul
  • 这个时候就会报错,而报错的原因,就是因为 torch 的流程是这样的:
  • 首先将所有的模型加载,先从 A 开始,进入 A 的 init 中把所有的内容加载,然后,通过 main 函数中的 to(device) 操作,就把加载的所有内容和网络定义都放到 cuda 上了,但是注意!!!
    Pytorch避坑之:RuntimeError: Input type(torch.cuda.FloatTensor) and weight type(torch.FloatTensor) shoul

改错思路

  • 将所有的内容都放到 cpu 上运行,即:
    Pytorch避坑之:RuntimeError: Input type(torch.cuda.FloatTensor) and weight type(torch.FloatTensor) shoul
  • 但显然这是个治标不治本的方法,我们就没有办法使用 gpu 训练了,因此我们选择把所有的网络层(只要有参数需要训练的网络层)都放到 init 里面去定义,只在 forward 中写运行时的逻辑,即:
class A(nn.Module):
    def __init__(self):
        super(A,self).__init__()
        self.conv = nn.Conv2d(in_channels=3
                              ,out_channels=8
                              ,kernel_size=3)
        self.relu = nn.ReLU(inplace=True)
        self.b_module = B()

    def forward(self,x):
        out = self.conv(x)
        out = self.relu(out)
        out = self.b_module(out)
        return out

class B(nn.Module):
    def __init__(self):
        super(B,self).__init__()
        self.conv = nn.Conv2d(in_channels=8
                              ,out_channels=16
                              ,kernel_size=3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.relu(out)
        return out

Original: https://blog.csdn.net/qq_42902997/article/details/122594017
Author: 暖仔会飞
Title: Pytorch避坑之:RuntimeError: Input type(torch.cuda.FloatTensor) and weight type(torch.FloatTensor) shoul

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/625550/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球