// Run model inference and return the outputs.
103 {
104 assert(input_names.size() == input_values.size());
105 assert(batch_size > 0);
106
107
108 std::vector<Value> input_tensors;
109 auto memory_info =
110 MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
111 for (const auto& name : input_node_strings_) {
112 auto iter = std::find(input_names.begin(), input_names.end(), name);
113 if (iter == input_names.end()) {
114 throw std::runtime_error("Input '" + name + "' is not provided!");
115 }
116 auto value = input_values.begin() + (iter - input_names.begin());
117 auto input_dims = input_node_dims_.at(name);
118 input_dims[0] = batch_size;
119 auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1,
120 std::multiplies<int64_t>());
121 if (expected_len != (int64_t)value->size()) {
122 throw std::runtime_error("Input array '" + name +
123 "' has a wrong size of " +
124 std::to_string(value->size()) + ", expected " +
125 std::to_string(expected_len));
126 }
127 auto input_tensor =
128 Value::CreateTensor<float>(memory_info, value->data(), value->size(),
129 input_dims.data(), input_dims.size());
130 assert(input_tensor.IsTensor());
131 input_tensors.emplace_back(std::move(input_tensor));
132 }
133
134
135
136 std::vector<const char*> run_output_node_names;
137 if (output_names.empty()) {
138 run_output_node_names = output_node_names_;
139 } else {
140 for (const auto& name : output_names) {
141 run_output_node_names.push_back(name.c_str());
142 }
143 }
144
145
146 auto output_tensors =
147 session_->Run(RunOptions{nullptr}, input_node_names_.data(),
148 input_tensors.data(), input_tensors.size(),
149 run_output_node_names.data(), run_output_node_names.size());
150
151
152 FloatArrays outputs;
153 for (auto& output_tensor : output_tensors) {
154 assert(output_tensor.IsTensor());
155
156
157 auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
158 auto length = tensor_info.GetElementCount();
159
160 auto floatarr = output_tensor.GetTensorMutableData<float>();
161 outputs.emplace_back(floatarr, floatarr + length);
162 }
163 assert(outputs.size() == run_output_node_names.size());
164
165 return outputs;
166}