Pytorch 是一个机器深度学习框架,易于上手,个人感觉比 tensorflow要友好。
Pytorch的深度学习程序分三个模块,实现三个功能,分别是取数据、建模型、运行程序。一般是分三个.py文件写,当然也可以写在一个文件里。我喜欢写成三个文件,这样看着比较方便点,而且Pytorch把这三个功能都写的挺好的,自己用的时候继承稍微改一下就好了。
其实深度学习的最终目标,就像求 y = f ( x ) y = f(x)y =f (x ) 这个公式中 f ( x ) f(x)f (x ) 的最佳参数一样:
继承 Dataset
就可以了,直接上代码
from torch.utils.data import Dataset
class DataSet_h(Dataset):
def __init__(self):
super(DataSet_h, self).__init__()
self.Arr = [(x1, y1), (x2, y2)...]
def __len__(self):
return len(self.Arr)
def __getitem__(self, item):
x = self.Arr[item][0]
y = self.Arr[item][1]
return x, y
继承 nn.Module
就可以了,直接上代码
import torch.nn as nn
class Model_h(nn.Module):
def __init__(self):
super(Model_h, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, inputs):
return self.fc(inputs)
直接上代码
from torch.utils.data import DataLoader
from Model import Model_h
import DataSet_h
model = Model_h()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
trainDataSet = DataSet_h.DataSet_h()
trainDataLoader = DataLoader(trainDataSet, batch_size=2)
for i, batch in enumerate(trainDataLoader):
x, y = batch
y_pre = model(x)
loss = y_pre - y
optimizer.zero_grad()
loss.backward()
optimizer.step()
最基本的意思就是这样,我还没试代码能不能跑,不过想法通了,代码的小问题都不是事儿了。
Original: https://blog.csdn.net/u010095372/article/details/120671413
Author: 赫凯
Title: Pytorch 深度学习运行代码简单教程
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/690954/
转载文章受原作者版权保护。转载请注明原作者出处!