04梯度下降算法比较

2023-12-13 12:40:18

?

BGD批量梯度下降

随机梯度下降

小批量梯度下降

import random
import numpy as np
import matplotlib.pyplot as plt

# BGD批量梯度下降

# # 用于生成一个形状为 (100, 1) 的随机数组。
# # 形状参数 (100, 1): 意味着生成的随机数组将是一个 100 行 1 列的二维数组
# X = np.random.rand(100, 1)
#
#
# # numpy.random.randint(low, high=None, size=None, dtype=int)
# # 生成一个形状为 (3, 2) 的二维数组,其中的随机整数范围在 [1, 10)
# # random_array = np.random.randint(1, 10, size=(3, 2))
# w, b = np.random.randint(1, 10, size=2)
#
# y = w * X + b + np.random.rand(100, 1)
#
# plt.scatter(X, y)
# # plt.show()
#
#
# # 在 X 数组的右侧添加一列全为 1 的列,
# # 通常是为了引入一个截距项,用于线性回归等模型中。
# # 这个操作使得 X 变成一个包含两列的数组,其中一列是原始的随机数据,另一列是全为 1 的列。
# # 这对于线性模型的训练很有用,因为它允许模型在学习时调整截距。
# # 在连接之前对X进行转置
# X = np.concatenate([X, np.full(shape=(100, 1), fill_value=1)], axis=1)
#
# # 循环次数
# epochs = 1000
#
# # 学习率
# eta = 0.01
#
# t0 = 5
# t1 = 1000
#
#
# # 定义学习率调度函数
# def learn_rate_shedule(t):
#     return t0 / (t + t1)
#
#
# # 求解的系数
# theta = np.random.rand(2, 1)
#
# t = 0
#
# for i in range(epochs):
#     g = X.T.dot(X.dot(theta) - y)  # 计算梯度
#     theta = theta - eta * g  # 更新系数
#     eta = learn_rate_shedule(t)  # 更新学习率
#     t = t + 1
#
# print('真实w, b ', w, b)
# print('预测w, b', theta)


# 随机梯度下降
# 生成随机数据
X = np.random.rand(100, 1)
w, b = np.random.randint(1, 10, size=2)
y = w * X + b + np.random.rand(100, 1)

# 添加截距项
X = np.concatenate([X, np.full_like(X, fill_value=1)], axis=1)

# 超参数设置
epochs = 1000
t0, t1 = 5, 1000

# 学习率调度函数
def learning_rate_schedule(t):
    return t0 / (t + t1)

# 初始化权重
theta = np.random.randn(2, 1)

count = 0
for epoch in range(epochs):
    # 随机打乱数据
    index = np.arange(100)
    np.random.shuffle(index)

    X = X[index]
    y = y[index]

    for i in range(100):
        X_i = X[[i]]
        y_i = y[[i]]

        # 计算梯度
        g = X_i.T.dot(X_i.dot(theta) - y_i)

        # 更新学习率
        eta = learning_rate_schedule(count)
        count += 1

        # 更新权重
        theta = theta - eta * g

# 打印结果
print('真实 w, b:', w, b)
print('预测 w, b:', theta)


# 小批量梯度下降

文章来源:https://blog.csdn.net/zzqtty/article/details/134968376
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。