PyTorch 基础篇(2):线性回归(Linear Regression)

2023-12-13 16:29:40

  
  
  1. # 包
  2. import torch
  3. import torch.nn as nn
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  
  
  1. # 超参数设置
  2. input_size = 1
  3. output_size = 1
  4. num_epochs = 60
  5. learning_rate = 0.001
  6. ?
  7. # Toy dataset
  8. # 玩具资料:小数据集
  9. x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
  10. [9.779], [6.182], [7.59], [2.167], [7.042],
  11. [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
  12. ?
  13. y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
  14. [3.366], [2.596], [2.53], [1.221], [2.827],
  15. [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)
  16. ?
  17. # 线性回归模型
  18. model = nn.Linear(input_size, output_size)
  19. ?
  20. # 损失函数和优化器
  21. criterion = nn.MSELoss()
  22. optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
  
  
  1. # 训练模型
  2. for epoch in range(num_epochs):
  3. # 将Numpy数组转换为torch张量
  4. inputs = torch.from_numpy(x_train)
  5. targets = torch.from_numpy(y_train)
  6. ?
  7. # 前向传播
  8. outputs = model(inputs)
  9. loss = criterion(outputs, targets)
  10. # 反向传播和优化
  11. optimizer.zero_grad()
  12. loss.backward()
  13. optimizer.step()
  14. if (epoch 1) % 5 == 0:
  15. print (‘Epoch [{}/{}], Loss: {:.4f}’.format(epoch 1, num_epochs, loss.item()))
  
  
  1. Epoch [5/60], Loss: 7.7737
  2. Epoch [10/60], Loss: 3.2548
  3. Epoch [15/60], Loss: 1.4241
  4. Epoch [20/60], Loss: 0.6824
  5. Epoch [25/60], Loss: 0.3820
  6. Epoch [30/60], Loss: 0.2602
  7. Epoch [35/60], Loss: 0.2109
  8. Epoch [40/60], Loss: 0.1909
  9. Epoch [45/60], Loss: 0.1828
  10. Epoch [50/60], Loss: 0.1795
  11. Epoch [55/60], Loss: 0.1781
  12. Epoch [60/60], Loss: 0.1776
  
  
  1. # 绘制图形
  2. # torch.from_numpy(x_train)将X_train转换为Tensor
  3. # model()根据输入和模型,得到输出
  4. # detach().numpy()预测结结果转换为numpy数组
  5. predicted = model(torch.from_numpy(x_train)).detach().numpy()
  6. plt.plot(x_train, y_train, ‘ro’, label=‘Original data’)
  7. plt.plot(x_train, predicted, label=‘Fitted line’)
  8. plt.legend()
  9. plt.show()

文章来源:https://blog.csdn.net/qq_15719613/article/details/134830515
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。