【Matlab算法】随机梯度下降法 (Stochastic Gradient Descent,SGD) (附MATLAB完整代码)
前言
随机梯度下降法 (Stochastic Gradient Descent,SGD) 是一种梯度下降法的变种,用于优化损失函数并更新模型参数。与传统的梯度下降法不同,SGD每次只使用一个样本来计算梯度和更新参数,而不是使用整个数据集。这种随机性使得SGD在大型数据集上更加高效,因为它在每次迭代中只需要处理一个样本。
以下是关于随机梯度下降法的详细描述:
- 初姶化参数:与梯度下降法类似,首先需要初始化模型的参数,通常使用随机的初始值。
- 选代过程:
- 对于每个训练样本 i i i :
- 计算损失函数关于当前参数的梯度,即 ? f i ( θ ) \nabla f_i(\theta) ?fi?(θ) ,其中 f i ( θ ) f_i(\theta) fi?(θ) 是针对第 i i i 个样本的损失。
- 使用计算得到的梯度来更新模型参数: θ = θ ? η ? ? f i ( θ ) \theta=\theta-\eta \cdot \nabla f_i(\theta) θ=θ?η??fi?(θ) ,其中 η \eta η 是学习率。
- 重复迭代: 重复以上过程,直到达到预定的迭代次数或满足停止条件(例如梯度的范数足够小)。
相比于传统的梯度下降法,SGD的优点包括:
- 高效:特别适用于大型数据集,因为每次迭代只使用一个样本。
- 在线学习: 可以用于在线学习,即在接收到新数据时立即更新模型。
然而,由于随机性的引入,SGD的参数更新可能会更加不稳定,因此学习率的选择变得尤为重要。为了解决这个问题,有一些SGD的变种,如Mini-batch SGD,它在每次迭代中使用小批量的样本来计算梯度。这样可以在保持高效性的同时减小参数更新的方差。
正文
对于给出的函数
f
(
x
)
f(x)
f(x) :
f
(
x
)
=
x
(
1
)
2
+
x
(
2
)
2
?
2
?
x
(
1
)
?
x
(
2
)
+
sin
?
(
x
(
1
)
)
+
cos
?
(
x
(
2
)
)
f(x)=x(1)^2+x(2)^2-2 \cdot x(1) \cdot x(2)+\sin (x(1))+\cos (x(2))
f(x)=x(1)2+x(2)2?2?x(1)?x(2)+sin(x(1))+cos(x(2))
- 初始化参数: 随机选择初始参数 x x x ,通常使用某种随机的初始值。
- 选择学习率: 选择一个适当的学习率 η \eta η ,这是一个重要的超参数,影响着参数更新的步长。
- 设置迭代次数和停止条件: 确定迭代次数的上限或设置停止条件,例如当梯度的范数小于某个容许误差时停止迭代。
- 随机梯度下降选代:
- 对于每次迭代 t t t ,从训练集中随机选择一个样本 i i i 。
- 计算该样本的梯度: ? f i ( x ( t ) ) \nabla f_i\left(x^{(t)}\right) ?fi?(x(t))
- 使用梯度更新参数: x ( t + 1 ) = x ( t ) ? η ? ? f i ( x ( t ) ) x^{(t+1)}=x^{(t)}-\eta \cdot \nabla f_i\left(x^{(t)}\right) x(t+1)=x(t)?η??fi?(x(t))
- 检查是否满足停止条件。如果满足,停止迭代;否则,继续下一次迭代。
- 输出结果: 输出最终的参数 x x x ,以及在最优点的目标函数值 f ( x ) f(x) f(x) 。
代码实现
可运行代码
% 定义目标函数
f = @(x) x(1)^2 + x(2)^2 - 2*x(1)*x(2) + sin(x(1)) + cos(x(2));
% 定义目标函数的梯度
grad_f = @(x) [2*x(1) - 2*x(2) + cos(x(1)); 2*x(2) - 2*x(1) - sin(x(2))];
% 设置参数
learning_rate = 0.01;
max_iterations = 1000;
tolerance = 1e-6;
% 初始化起始点
x = [0; 0];
% 随机梯度下降
for iteration = 1:max_iterations
% 随机选择一个样本
i = randi(2);
% 计算梯度
gradient = grad_f(x);
% 更新参数
x = x - learning_rate * gradient;
% 检查收敛性
if norm(gradient) < tolerance
break;
end
end
% 显示结果
fprintf('Optimal solution: x = [%f, %f]\n', x(1), x(2));
fprintf('Optimal value of f(x): %f\n', f(x));
fprintf('Number of iterations: %d\n', iteration);
结果
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!