LDMX Software
ONNXRuntime.h
1
2#ifndef TOOLS_ONNXRUNTIME_H
3#define TOOLS_ONNXRUNTIME_H
4
5#include <map>
6#include <memory>
7#include <string>
8#include <vector>
9
10#include "onnxruntime_cxx_api.h"
11
12namespace ldmx::Ort {
13
14typedef std::vector<std::vector<float>> FloatArrays;
15
21 public:
28 ONNXRuntime(const std::string& model_path,
29 const ::Ort::SessionOptions* session_options = nullptr);
30 ONNXRuntime(const ONNXRuntime&) = delete;
31 ONNXRuntime& operator=(const ONNXRuntime&) = delete;
33
47 FloatArrays run(const std::vector<std::string>& input_names,
48 FloatArrays& input_values,
49 const std::vector<std::string>& output_names = {},
50 int64_t batch_size = 1) const;
51
56 const std::vector<std::string>& getOutputNames() const;
57
64 const std::vector<int64_t>& getOutputShape(
65 const std::string& output_name) const;
66
67 private:
68 static ::Ort::Env env_;
69 std::unique_ptr<::Ort::Session> session_;
70
71 std::vector<std::string> input_node_strings_;
72 std::vector<const char*> input_node_names_;
73 std::map<std::string, std::vector<int64_t>> input_node_dims_;
74
75 std::vector<std::string> output_node_strings_;
76 std::vector<const char*> output_node_names_;
77 std::map<std::string, std::vector<int64_t>> output_node_dims_;
78};
79
80} // namespace ldmx::Ort
81
82#endif /* TOOLS_ONNXRUNTIME_H_ */
A convenience wrapper around the ONNXRuntime C++ API.
Definition ONNXRuntime.h:20
FloatArrays run(const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::string > &output_names={}, int64_t batch_size=1) const
Run model inference and get outputs.
ONNXRuntime(const std::string &model_path, const ::Ort::SessionOptions *session_options=nullptr)
Class constructor.
const std::vector< int64_t > & getOutputShape(const std::string &output_name) const
Get the shape of an output node.
const std::vector< std::string > & getOutputNames() const
Get the names of all the output nodes.