Python使用训练数据拟合模型
# 假设“满意度”是因变量,其他的是自变量
# 提取自变量(特征)和因变量(目标)
X = df_filtered_cleaned[['Bonus', 'Enhancement', 'Time_in_seconds']]
y = df_filtered_cleaned['Satisfaction']
# 分割为训练集和测试集。
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(X_train) + '-' + (X_test)
class MultivariateLinearRegression:
? ? def __init__(self):
? ? ? ? self.coefficients = None
? ? ? ? self.intercept = None
? ? def fit(self, X, y):
? ? ? ? """
? ? ? ? 使用普通最小二乘法的闭式解来拟合模型。
? ? ? ? X:形状为(n_samples,n_features)的numpy数组
? ? ? ? y:形状为(n_samples,)的numpy数组? ? ? ? """
? ? ? ? # 在输入特征中添加偏置列
? ? ? ? X_b = np.c_[np.ones((X.shape[0], 1)), X]
? ? ? ? # 使用正规方程计算最佳参数
? ? ? ? theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
? ? ? ? self.intercept = theta_best[0]
? ? ? ? self.coefficients = theta_best[1:]
? ? def predict(self, X):
? ? ? ? """
? ? ? ? 使用训练好的模型进行预测。
? ? ? ? X: numpy数组的形状(n_samples, n_features)
? ? ? ? """
? ? ? ? # 在输入特征中添加偏置列
? ? ? ? X_b = np.c_[np.ones((X.shape[0], 1)), X]
? ? ? ? # 计算预测
? ? ? ? return X_b.dot(np.r_[self.intercept, self.coefficients])
# 初始化线性回归模型
model_custom = MultivariateLinearRegression()
# 使用训练数据拟合模型
model_custom.fit(X_train, y_train)
# 预测测试集的满意度分数
y_pred_custom = model_custom.predict(X_test)
# 计算预测结果的均方误差(MSE)
mse_custom = mean_squared_error(y_test, y_pred_custom)
# 输出模型的系数和均方误差
coefficients_custom = model_custom.coefficients
intercept_custom = model_custom.intercept
# 格式化模型的方程式
model_equation = f"Satisfaction = {intercept_custom:.4f}"
for i, coef in enumerate(coefficients_custom):
? ? model_equation += f" + ({coef:.4f}) * X{i+1}"
print( mse_custom)
# 由于我们处理的是多个特征,无法绘制一条单独的线。
# 然而,我们可以绘制真实值与预测值之间的图表来观察模型性能。
# 对整个数据集进行满意度评分预测,并将其与真实值进行比较。
y_pred_entire_dataset = model_custom.predict(X)
print(y_pred_entire_dataset)
plt.figure(figsize=(10, 6))
plt.scatter(y, y_pred_entire_dataset, alpha=0.5)
plt.plot([min(y), max(y)], [min(y), max(y)], 'r--')
plt.title('True vs Predicted Satisfaction')
plt.xlabel('True Values')
plt.ylabel('Predicted Values')
plt.show()
plt.savefig('test5.jpg')
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!