在C++项目中集成机器学习能力时,TensorFlow和PyTorch都提供了官方C++ API,支持模型加载、推理、甚至简单的模型训练操作,适合对性能要求较高的场景。两种框架的C++接口设计思路不同,使用方式也存在差异,下面分别介绍具体的使用方法。

TensorFlow C++ API 使用流程
环境配置
首先需要下载TensorFlow C++的预编译库,或者从源码编译得到对应的头文件和动态库。编译时需要确保系统安装好Bazel构建工具,配置好C++编译环境,将TensorFlow的头文件路径和库文件路径添加到项目的编译配置中。
基础使用示例
以下代码演示了加载已保存的TensorFlow模型并执行推理的过程:
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/platform/env.h>
#include <iostream>
#include <vector>
int main() {
// 创建TensorFlow会话
tensorflow::Session* session;
tensorflow::Status status = tensorflow::NewSession(tensorflow::SessionOptions(), &session);
if (!status.ok()) {
std::cerr << "创建会话失败: " << status.ToString() << std::endl;
return -1;
}
// 加载已保存的模型文件
status = session->LoadSavedModel(
tensorflow::SessionOptions(),
tensorflow::RunOptions(),
"./saved_model", // 模型保存路径
{"serve"}, // 标签
&bundle
);
if (!status.ok()) {
std::cerr << "加载模型失败: " << status.ToString() << std::endl;
return -1;
}
// 准备输入数据,假设模型输入名为input,形状为[1,3]
tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 3}));
auto input_flat = input_tensor.flat<float>();
input_flat(0) = 1.0;
input_flat(1) = 2.0;
input_flat(2) = 3.0;
std::vector<std::pair<std::string, tensorflow::Tensor>> inputs = {
{"input", input_tensor}
};
std::vector<std::string> output_tensor_names = {"output"};
std::vector<tensorflow::Tensor> outputs;
// 执行推理
status = session->Run(inputs, output_tensor_names, {}, &outputs);
if (!status.ok()) {
std::cerr << "推理失败: " << status.ToString() << std::endl;
return -1;
}
// 输出结果
auto output_flat = outputs[0].flat<float>();
std::cout << "推理结果: " << output_flat(0) << std::endl;
// 关闭会话
session->Close();
delete session;
return 0;
}
PyTorch C++ API 使用流程
环境配置
PyTorch的C++ API被称为LibTorch,你可以从PyTorch官网下载对应系统的预编译LibTorch包,解压后得到头文件和库文件。项目中需要链接torch库、caffe2库等依赖,同时配置好C++17及以上的编译标准。
基础使用示例
以下代码演示了使用LibTorch加载模型并执行推理的过程:
#include <torch/script.h>
#include <iostream>
#include <vector>
int main() {
// 加载序列化的PyTorch模型
torch::jit::script::Module module;
try {
module = torch::jit::load("./model.pt"); // 模型保存路径
} catch (const c10::Error& e) {
std::cerr << "加载模型失败: " << e.what() << std::endl;
return -1;
}
// 准备输入数据,形状为[1,3]
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3}));
// 执行推理
auto output = module.forward(inputs).toTensor();
// 输出结果
std::cout << "推理结果: " << output << std::endl;
return 0;
}
两种API的对比与选择
| 对比维度 | TensorFlow C++ API | PyTorch C++ API(LibTorch) |
|---|---|---|
| 接口设计 | 偏向底层,配置项较多,灵活性高 | 风格接近PyTorch Python接口,易用性更好 |
| 模型兼容性 | 支持SavedModel格式,兼容性好 | 支持TorchScript序列化模型,需要提前将Python模型转成TorchScript |
| 适用场景 | 工业级部署、大规模分布式推理场景 | 快速原型开发、需要动态图特性的部署场景 |
注意事项
- 两种框架的C++ API都需要和编译时的版本严格匹配,否则可能出现运行时错误
- 模型输入的形状、数据类型需要和训练时的配置保持一致,否则推理结果会异常
- 如果需要在嵌入式设备上使用,建议选择轻量编译版本的库,减少资源占用
C++TensorFlow_C++_APIPyTorch_C++_API机器学习框架修改时间:2026-06-21 05:24:34