LDMX Software
Public Member Functions | Private Attributes | Static Private Attributes | List of all members
ldmx::Ort::ONNXRuntime Class Reference

A convenience wrapper of the ONNXRuntime C++ API. More...

#include <ONNXRuntime.h>

Public Member Functions

 ONNXRuntime (const std::string &model_path, const ::Ort::SessionOptions *session_options=nullptr)
 Class constructor.
 
 ONNXRuntime (const ONNXRuntime &)=delete
 
ONNXRuntime & operator= (const ONNXRuntime &)=delete
 
FloatArrays run (const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::string > &output_names={}, int64_t batch_size=1) const
 Run model inference and get outputs.
 
const std::vector< std::string > & getOutputNames () const
 Get the names of all the output nodes.
 
const std::vector< int64_t > & getOutputShape (const std::string &output_name) const
 Get the shape of an output node.
 

Private Attributes

std::unique_ptr<::Ort::Session > session_
 
std::vector< std::string > input_node_strings_
 
std::vector< const char * > input_node_names_
 
std::map< std::string, std::vector< int64_t > > input_node_dims_
 
std::vector< std::string > output_node_strings_
 
std::vector< const char * > output_node_names_
 
std::map< std::string, std::vector< int64_t > > output_node_dims_
 

Static Private Attributes

::Ort::Env env_
 

Detailed Description

A convenience wrapper of the ONNXRuntime C++ API.

Definition at line 20 of file ONNXRuntime.h.

Constructor & Destructor Documentation

◆ ONNXRuntime()

ldmx::Ort::ONNXRuntime::ONNXRuntime ( const std::string &  model_path,
const ::Ort::SessionOptions *  session_options = nullptr 
)

Class constructor.

Parameters
model_path: Path to the ONNX model file.
session_options: Configuration options of the ONNXRuntime Session. Leave as nullptr (the default) to use the default options.

◆ ~ONNXRuntime()

ldmx::Ort::ONNXRuntime::~ONNXRuntime ( )

Definition at line 108 of file ONNXRuntime.cxx.

108{}

Member Function Documentation

◆ getOutputNames()

const std::vector< std::string > & ldmx::Ort::ONNXRuntime::getOutputNames ( ) const

Get the names of all the output nodes.

Returns
A list of names of all the output nodes.

Definition at line 177 of file ONNXRuntime.cxx.

177 {
178 if (session_) {
179 return output_node_strings_;
180 } else {
181 throw std::runtime_error("ONNXRuntime session is not initialized!");
182 }
183}

◆ getOutputShape()

const std::vector< int64_t > & ldmx::Ort::ONNXRuntime::getOutputShape ( const std::string &  output_name) const

Get the shape of an output node.

The 0th dimension depends on the batch size and is therefore set to -1.

Parameters
output_name: Name of the output node.
Returns
The shape of the output node as a vector of integers.

Definition at line 185 of file ONNXRuntime.cxx.

186 {
187 auto iter = output_node_dims_.find(output_name);
188 if (iter == output_node_dims_.end()) {
189 throw std::runtime_error("Output name " + output_name + " is invalid!");
190 } else {
191 return iter->second;
192 }
193}

◆ run()

FloatArrays ldmx::Ort::ONNXRuntime::run ( const std::vector< std::string > &  input_names,
FloatArrays &  input_values,
const std::vector< std::string > &  output_names = {},
int64_t  batch_size = 1 
) const

Run model inference and get outputs.

Parameters
input_names: List of the names of the input nodes.
input_values: List of input arrays, one for each input node. The order of input_values must match input_names.
output_names: Names of the output nodes to get outputs from. An empty list means all output nodes.
batch_size: Number of samples in the batch. Each array in input_values must have a shape layout of (batch_size, ...).
Returns
A std::vector<std::vector<float>> whose order matches output_names. When output_names is empty, all outputs are returned, ordered as in getOutputNames().

Definition at line 110 of file ONNXRuntime.cxx.

113 {
114 assert(input_names.size() == input_values.size());
115 assert(batch_size > 0);
116
117 // create input tensor objects from data values
118 std::vector<Value> input_tensors;
119 auto memory_info =
120 MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
121 for (const auto& name : input_node_strings_) {
122 auto iter = std::find(input_names.begin(), input_names.end(), name);
123 if (iter == input_names.end()) {
124 throw std::runtime_error("Input " + name + " is not provided!");
125 }
126 auto value = input_values.begin() + (iter - input_names.begin());
127 auto input_dims = input_node_dims_.at(name);
128 input_dims[0] = batch_size;
129 auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1,
130 std::multiplies<int64_t>());
131 if (expected_len != (int64_t)value->size()) {
132 throw std::runtime_error("Input array " + name + " has a wrong size of " +
133 std::to_string(value->size()) + ", expected " +
134 std::to_string(expected_len));
135 }
136 auto input_tensor =
137 Value::CreateTensor<float>(memory_info, value->data(), value->size(),
138 input_dims.data(), input_dims.size());
139 assert(input_tensor.IsTensor());
140 input_tensors.emplace_back(std::move(input_tensor));
141 }
142
143 // set output node names; will get all outputs if `output_names` is not
144 // provided
145 std::vector<const char*> run_output_node_names;
146 if (output_names.empty()) {
147 run_output_node_names = output_node_names_;
148 } else {
149 for (const auto& name : output_names) {
150 run_output_node_names.push_back(name.c_str());
151 }
152 }
153
154 // run
155 auto output_tensors =
156 session_->Run(RunOptions{nullptr}, input_node_names_.data(),
157 input_tensors.data(), input_tensors.size(),
158 run_output_node_names.data(), run_output_node_names.size());
159
160 // convert output to floats
161 FloatArrays outputs;
162 for (auto& output_tensor : output_tensors) {
163 assert(output_tensor.IsTensor());
164
165 // get output shape
166 auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
167 auto length = tensor_info.GetElementCount();
168
169 auto floatarr = output_tensor.GetTensorMutableData<float>();
170 outputs.emplace_back(floatarr, floatarr + length);
171 }
172 assert(outputs.size() == run_output_node_names.size());
173
174 return outputs;
175}

Member Data Documentation

◆ env_

Env ldmx::Ort::ONNXRuntime::env_
static private

Definition at line 68 of file ONNXRuntime.h.

◆ input_node_dims_

std::map<std::string, std::vector<int64_t> > ldmx::Ort::ONNXRuntime::input_node_dims_
private

Definition at line 73 of file ONNXRuntime.h.

◆ input_node_names_

std::vector<const char*> ldmx::Ort::ONNXRuntime::input_node_names_
private

Definition at line 72 of file ONNXRuntime.h.

◆ input_node_strings_

std::vector<std::string> ldmx::Ort::ONNXRuntime::input_node_strings_
private

Definition at line 71 of file ONNXRuntime.h.

◆ output_node_dims_

std::map<std::string, std::vector<int64_t> > ldmx::Ort::ONNXRuntime::output_node_dims_
private

Definition at line 77 of file ONNXRuntime.h.

◆ output_node_names_

std::vector<const char*> ldmx::Ort::ONNXRuntime::output_node_names_
private

Definition at line 76 of file ONNXRuntime.h.

◆ output_node_strings_

std::vector<std::string> ldmx::Ort::ONNXRuntime::output_node_strings_
private

Definition at line 75 of file ONNXRuntime.h.

◆ session_

std::unique_ptr<::Ort::Session> ldmx::Ort::ONNXRuntime::session_
private

Definition at line 69 of file ONNXRuntime.h.


The documentation for this class was generated from the following files: