目录
模型训练完准确率一直为0
- 计算准确率或者输出看loss、准确率时,需要把原来的tensor数据类型,转成普通的数字
即.item()进行转换
total_accuracy = (total_accuracy +accuracy).item()
最好转item数据类型,不然这个accuracy会是一个tensor的数据类型,tensor数据类型和一个普通的数据相除,结果一定是0
如果不转,就会输出:
如果Tensor数据类型没有转换,直接用 total_accuracy直接除以测试集10000这一数值,出来的结果会是0
转换后,结果正确:
完整的模型验证套路:test
import torchvision
from PIL import Image
from model import *
image_path ='dog.jpg'
用PIL读取图片
PIL_image = Image.open(image_path)
定义transform
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
PIL图片变成32*32,同时转tensor
tensor_image = transform(PIL_image)
reshape,加一个batchsize
tensor_image = torch.reshape(tensor_image,(1,3,32,32))
tensor_image = tensor_image.cuda()
print(tensor_image.shape)
读取模型
modle = torch.load('module_10.pth')
module1 = Module()
module1.load_state_dict(modle)
module1 = module1.cuda()
print(module1)
with torch.no_grad():
output = module1(tensor_image)
print(output)
output = (output.argmax(1)).item()
print(output)
reshape和reszie的区别
要注意reshape和resize的区别:
torchvision.transforms.Resize((32,32))
tensor_image= torch.reshape(tensor_image,(B,C,W,H))
一个是切割裁剪图片,一个是对图片的像素进行重新排列组合,添加BatchSize
debug方法
- 此时我们不记得5对应的是哪个类别,可以去代码中debug一下,看看数据集
- debug是运行到红点的上一行
完成训练!
附一个修狗图片
这玩意都能预测出来是青蛙,厉害了~
Original: https://blog.csdn.net/weixin_42934729/article/details/123686040
Author: 可基大萌萌哒的马鹿
Title: 模型训练完准确率为0的解决方法,以及模型验证方法(resize和reshape区别)
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/707740/
转载文章受原作者版权保护。转载请注明原作者出处!