Run model inference and get outputs.
103 {
104 assert(input_names.size() == input_values.size());
105 assert(batch_size > 0);
106
107
108 std::vector<Value> input_tensors;
109 auto memory_info =
110 MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
111 for (const auto& name : input_node_strings_) {
112 auto iter = std::find(input_names.begin(), input_names.end(), name);
113 if (iter == input_names.end()) {
114 throw std::runtime_error("Input '" + name + "' is not provided!");
115 }
116 auto value = input_values.begin() + (iter - input_names.begin());
117 auto input_dims = input_node_dims_.at(name);
118 if (input_dims.size() > 0) {
119 input_dims[0] = batch_size;
120 }
121 auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1,
122 std::multiplies<int64_t>());
123 if (expected_len != (int64_t)value->size()) {
124 throw std::runtime_error("Input array '" + name +
125 "' has a wrong size of " +
126 std::to_string(value->size()) + ", expected " +
127 std::to_string(expected_len));
128 }
129 auto input_tensor =
130 Value::CreateTensor<float>(memory_info, value->data(), value->size(),
131 input_dims.data(), input_dims.size());
132 assert(input_tensor.IsTensor());
133 input_tensors.emplace_back(std::move(input_tensor));
134 }
135
136
137
138 std::vector<const char*> run_output_node_names;
139 if (output_names.empty()) {
140 run_output_node_names = output_node_names_;
141 } else {
142 for (const auto& name : output_names) {
143 run_output_node_names.push_back(name.c_str());
144 }
145 }
146
147
148 auto output_tensors =
149 session_->Run(RunOptions{nullptr}, input_node_names_.data(),
150 input_tensors.data(), input_tensors.size(),
151 run_output_node_names.data(), run_output_node_names.size());
152
153
154 FloatArrays outputs;
155 for (auto& output_tensor : output_tensors) {
156 assert(output_tensor.IsTensor());
157
158
159 auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
160 auto length = tensor_info.GetElementCount();
161
162 auto floatarr = output_tensor.GetTensorMutableData<float>();
163 outputs.emplace_back(floatarr, floatarr + length);
164 }
165 assert(outputs.size() == run_output_node_names.size());
166
167 return outputs;
168}