PyTorch 介绍 | 快速开始

本节介绍有关机器学习常见任务重的API。请参阅每一节的链接以深入了解。

Working with data

PyTorch有两个有关数据工作的原型torch.utils.data.DataLoadertorch.utils.data.DatasetDataset 存储了样本及其对应的标签,而 DataLoaderDataset 生成了一个迭代器。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

torchvision.datasets 模块包含多种真实世界的视觉数据集 Dataset 对象,如CIFAR、COCO(full list here)。本教程中,我们使用FashionMNIST数据集。每个TorchVision Dataset 均包括两个参数: transformtarget_transform分别用于修改样本和标签。

从公开数据集上下载训练数据
training_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor(),
)

Download test data from open datasets.

test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)

输出:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

我们将 Dataset 作为参数传递给 DataLoader。这为我们的dataset包装了一个迭代器,并支持自动生成batch、抽样、打乱和多进程数据加载。这里定义了一个大小为64的batch,即,dataloader迭代的每一个元素将返回一个包含64个样本及对应标签的batch。

batch_size = 64

Create data loaders
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtpe)
    break

输出:

Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
Shape of y:  torch.Size([64]) torch.int64

创建模型

为了在PyTorch定义模型,我们创建了一个类,继承自nn.Module。我们在 __init__函数中定义是网络的layers,并在 forward 函数中指定data如何通过网络。为加快神经网络中的操作,若GPU可用,则把其移动到GPU上。

Get cpu or gpu device for training
device = 'cuda' if torch.cuda.is_availabel() else "cpu"
print(f"Using {device} device")

Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)

        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

输出:

Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

优化模型参数

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

在单个训练循环中,模型在训练集上作出预测(分批喂给模型),并且反向传播预测误差来调整模型参数。

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(x)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

我们还可以检查模型在测试集上的性能,确保模型是在学习。

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # 第一维度是batch,第二维度是预测值
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")

训练过程由几次迭代( epochs)组成。每一个epoch,模型学习参数,作出更好的预测。我们在每次epoch都打印了模型的准确率和损失,我们希望看到随着每次epoch,准确率升高,而损失降低。

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-----------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

输出:

点击查看代码

`
Epoch 1
loss: 2.169147 [ 0/60000]
loss: 2.165468 [ 6400/60000]
loss: 2.118014 [12800/60000]
loss: 2.129221 [19200/60000]
loss: 2.074899 [25600/60000]
loss: 2.022606 [32000/60000]
loss: 2.033795 [38400/60000]
loss: 1.976709 [44800/60000]
loss: 1.982757 [51200/60000]
loss: 1.881978 [57600/60000]
Test Error:
Accuracy: 57.4%, Avg loss: 1.902724

Epoch 3
loss: 1.592845 [ 0/60000]
loss: 1.556097 [ 6400/60000]
loss: 1.417763 [12800/60000]
loss: 1.478243 [19200/60000]
loss: 1.357680 [25600/60000]
loss: 1.356057 [32000/60000]
loss: 1.360733 [38400/60000]
loss: 1.298324 [44800/60000]
loss: 1.329920 [51200/60000]
loss: 1.219030 [57600/60000]
Test Error:
Accuracy: 63.4%, Avg loss: 1.250318

Epoch 5

Original: https://www.cnblogs.com/DeepRS/p/15727075.html
Author: Deep_RS
Title: PyTorch 介绍 | 快速开始

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/621340/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • 真香警告!JitPack 开源库集成平台

    前言: 请各大网友尊重本人原创知识分享,谨记本人博客:南国以南i 简介 官方介绍: JitPack 是一个用于 JVM 和 Android 项目的新颖的包存储库。它按需构建 Git…

    Linux 2023年6月14日
    0124
  • Linux常用系统管理命令详解

    ps ps命令用于查看系统中的进程状态。 命令格式: ps [参数] 命令参数说明: 参数 作用 -a 显示现行终端机下的所有程序,包括…

    Linux 2023年5月27日
    0118
  • deepin安装Redis步骤以及简单配置

    一、安装Redis 安装完成之后,Redis服务器会自动启动 二、检查Redis服务器系统进程(非必要) 三、查看Redis端口状态(非必要) 四、输入redis-cli进入命令模…

    Linux 2023年5月28日
    0112
  • 如何写出健壮可靠的shell脚本

    1 脚本失败时即退出 ; set -e 例子: 可以在脚本的开头设置如下set -e 2 打印脚本执行过程 sh -x test.sh #整个过程执行了哪些命令或者在开头加上set…

    Linux 2023年5月28日
    0101
  • 3.20 什么是环境变量,Linux环境变量有哪些?

    变量是计算机系统用于保存可变值的数据类型,我们可以直接通过变量名称来提取到对应的变量值。在 Linux 系统中,环境变量是用来定义系统运行环境的一些参数,比如每个用户不同的家目录(…

    Linux 2023年6月7日
    0110
  • MacOS设置终端代理

    前言 国内的开发者或多或少都会因为网络而烦恼,因为一些特殊原因有时候网络不好的时候需要使用代理才能完成对应的操作。原来我一直都是使用斐讯路由器然后刷了梅林的固件,直接在路由器层面设…

    Linux 2023年6月14日
    0101
  • npm 和 maven 使用 Nexus3 私服 | 前后端一起学

    前文《Docker 搭建 Nexus3 私服 》介绍了在 docker 环境下安装 nexus3 以及 nexus3 的基本操作和管理,本文分别介绍 npm(前端)和 maven(…

    Linux 2023年6月7日
    077
  • @Aspect

    AOP是指在程序运行期间动态地将某段代码切入到指定位置并运行的编程方式。 AOP详解可参考:https://blog.csdn.net/javazejian/article/det…

    Linux 2023年6月8日
    0111
  • 应用实战:从Redis到Aerospike,我们踩了这些坑

    博客园 :当前访问的博文已被密码保护 请输入阅读密码: Original: https://www.cnblogs.com/duanxz/p/15878002.htmlAuthor…

    Linux 2023年5月28日
    0108
  • LINUX系统虚拟机环境的安装

    安装VM和Centos Step 1 去BIOS里修改设置开启虚拟化设备支持 设置BIOS: 1.开机按F2 、F12 、DEL 、ESC 等进入BIOS ,一般来说可以看屏幕的左…

    Linux 2023年6月7日
    090
  • GCC 内联汇编基础

    GCC 内联汇编 在 MIT6.828的实验中,有几处用到了很底层的函数,都以内联汇编的形式存在,例如 static inline uint32_t read_esp(void) …

    Linux 2023年6月8日
    091
  • Ubuntu更换镜像源

    当修改 sources.list文件时,我们需要将下面任意一个镜像源的代码 复制粘贴到该文件中。 阿里源 阿里镜像源 deb http://mirrors.aliyun.com/u…

    Linux 2023年6月14日
    095
  • 壁纸爬取——协程应用

    (协程)壁纸爬取 一、 算法解析 1.1 进入爬取壁纸的网站(表层网页) 彼岸桌面壁纸-二次元 少爬涩图,健康生活! 1.2 获取显示单张壁纸的页面(深层网页)地址 选择网页元素:…

    Linux 2023年6月14日
    0185
  • Linux系统编程—信号捕捉

    前面我们学习了信号产生的几种方式,而对于信号的处理有如下几种方式: 默认处理方式; 忽略; 捕捉。 信号的捕捉,说白了就是抓到一个信号后,执行我们指定的函数,或者执行我们指定的动作…

    Linux 2023年6月14日
    0127
  • 部署office在线预览服务器(Office Web Apps Server)

    引言为方便在web端方便的使用office。 简介 Office Online Server (OOS,下文简写为OOS ) 提供基于浏览器的 Word、PowerPoint、Ex…

    Linux 2023年6月14日
    0137
  • 网络设备配置–10、利用ACL配置访问控制

    一、前言 同系列前几篇:网络设备配置–1、配置交换机enable、console、telnet密码网络设备配置–2、通过交换机划分vlan网络设备配置&#8…

    Linux 2023年6月8日
    0121
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球