是什么?它

优化器(Optimizer)和损失函数(Loss)介绍

优化器(Optimizer)

在深度学习中,优化器用于更新模型参数以最小化损失函数。PyTorch提供了多种优化器,其中常用的有随机梯度下降(SGD)、Adam、Adadelta、Adagrad、AdamW等。

随机梯度下降(SGD)优化器

算法原理

随机梯度下降(SGD)是最基本的优化算法之一。它通过计算每个样本的梯度来更新参数。具体算法原理如下:

  1. 初始化模型参数。
  2. 对每个样本:
  3. 计算模型的输出。
  4. 计算损失函数对输出的梯度。
  5. 根据梯度更新模型参数。
  6. 重复步骤2直到所有样本均处理完毕。
公式推导

SGD的参数更新公式如下所示:
$$
\theta_{t+1}=\theta_t-\alpha\frac{\partial L}{\partial \theta_t}
$$
其中,$\theta_{t}$表示第t个参数的值,$\alpha$为学习率,$L$为损失函数。

计算步骤
  1. 遍历训练数据集。
  2. 对于每个样本,计算模型的输出和损失函数。
  3. 对损失函数进行反向传播,计算梯度。
  4. 根据梯度和学习率更新参数。
  5. 重复步骤2-4直到所有样本均处理完毕。
Python代码示例

下面是一个简单的使用SGD优化器训练线性回归模型的示例代码,代码中包含了详细的注释解释代码细节。

import torch
import torch.nn as nn
import torch.optim as optim

# 创建模型
class LinearRegression(nn.Module):
 def __init__(self):
 super(LinearRegression, self).__init__()
 self.linear = nn.Linear(1, 1) # 输入和输出的维度都为1

 def forward(self, x):
 return self.linear(x)

# 定义训练数据
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]])

# 创建模型实例
model = LinearRegression()

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 进行训练
for epoch in range(100):
 # 前向传播
 y_pred = model(x_train)

 # 计算损失
 loss = criterion(y_pred, y_train)

 # 反向传播和参数更新
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()

 # 打印训练结果
 print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 100, loss.item()))

# 输出模型参数
print('Learned parameters:')
for name, param in model.named_parameters():
 print(name, param.data)
代码细节解释
  • nn.Linear(1, 1) 创建了一个线性层,输入维度为1,输出维度为1,即一个线性回归模型。
  • nn.MSELoss() 创建了一个均方误差(Mean Squared Error)损失函数。
  • optim.SGD(model.parameters(), lr=0.01) 创建了一个SGD优化器,学习率为0.01。
  • optimizer.zero_grad() 清除梯度。
  • loss.backward() 反向传播,计算梯度。
  • optimizer.step() 根据梯度和学习率更新模型参数。
  • loss.item()获取当前损失的数值。

总结

本文介绍了PyTorch提供的优化器和损失函数。其中,SGD优化器是最基本的优化算法之一,可以通过计算每个样本的梯度来更新参数。通过使用PyTorch提供的优化器和损失函数,可以方便地进行模型训练,并根据实际需求选择合适的优化器和损失函数。以上示例代码展示了使用SGD优化器训练线性回归模型的完整流程,并解释了代码中的细节。

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/823483/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球