一、概述
机器学习模型的训练通常在Python环境下完成,而现实生产环境的复杂性和多样性使得模型的部署成为一个值得关注的重点。不同应用场景下有不同适应的实现方式,这里主要介绍通过一种通用中间格式——ONNX(Open Neural Network Exchange),来实现机器学习模型在C++平台的部署。
二、步骤
s1. Python环境中安装onnxruntime、skl2onnx工具模块;
s2. Python环境中训练机器学习模型;
s3. 将训练好的模型保存为.onnx格式的模型文件;
s4. C++环境中安装Microsoft.ML.OnnxRuntime程序包;
(Visual Studio 2022中可通过项目->管理NuGet程序包完成快捷安装)
S5. C++环境中加载模型文件,完成功能开发。
三、示例
使用 Python 训练一个线性回归模型并将其导出为 ONNX 格式的文件,在C++环境下完成对模型的部署和推理。
1.Python训练和导出
(环境:Python 3.11,scikit-learn 1.6.1,onnxruntime 1.22.0,skl2onnx 1.19.1)- import numpy as np
- import onnxruntime as ort
- from sklearn.datasets import make_regression
- from sklearn.linear_model import LinearRegression
- from skl2onnx import convert_sklearn
- from skl2onnx.common.data_types import FloatTensorType
- # 生成示例数据
- X, y = make_regression(n_samples=100, n_features=5, random_state=42)
- # 训练线性回归模型
- model = LinearRegression()
- model.fit(X, y)
- # 定义输入格式
- initial_type = [('input', FloatTensorType([None, 5]))]
- # 转换模型为 ONNX 格式
- onnx_model = convert_sklearn(model, initial_types=initial_type)
- # 保存 ONNX 模型
- with open("linear_regression.onnx", "wb") as f:
- f.write(onnx_model.SerializeToString())
- print("\n模型已保存为: linear_regression.onnx\n")
- # 测试导出的模型
- ort_session = ort.InferenceSession("linear_regression.onnx")
- input_name = ort_session.get_inputs()[0].name
- output_name = ort_session.get_outputs()[0].name
- # 创建一个测试样本
- test_input = np.array([0.1, 0.2, 0.3, 0.4, 0.5]).reshape(1,5).astype(np.float32)
- # 运行推理
- results = ort_session.run([output_name], {input_name: test_input})
- print(f"测试输入: {test_input}")
- print(f"预测结果: {results[0]}")
复制代码
2. C++ 部署和推理
(环境:C++ 14,Microsoft.ML.OnnxRuntime 1.22.0)
[code]#include #include #include #include #include int main() { // 初始化环境 Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXExample"); // 初始化会话选项 Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); // 加载模型 std::wstring model_path = L"linear_regression.onnx"; Ort::Session session(env, model_path.c_str(), session_options); // 获取输入信息 Ort::AllocatorWithDefaultOptions allocator; size_t num_inputs = session.GetInputCount(); size_t num_outputs = session.GetOutputCount(); // 假设只有一个输入和一个输出 if (num_inputs != 1 || num_outputs != 1) { std::cerr |