LDMX Software
ONNXRuntime.cxx
#include "Tools/ONNXRuntime.h"

#include <algorithm>
#include <cassert>
#include <functional>
#include <numeric>
#include <stdexcept>

namespace ldmx {
namespace ort {
using namespace ::Ort;
#if ORT_API_VERSION == 2
// API version in use when ONNX was first integrated into ldmx-sw and the
// version downloaded by the CMake infrastructure; its prebuilt binaries
// only support x86_64 architectures
std::string getInputName(std::unique_ptr<Session>& s, size_t i,
                         AllocatorWithDefaultOptions a) {
  return s->GetInputName(i, a);
}
std::string getOutputName(std::unique_ptr<Session>& s, size_t i,
                          AllocatorWithDefaultOptions a) {
  return s->GetOutputName(i, a);
}
#else
// newer version with prebuilt binaries for both x86_64 and arm64
// architectures, but with a slight API change
std::string getInputName(std::unique_ptr<Session>& s, size_t i,
                         AllocatorWithDefaultOptions a) {
  return s->GetInputNameAllocated(i, a).get();
}
std::string getOutputName(std::unique_ptr<Session>& s, size_t i,
                          AllocatorWithDefaultOptions a) {
  return s->GetOutputNameAllocated(i, a).get();
}
#if ORT_API_VERSION != 15
#pragma warning( \
    "Untested ONNX version, not certain of API, assuming API version 15.")
#endif
#endif
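
// both branches define getInputName()/getOutputName() with the same
// signature, so the rest of this file is independent of the ORT API version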

Env ONNXRuntime::env(ORT_LOGGING_LEVEL_WARNING, "");

ONNXRuntime::ONNXRuntime(const std::string& model_path,
                         const SessionOptions* session_options) {
  // create session
  if (session_options) {
    session_.reset(new Session(env, model_path.c_str(), *session_options));
  } else {
    SessionOptions sess_opts;
    sess_opts.SetIntraOpNumThreads(1);
    session_.reset(new Session(env, model_path.c_str(), sess_opts));
  }
  AllocatorWithDefaultOptions allocator;

  // get input names and shapes
  size_t num_input_nodes = session_->GetInputCount();
  input_node_strings_.resize(num_input_nodes);
  input_node_names_.resize(num_input_nodes);
  input_node_dims_.clear();

  for (size_t i = 0; i < num_input_nodes; i++) {
    // get input node names
    std::string input_name(getInputName(session_, i, allocator));
    input_node_strings_[i] = input_name;
    input_node_names_[i] = input_node_strings_[i].c_str();

    // get input shapes
    auto type_info = session_->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    size_t num_dims = tensor_info.GetDimensionsCount();
    input_node_dims_[input_name].resize(num_dims);
    const auto input_shape = tensor_info.GetShape();
    std::copy(input_shape.begin(), input_shape.end(),
              input_node_dims_[input_name].begin());

    // set the batch size to 1 by default
    input_node_dims_[input_name].at(0) = 1;
  }
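
  // for illustration (shapes hypothetical): a model whose only input has
  // ONNX shape {N, 1, 34} is stored above as {1, 1, 34}, with the dynamic
  // batch dimension pinned to 1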

  size_t num_output_nodes = session_->GetOutputCount();
  output_node_strings_.resize(num_output_nodes);
  output_node_names_.resize(num_output_nodes);
  output_node_dims_.clear();

  for (size_t i = 0; i < num_output_nodes; i++) {
    // get output node names
    std::string output_name(getOutputName(session_, i, allocator));
    output_node_strings_[i] = output_name;
    output_node_names_[i] = output_node_strings_[i].c_str();

    // get output node types
    auto type_info = session_->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    size_t num_dims = tensor_info.GetDimensionsCount();
    output_node_dims_[output_name].resize(num_dims);
    const auto output_shape = tensor_info.GetShape();
    std::copy(output_shape.begin(), output_shape.end(),
              output_node_dims_[output_name].begin());

    // the 0th dim depends on the batch size
    output_node_dims_[output_name].at(0) = -1;
  }
}
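
// a minimal construction sketch (illustrative only; the model path and the
// explicit options are hypothetical, not part of this file):
//
//   Ort::SessionOptions opts;
//   opts.SetIntraOpNumThreads(1);
//   ldmx::ort::ONNXRuntime rt("path/to/model.onnx", &opts);
//
// passing no options (the default nullptr) falls back to the single-threaded
// session options created above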

FloatArrays ONNXRuntime::run(const std::vector<std::string>& input_names,
                             FloatArrays& input_values,
                             const std::vector<std::string>& output_names,
                             int64_t batch_size) const {
  assert(input_names.size() == input_values.size());
  assert(batch_size > 0);

  // create input tensor objects from data values
  std::vector<Value> input_tensors;
  auto memory_info =
      MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  for (const auto& name : input_node_strings_) {
    auto iter = std::find(input_names.begin(), input_names.end(), name);
    if (iter == input_names.end()) {
      throw std::runtime_error("Input '" + name + "' is not provided!");
    }
    auto value = input_values.begin() + (iter - input_names.begin());
    auto input_dims = input_node_dims_.at(name);
    input_dims[0] = batch_size;
    auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
    if (expected_len != static_cast<int64_t>(value->size())) {
      throw std::runtime_error("Input array '" + name + "' has size " +
                               std::to_string(value->size()) + ", expected " +
                               std::to_string(expected_len));
    }
    auto input_tensor =
        Value::CreateTensor<float>(memory_info, value->data(), value->size(),
                                   input_dims.data(), input_dims.size());
    assert(input_tensor.IsTensor());
    input_tensors.emplace_back(std::move(input_tensor));
  }
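
  // note: Value::CreateTensor wraps the caller's buffers in place rather than
  // copying them, so `input_values` must remain alive and unmoved until
  // session_->Run() below has returned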

  // set output node names; will get all outputs if `output_names` is not
  // provided
  std::vector<const char*> run_output_node_names;
  if (output_names.empty()) {
    run_output_node_names = output_node_names_;
  } else {
    for (const auto& name : output_names) {
      run_output_node_names.push_back(name.c_str());
    }
  }

  // run
  auto output_tensors =
      session_->Run(RunOptions{nullptr}, input_node_names_.data(),
                    input_tensors.data(), input_tensors.size(),
                    run_output_node_names.data(), run_output_node_names.size());

  // convert output to floats
  FloatArrays outputs;
  for (auto& output_tensor : output_tensors) {
    assert(output_tensor.IsTensor());

    // get output shape
    auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
    auto length = tensor_info.GetElementCount();

    auto floatarr = output_tensor.GetTensorMutableData<float>();
    outputs.emplace_back(floatarr, floatarr + length);
  }
  assert(outputs.size() == run_output_node_names.size());

  return outputs;
}
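
// a minimal inference sketch (illustrative only; the input name and length
// are hypothetical and depend on the loaded model):
//
//   FloatArrays inputs{std::vector<float>(34, 0.f)};
//   FloatArrays outputs = rt.run({"features"}, inputs);  // batch_size = 1
//
// each entry of `outputs` holds the flattened tensor of one output node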

const std::vector<std::string>& ONNXRuntime::getOutputNames() const {
  if (session_) {
    return output_node_strings_;
  } else {
    throw std::runtime_error("ONNXRuntime session is not initialized!");
  }
}

const std::vector<int64_t>& ONNXRuntime::getOutputShape(
    const std::string& output_name) const {
  auto iter = output_node_dims_.find(output_name);
  if (iter == output_node_dims_.end()) {
    throw std::runtime_error("Output name '" + output_name + "' is invalid!");
  } else {
    return iter->second;
  }
}

}  // namespace ort
}  // namespace ldmx