04梯度下降算法比较
2023-12-13 12:40:18
?
BGD批量梯度下降
随机梯度下降
小批量梯度下降
import random
import numpy as np
import matplotlib.pyplot as plt
# BGD批量梯度下降
# # 用于生成一个形状为 (100, 1) 的随机数组。
# # 形状参数 (100, 1): 意味着生成的随机数组将是一个 100 行 1 列的二维数组
# X = np.random.rand(100, 1)
#
#
# # numpy.random.randint(low, high=None, size=None, dtype=int)
# # 生成一个形状为 (3, 2) 的二维数组,其中的随机整数范围在 [1, 10)
# # random_array = np.random.randint(1, 10, size=(3, 2))
# w, b = np.random.randint(1, 10, size=2)
#
# y = w * X + b + np.random.rand(100, 1)
#
# plt.scatter(X, y)
# # plt.show()
#
#
# # 在 X 数组的右侧添加一列全为 1 的列,
# # 通常是为了引入一个截距项,用于线性回归等模型中。
# # 这个操作使得 X 变成一个包含两列的数组,其中一列是原始的随机数据,另一列是全为 1 的列。
# # 这对于线性模型的训练很有用,因为它允许模型在学习时调整截距。
# # 在连接之前对X进行转置
# X = np.concatenate([X, np.full(shape=(100, 1), fill_value=1)], axis=1)
#
# # 循环次数
# epochs = 1000
#
# # 学习率
# eta = 0.01
#
# t0 = 5
# t1 = 1000
#
#
# # 定义学习率调度函数
# def learn_rate_shedule(t):
# return t0 / (t + t1)
#
#
# # 求解的系数
# theta = np.random.rand(2, 1)
#
# t = 0
#
# for i in range(epochs):
# g = X.T.dot(X.dot(theta) - y) # 计算梯度
# theta = theta - eta * g # 更新系数
# eta = learn_rate_shedule(t) # 更新学习率
# t = t + 1
#
# print('真实w, b ', w, b)
# print('预测w, b', theta)
# 随机梯度下降
# 生成随机数据
X = np.random.rand(100, 1)
w, b = np.random.randint(1, 10, size=2)
y = w * X + b + np.random.rand(100, 1)
# 添加截距项
X = np.concatenate([X, np.full_like(X, fill_value=1)], axis=1)
# 超参数设置
epochs = 1000
t0, t1 = 5, 1000
# 学习率调度函数
def learning_rate_schedule(t):
return t0 / (t + t1)
# 初始化权重
theta = np.random.randn(2, 1)
count = 0
for epoch in range(epochs):
# 随机打乱数据
index = np.arange(100)
np.random.shuffle(index)
X = X[index]
y = y[index]
for i in range(100):
X_i = X[[i]]
y_i = y[[i]]
# 计算梯度
g = X_i.T.dot(X_i.dot(theta) - y_i)
# 更新学习率
eta = learning_rate_schedule(count)
count += 1
# 更新权重
theta = theta - eta * g
# 打印结果
print('真实 w, b:', w, b)
print('预测 w, b:', theta)
# 小批量梯度下降
文章来源:https://blog.csdn.net/zzqtty/article/details/134968376
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!