[Python] 如何把scikit-learn的线性回归模型导出为onnx格式,并使用onnx模型文件进行预测
2024-01-09 11:26:19
什么是Scikit-learn?
Scikit-learn是一个用于Python编程语言的机器学习库。它提供了各种监督和无监督学习算法,包括分类、回归、聚类、降维等。Scikit-learn易于使用且功能强大,可以处理大型数据集,并且具有很好的可扩展性。它还提供了许多方便的工具,如数据预处理、模型选择、评估和可视化等。Scikit-learn是许多机器学习项目中使用的首选库之一。
什么是ONNX?
ONNX(Open Neural Network Exchange)是一个开放的生态系统,旨在使不同的深度学习框架之间能够互操作。它定义了一个通用的模型表示格式,使得在不同的深度学习框架之间进行模型转换和部署变得更加容易。ONNX模型可以被多种工具和平台所支持,包括ONNX Runtime、TensorFlow、PyTorch、Caffe2等。通过使用ONNX,开发者可以轻松地将一个深度学习模型转换为另一个框架所需的格式,从而实现模型的重用和加速。
什么是ONNX Runtime?
ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exchange (ONNX) models. For more information on ONNX Runtime, please see?aka.ms/onnxruntime?or the?Github project.
什么是skl2onnx库?
sklearn-onnx 1.16.0 documentation?
?
安装onnx onnxruntime skl2onnx库
pip install onnx onnxruntime skl2onnx
scikit-learn的线性回归模型导出为onnx格式
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from skl2onnx import to_onnx
# 创建数据集
np.random.seed(0)
x = np.random.rand(100, 1)
x = x.astype(np.float32)
y = 2 + 3 * x + np.random.rand(100, 1)
y = y.astype(np.float32)
# 将数据集分为训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)
# 训练模型
model = LinearRegression()
model.fit(x_train, y_train)
# 保存模型到文件
onx = to_onnx(model, x[:1])
with open("linear_regression_model.onnx", "wb") as f:
f.write(onx.SerializeToString())
?使用onnxruntime来加载onnx模型进行预测
import onnxruntime as rt
from sklearn.metrics import mean_squared_error, r2_score
# 从文件中加载模型
sess = rt.InferenceSession("linear_regression_model.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
print('input_name:', input_name)
print('label_name:', label_name)
# 使用加载的模型进行预测
# # 当前模型只有一个输入和一个输出,所以我们只需要通过[0]获取第一个输出,即为预测值
y_pred = sess.run([label_name], {input_name: x_test.astype(np.float32)})[0]
print(y_pred)
# 评估模型的性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print('均方误差:', mse)
print('R2分数:', r2)
输出结果:?
文章来源:https://blog.csdn.net/u011775793/article/details/135443392
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!