pytorch实现T-GCN
- 参考代码
- 完整 二维TCN实现代码
- 相关语法
* - Tensor.contiguous()
- torch.transpose() 和 torch.permute()比较
- 基础模型实验
* - 1、nn.Cov1D(input,output,kernal_size,padding)
- 2、nn.Cov2D(input,output,kernal_size,padding)
参考代码
TCN(Temporal Convolutional Network)由Shaojie Bai et al.提出,
https://arxiv.org/pdf/1803.01271.pdf
原始代码来自github:https://github.com/locuslab/TCN
完整 二维TCN实现代码
class delect_padding(nn.Module):
def __init__(self, chomp_size):
super(time_padding, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
"""
其实就是一个裁剪的模块,裁剪多出来的padding
"""
return x[:, :,:, :-self.chomp_size].contiguous()
class TCN(nn.Module):
def __init__(self, n_inputs, n_outputs, kernel_size, padding, dropout=0.2):
"""
:param n_inputs: int, 输入通道数
:param n_outputs: int, 输出通道数
:param kernel_size: int, 卷积核尺寸
:param padding: int, 填充系数
:param dropout: float, dropout比率
"""
super(TCN, self).__init__()
self.conv1 = weight_norm(nn.Conv2d(n_inputs, n_outputs, (1,kernel_size), padding=(0,padding)))
self.delect_pad = delect_padding(padding)
self.relu1 = nn.ReLU()
self.dropout1 = nn.Dropout(dropout)
self.net = nn.Sequential(self.conv1, self.delect_pad, self.relu1, self.dropout1)
self.relu = nn.ReLU()
self.init_weights()
def init_weights(self):
"""
参数初始化
"""
self.conv1.weight.data.normal_(0, 0.01)
def forward(self, x):
out = self.net(x)
return self.relu(out)
相关语法
Tensor.contiguous()
Tensor.contiguous(memory_format=torch.contiguous_format)
保存为一个连续格式tensor,一般用于 transpose/permute 后和 view 前
torch.transpose() 和 torch.permute()比较
a=torch.tensor([[[[0,1,2],[2,3,4]],
[[1,1,1],[4,3,1]],
[[2,1,1],[2,2,2]],
[[1,3,1],[2,1,1]]]])
print(a.shape)
test_x_1=a.transpose(1,3)
test_x_2 = a.permute(0,3,2,1)
print(test_x_1)
print(test_x_2)
比较test_x_1和test_x_2结果:
基础模型实验
1、nn.Cov1D(input,output,kernal_size,padding)
输入:[N, C1, H]
输出:[N, C2, H]
import torch
import torch.nn as nn
test_x=torch.randn(16,2,6)
text=nn.Conv1d(2,16,3,padding=0)
text_y=text(test_x)
普通1D-CNN实验,kernal=3, padding=0,输出时间片减少2个
2、nn.Cov2D(input,output,kernal_size,padding)
输入:[N, C1, H, W]
输出:[N, C2, H, W]
Original: https://blog.csdn.net/weixin_48635857/article/details/122301193
Author: 长生不老的咖啡精
Title: 二维 TCN pytorch实现完整代码和语法记录
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/709466/
转载文章受原作者版权保护。转载请注明原作者出处!