// LDMX Software: ONNXRuntime.cxx

#include "Tools/ONNXRuntime.h"

#include <algorithm>
#include <cassert>
#include <exception>
#include <functional>
#include <memory>
#include <numeric>
#include <stdexcept>

namespace ldmx::Ort {

using namespace ::Ort;

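// get_input_name/get_output_name wrap the ONNXRuntime node-name getters,
// whose API changed between releases, so the rest of this file can stay
// version-agnostic.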
#if ORT_API_VERSION == 2
// Version in use when ONNX was first integrated into ldmx-sw and the
// version downloaded by the CMake infrastructure; its prebuilds only
// support x86_64 architectures.
std::string get_input_name(std::unique_ptr<Session>& s, size_t i,
                           AllocatorWithDefaultOptions a) {
  return s->GetInputName(i, a);
}
std::string get_output_name(std::unique_ptr<Session>& s, size_t i,
                            AllocatorWithDefaultOptions a) {
  return s->GetOutputName(i, a);
}
#else
// Later versions ship prebuilds for both x86_64 and arm64 architectures but
// contain a slight API change: the name getters return allocated strings
// instead of raw pointers.
std::string get_input_name(std::unique_ptr<Session>& s, size_t i,
                           AllocatorWithDefaultOptions a) {
  return s->GetInputNameAllocated(i, a).get();
}
std::string get_output_name(std::unique_ptr<Session>& s, size_t i,
                            AllocatorWithDefaultOptions a) {
  return s->GetOutputNameAllocated(i, a).get();
}
#if ORT_API_VERSION != 15
#warning "Untested ONNX version, not certain of API, assuming API version 15."
#endif
#endif

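// Single ONNXRuntime environment shared by every session this class creates;
// logs at WARNING level with an empty log identifier.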
Env ONNXRuntime::env_(ORT_LOGGING_LEVEL_WARNING, "");

ONNXRuntime::ONNXRuntime(const std::string& model_path,
                         const SessionOptions* session_options) {
  // create session
  if (session_options) {
    session_ = std::make_unique<Session>(env_, model_path.c_str(),
                                         *session_options);
  } else {
    SessionOptions sess_opts;
    sess_opts.SetIntraOpNumThreads(1);
    session_ = std::make_unique<Session>(env_, model_path.c_str(), sess_opts);
  }
  AllocatorWithDefaultOptions allocator;

  // get input names and shapes
  size_t num_input_nodes = session_->GetInputCount();
  input_node_strings_.resize(num_input_nodes);
  input_node_names_.resize(num_input_nodes);
  input_node_dims_.clear();

  for (size_t i = 0; i < num_input_nodes; i++) {
    // get input node names
    std::string input_name(get_input_name(session_, i, allocator));
    input_node_strings_[i] = input_name;
    input_node_names_[i] = input_node_strings_[i].c_str();

    // get input shapes
    auto type_info = session_->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    size_t num_dims = tensor_info.GetDimensionsCount();
    input_node_dims_[input_name].resize(num_dims);
    const auto input_shape = tensor_info.GetShape();
    std::copy(input_shape.begin(), input_shape.end(),
              input_node_dims_[input_name].begin());

    // set the batch size to 1 by default
    input_node_dims_[input_name].at(0) = 1;
  }

  size_t num_output_nodes = session_->GetOutputCount();
  output_node_strings_.resize(num_output_nodes);
  output_node_names_.resize(num_output_nodes);
  output_node_dims_.clear();

  for (size_t i = 0; i < num_output_nodes; i++) {
    // get output node names
    std::string output_name(get_output_name(session_, i, allocator));
    output_node_strings_[i] = output_name;
    output_node_names_[i] = output_node_strings_[i].c_str();

    // get output node types
    auto type_info = session_->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    size_t num_dims = tensor_info.GetDimensionsCount();
    output_node_dims_[output_name].resize(num_dims);
    const auto output_shape = tensor_info.GetShape();
    std::copy(output_shape.begin(), output_shape.end(),
              output_node_dims_[output_name].begin());

    // the 0th dim depends on the batch size
    output_node_dims_[output_name].at(0) = -1;
  }
}

ONNXRuntime::~ONNXRuntime() = default;

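// Matches each configured input node to an entry of input_values by name,
// validates the flattened length against the node shape for this batch size,
// runs the session, and copies every requested output tensor into a flat
// float vector.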
FloatArrays ONNXRuntime::run(const std::vector<std::string>& input_names,
                             FloatArrays& input_values,
                             const std::vector<std::string>& output_names,
                             int64_t batch_size) const {
  assert(input_names.size() == input_values.size());
  assert(batch_size > 0);

  // create input tensor objects from data values
  std::vector<Value> input_tensors;
  auto memory_info =
      MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  for (const auto& name : input_node_strings_) {
    auto iter = std::find(input_names.begin(), input_names.end(), name);
    if (iter == input_names.end()) {
      throw std::runtime_error("Input " + name + " is not provided!");
    }
    auto value = input_values.begin() + (iter - input_names.begin());
    auto input_dims = input_node_dims_.at(name);
    input_dims[0] = batch_size;
    // use an int64_t accumulator so large tensors do not overflow int
    auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(),
                                        int64_t{1},
                                        std::multiplies<int64_t>());
    if (expected_len != static_cast<int64_t>(value->size())) {
      throw std::runtime_error("Input array " + name + " has size " +
                               std::to_string(value->size()) + ", expected " +
                               std::to_string(expected_len));
    }
    auto input_tensor =
        Value::CreateTensor<float>(memory_info, value->data(), value->size(),
                                   input_dims.data(), input_dims.size());
    assert(input_tensor.IsTensor());
    input_tensors.emplace_back(std::move(input_tensor));
  }

  // set output node names; will get all outputs if `output_names` is not
  // provided
  std::vector<const char*> run_output_node_names;
  if (output_names.empty()) {
    run_output_node_names = output_node_names_;
  } else {
    for (const auto& name : output_names) {
      run_output_node_names.push_back(name.c_str());
    }
  }

  // run
  auto output_tensors =
      session_->Run(RunOptions{nullptr}, input_node_names_.data(),
                    input_tensors.data(), input_tensors.size(),
                    run_output_node_names.data(),
                    run_output_node_names.size());

  // convert output to floats
  FloatArrays outputs;
  for (auto& output_tensor : output_tensors) {
    assert(output_tensor.IsTensor());

    // total number of elements in this output tensor
    auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
    auto length = tensor_info.GetElementCount();

    auto floatarr = output_tensor.GetTensorMutableData<float>();
    outputs.emplace_back(floatarr, floatarr + length);
  }
  assert(outputs.size() == run_output_node_names.size());

  return outputs;
}

const std::vector<std::string>& ONNXRuntime::getOutputNames() const {
  if (session_) {
    return output_node_strings_;
  } else {
    throw std::runtime_error("ONNXRuntime session is not initialized!");
  }
}

const std::vector<int64_t>& ONNXRuntime::getOutputShape(
    const std::string& output_name) const {
  auto iter = output_node_dims_.find(output_name);
  if (iter == output_node_dims_.end()) {
    throw std::runtime_error("Output name " + output_name + " is invalid!");
  } else {
    return iter->second;
  }
}

} /* namespace ldmx::Ort */
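
// Example usage (an illustrative sketch: the model file "model.onnx", the
// input node name "features", and the flat length 10 are hypothetical and
// depend on the exported model):
//
//   ldmx::Ort::ONNXRuntime rt("model.onnx");
//   // one flat float vector per input node, batch_size * feature_len long
//   ldmx::Ort::FloatArrays inputs{std::vector<float>(10, 0.f)};
//   // output_names defaults to all outputs; batch_size defaults to 1
//   ldmx::Ort::FloatArrays outputs = rt.run({"features"}, inputs);
//   const std::vector<float>& scores = outputs.at(0);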