Run model inference on the named input arrays and return each requested output as a flat float array.
113 {
114 assert(input_names.size() == input_values.size());
115 assert(batch_size > 0);
116
117
118 std::vector<Value> input_tensors;
119 auto memory_info =
120 MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
121 for (const auto& name : input_node_strings_) {
122 auto iter = std::find(input_names.begin(), input_names.end(), name);
123 if (iter == input_names.end()) {
124 throw std::runtime_error("Input " + name + " is not provided!");
125 }
126 auto value = input_values.begin() + (iter - input_names.begin());
127 auto input_dims = input_node_dims_.at(name);
128 input_dims[0] = batch_size;
129 auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1,
130 std::multiplies<int64_t>());
131 if (expected_len != (int64_t)value->size()) {
132 throw std::runtime_error("Input array " + name + " has a wrong size of " +
133 std::to_string(value->size()) + ", expected " +
134 std::to_string(expected_len));
135 }
136 auto input_tensor =
137 Value::CreateTensor<float>(memory_info, value->data(), value->size(),
138 input_dims.data(), input_dims.size());
139 assert(input_tensor.IsTensor());
140 input_tensors.emplace_back(std::move(input_tensor));
141 }
142
143
144
145 std::vector<const char*> run_output_node_names;
146 if (output_names.empty()) {
147 run_output_node_names = output_node_names_;
148 } else {
149 for (const auto& name : output_names) {
150 run_output_node_names.push_back(name.c_str());
151 }
152 }
153
154
155 auto output_tensors =
156 session_->Run(RunOptions{nullptr}, input_node_names_.data(),
157 input_tensors.data(), input_tensors.size(),
158 run_output_node_names.data(), run_output_node_names.size());
159
160
161 FloatArrays outputs;
162 for (auto& output_tensor : output_tensors) {
163 assert(output_tensor.IsTensor());
164
165
166 auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
167 auto length = tensor_info.GetElementCount();
168
169 auto floatarr = output_tensor.GetTensorMutableData<float>();
170 outputs.emplace_back(floatarr, floatarr + length);
171 }
172 assert(outputs.size() == run_output_node_names.size());
173
174 return outputs;
175}