LDMX Software
ONNXRuntime.h
1
2#ifndef TOOLS_ONNXRUNTIME_H
3#define TOOLS_ONNXRUNTIME_H
4
5#include <algorithm>
6#include <cassert>
7#include <exception>
8#include <functional>
9#include <iostream>
10#include <map>
11#include <memory>
12#include <numeric>
13#include <string>
14#include <vector>
15
16#include "onnxruntime_cxx_api.h"
17
18namespace ldmx {
19namespace ort {
20
21typedef std::vector<std::vector<float>> FloatArrays;  ///< Container type for both the inputs (`input_values`) and the return value of ONNXRuntime::run(); presumably one inner float vector per named input/output node, in name order — TODO confirm.
22
28 public:
35 ONNXRuntime(const std::string& model_path,
36 const ::Ort::SessionOptions* session_options = nullptr);
37 ONNXRuntime(const ONNXRuntime&) = delete;
38 ONNXRuntime& operator=(const ONNXRuntime&) = delete;
39 ~ONNXRuntime() = default;
40
54 FloatArrays run(const std::vector<std::string>& input_names,
55 FloatArrays& input_values,
56 const std::vector<std::string>& output_names = {},
57 int64_t batch_size = 1) const;
58
63 const std::vector<std::string>& getOutputNames() const;
64
71 const std::vector<int64_t>& getOutputShape(
72 const std::string& output_name) const;
73
74 private:
75 static ::Ort::Env env;
76 std::unique_ptr<::Ort::Session> session_;
77
78 std::vector<std::string> input_node_strings_;
79 std::vector<const char*> input_node_names_;
80 std::map<std::string, std::vector<int64_t>> input_node_dims_;
81
82 std::vector<std::string> output_node_strings_;
83 std::vector<const char*> output_node_names_;
84 std::map<std::string, std::vector<int64_t>> output_node_dims_;
85};
86
87} // namespace ort
88} // namespace ldmx
89
90#endif /* TOOLS_ONNXRUNTIME_H */
A convenience wrapper around the ONNX Runtime C++ API.
Definition ONNXRuntime.h:27
const std::vector< int64_t > & getOutputShape(const std::string &output_name) const
Get the shape of an output node.
ONNXRuntime(const std::string &model_path, const ::Ort::SessionOptions *session_options=nullptr)
Class constructor.
FloatArrays run(const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::string > &output_names={}, int64_t batch_size=1) const
Run model inference and get outputs.
const std::vector< std::string > & getOutputNames() const
Get the names of all the output nodes.