【stgcn】代码解读之主函数(一)

该代码是 python-torch 写的!
请看序列(一、二、三)

一、模型概述

文件分布

首先看文件的内容:STSGCN中包含两个文件夹:model,PeMSD7(M) 。model文件中包含:main.py,stsgcn.py ,utils.py三个文件。PeMSD7(M)中包含:矩阵文件 adj_mat.npy和特征数据node_values.npy两个文件。

【stgcn】代码解读之主函数(一)

; PeMSD7(M)中的数据

特征数据为:(34722,207,2)
矩阵数据为:(207,207)
表明共有207个节点,每个节点2个特征

【stgcn】代码解读之主函数(一)

二、导入库和参数设置–main.py

import os
import argparse
import pickle as pk
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from stgcn import STGCN
from utils import generate_dataset, load_metr_la_data, get_normalized_adj

use_gpu = False
num_timesteps_input = 12
num_timesteps_output = 3

epochs = 1000
batch_size = 50

parser = argparse.ArgumentParser(description='STGCN')
parser.add_argument('--enable-cuda', action='store_true',
                    help='Enable CUDA')

args = parser.parse_args()
args.device = None
if args.enable_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

三、查看主函数–main.py

if __name__ == '__main__':
    torch.manual_seed(7)

    A, X, means, stds = load_metr_la_data()
    print("数据加载...")

    split_line1 = int(X.shape[2] * 0.6)
    split_line2 = int(X.shape[2] * 0.8)

    train_original_data = X[:, :, :split_line1]
    val_original_data = X[:, :, split_line1:split_line2]
    test_original_data = X[:, :, split_line2:]

    training_input, training_target = generate_dataset(train_original_data,
                                                       num_timesteps_input=num_timesteps_input,
                                                       num_timesteps_output=num_timesteps_output)
    val_input, val_target = generate_dataset(val_original_data,
                                             num_timesteps_input=num_timesteps_input,
                                             num_timesteps_output=num_timesteps_output)
    test_input, test_target = generate_dataset(test_original_data,
                                               num_timesteps_input=num_timesteps_input,
                                               num_timesteps_output=num_timesteps_output)
    print("数据生成器")

    A_wave = get_normalized_adj(A)
    A_wave = torch.from_numpy(A_wave)

    A_wave = A_wave.to(device=args.device)
    print("矩阵归一化后并存在设备上")

    net = STGCN(A_wave.shape[0],
                training_input.shape[3],
                num_timesteps_input,
                num_timesteps_output).to(device=args.device)
    print("模型实例化并且存放在设备上")
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    loss_criterion = nn.MSELoss()
    print("建立损失函数和优化器")
    training_losses = []
    validation_losses = []
    validation_maes = []
    for epoch in range(epochs):
        loss = train_epoch(training_input, training_target,
                           batch_size=batch_size)
        training_losses.append(loss)

        with torch.no_grad():

            net.eval()
            val_input = val_input.to(device=args.device)
            val_target = val_target.to(device=args.device)

            out = net(A_wave, val_input)
            val_loss = loss_criterion(out, val_target).to(device="cpu")
            validation_losses.append(np.asscalar(val_loss.detach().numpy()))

            out_unnormalized = out.detach().cpu().numpy()*stds[0]+means[0]
            target_unnormalized = val_target.detach().cpu().numpy()*stds[0]+means[0]
            mae = np.mean(np.absolute(out_unnormalized - target_unnormalized))
            validation_maes.append(mae)

            out = None
            val_input = val_input.to(device="cpu")
            val_target = val_target.to(device="cpu")

        print("Training loss: {}".format(training_losses[-1]))
        print("Validation loss: {}".format(validation_losses[-1]))
        print("Validation MAE: {}".format(validation_maes[-1]))
        plt.plot(training_losses, label="training loss")
        plt.plot(validation_losses, label="validation loss")
        plt.legend()
        plt.show()

        checkpoint_path = "checkpoints/"
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)
        with open("checkpoints/losses.pk", "wb") as fd:
            pk.dump((training_losses, validation_losses, validation_maes), fd)

注释:

  1. with torch.no_grad():见链接no_grad()
  2. net.eval():用于测试或者评估之前。原因见参考博客园链接1和简书链接2
  3. np.absolute= np.abs:对数组内每个元素求绝对值。在其上即求 mae

四、训练函数–main.py

def train_epoch(training_input, training_target, batch_size):
"""
    Trains one epoch with the given data.

    :param training_input: Training inputs of shape (num_samples, num_nodes,
    num_timesteps_train, num_features).

    :param training_target: Training targets of shape (num_samples, num_nodes,
    num_timesteps_predict).

    :param batch_size: Batch size to use during training.

    :return: Average loss for this epoch.

"""
    permutation = torch.randperm(training_input.shape[0])

    epoch_training_losses = []
    for i in range(0, training_input.shape[0], batch_size):
        net.train()
        optimizer.zero_grad()

        indices = permutation[i:i + batch_size]
        X_batch, y_batch = training_input[indices], training_target[indices]
        X_batch = X_batch.to(device=args.device)
        y_batch = y_batch.to(device=args.device)

        out = net(A_wave, X_batch)
        loss = loss_criterion(out, y_batch)
        loss.backward()
        optimizer.step()
        epoch_training_losses.append(loss.detach().cpu().numpy())
    return sum(epoch_training_losses)/len(epoch_training_losses)

注释:

  1. torch.randperm:返回一个0到n-1的数组。

torch.randperm(n, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False)
torch.randperm(4)
tensor([ 2, 1, 0, 3])

  1. optimizer.zero_grad()的原因:

下一篇:【stgcn】代码pytorch解读(二)

Original: https://blog.csdn.net/panbaoran913/article/details/123115022
Author: panbaoran913
Title: 【stgcn】代码解读之主函数(一)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/712914/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • fairseq笔记

    训练新模型 以机器翻译为例子开始 Fairseq 包含多个翻译数据集的示例预处理脚本:IWSLT 2014(德语-英语)、WMT 2014(英语-法语)和 WMT 2014(英语-…

    人工智能 2023年5月28日
    0109
  • 牛客网-《刷C语言百题》第三期

    ✅作者简介:嵌入式入坑者,与大家一起加油,希望文章能够帮助各位!!!!📃个人主页: @rivencode的个人主页🔥系列专栏: 《C语言入门必刷百题》💬推荐一款模拟面试、刷题神器,…

    人工智能 2023年6月26日
    079
  • 【案例实战】SpringBoot整合Redis连接池生成图形验证码

    回答1: Spring Boot 配置 可以通过以下步骤实现: 1. 在 pom.xml 文件中添加 相关依赖,例如: </p> <p> 2. 在 appl…

    人工智能 2023年6月29日
    072
  • Ubuntu 配置Python环境(包括Tensorflow)

    抵扣说明: 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。 Original: https:…

    人工智能 2023年5月25日
    0121
  • 人工智能实验——八数码难题

    人工智能实验——八数码难题 人工智能实验——八数码难题 人工智能实验——八数码难题 * 八数码难题简介 八数码难题所用到的算法简介 代码实现解释 运行结果显示 代码附件 程序可视化…

    人工智能 2023年7月26日
    078
  • 爬虫——刘飞

    回答1: 使用pyquery可以通过CSS选择器或XPath表达式来查找HTML文档中的元素,从而提取所需的数据。具体步骤如下: 1. 导入pyquery库:from pyquer…

    人工智能 2023年7月8日
    095
  • 《深度学习之pytorch实战计算机视觉》第9章 多模型融合(代码可跑通)

    上一章《深度学习之pytorch实战计算机视觉》第8章 图像风格迁移实战(代码可跑通)讲了图像风格迁移实战,是个很有趣的应用。 多模型融合是一种”集百家之所长&#822…

    人工智能 2023年5月26日
    0138
  • Python创建/读取 Excel表

    学习心得: 一 通过pandas创建Excel 1.需要import pandas as pd 库来操作 2.通过 pandas.DataFrame()创建一个数据框datafra…

    人工智能 2023年7月6日
    0107
  • 从零手写RGBD SLAM

    刚学习完ORB-SLAM2框架,但苦于没有实战项目,总感觉心里没底。偶然间发现了高翔博士的一起做RGBD SLAM博客,简单看了一点就感觉对自己大有帮助,hh大佬就是大佬,完全理解…

    人工智能 2023年7月19日
    071
  • 图像平滑处理

    图像滤波是图像处理和计算机视觉中最常用、最基本的操作。主要是去除图像中的噪声,因为图像平滑处理过程中往往会使得图像变的模糊,因此又叫模糊处理。 基本原理 图像平滑的基本原理是,将噪…

    人工智能 2023年6月18日
    077
  • 知识图谱融合

    知识融合,即合并两个知识图谱(本体),基本的问题都是研究怎样将来自多个来源的关于同一个实体或概念的描述信息融合起来。知识图谱包含描述抽象知识的本体层和描述具有事实的实例层。本体层用…

    人工智能 2023年6月1日
    0132
  • 池化分类、作用简单总结

    池化分类 平均池化:对邻域内特征点求平均 正向传播:邻域内取平均 反向传递:梯度根据邻域大小被平均,然后传递给索引位置 参考链接:平均池化最大池化:对邻域内特征点求最大值 正向传播…

    人工智能 2023年7月2日
    0137
  • 机器学习笔记–2.1文本分类

    从分类算法层面来看,各类语言的文本分类技术大同小异,但从整个流程来考察,不同语言的文本处理所用到的技术还是有差别的。下面给出中文语言的文本分类技术和流程,主要包括以下几个步骤: (…

    人工智能 2023年5月28日
    078
  • 论文阅读 A Data-Driven Graph Generative Model for Temporal Interaction Networks

    13 A Data-Driven Graph Generative Model for Temporal Interaction Networks link:https://sch…

    人工智能 2023年6月4日
    0111
  • 【python数据分析】:数据预处理之连续数据离散化

    连续属性变换成分类属性,即连续数据离散化。 在数值的取值范围内设定若干个离散划分点,将取值范围划分为一些离散化的区间,最后用不同的符号或整数值代表每个子区间中的数据值。 连续数据离…

    人工智能 2023年7月6日
    0204
  • 卷积神经网络分类实战

    卷积神经网络分类实战 基于唐宇迪老师的神经网络课程,初步实现课堂上的神经网络构建。 1.torchvision torchvision是 pytorch的一个图形库,它服务于 Py…

    人工智能 2023年7月1日
    0114
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球