LDMX Software
DNNEcalVetoProcessor.cxx
1#include "Ecal/DNNEcalVetoProcessor.h"
2
3// LDMX
4#include <algorithm>
5
6#include "Ecal/Event/EcalHit.h"
7
8namespace ecal {
9
10const std::vector<std::string> DNNEcalVetoProcessor::input_names_{"coordinates",
11 "features"};
12const std::vector<unsigned int> DNNEcalVetoProcessor::input_sizes_{
13 n_coordinate_dim_ * max_num_hits_, n_feature_dim_* max_num_hits_};
14
15DNNEcalVetoProcessor::DNNEcalVetoProcessor(const std::string& name,
16 framework::Process& process)
17 : Producer(name, process) {
18 for (const auto& s : input_sizes_) {
19 data_.emplace_back(s, 0);
20 }
21}
22
23void DNNEcalVetoProcessor::configure(
25 disc_cut_ = parameters.getParameter<double>("disc_cut");
26 rt_ = std::make_unique<ldmx::Ort::ONNXRuntime>(
27 parameters.getParameter<std::string>("model_path"));
28
29 // debug mode
30 debug_ = parameters.getParameter<bool>("debug");
31
32 // Set the collection name as defined in the configuration
33 collectionName_ = parameters.getParameter<std::string>("collection_name");
34}
35
36void DNNEcalVetoProcessor::produce(framework::Event& event) {
38
39 // Get the Ecal Geometry
40 const auto& ecal_geometry = getCondition<ldmx::EcalGeometry>(
41 ldmx::EcalGeometry::CONDITIONS_OBJECT_NAME);
42
43 // Get the collection of digitized Ecal hits from the event.
44 const auto ecalRecHits = event.getCollection<ldmx::EcalHit>("EcalRecHits");
45 auto nhits = std::count_if(
46 ecalRecHits.begin(), ecalRecHits.end(),
47 [](const ldmx::EcalHit& hit) { return hit.getEnergy() > 0; });
48
49 if (nhits < max_num_hits_) {
50 // make inputs
51 make_inputs(ecal_geometry, ecalRecHits);
52 // run the DNN
53 auto outputs = rt_->run(input_names_, data_)[0];
54 result.setDiscValue(outputs.at(1));
55 } else {
56 result.setDiscValue(-99);
57 }
58
59 if (debug_) {
60 std::cout << "... disc_val = " << result.getDisc() << std::endl;
61 }
62
63 result.setVetoResult(result.getDisc() > disc_cut_);
64
65 // If the event passes the veto, keep it. Otherwise, drop the event.
66 if (result.passesVeto()) {
67 setStorageHint(framework::hint_shouldKeep);
68 } else {
69 setStorageHint(framework::hint_shouldDrop);
70 }
71
72 event.add(collectionName_, result);
73}
74
75void DNNEcalVetoProcessor::make_inputs(
76 const ldmx::EcalGeometry& geom,
77 const std::vector<ldmx::EcalHit>& ecalRecHits) {
78 // clear data
79 for (auto& v : data_) {
80 std::fill(v.begin(), v.end(), 0);
81 }
82
83 unsigned idx = 0;
84 for (const auto& hit : ecalRecHits) {
85 if (hit.getEnergy() <= 0) continue;
86 ldmx::EcalID id(hit.getID());
87 auto [x, y, z] = geom.getPosition(id);
88
89 data_[0].at(coordinate_x_offset_ + idx) = x;
90 data_[0].at(coordinate_y_offset_ + idx) = y;
91 data_[0].at(coordinate_z_offset_ + idx) = z;
92
93 data_[1].at(feature_x_offset_ + idx) = x;
94 data_[1].at(feature_y_offset_ + idx) = y;
95 data_[1].at(feature_z_offset_ + idx) = z;
96 data_[1].at(feature_layerid_offset_ + idx) = id.layer();
97 data_[1].at(feature_energy_offset_ + idx) = std::log(hit.getEnergy());
98
99 ++idx;
100 }
101
102 if (debug_) {
103 for (unsigned iname = 0; iname < input_names_.size(); ++iname) {
104 std::cout << "=== " << input_names_[iname] << " ===" << std::endl;
105 for (unsigned i = 0; i < input_sizes_[iname]; ++i) {
106 std::cout << data_[iname].at(i) << ", ";
107 if ((i + 1) % max_num_hits_ == 0) {
108 std::cout << std::endl;
109 }
110 }
111 }
112 } // debug
113}
114
115} // namespace ecal
116
// Register the producer so the framework can construct it by name during
// configuration.
DECLARE_PRODUCER_NS(ecal, DNNEcalVetoProcessor);
#define DECLARE_PRODUCER_NS(NS, CLASS)
Macro which allows the framework to construct a producer given its name during configuration.
Implements an event buffer system for storing event data.
Definition Event.h:41
Class which represents the process under execution.
Definition Process.h:36
Class encapsulating parameters for configuring a processor.
Definition Parameters.h:27
T getParameter(const std::string &name) const
Retrieve the parameter of the given name.
Definition Parameters.h:89
Translation between real-space positions and cell IDs within the ECal.
std::tuple< double, double, double > getPosition(EcalID id) const
Get a cell's position from its ID number.
Stores reconstructed hit information from the ECAL.
Definition EcalHit.h:19
Extension of DetectorID providing access to ECal layers and cell numbers in a hex grid.
Definition EcalID.h:20
bool passesVeto() const
Checks if the event passes the Ecal veto.
constexpr StorageControl::Hint hint_shouldKeep
storage control hint alias for backwards compatibility
constexpr StorageControl::Hint hint_shouldDrop
storage control hint alias for backwards compatibility