1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2abfd1a8bSRiver Riddle //
3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6abfd1a8bSRiver Riddle //
7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle //
9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter.
10abfd1a8bSRiver Riddle //
11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
12abfd1a8bSRiver Riddle 
#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <memory>
23abfd1a8bSRiver Riddle 
24abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode"
25abfd1a8bSRiver Riddle 
26abfd1a8bSRiver Riddle using namespace mlir;
27abfd1a8bSRiver Riddle using namespace mlir::detail;
28abfd1a8bSRiver Riddle 
29abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
30abfd1a8bSRiver Riddle // PDLByteCodePattern
31abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
32abfd1a8bSRiver Riddle 
33abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
34abfd1a8bSRiver Riddle                                               ByteCodeAddr rewriterAddr) {
35abfd1a8bSRiver Riddle   SmallVector<StringRef, 8> generatedOps;
36abfd1a8bSRiver Riddle   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
37abfd1a8bSRiver Riddle     generatedOps =
38abfd1a8bSRiver Riddle         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
39abfd1a8bSRiver Riddle 
40abfd1a8bSRiver Riddle   PatternBenefit benefit = matchOp.benefit();
41abfd1a8bSRiver Riddle   MLIRContext *ctx = matchOp.getContext();
42abfd1a8bSRiver Riddle 
43abfd1a8bSRiver Riddle   // Check to see if this is pattern matches a specific operation type.
44abfd1a8bSRiver Riddle   if (Optional<StringRef> rootKind = matchOp.rootKind())
45abfd1a8bSRiver Riddle     return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
46abfd1a8bSRiver Riddle                               ctx);
47abfd1a8bSRiver Riddle   return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
48abfd1a8bSRiver Riddle                             MatchAnyOpTypeTag());
49abfd1a8bSRiver Riddle }
50abfd1a8bSRiver Riddle 
51abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
52abfd1a8bSRiver Riddle // PDLByteCodeMutableState
53abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
54abfd1a8bSRiver Riddle 
55abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
56abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by
57abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`.
58abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
59abfd1a8bSRiver Riddle                                                    PatternBenefit benefit) {
60abfd1a8bSRiver Riddle   currentPatternBenefits[patternIndex] = benefit;
61abfd1a8bSRiver Riddle }
62abfd1a8bSRiver Riddle 
63abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
64abfd1a8bSRiver Riddle // Bytecode OpCodes
65abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
66abfd1a8bSRiver Riddle 
namespace {
/// Opcodes of the PDL bytecode. `ByteCodeWriter::append(OpCode)` writes an
/// opcode into the stream as a single ByteCodeField. No enumerator is given an
/// explicit value, so each opcode's numeric encoding is simply its position in
/// this list: inserting, removing, or reordering entries changes the encoding,
/// and any code decoding the stream must be kept in lockstep.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Invoke a native creation method.
  CreateNative,
  /// Create an operation.
  CreateOperation,
  /// Erase an operation.
  EraseOp,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. Indices 0-3 have dedicated
  /// opcodes; `GetOperandN` presumably encodes the index as an extra field
  /// (NOTE(review): confirm against the generator/interpreter).
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific result of an operation. Same 0-3 / N split as the operand
  /// opcodes above.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get the type of a value.
  GetValueType,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
};

/// Kind tag written in front of each entry of a PDL value list (see
/// `ByteCodeWriter::appendPDLValueList`). It identifies which kind of PDL
/// entity the memory index that follows it refers to.
enum class PDLValueKind { Attribute, Operation, Type, Value };
} // end anonymous namespace
131abfd1a8bSRiver Riddle 
132abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
133abfd1a8bSRiver Riddle // ByteCode Generation
134abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
135abfd1a8bSRiver Riddle 
136abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
137abfd1a8bSRiver Riddle // Generator
138abfd1a8bSRiver Riddle 
139abfd1a8bSRiver Riddle namespace {
140abfd1a8bSRiver Riddle struct ByteCodeWriter;
141abfd1a8bSRiver Riddle 
142abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode.
143abfd1a8bSRiver Riddle class Generator {
144abfd1a8bSRiver Riddle public:
145abfd1a8bSRiver Riddle   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
146abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &matcherByteCode,
147abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
148abfd1a8bSRiver Riddle             SmallVectorImpl<PDLByteCodePattern> &patterns,
149abfd1a8bSRiver Riddle             ByteCodeField &maxValueMemoryIndex,
150abfd1a8bSRiver Riddle             llvm::StringMap<PDLConstraintFunction> &constraintFns,
151abfd1a8bSRiver Riddle             llvm::StringMap<PDLCreateFunction> &createFns,
152abfd1a8bSRiver Riddle             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
153abfd1a8bSRiver Riddle       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
154abfd1a8bSRiver Riddle         rewriterByteCode(rewriterByteCode), patterns(patterns),
155abfd1a8bSRiver Riddle         maxValueMemoryIndex(maxValueMemoryIndex) {
156abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(constraintFns))
157abfd1a8bSRiver Riddle       constraintToMemIndex.try_emplace(it.value().first(), it.index());
158abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(createFns))
159abfd1a8bSRiver Riddle       nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
160abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(rewriteFns))
161abfd1a8bSRiver Riddle       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
162abfd1a8bSRiver Riddle   }
163abfd1a8bSRiver Riddle 
164abfd1a8bSRiver Riddle   /// Generate the bytecode for the given PDL interpreter module.
165abfd1a8bSRiver Riddle   void generate(ModuleOp module);
166abfd1a8bSRiver Riddle 
167abfd1a8bSRiver Riddle   /// Return the memory index to use for the given value.
168abfd1a8bSRiver Riddle   ByteCodeField &getMemIndex(Value value) {
169abfd1a8bSRiver Riddle     assert(valueToMemIndex.count(value) &&
170abfd1a8bSRiver Riddle            "expected memory index to be assigned");
171abfd1a8bSRiver Riddle     return valueToMemIndex[value];
172abfd1a8bSRiver Riddle   }
173abfd1a8bSRiver Riddle 
174abfd1a8bSRiver Riddle   /// Return an index to use when referring to the given data that is uniqued in
175abfd1a8bSRiver Riddle   /// the MLIR context.
176abfd1a8bSRiver Riddle   template <typename T>
177abfd1a8bSRiver Riddle   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
178abfd1a8bSRiver Riddle   getMemIndex(T val) {
179abfd1a8bSRiver Riddle     const void *opaqueVal = val.getAsOpaquePointer();
180abfd1a8bSRiver Riddle 
181abfd1a8bSRiver Riddle     // Get or insert a reference to this value.
182abfd1a8bSRiver Riddle     auto it = uniquedDataToMemIndex.try_emplace(
183abfd1a8bSRiver Riddle         opaqueVal, maxValueMemoryIndex + uniquedData.size());
184abfd1a8bSRiver Riddle     if (it.second)
185abfd1a8bSRiver Riddle       uniquedData.push_back(opaqueVal);
186abfd1a8bSRiver Riddle     return it.first->second;
187abfd1a8bSRiver Riddle   }
188abfd1a8bSRiver Riddle 
189abfd1a8bSRiver Riddle private:
190abfd1a8bSRiver Riddle   /// Allocate memory indices for the results of operations within the matcher
191abfd1a8bSRiver Riddle   /// and rewriters.
192abfd1a8bSRiver Riddle   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
193abfd1a8bSRiver Riddle 
194abfd1a8bSRiver Riddle   /// Generate the bytecode for the given operation.
195abfd1a8bSRiver Riddle   void generate(Operation *op, ByteCodeWriter &writer);
196abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
197abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
198abfd1a8bSRiver Riddle   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
199abfd1a8bSRiver Riddle   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
200abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
201abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
202abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
203abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
204abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
205abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
206abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
207abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
208abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
209abfd1a8bSRiver Riddle   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
210abfd1a8bSRiver Riddle   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
211abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
212abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
213abfd1a8bSRiver Riddle   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
214abfd1a8bSRiver Riddle   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
215abfd1a8bSRiver Riddle   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
216abfd1a8bSRiver Riddle   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
217abfd1a8bSRiver Riddle   void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
218abfd1a8bSRiver Riddle   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
219abfd1a8bSRiver Riddle   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
220abfd1a8bSRiver Riddle   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
221abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
222abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
223abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
224abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
225abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
226abfd1a8bSRiver Riddle 
227abfd1a8bSRiver Riddle   /// Mapping from value to its corresponding memory index.
228abfd1a8bSRiver Riddle   DenseMap<Value, ByteCodeField> valueToMemIndex;
229abfd1a8bSRiver Riddle 
230abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered rewrite to its index in
231abfd1a8bSRiver Riddle   /// the bytecode registry.
232abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
233abfd1a8bSRiver Riddle 
234abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered constraint to its index
235abfd1a8bSRiver Riddle   /// in the bytecode registry.
236abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> constraintToMemIndex;
237abfd1a8bSRiver Riddle 
238abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered creation method to its
239abfd1a8bSRiver Riddle   /// index in the bytecode registry.
240abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
241abfd1a8bSRiver Riddle 
242abfd1a8bSRiver Riddle   /// Mapping from rewriter function name to the bytecode address of the
243abfd1a8bSRiver Riddle   /// rewriter function in byte.
244abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
245abfd1a8bSRiver Riddle 
246abfd1a8bSRiver Riddle   /// Mapping from a uniqued storage object to its memory index within
247abfd1a8bSRiver Riddle   /// `uniquedData`.
248abfd1a8bSRiver Riddle   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
249abfd1a8bSRiver Riddle 
250abfd1a8bSRiver Riddle   /// The current MLIR context.
251abfd1a8bSRiver Riddle   MLIRContext *ctx;
252abfd1a8bSRiver Riddle 
253abfd1a8bSRiver Riddle   /// Data of the ByteCode class to be populated.
254abfd1a8bSRiver Riddle   std::vector<const void *> &uniquedData;
255abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &matcherByteCode;
256abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
257abfd1a8bSRiver Riddle   SmallVectorImpl<PDLByteCodePattern> &patterns;
258abfd1a8bSRiver Riddle   ByteCodeField &maxValueMemoryIndex;
259abfd1a8bSRiver Riddle };
260abfd1a8bSRiver Riddle 
261abfd1a8bSRiver Riddle /// This class provides utilities for writing a bytecode stream.
262abfd1a8bSRiver Riddle struct ByteCodeWriter {
263abfd1a8bSRiver Riddle   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
264abfd1a8bSRiver Riddle       : bytecode(bytecode), generator(generator) {}
265abfd1a8bSRiver Riddle 
266abfd1a8bSRiver Riddle   /// Append a field to the bytecode.
267abfd1a8bSRiver Riddle   void append(ByteCodeField field) { bytecode.push_back(field); }
268*fa20ab7bSRiver Riddle   void append(OpCode opCode) { bytecode.push_back(opCode); }
269abfd1a8bSRiver Riddle 
270abfd1a8bSRiver Riddle   /// Append an address to the bytecode.
271abfd1a8bSRiver Riddle   void append(ByteCodeAddr field) {
272abfd1a8bSRiver Riddle     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
273abfd1a8bSRiver Riddle                   "unexpected ByteCode address size");
274abfd1a8bSRiver Riddle 
275abfd1a8bSRiver Riddle     ByteCodeField fieldParts[2];
276abfd1a8bSRiver Riddle     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
277abfd1a8bSRiver Riddle     bytecode.append({fieldParts[0], fieldParts[1]});
278abfd1a8bSRiver Riddle   }
279abfd1a8bSRiver Riddle 
280abfd1a8bSRiver Riddle   /// Append a successor range to the bytecode, the exact address will need to
281abfd1a8bSRiver Riddle   /// be resolved later.
282abfd1a8bSRiver Riddle   void append(SuccessorRange successors) {
283abfd1a8bSRiver Riddle     // Add back references to the any successors so that the address can be
284abfd1a8bSRiver Riddle     // resolved later.
285abfd1a8bSRiver Riddle     for (Block *successor : successors) {
286abfd1a8bSRiver Riddle       unresolvedSuccessorRefs[successor].push_back(bytecode.size());
287abfd1a8bSRiver Riddle       append(ByteCodeAddr(0));
288abfd1a8bSRiver Riddle     }
289abfd1a8bSRiver Riddle   }
290abfd1a8bSRiver Riddle 
291abfd1a8bSRiver Riddle   /// Append a range of values that will be read as generic PDLValues.
292abfd1a8bSRiver Riddle   void appendPDLValueList(OperandRange values) {
293abfd1a8bSRiver Riddle     bytecode.push_back(values.size());
294abfd1a8bSRiver Riddle     for (Value value : values) {
295abfd1a8bSRiver Riddle       // Append the type of the value in addition to the value itself.
296abfd1a8bSRiver Riddle       PDLValueKind kind =
297abfd1a8bSRiver Riddle           TypeSwitch<Type, PDLValueKind>(value.getType())
298abfd1a8bSRiver Riddle               .Case<pdl::AttributeType>(
299abfd1a8bSRiver Riddle                   [](Type) { return PDLValueKind::Attribute; })
300abfd1a8bSRiver Riddle               .Case<pdl::OperationType>(
301abfd1a8bSRiver Riddle                   [](Type) { return PDLValueKind::Operation; })
302abfd1a8bSRiver Riddle               .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
303abfd1a8bSRiver Riddle               .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
304abfd1a8bSRiver Riddle       bytecode.push_back(static_cast<ByteCodeField>(kind));
305abfd1a8bSRiver Riddle       append(value);
306abfd1a8bSRiver Riddle     }
307abfd1a8bSRiver Riddle   }
308abfd1a8bSRiver Riddle 
309abfd1a8bSRiver Riddle   /// Check if the given class `T` has an iterator type.
310abfd1a8bSRiver Riddle   template <typename T, typename... Args>
311abfd1a8bSRiver Riddle   using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
312abfd1a8bSRiver Riddle 
313abfd1a8bSRiver Riddle   /// Append a value that will be stored in a memory slot and not inline within
314abfd1a8bSRiver Riddle   /// the bytecode.
315abfd1a8bSRiver Riddle   template <typename T>
316abfd1a8bSRiver Riddle   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
317abfd1a8bSRiver Riddle                    std::is_pointer<T>::value>
318abfd1a8bSRiver Riddle   append(T value) {
319abfd1a8bSRiver Riddle     bytecode.push_back(generator.getMemIndex(value));
320abfd1a8bSRiver Riddle   }
321abfd1a8bSRiver Riddle 
322abfd1a8bSRiver Riddle   /// Append a range of values.
323abfd1a8bSRiver Riddle   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
324abfd1a8bSRiver Riddle   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
325abfd1a8bSRiver Riddle   append(T range) {
326abfd1a8bSRiver Riddle     bytecode.push_back(llvm::size(range));
327abfd1a8bSRiver Riddle     for (auto it : range)
328abfd1a8bSRiver Riddle       append(it);
329abfd1a8bSRiver Riddle   }
330abfd1a8bSRiver Riddle 
331abfd1a8bSRiver Riddle   /// Append a variadic number of fields to the bytecode.
332abfd1a8bSRiver Riddle   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
333abfd1a8bSRiver Riddle   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
334abfd1a8bSRiver Riddle     append(field);
335abfd1a8bSRiver Riddle     append(field2, fields...);
336abfd1a8bSRiver Riddle   }
337abfd1a8bSRiver Riddle 
338abfd1a8bSRiver Riddle   /// Successor references in the bytecode that have yet to be resolved.
339abfd1a8bSRiver Riddle   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
340abfd1a8bSRiver Riddle 
341abfd1a8bSRiver Riddle   /// The underlying bytecode buffer.
342abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &bytecode;
343abfd1a8bSRiver Riddle 
344abfd1a8bSRiver Riddle   /// The main generator producing PDL.
345abfd1a8bSRiver Riddle   Generator &generator;
346abfd1a8bSRiver Riddle };
347abfd1a8bSRiver Riddle } // end anonymous namespace
348abfd1a8bSRiver Riddle 
349abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) {
350abfd1a8bSRiver Riddle   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
351abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
352abfd1a8bSRiver Riddle   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
353abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getRewriterModuleName());
354abfd1a8bSRiver Riddle   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
355abfd1a8bSRiver Riddle 
356abfd1a8bSRiver Riddle   // Allocate memory indices for the results of operations within the matcher
357abfd1a8bSRiver Riddle   // and rewriters.
358abfd1a8bSRiver Riddle   allocateMemoryIndices(matcherFunc, rewriterModule);
359abfd1a8bSRiver Riddle 
360abfd1a8bSRiver Riddle   // Generate code for the rewriter functions.
361abfd1a8bSRiver Riddle   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
362abfd1a8bSRiver Riddle   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
363abfd1a8bSRiver Riddle     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
364abfd1a8bSRiver Riddle     for (Operation &op : rewriterFunc.getOps())
365abfd1a8bSRiver Riddle       generate(&op, rewriterByteCodeWriter);
366abfd1a8bSRiver Riddle   }
367abfd1a8bSRiver Riddle   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
368abfd1a8bSRiver Riddle          "unexpected branches in rewriter function");
369abfd1a8bSRiver Riddle 
370abfd1a8bSRiver Riddle   // Generate code for the matcher function.
371abfd1a8bSRiver Riddle   DenseMap<Block *, ByteCodeAddr> blockToAddr;
372abfd1a8bSRiver Riddle   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
373abfd1a8bSRiver Riddle   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
374abfd1a8bSRiver Riddle   for (Block *block : rpot) {
375abfd1a8bSRiver Riddle     // Keep track of where this block begins within the matcher function.
376abfd1a8bSRiver Riddle     blockToAddr.try_emplace(block, matcherByteCode.size());
377abfd1a8bSRiver Riddle     for (Operation &op : *block)
378abfd1a8bSRiver Riddle       generate(&op, matcherByteCodeWriter);
379abfd1a8bSRiver Riddle   }
380abfd1a8bSRiver Riddle 
381abfd1a8bSRiver Riddle   // Resolve successor references in the matcher.
382abfd1a8bSRiver Riddle   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
383abfd1a8bSRiver Riddle     ByteCodeAddr addr = blockToAddr[it.first];
384abfd1a8bSRiver Riddle     for (unsigned offsetToFix : it.second)
385abfd1a8bSRiver Riddle       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
386abfd1a8bSRiver Riddle   }
387abfd1a8bSRiver Riddle }
388abfd1a8bSRiver Riddle 
389abfd1a8bSRiver Riddle void Generator::allocateMemoryIndices(FuncOp matcherFunc,
390abfd1a8bSRiver Riddle                                       ModuleOp rewriterModule) {
391abfd1a8bSRiver Riddle   // Rewriters use simplistic allocation scheme that simply assigns an index to
392abfd1a8bSRiver Riddle   // each result.
393abfd1a8bSRiver Riddle   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
394abfd1a8bSRiver Riddle     ByteCodeField index = 0;
395abfd1a8bSRiver Riddle     for (BlockArgument arg : rewriterFunc.getArguments())
396abfd1a8bSRiver Riddle       valueToMemIndex.try_emplace(arg, index++);
397abfd1a8bSRiver Riddle     rewriterFunc.getBody().walk([&](Operation *op) {
398abfd1a8bSRiver Riddle       for (Value result : op->getResults())
399abfd1a8bSRiver Riddle         valueToMemIndex.try_emplace(result, index++);
400abfd1a8bSRiver Riddle     });
401abfd1a8bSRiver Riddle     if (index > maxValueMemoryIndex)
402abfd1a8bSRiver Riddle       maxValueMemoryIndex = index;
403abfd1a8bSRiver Riddle   }
404abfd1a8bSRiver Riddle 
405abfd1a8bSRiver Riddle   // The matcher function uses a more sophisticated numbering that tries to
406abfd1a8bSRiver Riddle   // minimize the number of memory indices assigned. This is done by determining
407abfd1a8bSRiver Riddle   // a live range of the values within the matcher, then the allocation is just
408abfd1a8bSRiver Riddle   // finding the minimal number of overlapping live ranges. This is essentially
409abfd1a8bSRiver Riddle   // a simplified form of register allocation where we don't necessarily have a
410abfd1a8bSRiver Riddle   // limited number of registers, but we still want to minimize the number used.
411abfd1a8bSRiver Riddle   DenseMap<Operation *, ByteCodeField> opToIndex;
412abfd1a8bSRiver Riddle   matcherFunc.getBody().walk([&](Operation *op) {
413abfd1a8bSRiver Riddle     opToIndex.insert(std::make_pair(op, opToIndex.size()));
414abfd1a8bSRiver Riddle   });
415abfd1a8bSRiver Riddle 
416abfd1a8bSRiver Riddle   // Liveness info for each of the defs within the matcher.
417abfd1a8bSRiver Riddle   using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
418abfd1a8bSRiver Riddle   LivenessSet::Allocator allocator;
419abfd1a8bSRiver Riddle   DenseMap<Value, LivenessSet> valueDefRanges;
420abfd1a8bSRiver Riddle 
421abfd1a8bSRiver Riddle   // Assign the root operation being matched to slot 0.
422abfd1a8bSRiver Riddle   BlockArgument rootOpArg = matcherFunc.getArgument(0);
423abfd1a8bSRiver Riddle   valueToMemIndex[rootOpArg] = 0;
424abfd1a8bSRiver Riddle 
425abfd1a8bSRiver Riddle   // Walk each of the blocks, computing the def interval that the value is used.
426abfd1a8bSRiver Riddle   Liveness matcherLiveness(matcherFunc);
427abfd1a8bSRiver Riddle   for (Block &block : matcherFunc.getBody()) {
428abfd1a8bSRiver Riddle     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
429abfd1a8bSRiver Riddle     assert(info && "expected liveness info for block");
430abfd1a8bSRiver Riddle     auto processValue = [&](Value value, Operation *firstUseOrDef) {
431abfd1a8bSRiver Riddle       // We don't need to process the root op argument, this value is always
432abfd1a8bSRiver Riddle       // assigned to the first memory slot.
433abfd1a8bSRiver Riddle       if (value == rootOpArg)
434abfd1a8bSRiver Riddle         return;
435abfd1a8bSRiver Riddle 
436abfd1a8bSRiver Riddle       // Set indices for the range of this block that the value is used.
437abfd1a8bSRiver Riddle       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
438abfd1a8bSRiver Riddle       defRangeIt->second.insert(
439abfd1a8bSRiver Riddle           opToIndex[firstUseOrDef],
440abfd1a8bSRiver Riddle           opToIndex[info->getEndOperation(value, firstUseOrDef)],
441abfd1a8bSRiver Riddle           /*dummyValue*/ 0);
442abfd1a8bSRiver Riddle     };
443abfd1a8bSRiver Riddle 
444abfd1a8bSRiver Riddle     // Process the live-ins of this block.
445abfd1a8bSRiver Riddle     for (Value liveIn : info->in())
446abfd1a8bSRiver Riddle       processValue(liveIn, &block.front());
447abfd1a8bSRiver Riddle 
448abfd1a8bSRiver Riddle     // Process any new defs within this block.
449abfd1a8bSRiver Riddle     for (Operation &op : block)
450abfd1a8bSRiver Riddle       for (Value result : op.getResults())
451abfd1a8bSRiver Riddle         processValue(result, &op);
452abfd1a8bSRiver Riddle   }
453abfd1a8bSRiver Riddle 
454abfd1a8bSRiver Riddle   // Greedily allocate memory slots using the computed def live ranges.
455abfd1a8bSRiver Riddle   std::vector<LivenessSet> allocatedIndices;
456abfd1a8bSRiver Riddle   for (auto &defIt : valueDefRanges) {
457abfd1a8bSRiver Riddle     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
458abfd1a8bSRiver Riddle     LivenessSet &defSet = defIt.second;
459abfd1a8bSRiver Riddle 
460abfd1a8bSRiver Riddle     // Try to allocate to an existing index.
461abfd1a8bSRiver Riddle     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
462abfd1a8bSRiver Riddle       LivenessSet &existingIndex = existingIndexIt.value();
463abfd1a8bSRiver Riddle       llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
464abfd1a8bSRiver Riddle           defIt.second, existingIndex);
465abfd1a8bSRiver Riddle       if (overlaps.valid())
466abfd1a8bSRiver Riddle         continue;
467abfd1a8bSRiver Riddle       // Union the range of the def within the existing index.
468abfd1a8bSRiver Riddle       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
469abfd1a8bSRiver Riddle         existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
470abfd1a8bSRiver Riddle       memIndex = existingIndexIt.index() + 1;
471abfd1a8bSRiver Riddle     }
472abfd1a8bSRiver Riddle 
473abfd1a8bSRiver Riddle     // If no existing index could be used, add a new one.
474abfd1a8bSRiver Riddle     if (memIndex == 0) {
475abfd1a8bSRiver Riddle       allocatedIndices.emplace_back(allocator);
476abfd1a8bSRiver Riddle       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
477abfd1a8bSRiver Riddle         allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
478abfd1a8bSRiver Riddle       memIndex = allocatedIndices.size();
479abfd1a8bSRiver Riddle     }
480abfd1a8bSRiver Riddle   }
481abfd1a8bSRiver Riddle 
482abfd1a8bSRiver Riddle   // Update the max number of indices.
483abfd1a8bSRiver Riddle   ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
484abfd1a8bSRiver Riddle   if (numMatcherIndices > maxValueMemoryIndex)
485abfd1a8bSRiver Riddle     maxValueMemoryIndex = numMatcherIndices;
486abfd1a8bSRiver Riddle }
487abfd1a8bSRiver Riddle 
488abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) {
489abfd1a8bSRiver Riddle   TypeSwitch<Operation *>(op)
490abfd1a8bSRiver Riddle       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
491abfd1a8bSRiver Riddle             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
492abfd1a8bSRiver Riddle             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
493abfd1a8bSRiver Riddle             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
494abfd1a8bSRiver Riddle             pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
495abfd1a8bSRiver Riddle             pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
496abfd1a8bSRiver Riddle             pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
497abfd1a8bSRiver Riddle             pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
498abfd1a8bSRiver Riddle             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
499abfd1a8bSRiver Riddle             pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
500abfd1a8bSRiver Riddle             pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
501abfd1a8bSRiver Riddle             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
502abfd1a8bSRiver Riddle             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
503abfd1a8bSRiver Riddle             pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
504abfd1a8bSRiver Riddle             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
505abfd1a8bSRiver Riddle           [&](auto interpOp) { this->generate(interpOp, writer); })
506abfd1a8bSRiver Riddle       .Default([](Operation *) {
507abfd1a8bSRiver Riddle         llvm_unreachable("unknown `pdl_interp` operation");
508abfd1a8bSRiver Riddle       });
509abfd1a8bSRiver Riddle }
510abfd1a8bSRiver Riddle 
511abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op,
512abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
513abfd1a8bSRiver Riddle   assert(constraintToMemIndex.count(op.name()) &&
514abfd1a8bSRiver Riddle          "expected index for constraint function");
515abfd1a8bSRiver Riddle   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
516abfd1a8bSRiver Riddle                 op.constParamsAttr());
517abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
518abfd1a8bSRiver Riddle   writer.append(op.getSuccessors());
519abfd1a8bSRiver Riddle }
520abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op,
521abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
522abfd1a8bSRiver Riddle   assert(externalRewriterToMemIndex.count(op.name()) &&
523abfd1a8bSRiver Riddle          "expected index for rewrite function");
524abfd1a8bSRiver Riddle   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
525abfd1a8bSRiver Riddle                 op.constParamsAttr(), op.root());
526abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
527abfd1a8bSRiver Riddle }
528abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
529abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
530abfd1a8bSRiver Riddle }
531abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
532abfd1a8bSRiver Riddle   writer.append(OpCode::Branch, SuccessorRange(op));
533abfd1a8bSRiver Riddle }
534abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op,
535abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
536abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
537abfd1a8bSRiver Riddle                 op.getSuccessors());
538abfd1a8bSRiver Riddle }
539abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op,
540abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
541abfd1a8bSRiver Riddle   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
542abfd1a8bSRiver Riddle                 op.getSuccessors());
543abfd1a8bSRiver Riddle }
544abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op,
545abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
546abfd1a8bSRiver Riddle   writer.append(OpCode::CheckOperationName, op.operation(),
547abfd1a8bSRiver Riddle                 OperationName(op.name(), ctx), op.getSuccessors());
548abfd1a8bSRiver Riddle }
549abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op,
550abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
551abfd1a8bSRiver Riddle   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
552abfd1a8bSRiver Riddle                 op.getSuccessors());
553abfd1a8bSRiver Riddle }
554abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
555abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
556abfd1a8bSRiver Riddle }
557abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op,
558abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
559abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
560abfd1a8bSRiver Riddle   getMemIndex(op.attribute()) = getMemIndex(op.value());
561abfd1a8bSRiver Riddle }
562abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateNativeOp op,
563abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
564abfd1a8bSRiver Riddle   assert(nativeCreateToMemIndex.count(op.name()) &&
565abfd1a8bSRiver Riddle          "expected index for creation function");
566abfd1a8bSRiver Riddle   writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
567abfd1a8bSRiver Riddle                 op.result(), op.constParamsAttr());
568abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
569abfd1a8bSRiver Riddle }
570abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op,
571abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
572abfd1a8bSRiver Riddle   writer.append(OpCode::CreateOperation, op.operation(),
573abfd1a8bSRiver Riddle                 OperationName(op.name(), ctx), op.operands());
574abfd1a8bSRiver Riddle 
575abfd1a8bSRiver Riddle   // Add the attributes.
576abfd1a8bSRiver Riddle   OperandRange attributes = op.attributes();
577abfd1a8bSRiver Riddle   writer.append(static_cast<ByteCodeField>(attributes.size()));
578abfd1a8bSRiver Riddle   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
579abfd1a8bSRiver Riddle     writer.append(
580abfd1a8bSRiver Riddle         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
581abfd1a8bSRiver Riddle         std::get<1>(it));
582abfd1a8bSRiver Riddle   }
583abfd1a8bSRiver Riddle   writer.append(op.types());
584abfd1a8bSRiver Riddle }
585abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
586abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
587abfd1a8bSRiver Riddle   getMemIndex(op.result()) = getMemIndex(op.value());
588abfd1a8bSRiver Riddle }
589abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
590abfd1a8bSRiver Riddle   writer.append(OpCode::EraseOp, op.operation());
591abfd1a8bSRiver Riddle }
592abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
593abfd1a8bSRiver Riddle   writer.append(OpCode::Finalize);
594abfd1a8bSRiver Riddle }
595abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op,
596abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
597abfd1a8bSRiver Riddle   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
598abfd1a8bSRiver Riddle                 Identifier::get(op.name(), ctx));
599abfd1a8bSRiver Riddle }
600abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op,
601abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
602abfd1a8bSRiver Riddle   writer.append(OpCode::GetAttributeType, op.result(), op.value());
603abfd1a8bSRiver Riddle }
604abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op,
605abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
606abfd1a8bSRiver Riddle   writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
607abfd1a8bSRiver Riddle }
608abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
609abfd1a8bSRiver Riddle   uint32_t index = op.index();
610abfd1a8bSRiver Riddle   if (index < 4)
611abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
612abfd1a8bSRiver Riddle   else
613abfd1a8bSRiver Riddle     writer.append(OpCode::GetOperandN, index);
614abfd1a8bSRiver Riddle   writer.append(op.operation(), op.value());
615abfd1a8bSRiver Riddle }
616abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
617abfd1a8bSRiver Riddle   uint32_t index = op.index();
618abfd1a8bSRiver Riddle   if (index < 4)
619abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
620abfd1a8bSRiver Riddle   else
621abfd1a8bSRiver Riddle     writer.append(OpCode::GetResultN, index);
622abfd1a8bSRiver Riddle   writer.append(op.operation(), op.value());
623abfd1a8bSRiver Riddle }
624abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op,
625abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
626abfd1a8bSRiver Riddle   writer.append(OpCode::GetValueType, op.result(), op.value());
627abfd1a8bSRiver Riddle }
628abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::InferredTypeOp op,
629abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
630abfd1a8bSRiver Riddle   // InferType maps to a null type as a marker for inferring a result type.
631abfd1a8bSRiver Riddle   getMemIndex(op.type()) = getMemIndex(Type());
632abfd1a8bSRiver Riddle }
633abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
634abfd1a8bSRiver Riddle   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
635abfd1a8bSRiver Riddle }
636abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
637abfd1a8bSRiver Riddle   ByteCodeField patternIndex = patterns.size();
638abfd1a8bSRiver Riddle   patterns.emplace_back(PDLByteCodePattern::create(
639abfd1a8bSRiver Riddle       op, rewriterToAddr[op.rewriter().getLeafReference()]));
640abfd1a8bSRiver Riddle   writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op),
641abfd1a8bSRiver Riddle                 op.matchedOps(), op.inputs());
642abfd1a8bSRiver Riddle }
643abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
644abfd1a8bSRiver Riddle   writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
645abfd1a8bSRiver Riddle }
646abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op,
647abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
648abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
649abfd1a8bSRiver Riddle                 op.getSuccessors());
650abfd1a8bSRiver Riddle }
651abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op,
652abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
653abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
654abfd1a8bSRiver Riddle                 op.getSuccessors());
655abfd1a8bSRiver Riddle }
656abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op,
657abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
658abfd1a8bSRiver Riddle   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
659abfd1a8bSRiver Riddle     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
660abfd1a8bSRiver Riddle   });
661abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
662abfd1a8bSRiver Riddle                 op.getSuccessors());
663abfd1a8bSRiver Riddle }
664abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op,
665abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
666abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
667abfd1a8bSRiver Riddle                 op.getSuccessors());
668abfd1a8bSRiver Riddle }
669abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
670abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
671abfd1a8bSRiver Riddle                 op.getSuccessors());
672abfd1a8bSRiver Riddle }
673abfd1a8bSRiver Riddle 
674abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
675abfd1a8bSRiver Riddle // PDLByteCode
676abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
677abfd1a8bSRiver Riddle 
678abfd1a8bSRiver Riddle PDLByteCode::PDLByteCode(ModuleOp module,
679abfd1a8bSRiver Riddle                          llvm::StringMap<PDLConstraintFunction> constraintFns,
680abfd1a8bSRiver Riddle                          llvm::StringMap<PDLCreateFunction> createFns,
681abfd1a8bSRiver Riddle                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
682abfd1a8bSRiver Riddle   Generator generator(module.getContext(), uniquedData, matcherByteCode,
683abfd1a8bSRiver Riddle                       rewriterByteCode, patterns, maxValueMemoryIndex,
684abfd1a8bSRiver Riddle                       constraintFns, createFns, rewriteFns);
685abfd1a8bSRiver Riddle   generator.generate(module);
686abfd1a8bSRiver Riddle 
687abfd1a8bSRiver Riddle   // Initialize the external functions.
688abfd1a8bSRiver Riddle   for (auto &it : constraintFns)
689abfd1a8bSRiver Riddle     constraintFunctions.push_back(std::move(it.second));
690abfd1a8bSRiver Riddle   for (auto &it : createFns)
691abfd1a8bSRiver Riddle     createFunctions.push_back(std::move(it.second));
692abfd1a8bSRiver Riddle   for (auto &it : rewriteFns)
693abfd1a8bSRiver Riddle     rewriteFunctions.push_back(std::move(it.second));
694abfd1a8bSRiver Riddle }
695abfd1a8bSRiver Riddle 
696abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current
697abfd1a8bSRiver Riddle /// bytecode.
698abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
699abfd1a8bSRiver Riddle   state.memory.resize(maxValueMemoryIndex, nullptr);
700abfd1a8bSRiver Riddle   state.currentPatternBenefits.reserve(patterns.size());
701abfd1a8bSRiver Riddle   for (const PDLByteCodePattern &pattern : patterns)
702abfd1a8bSRiver Riddle     state.currentPatternBenefits.push_back(pattern.getBenefit());
703abfd1a8bSRiver Riddle }
704abfd1a8bSRiver Riddle 
705abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
706abfd1a8bSRiver Riddle // ByteCode Execution
707abfd1a8bSRiver Riddle 
708abfd1a8bSRiver Riddle namespace {
709abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream.
710abfd1a8bSRiver Riddle class ByteCodeExecutor {
711abfd1a8bSRiver Riddle public:
712abfd1a8bSRiver Riddle   ByteCodeExecutor(const ByteCodeField *curCodeIt,
713abfd1a8bSRiver Riddle                    MutableArrayRef<const void *> memory,
714abfd1a8bSRiver Riddle                    ArrayRef<const void *> uniquedMemory,
715abfd1a8bSRiver Riddle                    ArrayRef<ByteCodeField> code,
716abfd1a8bSRiver Riddle                    ArrayRef<PatternBenefit> currentPatternBenefits,
717abfd1a8bSRiver Riddle                    ArrayRef<PDLByteCodePattern> patterns,
718abfd1a8bSRiver Riddle                    ArrayRef<PDLConstraintFunction> constraintFunctions,
719abfd1a8bSRiver Riddle                    ArrayRef<PDLCreateFunction> createFunctions,
720abfd1a8bSRiver Riddle                    ArrayRef<PDLRewriteFunction> rewriteFunctions)
721abfd1a8bSRiver Riddle       : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
722abfd1a8bSRiver Riddle         code(code), currentPatternBenefits(currentPatternBenefits),
723abfd1a8bSRiver Riddle         patterns(patterns), constraintFunctions(constraintFunctions),
724abfd1a8bSRiver Riddle         createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
725abfd1a8bSRiver Riddle 
726abfd1a8bSRiver Riddle   /// Start executing the code at the current bytecode index. `matches` is an
727abfd1a8bSRiver Riddle   /// optional field provided when this function is executed in a matching
728abfd1a8bSRiver Riddle   /// context.
729abfd1a8bSRiver Riddle   void execute(PatternRewriter &rewriter,
730abfd1a8bSRiver Riddle                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
731abfd1a8bSRiver Riddle                Optional<Location> mainRewriteLoc = {});
732abfd1a8bSRiver Riddle 
733abfd1a8bSRiver Riddle private:
734abfd1a8bSRiver Riddle   /// Read a value from the bytecode buffer, optionally skipping a certain
735abfd1a8bSRiver Riddle   /// number of prefix values. These methods always update the buffer to point
736abfd1a8bSRiver Riddle   /// to the next field after the read data.
737abfd1a8bSRiver Riddle   template <typename T = ByteCodeField>
738abfd1a8bSRiver Riddle   T read(size_t skipN = 0) {
739abfd1a8bSRiver Riddle     curCodeIt += skipN;
740abfd1a8bSRiver Riddle     return readImpl<T>();
741abfd1a8bSRiver Riddle   }
742abfd1a8bSRiver Riddle   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
743abfd1a8bSRiver Riddle 
744abfd1a8bSRiver Riddle   /// Read a list of values from the bytecode buffer.
745abfd1a8bSRiver Riddle   template <typename ValueT, typename T>
746abfd1a8bSRiver Riddle   void readList(SmallVectorImpl<T> &list) {
747abfd1a8bSRiver Riddle     list.clear();
748abfd1a8bSRiver Riddle     for (unsigned i = 0, e = read(); i != e; ++i)
749abfd1a8bSRiver Riddle       list.push_back(read<ValueT>());
750abfd1a8bSRiver Riddle   }
751abfd1a8bSRiver Riddle 
752abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a predicate value.
753abfd1a8bSRiver Riddle   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
754abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a destination index.
755abfd1a8bSRiver Riddle   void selectJump(size_t destIndex) {
756abfd1a8bSRiver Riddle     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
757abfd1a8bSRiver Riddle   }
758abfd1a8bSRiver Riddle 
759abfd1a8bSRiver Riddle   /// Handle a switch operation with the provided value and cases.
760abfd1a8bSRiver Riddle   template <typename T, typename RangeT>
761abfd1a8bSRiver Riddle   void handleSwitch(const T &value, RangeT &&cases) {
762abfd1a8bSRiver Riddle     LLVM_DEBUG({
763abfd1a8bSRiver Riddle       llvm::dbgs() << "  * Value: " << value << "\n"
764abfd1a8bSRiver Riddle                    << "  * Cases: ";
765abfd1a8bSRiver Riddle       llvm::interleaveComma(cases, llvm::dbgs());
766abfd1a8bSRiver Riddle       llvm::dbgs() << "\n\n";
767abfd1a8bSRiver Riddle     });
768abfd1a8bSRiver Riddle 
769abfd1a8bSRiver Riddle     // Check to see if the attribute value is within the case list. Jump to
770abfd1a8bSRiver Riddle     // the correct successor index based on the result.
771abfd1a8bSRiver Riddle     auto it = llvm::find(cases, value);
772abfd1a8bSRiver Riddle     selectJump(it == cases.end() ? size_t(0) : ((it - cases.begin()) + 1));
773abfd1a8bSRiver Riddle   }
774abfd1a8bSRiver Riddle 
775abfd1a8bSRiver Riddle   /// Internal implementation of reading various data types from the bytecode
776abfd1a8bSRiver Riddle   /// stream.
777abfd1a8bSRiver Riddle   template <typename T>
778abfd1a8bSRiver Riddle   const void *readFromMemory() {
779abfd1a8bSRiver Riddle     size_t index = *curCodeIt++;
780abfd1a8bSRiver Riddle 
781abfd1a8bSRiver Riddle     // If this type is an SSA value, it can only be stored in non-const memory.
782abfd1a8bSRiver Riddle     if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
783abfd1a8bSRiver Riddle       return memory[index];
784abfd1a8bSRiver Riddle 
785abfd1a8bSRiver Riddle     // Otherwise, if this index is not inbounds it is uniqued.
786abfd1a8bSRiver Riddle     return uniquedMemory[index - memory.size()];
787abfd1a8bSRiver Riddle   }
788abfd1a8bSRiver Riddle   template <typename T>
789abfd1a8bSRiver Riddle   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
790abfd1a8bSRiver Riddle     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
791abfd1a8bSRiver Riddle   }
792abfd1a8bSRiver Riddle   template <typename T>
793abfd1a8bSRiver Riddle   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
794abfd1a8bSRiver Riddle                    T>
795abfd1a8bSRiver Riddle   readImpl() {
796abfd1a8bSRiver Riddle     return T(T::getFromOpaquePointer(readFromMemory<T>()));
797abfd1a8bSRiver Riddle   }
798abfd1a8bSRiver Riddle   template <typename T>
799abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
800abfd1a8bSRiver Riddle     switch (static_cast<PDLValueKind>(read())) {
801abfd1a8bSRiver Riddle     case PDLValueKind::Attribute:
802abfd1a8bSRiver Riddle       return read<Attribute>();
803abfd1a8bSRiver Riddle     case PDLValueKind::Operation:
804abfd1a8bSRiver Riddle       return read<Operation *>();
805abfd1a8bSRiver Riddle     case PDLValueKind::Type:
806abfd1a8bSRiver Riddle       return read<Type>();
807abfd1a8bSRiver Riddle     case PDLValueKind::Value:
808abfd1a8bSRiver Riddle       return read<Value>();
809abfd1a8bSRiver Riddle     }
810abfd1a8bSRiver Riddle   }
811abfd1a8bSRiver Riddle   template <typename T>
812abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
813abfd1a8bSRiver Riddle     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
814abfd1a8bSRiver Riddle                   "unexpected ByteCode address size");
815abfd1a8bSRiver Riddle     ByteCodeAddr result;
816abfd1a8bSRiver Riddle     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
817abfd1a8bSRiver Riddle     curCodeIt += 2;
818abfd1a8bSRiver Riddle     return result;
819abfd1a8bSRiver Riddle   }
820abfd1a8bSRiver Riddle   template <typename T>
821abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
822abfd1a8bSRiver Riddle     return *curCodeIt++;
823abfd1a8bSRiver Riddle   }
824abfd1a8bSRiver Riddle 
825abfd1a8bSRiver Riddle   /// The underlying bytecode buffer.
826abfd1a8bSRiver Riddle   const ByteCodeField *curCodeIt;
827abfd1a8bSRiver Riddle 
828abfd1a8bSRiver Riddle   /// The current execution memory.
829abfd1a8bSRiver Riddle   MutableArrayRef<const void *> memory;
830abfd1a8bSRiver Riddle 
831abfd1a8bSRiver Riddle   /// References to ByteCode data necessary for execution.
832abfd1a8bSRiver Riddle   ArrayRef<const void *> uniquedMemory;
833abfd1a8bSRiver Riddle   ArrayRef<ByteCodeField> code;
834abfd1a8bSRiver Riddle   ArrayRef<PatternBenefit> currentPatternBenefits;
835abfd1a8bSRiver Riddle   ArrayRef<PDLByteCodePattern> patterns;
836abfd1a8bSRiver Riddle   ArrayRef<PDLConstraintFunction> constraintFunctions;
837abfd1a8bSRiver Riddle   ArrayRef<PDLCreateFunction> createFunctions;
838abfd1a8bSRiver Riddle   ArrayRef<PDLRewriteFunction> rewriteFunctions;
839abfd1a8bSRiver Riddle };
840abfd1a8bSRiver Riddle } // end anonymous namespace
841abfd1a8bSRiver Riddle 
842abfd1a8bSRiver Riddle void ByteCodeExecutor::execute(
843abfd1a8bSRiver Riddle     PatternRewriter &rewriter,
844abfd1a8bSRiver Riddle     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
845abfd1a8bSRiver Riddle     Optional<Location> mainRewriteLoc) {
846abfd1a8bSRiver Riddle   while (true) {
847abfd1a8bSRiver Riddle     OpCode opCode = static_cast<OpCode>(read());
848abfd1a8bSRiver Riddle     switch (opCode) {
849abfd1a8bSRiver Riddle     case ApplyConstraint: {
850abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
851abfd1a8bSRiver Riddle       const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
852abfd1a8bSRiver Riddle       ArrayAttr constParams = read<ArrayAttr>();
853abfd1a8bSRiver Riddle       SmallVector<PDLValue, 16> args;
854abfd1a8bSRiver Riddle       readList<PDLValue>(args);
855abfd1a8bSRiver Riddle       LLVM_DEBUG({
856abfd1a8bSRiver Riddle         llvm::dbgs() << "  * Arguments: ";
857abfd1a8bSRiver Riddle         llvm::interleaveComma(args, llvm::dbgs());
858abfd1a8bSRiver Riddle         llvm::dbgs() << "\n  * Parameters: " << constParams << "\n\n";
859abfd1a8bSRiver Riddle       });
860abfd1a8bSRiver Riddle 
861abfd1a8bSRiver Riddle       // Invoke the constraint and jump to the proper destination.
862abfd1a8bSRiver Riddle       selectJump(succeeded(constraintFn(args, constParams, rewriter)));
863abfd1a8bSRiver Riddle       break;
864abfd1a8bSRiver Riddle     }
865abfd1a8bSRiver Riddle     case ApplyRewrite: {
866abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
867abfd1a8bSRiver Riddle       const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
868abfd1a8bSRiver Riddle       ArrayAttr constParams = read<ArrayAttr>();
869abfd1a8bSRiver Riddle       Operation *root = read<Operation *>();
870abfd1a8bSRiver Riddle       SmallVector<PDLValue, 16> args;
871abfd1a8bSRiver Riddle       readList<PDLValue>(args);
872abfd1a8bSRiver Riddle 
873abfd1a8bSRiver Riddle       LLVM_DEBUG({
874abfd1a8bSRiver Riddle         llvm::dbgs() << "  * Root: " << *root << "\n"
875abfd1a8bSRiver Riddle                      << "  * Arguments: ";
876abfd1a8bSRiver Riddle         llvm::interleaveComma(args, llvm::dbgs());
877abfd1a8bSRiver Riddle         llvm::dbgs() << "\n  * Parameters: " << constParams << "\n\n";
878abfd1a8bSRiver Riddle       });
879abfd1a8bSRiver Riddle       rewriteFn(root, args, constParams, rewriter);
880abfd1a8bSRiver Riddle       break;
881abfd1a8bSRiver Riddle     }
882abfd1a8bSRiver Riddle     case AreEqual: {
883abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
884abfd1a8bSRiver Riddle       const void *lhs = read<const void *>();
885abfd1a8bSRiver Riddle       const void *rhs = read<const void *>();
886abfd1a8bSRiver Riddle 
887abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
888abfd1a8bSRiver Riddle       selectJump(lhs == rhs);
889abfd1a8bSRiver Riddle       break;
890abfd1a8bSRiver Riddle     }
891abfd1a8bSRiver Riddle     case Branch: {
892abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
893abfd1a8bSRiver Riddle       curCodeIt = &code[read<ByteCodeAddr>()];
894abfd1a8bSRiver Riddle       break;
895abfd1a8bSRiver Riddle     }
896abfd1a8bSRiver Riddle     case CheckOperandCount: {
897abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
898abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
899abfd1a8bSRiver Riddle       uint32_t expectedCount = read<uint32_t>();
900abfd1a8bSRiver Riddle 
901abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
902abfd1a8bSRiver Riddle                               << "  * Expected: " << expectedCount << "\n\n");
903abfd1a8bSRiver Riddle       selectJump(op->getNumOperands() == expectedCount);
904abfd1a8bSRiver Riddle       break;
905abfd1a8bSRiver Riddle     }
906abfd1a8bSRiver Riddle     case CheckOperationName: {
907abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
908abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
909abfd1a8bSRiver Riddle       OperationName expectedName = read<OperationName>();
910abfd1a8bSRiver Riddle 
911abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs()
912abfd1a8bSRiver Riddle                  << "  * Found: \"" << op->getName() << "\"\n"
913abfd1a8bSRiver Riddle                  << "  * Expected: \"" << expectedName << "\"\n\n");
914abfd1a8bSRiver Riddle       selectJump(op->getName() == expectedName);
915abfd1a8bSRiver Riddle       break;
916abfd1a8bSRiver Riddle     }
917abfd1a8bSRiver Riddle     case CheckResultCount: {
918abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
919abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
920abfd1a8bSRiver Riddle       uint32_t expectedCount = read<uint32_t>();
921abfd1a8bSRiver Riddle 
922abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
923abfd1a8bSRiver Riddle                               << "  * Expected: " << expectedCount << "\n\n");
924abfd1a8bSRiver Riddle       selectJump(op->getNumResults() == expectedCount);
925abfd1a8bSRiver Riddle       break;
926abfd1a8bSRiver Riddle     }
927abfd1a8bSRiver Riddle     case CreateNative: {
928abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
929abfd1a8bSRiver Riddle       const PDLCreateFunction &createFn = createFunctions[read()];
930abfd1a8bSRiver Riddle       ByteCodeField resultIndex = read();
931abfd1a8bSRiver Riddle       ArrayAttr constParams = read<ArrayAttr>();
932abfd1a8bSRiver Riddle       SmallVector<PDLValue, 16> args;
933abfd1a8bSRiver Riddle       readList<PDLValue>(args);
934abfd1a8bSRiver Riddle 
935abfd1a8bSRiver Riddle       LLVM_DEBUG({
936abfd1a8bSRiver Riddle         llvm::dbgs() << "  * Arguments: ";
937abfd1a8bSRiver Riddle         llvm::interleaveComma(args, llvm::dbgs());
938abfd1a8bSRiver Riddle         llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
939abfd1a8bSRiver Riddle       });
940abfd1a8bSRiver Riddle 
941abfd1a8bSRiver Riddle       PDLValue result = createFn(args, constParams, rewriter);
942abfd1a8bSRiver Riddle       memory[resultIndex] = result.getAsOpaquePointer();
943abfd1a8bSRiver Riddle 
944abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n\n");
945abfd1a8bSRiver Riddle       break;
946abfd1a8bSRiver Riddle     }
947abfd1a8bSRiver Riddle     case CreateOperation: {
948abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
949abfd1a8bSRiver Riddle       assert(mainRewriteLoc && "expected rewrite loc to be provided when "
950abfd1a8bSRiver Riddle                                "executing the rewriter bytecode");
951abfd1a8bSRiver Riddle 
952abfd1a8bSRiver Riddle       unsigned memIndex = read();
953abfd1a8bSRiver Riddle       OperationState state(*mainRewriteLoc, read<OperationName>());
954abfd1a8bSRiver Riddle       readList<Value>(state.operands);
955abfd1a8bSRiver Riddle       for (unsigned i = 0, e = read(); i != e; ++i) {
956abfd1a8bSRiver Riddle         Identifier name = read<Identifier>();
957abfd1a8bSRiver Riddle         if (Attribute attr = read<Attribute>())
958abfd1a8bSRiver Riddle           state.addAttribute(name, attr);
959abfd1a8bSRiver Riddle       }
960abfd1a8bSRiver Riddle 
961abfd1a8bSRiver Riddle       bool hasInferredTypes = false;
962abfd1a8bSRiver Riddle       for (unsigned i = 0, e = read(); i != e; ++i) {
963abfd1a8bSRiver Riddle         Type resultType = read<Type>();
964abfd1a8bSRiver Riddle         hasInferredTypes |= !resultType;
965abfd1a8bSRiver Riddle         state.types.push_back(resultType);
966abfd1a8bSRiver Riddle       }
967abfd1a8bSRiver Riddle 
968abfd1a8bSRiver Riddle       // Handle the case where the operation has inferred types.
969abfd1a8bSRiver Riddle       if (hasInferredTypes) {
970abfd1a8bSRiver Riddle         InferTypeOpInterface::Concept *concept =
971abfd1a8bSRiver Riddle             state.name.getAbstractOperation()
972abfd1a8bSRiver Riddle                 ->getInterface<InferTypeOpInterface>();
973abfd1a8bSRiver Riddle 
974abfd1a8bSRiver Riddle         // TODO: Handle failure.
975abfd1a8bSRiver Riddle         SmallVector<Type, 2> inferredTypes;
976abfd1a8bSRiver Riddle         if (failed(concept->inferReturnTypes(
977abfd1a8bSRiver Riddle                 state.getContext(), state.location, state.operands,
978abfd1a8bSRiver Riddle                 state.attributes.getDictionary(state.getContext()),
979abfd1a8bSRiver Riddle                 state.regions, inferredTypes)))
980abfd1a8bSRiver Riddle           return;
981abfd1a8bSRiver Riddle 
982abfd1a8bSRiver Riddle         for (unsigned i = 0, e = state.types.size(); i != e; ++i)
983abfd1a8bSRiver Riddle           if (!state.types[i])
984abfd1a8bSRiver Riddle             state.types[i] = inferredTypes[i];
985abfd1a8bSRiver Riddle       }
986abfd1a8bSRiver Riddle       Operation *resultOp = rewriter.createOperation(state);
987abfd1a8bSRiver Riddle       memory[memIndex] = resultOp;
988abfd1a8bSRiver Riddle 
989abfd1a8bSRiver Riddle       LLVM_DEBUG({
990abfd1a8bSRiver Riddle         llvm::dbgs() << "  * Attributes: "
991abfd1a8bSRiver Riddle                      << state.attributes.getDictionary(state.getContext())
992abfd1a8bSRiver Riddle                      << "\n  * Operands: ";
993abfd1a8bSRiver Riddle         llvm::interleaveComma(state.operands, llvm::dbgs());
994abfd1a8bSRiver Riddle         llvm::dbgs() << "\n  * Result Types: ";
995abfd1a8bSRiver Riddle         llvm::interleaveComma(state.types, llvm::dbgs());
996abfd1a8bSRiver Riddle         llvm::dbgs() << "\n  * Result: " << *resultOp << "\n\n";
997abfd1a8bSRiver Riddle       });
998abfd1a8bSRiver Riddle       break;
999abfd1a8bSRiver Riddle     }
1000abfd1a8bSRiver Riddle     case EraseOp: {
1001abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1002abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
1003abfd1a8bSRiver Riddle 
1004abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n\n");
1005abfd1a8bSRiver Riddle       rewriter.eraseOp(op);
1006abfd1a8bSRiver Riddle       break;
1007abfd1a8bSRiver Riddle     }
1008abfd1a8bSRiver Riddle     case Finalize: {
1009abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1010abfd1a8bSRiver Riddle       return;
1011abfd1a8bSRiver Riddle     }
1012abfd1a8bSRiver Riddle     case GetAttribute: {
1013abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1014abfd1a8bSRiver Riddle       unsigned memIndex = read();
1015abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
1016abfd1a8bSRiver Riddle       Identifier attrName = read<Identifier>();
1017abfd1a8bSRiver Riddle       Attribute attr = op->getAttr(attrName);
1018abfd1a8bSRiver Riddle 
1019abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1020abfd1a8bSRiver Riddle                               << "  * Attribute: " << attrName << "\n"
1021abfd1a8bSRiver Riddle                               << "  * Result: " << attr << "\n\n");
1022abfd1a8bSRiver Riddle       memory[memIndex] = attr.getAsOpaquePointer();
1023abfd1a8bSRiver Riddle       break;
1024abfd1a8bSRiver Riddle     }
1025abfd1a8bSRiver Riddle     case GetAttributeType: {
1026abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1027abfd1a8bSRiver Riddle       unsigned memIndex = read();
1028abfd1a8bSRiver Riddle       Attribute attr = read<Attribute>();
1029abfd1a8bSRiver Riddle 
1030abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1031abfd1a8bSRiver Riddle                               << "  * Result: " << attr.getType() << "\n\n");
1032abfd1a8bSRiver Riddle       memory[memIndex] = attr.getType().getAsOpaquePointer();
1033abfd1a8bSRiver Riddle       break;
1034abfd1a8bSRiver Riddle     }
1035abfd1a8bSRiver Riddle     case GetDefiningOp: {
1036abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1037abfd1a8bSRiver Riddle       unsigned memIndex = read();
1038abfd1a8bSRiver Riddle       Value value = read<Value>();
1039abfd1a8bSRiver Riddle       Operation *op = value ? value.getDefiningOp() : nullptr;
1040abfd1a8bSRiver Riddle 
1041abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1042abfd1a8bSRiver Riddle                               << "  * Result: " << *op << "\n\n");
1043abfd1a8bSRiver Riddle       memory[memIndex] = op;
1044abfd1a8bSRiver Riddle       break;
1045abfd1a8bSRiver Riddle     }
1046abfd1a8bSRiver Riddle     case GetOperand0:
1047abfd1a8bSRiver Riddle     case GetOperand1:
1048abfd1a8bSRiver Riddle     case GetOperand2:
1049abfd1a8bSRiver Riddle     case GetOperand3:
1050abfd1a8bSRiver Riddle     case GetOperandN: {
1051abfd1a8bSRiver Riddle       LLVM_DEBUG({
1052abfd1a8bSRiver Riddle         llvm::dbgs() << "Executing GetOperand"
1053abfd1a8bSRiver Riddle                      << (opCode == GetOperandN ? Twine("N")
1054abfd1a8bSRiver Riddle                                                : Twine(opCode - GetOperand0))
1055abfd1a8bSRiver Riddle                      << ":\n";
1056abfd1a8bSRiver Riddle       });
1057abfd1a8bSRiver Riddle       unsigned index =
1058abfd1a8bSRiver Riddle           opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
1059abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
1060abfd1a8bSRiver Riddle       unsigned memIndex = read();
1061abfd1a8bSRiver Riddle       Value operand =
1062abfd1a8bSRiver Riddle           index < op->getNumOperands() ? op->getOperand(index) : Value();
1063abfd1a8bSRiver Riddle 
1064abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1065abfd1a8bSRiver Riddle                               << "  * Index: " << index << "\n"
1066abfd1a8bSRiver Riddle                               << "  * Result: " << operand << "\n\n");
1067abfd1a8bSRiver Riddle       memory[memIndex] = operand.getAsOpaquePointer();
1068abfd1a8bSRiver Riddle       break;
1069abfd1a8bSRiver Riddle     }
1070abfd1a8bSRiver Riddle     case GetResult0:
1071abfd1a8bSRiver Riddle     case GetResult1:
1072abfd1a8bSRiver Riddle     case GetResult2:
1073abfd1a8bSRiver Riddle     case GetResult3:
1074abfd1a8bSRiver Riddle     case GetResultN: {
1075abfd1a8bSRiver Riddle       LLVM_DEBUG({
1076abfd1a8bSRiver Riddle         llvm::dbgs() << "Executing GetResult"
1077abfd1a8bSRiver Riddle                      << (opCode == GetResultN ? Twine("N")
1078abfd1a8bSRiver Riddle                                               : Twine(opCode - GetResult0))
1079abfd1a8bSRiver Riddle                      << ":\n";
1080abfd1a8bSRiver Riddle       });
1081abfd1a8bSRiver Riddle       unsigned index =
1082abfd1a8bSRiver Riddle           opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
1083abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
1084abfd1a8bSRiver Riddle       unsigned memIndex = read();
1085abfd1a8bSRiver Riddle       OpResult result =
1086abfd1a8bSRiver Riddle           index < op->getNumResults() ? op->getResult(index) : OpResult();
1087abfd1a8bSRiver Riddle 
1088abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1089abfd1a8bSRiver Riddle                               << "  * Index: " << index << "\n"
1090abfd1a8bSRiver Riddle                               << "  * Result: " << result << "\n\n");
1091abfd1a8bSRiver Riddle       memory[memIndex] = result.getAsOpaquePointer();
1092abfd1a8bSRiver Riddle       break;
1093abfd1a8bSRiver Riddle     }
1094abfd1a8bSRiver Riddle     case GetValueType: {
1095abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1096abfd1a8bSRiver Riddle       unsigned memIndex = read();
1097abfd1a8bSRiver Riddle       Value value = read<Value>();
1098abfd1a8bSRiver Riddle 
1099abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1100abfd1a8bSRiver Riddle                               << "  * Result: " << value.getType() << "\n\n");
1101abfd1a8bSRiver Riddle       memory[memIndex] = value.getType().getAsOpaquePointer();
1102abfd1a8bSRiver Riddle       break;
1103abfd1a8bSRiver Riddle     }
1104abfd1a8bSRiver Riddle     case IsNotNull: {
1105abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1106abfd1a8bSRiver Riddle       const void *value = read<const void *>();
1107abfd1a8bSRiver Riddle 
1108abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n\n");
1109abfd1a8bSRiver Riddle       selectJump(value != nullptr);
1110abfd1a8bSRiver Riddle       break;
1111abfd1a8bSRiver Riddle     }
1112abfd1a8bSRiver Riddle     case RecordMatch: {
1113abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1114abfd1a8bSRiver Riddle       assert(matches &&
1115abfd1a8bSRiver Riddle              "expected matches to be provided when executing the matcher");
1116abfd1a8bSRiver Riddle       unsigned patternIndex = read();
1117abfd1a8bSRiver Riddle       PatternBenefit benefit = currentPatternBenefits[patternIndex];
1118abfd1a8bSRiver Riddle       const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1119abfd1a8bSRiver Riddle 
1120abfd1a8bSRiver Riddle       // If the benefit of the pattern is impossible, skip the processing of the
1121abfd1a8bSRiver Riddle       // rest of the pattern.
1122abfd1a8bSRiver Riddle       if (benefit.isImpossibleToMatch()) {
1123abfd1a8bSRiver Riddle         LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n\n");
1124abfd1a8bSRiver Riddle         curCodeIt = dest;
1125abfd1a8bSRiver Riddle         break;
1126abfd1a8bSRiver Riddle       }
1127abfd1a8bSRiver Riddle 
1128abfd1a8bSRiver Riddle       // Create a fused location containing the locations of each of the
1129abfd1a8bSRiver Riddle       // operations used in the match. This will be used as the location for
1130abfd1a8bSRiver Riddle       // created operations during the rewrite that don't already have an
1131abfd1a8bSRiver Riddle       // explicit location set.
1132abfd1a8bSRiver Riddle       unsigned numMatchLocs = read();
1133abfd1a8bSRiver Riddle       SmallVector<Location, 4> matchLocs;
1134abfd1a8bSRiver Riddle       matchLocs.reserve(numMatchLocs);
1135abfd1a8bSRiver Riddle       for (unsigned i = 0; i != numMatchLocs; ++i)
1136abfd1a8bSRiver Riddle         matchLocs.push_back(read<Operation *>()->getLoc());
1137abfd1a8bSRiver Riddle       Location matchLoc = rewriter.getFusedLoc(matchLocs);
1138abfd1a8bSRiver Riddle 
1139abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1140abfd1a8bSRiver Riddle                               << "  * Location: " << matchLoc << "\n\n");
1141abfd1a8bSRiver Riddle       matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
1142abfd1a8bSRiver Riddle       readList<const void *>(matches->back().values);
1143abfd1a8bSRiver Riddle       curCodeIt = dest;
1144abfd1a8bSRiver Riddle       break;
1145abfd1a8bSRiver Riddle     }
1146abfd1a8bSRiver Riddle     case ReplaceOp: {
1147abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1148abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
1149abfd1a8bSRiver Riddle       SmallVector<Value, 16> args;
1150abfd1a8bSRiver Riddle       readList<Value>(args);
1151abfd1a8bSRiver Riddle 
1152abfd1a8bSRiver Riddle       LLVM_DEBUG({
1153abfd1a8bSRiver Riddle         llvm::dbgs() << "  * Operation: " << *op << "\n"
1154abfd1a8bSRiver Riddle                      << "  * Values: ";
1155abfd1a8bSRiver Riddle         llvm::interleaveComma(args, llvm::dbgs());
1156abfd1a8bSRiver Riddle         llvm::dbgs() << "\n\n";
1157abfd1a8bSRiver Riddle       });
1158abfd1a8bSRiver Riddle       rewriter.replaceOp(op, args);
1159abfd1a8bSRiver Riddle       break;
1160abfd1a8bSRiver Riddle     }
1161abfd1a8bSRiver Riddle     case SwitchAttribute: {
1162abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1163abfd1a8bSRiver Riddle       Attribute value = read<Attribute>();
1164abfd1a8bSRiver Riddle       ArrayAttr cases = read<ArrayAttr>();
1165abfd1a8bSRiver Riddle       handleSwitch(value, cases);
1166abfd1a8bSRiver Riddle       break;
1167abfd1a8bSRiver Riddle     }
1168abfd1a8bSRiver Riddle     case SwitchOperandCount: {
1169abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1170abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
1171abfd1a8bSRiver Riddle       auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1172abfd1a8bSRiver Riddle 
1173abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1174abfd1a8bSRiver Riddle       handleSwitch(op->getNumOperands(), cases);
1175abfd1a8bSRiver Riddle       break;
1176abfd1a8bSRiver Riddle     }
1177abfd1a8bSRiver Riddle     case SwitchOperationName: {
1178abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1179abfd1a8bSRiver Riddle       OperationName value = read<Operation *>()->getName();
1180abfd1a8bSRiver Riddle       size_t caseCount = read();
1181abfd1a8bSRiver Riddle 
1182abfd1a8bSRiver Riddle       // The operation names are stored in-line, so to print them out for
1183abfd1a8bSRiver Riddle       // debugging purposes we need to read the array before executing the
1184abfd1a8bSRiver Riddle       // switch so that we can display all of the possible values.
1185abfd1a8bSRiver Riddle       LLVM_DEBUG({
1186abfd1a8bSRiver Riddle         const ByteCodeField *prevCodeIt = curCodeIt;
1187abfd1a8bSRiver Riddle         llvm::dbgs() << "  * Value: " << value << "\n"
1188abfd1a8bSRiver Riddle                      << "  * Cases: ";
1189abfd1a8bSRiver Riddle         llvm::interleaveComma(
1190abfd1a8bSRiver Riddle             llvm::map_range(llvm::seq<size_t>(0, caseCount),
1191abfd1a8bSRiver Riddle                             [&](size_t i) { return read<OperationName>(); }),
1192abfd1a8bSRiver Riddle             llvm::dbgs());
1193abfd1a8bSRiver Riddle         llvm::dbgs() << "\n\n";
1194abfd1a8bSRiver Riddle         curCodeIt = prevCodeIt;
1195abfd1a8bSRiver Riddle       });
1196abfd1a8bSRiver Riddle 
1197abfd1a8bSRiver Riddle       // Try to find the switch value within any of the cases.
1198abfd1a8bSRiver Riddle       size_t jumpDest = 0;
1199abfd1a8bSRiver Riddle       for (size_t i = 0; i != caseCount; ++i) {
1200abfd1a8bSRiver Riddle         if (read<OperationName>() == value) {
1201abfd1a8bSRiver Riddle           curCodeIt += (caseCount - i - 1);
1202abfd1a8bSRiver Riddle           jumpDest = i + 1;
1203abfd1a8bSRiver Riddle           break;
1204abfd1a8bSRiver Riddle         }
1205abfd1a8bSRiver Riddle       }
1206abfd1a8bSRiver Riddle       selectJump(jumpDest);
1207abfd1a8bSRiver Riddle       break;
1208abfd1a8bSRiver Riddle     }
1209abfd1a8bSRiver Riddle     case SwitchResultCount: {
1210abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1211abfd1a8bSRiver Riddle       Operation *op = read<Operation *>();
1212abfd1a8bSRiver Riddle       auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1213abfd1a8bSRiver Riddle 
1214abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1215abfd1a8bSRiver Riddle       handleSwitch(op->getNumResults(), cases);
1216abfd1a8bSRiver Riddle       break;
1217abfd1a8bSRiver Riddle     }
1218abfd1a8bSRiver Riddle     case SwitchType: {
1219abfd1a8bSRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1220abfd1a8bSRiver Riddle       Type value = read<Type>();
1221abfd1a8bSRiver Riddle       auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1222abfd1a8bSRiver Riddle       handleSwitch(value, cases);
1223abfd1a8bSRiver Riddle       break;
1224abfd1a8bSRiver Riddle     }
1225abfd1a8bSRiver Riddle     }
1226abfd1a8bSRiver Riddle   }
1227abfd1a8bSRiver Riddle }
1228abfd1a8bSRiver Riddle 
1229abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched
1230abfd1a8bSRiver Riddle /// patterns in `matches`.
1231abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1232abfd1a8bSRiver Riddle                         SmallVectorImpl<MatchResult> &matches,
1233abfd1a8bSRiver Riddle                         PDLByteCodeMutableState &state) const {
1234abfd1a8bSRiver Riddle   // The first memory slot is always the root operation.
1235abfd1a8bSRiver Riddle   state.memory[0] = op;
1236abfd1a8bSRiver Riddle 
1237abfd1a8bSRiver Riddle   // The matcher function always starts at code address 0.
1238abfd1a8bSRiver Riddle   ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
1239abfd1a8bSRiver Riddle                             matcherByteCode, state.currentPatternBenefits,
1240abfd1a8bSRiver Riddle                             patterns, constraintFunctions, createFunctions,
1241abfd1a8bSRiver Riddle                             rewriteFunctions);
1242abfd1a8bSRiver Riddle   executor.execute(rewriter, &matches);
1243abfd1a8bSRiver Riddle 
1244abfd1a8bSRiver Riddle   // Order the found matches by benefit.
1245abfd1a8bSRiver Riddle   std::stable_sort(matches.begin(), matches.end(),
1246abfd1a8bSRiver Riddle                    [](const MatchResult &lhs, const MatchResult &rhs) {
1247abfd1a8bSRiver Riddle                      return lhs.benefit > rhs.benefit;
1248abfd1a8bSRiver Riddle                    });
1249abfd1a8bSRiver Riddle }
1250abfd1a8bSRiver Riddle 
1251abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`.
1252abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1253abfd1a8bSRiver Riddle                           PDLByteCodeMutableState &state) const {
1254abfd1a8bSRiver Riddle   // The arguments of the rewrite function are stored at the start of the
1255abfd1a8bSRiver Riddle   // memory buffer.
1256abfd1a8bSRiver Riddle   llvm::copy(match.values, state.memory.begin());
1257abfd1a8bSRiver Riddle 
1258abfd1a8bSRiver Riddle   ByteCodeExecutor executor(
1259abfd1a8bSRiver Riddle       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1260abfd1a8bSRiver Riddle       uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
1261abfd1a8bSRiver Riddle       constraintFunctions, createFunctions, rewriteFunctions);
1262abfd1a8bSRiver Riddle   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1263abfd1a8bSRiver Riddle }
1264