1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2abfd1a8bSRiver Riddle //
3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6abfd1a8bSRiver Riddle //
7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle //
9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter.
10abfd1a8bSRiver Riddle //
11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
12abfd1a8bSRiver Riddle 
13abfd1a8bSRiver Riddle #include "ByteCode.h"
14abfd1a8bSRiver Riddle #include "mlir/Analysis/Liveness.h"
15abfd1a8bSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16abfd1a8bSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17e66c2e25SRiver Riddle #include "mlir/IR/BuiltinOps.h"
18abfd1a8bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h"
19abfd1a8bSRiver Riddle #include "llvm/ADT/IntervalMap.h"
20abfd1a8bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h"
21abfd1a8bSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
22abfd1a8bSRiver Riddle #include "llvm/Support/Debug.h"
23abfd1a8bSRiver Riddle 
24abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode"
25abfd1a8bSRiver Riddle 
26abfd1a8bSRiver Riddle using namespace mlir;
27abfd1a8bSRiver Riddle using namespace mlir::detail;
28abfd1a8bSRiver Riddle 
29abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
30abfd1a8bSRiver Riddle // PDLByteCodePattern
31abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
32abfd1a8bSRiver Riddle 
33abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
34abfd1a8bSRiver Riddle                                               ByteCodeAddr rewriterAddr) {
35abfd1a8bSRiver Riddle   SmallVector<StringRef, 8> generatedOps;
36abfd1a8bSRiver Riddle   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
37abfd1a8bSRiver Riddle     generatedOps =
38abfd1a8bSRiver Riddle         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
39abfd1a8bSRiver Riddle 
40abfd1a8bSRiver Riddle   PatternBenefit benefit = matchOp.benefit();
41abfd1a8bSRiver Riddle   MLIRContext *ctx = matchOp.getContext();
42abfd1a8bSRiver Riddle 
43abfd1a8bSRiver Riddle   // Check to see if this is pattern matches a specific operation type.
44abfd1a8bSRiver Riddle   if (Optional<StringRef> rootKind = matchOp.rootKind())
45abfd1a8bSRiver Riddle     return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
46abfd1a8bSRiver Riddle                               ctx);
47abfd1a8bSRiver Riddle   return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
48abfd1a8bSRiver Riddle                             MatchAnyOpTypeTag());
49abfd1a8bSRiver Riddle }
50abfd1a8bSRiver Riddle 
51abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
52abfd1a8bSRiver Riddle // PDLByteCodeMutableState
53abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
54abfd1a8bSRiver Riddle 
55abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
56abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by
57abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`.
58abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
59abfd1a8bSRiver Riddle                                                    PatternBenefit benefit) {
60abfd1a8bSRiver Riddle   currentPatternBenefits[patternIndex] = benefit;
61abfd1a8bSRiver Riddle }
62abfd1a8bSRiver Riddle 
63abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
64abfd1a8bSRiver Riddle // Bytecode OpCodes
65abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
66abfd1a8bSRiver Riddle 
67abfd1a8bSRiver Riddle namespace {
68abfd1a8bSRiver Riddle enum OpCode : ByteCodeField {
69abfd1a8bSRiver Riddle   /// Apply an externally registered constraint.
70abfd1a8bSRiver Riddle   ApplyConstraint,
71abfd1a8bSRiver Riddle   /// Apply an externally registered rewrite.
72abfd1a8bSRiver Riddle   ApplyRewrite,
73abfd1a8bSRiver Riddle   /// Check if two generic values are equal.
74abfd1a8bSRiver Riddle   AreEqual,
75abfd1a8bSRiver Riddle   /// Unconditional branch.
76abfd1a8bSRiver Riddle   Branch,
77abfd1a8bSRiver Riddle   /// Compare the operand count of an operation with a constant.
78abfd1a8bSRiver Riddle   CheckOperandCount,
79abfd1a8bSRiver Riddle   /// Compare the name of an operation with a constant.
80abfd1a8bSRiver Riddle   CheckOperationName,
81abfd1a8bSRiver Riddle   /// Compare the result count of an operation with a constant.
82abfd1a8bSRiver Riddle   CheckResultCount,
83abfd1a8bSRiver Riddle   /// Invoke a native creation method.
84abfd1a8bSRiver Riddle   CreateNative,
85abfd1a8bSRiver Riddle   /// Create an operation.
86abfd1a8bSRiver Riddle   CreateOperation,
87abfd1a8bSRiver Riddle   /// Erase an operation.
88abfd1a8bSRiver Riddle   EraseOp,
89abfd1a8bSRiver Riddle   /// Terminate a matcher or rewrite sequence.
90abfd1a8bSRiver Riddle   Finalize,
91abfd1a8bSRiver Riddle   /// Get a specific attribute of an operation.
92abfd1a8bSRiver Riddle   GetAttribute,
93abfd1a8bSRiver Riddle   /// Get the type of an attribute.
94abfd1a8bSRiver Riddle   GetAttributeType,
95abfd1a8bSRiver Riddle   /// Get the defining operation of a value.
96abfd1a8bSRiver Riddle   GetDefiningOp,
97abfd1a8bSRiver Riddle   /// Get a specific operand of an operation.
98abfd1a8bSRiver Riddle   GetOperand0,
99abfd1a8bSRiver Riddle   GetOperand1,
100abfd1a8bSRiver Riddle   GetOperand2,
101abfd1a8bSRiver Riddle   GetOperand3,
102abfd1a8bSRiver Riddle   GetOperandN,
103abfd1a8bSRiver Riddle   /// Get a specific result of an operation.
104abfd1a8bSRiver Riddle   GetResult0,
105abfd1a8bSRiver Riddle   GetResult1,
106abfd1a8bSRiver Riddle   GetResult2,
107abfd1a8bSRiver Riddle   GetResult3,
108abfd1a8bSRiver Riddle   GetResultN,
109abfd1a8bSRiver Riddle   /// Get the type of a value.
110abfd1a8bSRiver Riddle   GetValueType,
111abfd1a8bSRiver Riddle   /// Check if a generic value is not null.
112abfd1a8bSRiver Riddle   IsNotNull,
113abfd1a8bSRiver Riddle   /// Record a successful pattern match.
114abfd1a8bSRiver Riddle   RecordMatch,
115abfd1a8bSRiver Riddle   /// Replace an operation.
116abfd1a8bSRiver Riddle   ReplaceOp,
117abfd1a8bSRiver Riddle   /// Compare an attribute with a set of constants.
118abfd1a8bSRiver Riddle   SwitchAttribute,
119abfd1a8bSRiver Riddle   /// Compare the operand count of an operation with a set of constants.
120abfd1a8bSRiver Riddle   SwitchOperandCount,
121abfd1a8bSRiver Riddle   /// Compare the name of an operation with a set of constants.
122abfd1a8bSRiver Riddle   SwitchOperationName,
123abfd1a8bSRiver Riddle   /// Compare the result count of an operation with a set of constants.
124abfd1a8bSRiver Riddle   SwitchResultCount,
125abfd1a8bSRiver Riddle   /// Compare a type with a set of constants.
126abfd1a8bSRiver Riddle   SwitchType,
127abfd1a8bSRiver Riddle };
128abfd1a8bSRiver Riddle 
129abfd1a8bSRiver Riddle enum class PDLValueKind { Attribute, Operation, Type, Value };
130abfd1a8bSRiver Riddle } // end anonymous namespace
131abfd1a8bSRiver Riddle 
132abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
133abfd1a8bSRiver Riddle // ByteCode Generation
134abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
135abfd1a8bSRiver Riddle 
136abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
137abfd1a8bSRiver Riddle // Generator
138abfd1a8bSRiver Riddle 
139abfd1a8bSRiver Riddle namespace {
140abfd1a8bSRiver Riddle struct ByteCodeWriter;
141abfd1a8bSRiver Riddle 
142abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode.
143abfd1a8bSRiver Riddle class Generator {
144abfd1a8bSRiver Riddle public:
145abfd1a8bSRiver Riddle   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
146abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &matcherByteCode,
147abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
148abfd1a8bSRiver Riddle             SmallVectorImpl<PDLByteCodePattern> &patterns,
149abfd1a8bSRiver Riddle             ByteCodeField &maxValueMemoryIndex,
150abfd1a8bSRiver Riddle             llvm::StringMap<PDLConstraintFunction> &constraintFns,
151abfd1a8bSRiver Riddle             llvm::StringMap<PDLCreateFunction> &createFns,
152abfd1a8bSRiver Riddle             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
153abfd1a8bSRiver Riddle       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
154abfd1a8bSRiver Riddle         rewriterByteCode(rewriterByteCode), patterns(patterns),
155abfd1a8bSRiver Riddle         maxValueMemoryIndex(maxValueMemoryIndex) {
156abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(constraintFns))
157abfd1a8bSRiver Riddle       constraintToMemIndex.try_emplace(it.value().first(), it.index());
158abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(createFns))
159abfd1a8bSRiver Riddle       nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
160abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(rewriteFns))
161abfd1a8bSRiver Riddle       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
162abfd1a8bSRiver Riddle   }
163abfd1a8bSRiver Riddle 
164abfd1a8bSRiver Riddle   /// Generate the bytecode for the given PDL interpreter module.
165abfd1a8bSRiver Riddle   void generate(ModuleOp module);
166abfd1a8bSRiver Riddle 
167abfd1a8bSRiver Riddle   /// Return the memory index to use for the given value.
168abfd1a8bSRiver Riddle   ByteCodeField &getMemIndex(Value value) {
169abfd1a8bSRiver Riddle     assert(valueToMemIndex.count(value) &&
170abfd1a8bSRiver Riddle            "expected memory index to be assigned");
171abfd1a8bSRiver Riddle     return valueToMemIndex[value];
172abfd1a8bSRiver Riddle   }
173abfd1a8bSRiver Riddle 
174abfd1a8bSRiver Riddle   /// Return an index to use when referring to the given data that is uniqued in
175abfd1a8bSRiver Riddle   /// the MLIR context.
176abfd1a8bSRiver Riddle   template <typename T>
177abfd1a8bSRiver Riddle   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
178abfd1a8bSRiver Riddle   getMemIndex(T val) {
179abfd1a8bSRiver Riddle     const void *opaqueVal = val.getAsOpaquePointer();
180abfd1a8bSRiver Riddle 
181abfd1a8bSRiver Riddle     // Get or insert a reference to this value.
182abfd1a8bSRiver Riddle     auto it = uniquedDataToMemIndex.try_emplace(
183abfd1a8bSRiver Riddle         opaqueVal, maxValueMemoryIndex + uniquedData.size());
184abfd1a8bSRiver Riddle     if (it.second)
185abfd1a8bSRiver Riddle       uniquedData.push_back(opaqueVal);
186abfd1a8bSRiver Riddle     return it.first->second;
187abfd1a8bSRiver Riddle   }
188abfd1a8bSRiver Riddle 
189abfd1a8bSRiver Riddle private:
190abfd1a8bSRiver Riddle   /// Allocate memory indices for the results of operations within the matcher
191abfd1a8bSRiver Riddle   /// and rewriters.
192abfd1a8bSRiver Riddle   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
193abfd1a8bSRiver Riddle 
194abfd1a8bSRiver Riddle   /// Generate the bytecode for the given operation.
195abfd1a8bSRiver Riddle   void generate(Operation *op, ByteCodeWriter &writer);
196abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
197abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
198abfd1a8bSRiver Riddle   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
199abfd1a8bSRiver Riddle   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
200abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
201abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
202abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
203abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
204abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
205abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
206abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
207abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
208abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
209abfd1a8bSRiver Riddle   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
210abfd1a8bSRiver Riddle   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
211abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
212abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
213abfd1a8bSRiver Riddle   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
214abfd1a8bSRiver Riddle   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
215abfd1a8bSRiver Riddle   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
216abfd1a8bSRiver Riddle   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
217abfd1a8bSRiver Riddle   void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
218abfd1a8bSRiver Riddle   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
219abfd1a8bSRiver Riddle   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
220abfd1a8bSRiver Riddle   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
221abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
222abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
223abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
224abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
225abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
226abfd1a8bSRiver Riddle 
227abfd1a8bSRiver Riddle   /// Mapping from value to its corresponding memory index.
228abfd1a8bSRiver Riddle   DenseMap<Value, ByteCodeField> valueToMemIndex;
229abfd1a8bSRiver Riddle 
230abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered rewrite to its index in
231abfd1a8bSRiver Riddle   /// the bytecode registry.
232abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
233abfd1a8bSRiver Riddle 
234abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered constraint to its index
235abfd1a8bSRiver Riddle   /// in the bytecode registry.
236abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> constraintToMemIndex;
237abfd1a8bSRiver Riddle 
238abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered creation method to its
239abfd1a8bSRiver Riddle   /// index in the bytecode registry.
240abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
241abfd1a8bSRiver Riddle 
242abfd1a8bSRiver Riddle   /// Mapping from rewriter function name to the bytecode address of the
243abfd1a8bSRiver Riddle   /// rewriter function in byte.
244abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
245abfd1a8bSRiver Riddle 
246abfd1a8bSRiver Riddle   /// Mapping from a uniqued storage object to its memory index within
247abfd1a8bSRiver Riddle   /// `uniquedData`.
248abfd1a8bSRiver Riddle   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
249abfd1a8bSRiver Riddle 
250abfd1a8bSRiver Riddle   /// The current MLIR context.
251abfd1a8bSRiver Riddle   MLIRContext *ctx;
252abfd1a8bSRiver Riddle 
253abfd1a8bSRiver Riddle   /// Data of the ByteCode class to be populated.
254abfd1a8bSRiver Riddle   std::vector<const void *> &uniquedData;
255abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &matcherByteCode;
256abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
257abfd1a8bSRiver Riddle   SmallVectorImpl<PDLByteCodePattern> &patterns;
258abfd1a8bSRiver Riddle   ByteCodeField &maxValueMemoryIndex;
259abfd1a8bSRiver Riddle };
260abfd1a8bSRiver Riddle 
261abfd1a8bSRiver Riddle /// This class provides utilities for writing a bytecode stream.
262abfd1a8bSRiver Riddle struct ByteCodeWriter {
263abfd1a8bSRiver Riddle   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
264abfd1a8bSRiver Riddle       : bytecode(bytecode), generator(generator) {}
265abfd1a8bSRiver Riddle 
266abfd1a8bSRiver Riddle   /// Append a field to the bytecode.
267abfd1a8bSRiver Riddle   void append(ByteCodeField field) { bytecode.push_back(field); }
268fa20ab7bSRiver Riddle   void append(OpCode opCode) { bytecode.push_back(opCode); }
269abfd1a8bSRiver Riddle 
270abfd1a8bSRiver Riddle   /// Append an address to the bytecode.
271abfd1a8bSRiver Riddle   void append(ByteCodeAddr field) {
272abfd1a8bSRiver Riddle     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
273abfd1a8bSRiver Riddle                   "unexpected ByteCode address size");
274abfd1a8bSRiver Riddle 
275abfd1a8bSRiver Riddle     ByteCodeField fieldParts[2];
276abfd1a8bSRiver Riddle     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
277abfd1a8bSRiver Riddle     bytecode.append({fieldParts[0], fieldParts[1]});
278abfd1a8bSRiver Riddle   }
279abfd1a8bSRiver Riddle 
280abfd1a8bSRiver Riddle   /// Append a successor range to the bytecode, the exact address will need to
281abfd1a8bSRiver Riddle   /// be resolved later.
282abfd1a8bSRiver Riddle   void append(SuccessorRange successors) {
283abfd1a8bSRiver Riddle     // Add back references to the any successors so that the address can be
284abfd1a8bSRiver Riddle     // resolved later.
285abfd1a8bSRiver Riddle     for (Block *successor : successors) {
286abfd1a8bSRiver Riddle       unresolvedSuccessorRefs[successor].push_back(bytecode.size());
287abfd1a8bSRiver Riddle       append(ByteCodeAddr(0));
288abfd1a8bSRiver Riddle     }
289abfd1a8bSRiver Riddle   }
290abfd1a8bSRiver Riddle 
291abfd1a8bSRiver Riddle   /// Append a range of values that will be read as generic PDLValues.
292abfd1a8bSRiver Riddle   void appendPDLValueList(OperandRange values) {
293abfd1a8bSRiver Riddle     bytecode.push_back(values.size());
294abfd1a8bSRiver Riddle     for (Value value : values) {
295abfd1a8bSRiver Riddle       // Append the type of the value in addition to the value itself.
296abfd1a8bSRiver Riddle       PDLValueKind kind =
297abfd1a8bSRiver Riddle           TypeSwitch<Type, PDLValueKind>(value.getType())
298abfd1a8bSRiver Riddle               .Case<pdl::AttributeType>(
299abfd1a8bSRiver Riddle                   [](Type) { return PDLValueKind::Attribute; })
300abfd1a8bSRiver Riddle               .Case<pdl::OperationType>(
301abfd1a8bSRiver Riddle                   [](Type) { return PDLValueKind::Operation; })
302abfd1a8bSRiver Riddle               .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
303abfd1a8bSRiver Riddle               .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
304abfd1a8bSRiver Riddle       bytecode.push_back(static_cast<ByteCodeField>(kind));
305abfd1a8bSRiver Riddle       append(value);
306abfd1a8bSRiver Riddle     }
307abfd1a8bSRiver Riddle   }
308abfd1a8bSRiver Riddle 
309abfd1a8bSRiver Riddle   /// Check if the given class `T` has an iterator type.
310abfd1a8bSRiver Riddle   template <typename T, typename... Args>
311abfd1a8bSRiver Riddle   using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
312abfd1a8bSRiver Riddle 
313abfd1a8bSRiver Riddle   /// Append a value that will be stored in a memory slot and not inline within
314abfd1a8bSRiver Riddle   /// the bytecode.
315abfd1a8bSRiver Riddle   template <typename T>
316abfd1a8bSRiver Riddle   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
317abfd1a8bSRiver Riddle                    std::is_pointer<T>::value>
318abfd1a8bSRiver Riddle   append(T value) {
319abfd1a8bSRiver Riddle     bytecode.push_back(generator.getMemIndex(value));
320abfd1a8bSRiver Riddle   }
321abfd1a8bSRiver Riddle 
322abfd1a8bSRiver Riddle   /// Append a range of values.
323abfd1a8bSRiver Riddle   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
324abfd1a8bSRiver Riddle   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
325abfd1a8bSRiver Riddle   append(T range) {
326abfd1a8bSRiver Riddle     bytecode.push_back(llvm::size(range));
327abfd1a8bSRiver Riddle     for (auto it : range)
328abfd1a8bSRiver Riddle       append(it);
329abfd1a8bSRiver Riddle   }
330abfd1a8bSRiver Riddle 
331abfd1a8bSRiver Riddle   /// Append a variadic number of fields to the bytecode.
332abfd1a8bSRiver Riddle   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
333abfd1a8bSRiver Riddle   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
334abfd1a8bSRiver Riddle     append(field);
335abfd1a8bSRiver Riddle     append(field2, fields...);
336abfd1a8bSRiver Riddle   }
337abfd1a8bSRiver Riddle 
338abfd1a8bSRiver Riddle   /// Successor references in the bytecode that have yet to be resolved.
339abfd1a8bSRiver Riddle   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
340abfd1a8bSRiver Riddle 
341abfd1a8bSRiver Riddle   /// The underlying bytecode buffer.
342abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &bytecode;
343abfd1a8bSRiver Riddle 
344abfd1a8bSRiver Riddle   /// The main generator producing PDL.
345abfd1a8bSRiver Riddle   Generator &generator;
346abfd1a8bSRiver Riddle };
347abfd1a8bSRiver Riddle } // end anonymous namespace
348abfd1a8bSRiver Riddle 
349abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) {
350abfd1a8bSRiver Riddle   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
351abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
352abfd1a8bSRiver Riddle   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
353abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getRewriterModuleName());
354abfd1a8bSRiver Riddle   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
355abfd1a8bSRiver Riddle 
356abfd1a8bSRiver Riddle   // Allocate memory indices for the results of operations within the matcher
357abfd1a8bSRiver Riddle   // and rewriters.
358abfd1a8bSRiver Riddle   allocateMemoryIndices(matcherFunc, rewriterModule);
359abfd1a8bSRiver Riddle 
360abfd1a8bSRiver Riddle   // Generate code for the rewriter functions.
361abfd1a8bSRiver Riddle   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
362abfd1a8bSRiver Riddle   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
363abfd1a8bSRiver Riddle     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
364abfd1a8bSRiver Riddle     for (Operation &op : rewriterFunc.getOps())
365abfd1a8bSRiver Riddle       generate(&op, rewriterByteCodeWriter);
366abfd1a8bSRiver Riddle   }
367abfd1a8bSRiver Riddle   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
368abfd1a8bSRiver Riddle          "unexpected branches in rewriter function");
369abfd1a8bSRiver Riddle 
370abfd1a8bSRiver Riddle   // Generate code for the matcher function.
371abfd1a8bSRiver Riddle   DenseMap<Block *, ByteCodeAddr> blockToAddr;
372abfd1a8bSRiver Riddle   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
373abfd1a8bSRiver Riddle   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
374abfd1a8bSRiver Riddle   for (Block *block : rpot) {
375abfd1a8bSRiver Riddle     // Keep track of where this block begins within the matcher function.
376abfd1a8bSRiver Riddle     blockToAddr.try_emplace(block, matcherByteCode.size());
377abfd1a8bSRiver Riddle     for (Operation &op : *block)
378abfd1a8bSRiver Riddle       generate(&op, matcherByteCodeWriter);
379abfd1a8bSRiver Riddle   }
380abfd1a8bSRiver Riddle 
381abfd1a8bSRiver Riddle   // Resolve successor references in the matcher.
382abfd1a8bSRiver Riddle   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
383abfd1a8bSRiver Riddle     ByteCodeAddr addr = blockToAddr[it.first];
384abfd1a8bSRiver Riddle     for (unsigned offsetToFix : it.second)
385abfd1a8bSRiver Riddle       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
386abfd1a8bSRiver Riddle   }
387abfd1a8bSRiver Riddle }
388abfd1a8bSRiver Riddle 
389abfd1a8bSRiver Riddle void Generator::allocateMemoryIndices(FuncOp matcherFunc,
390abfd1a8bSRiver Riddle                                       ModuleOp rewriterModule) {
391abfd1a8bSRiver Riddle   // Rewriters use simplistic allocation scheme that simply assigns an index to
392abfd1a8bSRiver Riddle   // each result.
393abfd1a8bSRiver Riddle   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
394abfd1a8bSRiver Riddle     ByteCodeField index = 0;
395abfd1a8bSRiver Riddle     for (BlockArgument arg : rewriterFunc.getArguments())
396abfd1a8bSRiver Riddle       valueToMemIndex.try_emplace(arg, index++);
397abfd1a8bSRiver Riddle     rewriterFunc.getBody().walk([&](Operation *op) {
398abfd1a8bSRiver Riddle       for (Value result : op->getResults())
399abfd1a8bSRiver Riddle         valueToMemIndex.try_emplace(result, index++);
400abfd1a8bSRiver Riddle     });
401abfd1a8bSRiver Riddle     if (index > maxValueMemoryIndex)
402abfd1a8bSRiver Riddle       maxValueMemoryIndex = index;
403abfd1a8bSRiver Riddle   }
404abfd1a8bSRiver Riddle 
405abfd1a8bSRiver Riddle   // The matcher function uses a more sophisticated numbering that tries to
406abfd1a8bSRiver Riddle   // minimize the number of memory indices assigned. This is done by determining
407abfd1a8bSRiver Riddle   // a live range of the values within the matcher, then the allocation is just
408abfd1a8bSRiver Riddle   // finding the minimal number of overlapping live ranges. This is essentially
409abfd1a8bSRiver Riddle   // a simplified form of register allocation where we don't necessarily have a
410abfd1a8bSRiver Riddle   // limited number of registers, but we still want to minimize the number used.
411abfd1a8bSRiver Riddle   DenseMap<Operation *, ByteCodeField> opToIndex;
412abfd1a8bSRiver Riddle   matcherFunc.getBody().walk([&](Operation *op) {
413abfd1a8bSRiver Riddle     opToIndex.insert(std::make_pair(op, opToIndex.size()));
414abfd1a8bSRiver Riddle   });
415abfd1a8bSRiver Riddle 
416abfd1a8bSRiver Riddle   // Liveness info for each of the defs within the matcher.
417abfd1a8bSRiver Riddle   using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
418abfd1a8bSRiver Riddle   LivenessSet::Allocator allocator;
419abfd1a8bSRiver Riddle   DenseMap<Value, LivenessSet> valueDefRanges;
420abfd1a8bSRiver Riddle 
421abfd1a8bSRiver Riddle   // Assign the root operation being matched to slot 0.
422abfd1a8bSRiver Riddle   BlockArgument rootOpArg = matcherFunc.getArgument(0);
423abfd1a8bSRiver Riddle   valueToMemIndex[rootOpArg] = 0;
424abfd1a8bSRiver Riddle 
425abfd1a8bSRiver Riddle   // Walk each of the blocks, computing the def interval that the value is used.
426abfd1a8bSRiver Riddle   Liveness matcherLiveness(matcherFunc);
427abfd1a8bSRiver Riddle   for (Block &block : matcherFunc.getBody()) {
428abfd1a8bSRiver Riddle     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
429abfd1a8bSRiver Riddle     assert(info && "expected liveness info for block");
430abfd1a8bSRiver Riddle     auto processValue = [&](Value value, Operation *firstUseOrDef) {
431abfd1a8bSRiver Riddle       // We don't need to process the root op argument, this value is always
432abfd1a8bSRiver Riddle       // assigned to the first memory slot.
433abfd1a8bSRiver Riddle       if (value == rootOpArg)
434abfd1a8bSRiver Riddle         return;
435abfd1a8bSRiver Riddle 
436abfd1a8bSRiver Riddle       // Set indices for the range of this block that the value is used.
437abfd1a8bSRiver Riddle       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
438abfd1a8bSRiver Riddle       defRangeIt->second.insert(
439abfd1a8bSRiver Riddle           opToIndex[firstUseOrDef],
440abfd1a8bSRiver Riddle           opToIndex[info->getEndOperation(value, firstUseOrDef)],
441abfd1a8bSRiver Riddle           /*dummyValue*/ 0);
442abfd1a8bSRiver Riddle     };
443abfd1a8bSRiver Riddle 
444abfd1a8bSRiver Riddle     // Process the live-ins of this block.
445abfd1a8bSRiver Riddle     for (Value liveIn : info->in())
446abfd1a8bSRiver Riddle       processValue(liveIn, &block.front());
447abfd1a8bSRiver Riddle 
448abfd1a8bSRiver Riddle     // Process any new defs within this block.
449abfd1a8bSRiver Riddle     for (Operation &op : block)
450abfd1a8bSRiver Riddle       for (Value result : op.getResults())
451abfd1a8bSRiver Riddle         processValue(result, &op);
452abfd1a8bSRiver Riddle   }
453abfd1a8bSRiver Riddle 
454abfd1a8bSRiver Riddle   // Greedily allocate memory slots using the computed def live ranges.
455abfd1a8bSRiver Riddle   std::vector<LivenessSet> allocatedIndices;
456abfd1a8bSRiver Riddle   for (auto &defIt : valueDefRanges) {
457abfd1a8bSRiver Riddle     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
458abfd1a8bSRiver Riddle     LivenessSet &defSet = defIt.second;
459abfd1a8bSRiver Riddle 
460abfd1a8bSRiver Riddle     // Try to allocate to an existing index.
461abfd1a8bSRiver Riddle     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
462abfd1a8bSRiver Riddle       LivenessSet &existingIndex = existingIndexIt.value();
463abfd1a8bSRiver Riddle       llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
464abfd1a8bSRiver Riddle           defIt.second, existingIndex);
465abfd1a8bSRiver Riddle       if (overlaps.valid())
466abfd1a8bSRiver Riddle         continue;
467abfd1a8bSRiver Riddle       // Union the range of the def within the existing index.
468abfd1a8bSRiver Riddle       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
469abfd1a8bSRiver Riddle         existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
470abfd1a8bSRiver Riddle       memIndex = existingIndexIt.index() + 1;
471abfd1a8bSRiver Riddle     }
472abfd1a8bSRiver Riddle 
473abfd1a8bSRiver Riddle     // If no existing index could be used, add a new one.
474abfd1a8bSRiver Riddle     if (memIndex == 0) {
475abfd1a8bSRiver Riddle       allocatedIndices.emplace_back(allocator);
476abfd1a8bSRiver Riddle       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
477abfd1a8bSRiver Riddle         allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
478abfd1a8bSRiver Riddle       memIndex = allocatedIndices.size();
479abfd1a8bSRiver Riddle     }
480abfd1a8bSRiver Riddle   }
481abfd1a8bSRiver Riddle 
482abfd1a8bSRiver Riddle   // Update the max number of indices.
483abfd1a8bSRiver Riddle   ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
484abfd1a8bSRiver Riddle   if (numMatcherIndices > maxValueMemoryIndex)
485abfd1a8bSRiver Riddle     maxValueMemoryIndex = numMatcherIndices;
486abfd1a8bSRiver Riddle }
487abfd1a8bSRiver Riddle 
488abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) {
489abfd1a8bSRiver Riddle   TypeSwitch<Operation *>(op)
490abfd1a8bSRiver Riddle       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
491abfd1a8bSRiver Riddle             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
492abfd1a8bSRiver Riddle             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
493abfd1a8bSRiver Riddle             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
494abfd1a8bSRiver Riddle             pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
495abfd1a8bSRiver Riddle             pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
496abfd1a8bSRiver Riddle             pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
497abfd1a8bSRiver Riddle             pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
498abfd1a8bSRiver Riddle             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
499abfd1a8bSRiver Riddle             pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
500abfd1a8bSRiver Riddle             pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
501abfd1a8bSRiver Riddle             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
502abfd1a8bSRiver Riddle             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
503abfd1a8bSRiver Riddle             pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
504abfd1a8bSRiver Riddle             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
505abfd1a8bSRiver Riddle           [&](auto interpOp) { this->generate(interpOp, writer); })
506abfd1a8bSRiver Riddle       .Default([](Operation *) {
507abfd1a8bSRiver Riddle         llvm_unreachable("unknown `pdl_interp` operation");
508abfd1a8bSRiver Riddle       });
509abfd1a8bSRiver Riddle }
510abfd1a8bSRiver Riddle 
511abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op,
512abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
513abfd1a8bSRiver Riddle   assert(constraintToMemIndex.count(op.name()) &&
514abfd1a8bSRiver Riddle          "expected index for constraint function");
515abfd1a8bSRiver Riddle   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
516abfd1a8bSRiver Riddle                 op.constParamsAttr());
517abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
518abfd1a8bSRiver Riddle   writer.append(op.getSuccessors());
519abfd1a8bSRiver Riddle }
520abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op,
521abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
522abfd1a8bSRiver Riddle   assert(externalRewriterToMemIndex.count(op.name()) &&
523abfd1a8bSRiver Riddle          "expected index for rewrite function");
524abfd1a8bSRiver Riddle   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
525abfd1a8bSRiver Riddle                 op.constParamsAttr(), op.root());
526abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
527abfd1a8bSRiver Riddle }
528abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
529abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
530abfd1a8bSRiver Riddle }
531abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
5328affe881SRiver Riddle   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
533abfd1a8bSRiver Riddle }
534abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op,
535abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
536abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
537abfd1a8bSRiver Riddle                 op.getSuccessors());
538abfd1a8bSRiver Riddle }
539abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op,
540abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
541abfd1a8bSRiver Riddle   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
542abfd1a8bSRiver Riddle                 op.getSuccessors());
543abfd1a8bSRiver Riddle }
544abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op,
545abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
546abfd1a8bSRiver Riddle   writer.append(OpCode::CheckOperationName, op.operation(),
547abfd1a8bSRiver Riddle                 OperationName(op.name(), ctx), op.getSuccessors());
548abfd1a8bSRiver Riddle }
549abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op,
550abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
551abfd1a8bSRiver Riddle   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
552abfd1a8bSRiver Riddle                 op.getSuccessors());
553abfd1a8bSRiver Riddle }
554abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
555abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
556abfd1a8bSRiver Riddle }
557abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op,
558abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
559abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
560abfd1a8bSRiver Riddle   getMemIndex(op.attribute()) = getMemIndex(op.value());
561abfd1a8bSRiver Riddle }
562abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateNativeOp op,
563abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
564abfd1a8bSRiver Riddle   assert(nativeCreateToMemIndex.count(op.name()) &&
565abfd1a8bSRiver Riddle          "expected index for creation function");
566abfd1a8bSRiver Riddle   writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
567abfd1a8bSRiver Riddle                 op.result(), op.constParamsAttr());
568abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
569abfd1a8bSRiver Riddle }
570abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op,
571abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
572abfd1a8bSRiver Riddle   writer.append(OpCode::CreateOperation, op.operation(),
573abfd1a8bSRiver Riddle                 OperationName(op.name(), ctx), op.operands());
574abfd1a8bSRiver Riddle 
575abfd1a8bSRiver Riddle   // Add the attributes.
576abfd1a8bSRiver Riddle   OperandRange attributes = op.attributes();
577abfd1a8bSRiver Riddle   writer.append(static_cast<ByteCodeField>(attributes.size()));
578abfd1a8bSRiver Riddle   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
579abfd1a8bSRiver Riddle     writer.append(
580abfd1a8bSRiver Riddle         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
581abfd1a8bSRiver Riddle         std::get<1>(it));
582abfd1a8bSRiver Riddle   }
583abfd1a8bSRiver Riddle   writer.append(op.types());
584abfd1a8bSRiver Riddle }
585abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
586abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
587abfd1a8bSRiver Riddle   getMemIndex(op.result()) = getMemIndex(op.value());
588abfd1a8bSRiver Riddle }
589abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
590abfd1a8bSRiver Riddle   writer.append(OpCode::EraseOp, op.operation());
591abfd1a8bSRiver Riddle }
592abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
593abfd1a8bSRiver Riddle   writer.append(OpCode::Finalize);
594abfd1a8bSRiver Riddle }
595abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op,
596abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
597abfd1a8bSRiver Riddle   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
598abfd1a8bSRiver Riddle                 Identifier::get(op.name(), ctx));
599abfd1a8bSRiver Riddle }
600abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op,
601abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
602abfd1a8bSRiver Riddle   writer.append(OpCode::GetAttributeType, op.result(), op.value());
603abfd1a8bSRiver Riddle }
604abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op,
605abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
606abfd1a8bSRiver Riddle   writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
607abfd1a8bSRiver Riddle }
608abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
609abfd1a8bSRiver Riddle   uint32_t index = op.index();
610abfd1a8bSRiver Riddle   if (index < 4)
611abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
612abfd1a8bSRiver Riddle   else
613abfd1a8bSRiver Riddle     writer.append(OpCode::GetOperandN, index);
614abfd1a8bSRiver Riddle   writer.append(op.operation(), op.value());
615abfd1a8bSRiver Riddle }
616abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
617abfd1a8bSRiver Riddle   uint32_t index = op.index();
618abfd1a8bSRiver Riddle   if (index < 4)
619abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
620abfd1a8bSRiver Riddle   else
621abfd1a8bSRiver Riddle     writer.append(OpCode::GetResultN, index);
622abfd1a8bSRiver Riddle   writer.append(op.operation(), op.value());
623abfd1a8bSRiver Riddle }
624abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op,
625abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
626abfd1a8bSRiver Riddle   writer.append(OpCode::GetValueType, op.result(), op.value());
627abfd1a8bSRiver Riddle }
628abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::InferredTypeOp op,
629abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
630abfd1a8bSRiver Riddle   // InferType maps to a null type as a marker for inferring a result type.
631abfd1a8bSRiver Riddle   getMemIndex(op.type()) = getMemIndex(Type());
632abfd1a8bSRiver Riddle }
633abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
634abfd1a8bSRiver Riddle   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
635abfd1a8bSRiver Riddle }
636abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
637abfd1a8bSRiver Riddle   ByteCodeField patternIndex = patterns.size();
638abfd1a8bSRiver Riddle   patterns.emplace_back(PDLByteCodePattern::create(
639abfd1a8bSRiver Riddle       op, rewriterToAddr[op.rewriter().getLeafReference()]));
6408affe881SRiver Riddle   writer.append(OpCode::RecordMatch, patternIndex,
6418affe881SRiver Riddle                 SuccessorRange(op.getOperation()), op.matchedOps(),
6428affe881SRiver Riddle                 op.inputs());
643abfd1a8bSRiver Riddle }
644abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
645abfd1a8bSRiver Riddle   writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
646abfd1a8bSRiver Riddle }
647abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op,
648abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
649abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
650abfd1a8bSRiver Riddle                 op.getSuccessors());
651abfd1a8bSRiver Riddle }
652abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op,
653abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
654abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
655abfd1a8bSRiver Riddle                 op.getSuccessors());
656abfd1a8bSRiver Riddle }
657abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op,
658abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
659abfd1a8bSRiver Riddle   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
660abfd1a8bSRiver Riddle     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
661abfd1a8bSRiver Riddle   });
662abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
663abfd1a8bSRiver Riddle                 op.getSuccessors());
664abfd1a8bSRiver Riddle }
665abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op,
666abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
667abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
668abfd1a8bSRiver Riddle                 op.getSuccessors());
669abfd1a8bSRiver Riddle }
670abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
671abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
672abfd1a8bSRiver Riddle                 op.getSuccessors());
673abfd1a8bSRiver Riddle }
674abfd1a8bSRiver Riddle 
675abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
676abfd1a8bSRiver Riddle // PDLByteCode
677abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
678abfd1a8bSRiver Riddle 
679abfd1a8bSRiver Riddle PDLByteCode::PDLByteCode(ModuleOp module,
680abfd1a8bSRiver Riddle                          llvm::StringMap<PDLConstraintFunction> constraintFns,
681abfd1a8bSRiver Riddle                          llvm::StringMap<PDLCreateFunction> createFns,
682abfd1a8bSRiver Riddle                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
683abfd1a8bSRiver Riddle   Generator generator(module.getContext(), uniquedData, matcherByteCode,
684abfd1a8bSRiver Riddle                       rewriterByteCode, patterns, maxValueMemoryIndex,
685abfd1a8bSRiver Riddle                       constraintFns, createFns, rewriteFns);
686abfd1a8bSRiver Riddle   generator.generate(module);
687abfd1a8bSRiver Riddle 
688abfd1a8bSRiver Riddle   // Initialize the external functions.
689abfd1a8bSRiver Riddle   for (auto &it : constraintFns)
690abfd1a8bSRiver Riddle     constraintFunctions.push_back(std::move(it.second));
691abfd1a8bSRiver Riddle   for (auto &it : createFns)
692abfd1a8bSRiver Riddle     createFunctions.push_back(std::move(it.second));
693abfd1a8bSRiver Riddle   for (auto &it : rewriteFns)
694abfd1a8bSRiver Riddle     rewriteFunctions.push_back(std::move(it.second));
695abfd1a8bSRiver Riddle }
696abfd1a8bSRiver Riddle 
697abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current
698abfd1a8bSRiver Riddle /// bytecode.
699abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
700abfd1a8bSRiver Riddle   state.memory.resize(maxValueMemoryIndex, nullptr);
701abfd1a8bSRiver Riddle   state.currentPatternBenefits.reserve(patterns.size());
702abfd1a8bSRiver Riddle   for (const PDLByteCodePattern &pattern : patterns)
703abfd1a8bSRiver Riddle     state.currentPatternBenefits.push_back(pattern.getBenefit());
704abfd1a8bSRiver Riddle }
705abfd1a8bSRiver Riddle 
706abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
707abfd1a8bSRiver Riddle // ByteCode Execution
708abfd1a8bSRiver Riddle 
709abfd1a8bSRiver Riddle namespace {
710abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream.
711abfd1a8bSRiver Riddle class ByteCodeExecutor {
712abfd1a8bSRiver Riddle public:
713abfd1a8bSRiver Riddle   ByteCodeExecutor(const ByteCodeField *curCodeIt,
714abfd1a8bSRiver Riddle                    MutableArrayRef<const void *> memory,
715abfd1a8bSRiver Riddle                    ArrayRef<const void *> uniquedMemory,
716abfd1a8bSRiver Riddle                    ArrayRef<ByteCodeField> code,
717abfd1a8bSRiver Riddle                    ArrayRef<PatternBenefit> currentPatternBenefits,
718abfd1a8bSRiver Riddle                    ArrayRef<PDLByteCodePattern> patterns,
719abfd1a8bSRiver Riddle                    ArrayRef<PDLConstraintFunction> constraintFunctions,
720abfd1a8bSRiver Riddle                    ArrayRef<PDLCreateFunction> createFunctions,
721abfd1a8bSRiver Riddle                    ArrayRef<PDLRewriteFunction> rewriteFunctions)
722abfd1a8bSRiver Riddle       : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
723abfd1a8bSRiver Riddle         code(code), currentPatternBenefits(currentPatternBenefits),
724abfd1a8bSRiver Riddle         patterns(patterns), constraintFunctions(constraintFunctions),
725abfd1a8bSRiver Riddle         createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
726abfd1a8bSRiver Riddle 
727abfd1a8bSRiver Riddle   /// Start executing the code at the current bytecode index. `matches` is an
728abfd1a8bSRiver Riddle   /// optional field provided when this function is executed in a matching
729abfd1a8bSRiver Riddle   /// context.
730abfd1a8bSRiver Riddle   void execute(PatternRewriter &rewriter,
731abfd1a8bSRiver Riddle                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
732abfd1a8bSRiver Riddle                Optional<Location> mainRewriteLoc = {});
733abfd1a8bSRiver Riddle 
734abfd1a8bSRiver Riddle private:
735*154cabe7SRiver Riddle   /// Internal implementation of executing each of the bytecode commands.
736*154cabe7SRiver Riddle   void executeApplyConstraint(PatternRewriter &rewriter);
737*154cabe7SRiver Riddle   void executeApplyRewrite(PatternRewriter &rewriter);
738*154cabe7SRiver Riddle   void executeAreEqual();
739*154cabe7SRiver Riddle   void executeBranch();
740*154cabe7SRiver Riddle   void executeCheckOperandCount();
741*154cabe7SRiver Riddle   void executeCheckOperationName();
742*154cabe7SRiver Riddle   void executeCheckResultCount();
743*154cabe7SRiver Riddle   void executeCreateNative(PatternRewriter &rewriter);
744*154cabe7SRiver Riddle   void executeCreateOperation(PatternRewriter &rewriter,
745*154cabe7SRiver Riddle                               Location mainRewriteLoc);
746*154cabe7SRiver Riddle   void executeEraseOp(PatternRewriter &rewriter);
747*154cabe7SRiver Riddle   void executeGetAttribute();
748*154cabe7SRiver Riddle   void executeGetAttributeType();
749*154cabe7SRiver Riddle   void executeGetDefiningOp();
750*154cabe7SRiver Riddle   void executeGetOperand(unsigned index);
751*154cabe7SRiver Riddle   void executeGetResult(unsigned index);
752*154cabe7SRiver Riddle   void executeGetValueType();
753*154cabe7SRiver Riddle   void executeIsNotNull();
754*154cabe7SRiver Riddle   void executeRecordMatch(PatternRewriter &rewriter,
755*154cabe7SRiver Riddle                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
756*154cabe7SRiver Riddle   void executeReplaceOp(PatternRewriter &rewriter);
757*154cabe7SRiver Riddle   void executeSwitchAttribute();
758*154cabe7SRiver Riddle   void executeSwitchOperandCount();
759*154cabe7SRiver Riddle   void executeSwitchOperationName();
760*154cabe7SRiver Riddle   void executeSwitchResultCount();
761*154cabe7SRiver Riddle   void executeSwitchType();
762*154cabe7SRiver Riddle 
763abfd1a8bSRiver Riddle   /// Read a value from the bytecode buffer, optionally skipping a certain
764abfd1a8bSRiver Riddle   /// number of prefix values. These methods always update the buffer to point
765abfd1a8bSRiver Riddle   /// to the next field after the read data.
766abfd1a8bSRiver Riddle   template <typename T = ByteCodeField>
767abfd1a8bSRiver Riddle   T read(size_t skipN = 0) {
768abfd1a8bSRiver Riddle     curCodeIt += skipN;
769abfd1a8bSRiver Riddle     return readImpl<T>();
770abfd1a8bSRiver Riddle   }
771abfd1a8bSRiver Riddle   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
772abfd1a8bSRiver Riddle 
773abfd1a8bSRiver Riddle   /// Read a list of values from the bytecode buffer.
774abfd1a8bSRiver Riddle   template <typename ValueT, typename T>
775abfd1a8bSRiver Riddle   void readList(SmallVectorImpl<T> &list) {
776abfd1a8bSRiver Riddle     list.clear();
777abfd1a8bSRiver Riddle     for (unsigned i = 0, e = read(); i != e; ++i)
778abfd1a8bSRiver Riddle       list.push_back(read<ValueT>());
779abfd1a8bSRiver Riddle   }
780abfd1a8bSRiver Riddle 
781abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a predicate value.
782abfd1a8bSRiver Riddle   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
783abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a destination index.
784abfd1a8bSRiver Riddle   void selectJump(size_t destIndex) {
785abfd1a8bSRiver Riddle     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
786abfd1a8bSRiver Riddle   }
787abfd1a8bSRiver Riddle 
788abfd1a8bSRiver Riddle   /// Handle a switch operation with the provided value and cases.
789abfd1a8bSRiver Riddle   template <typename T, typename RangeT>
790abfd1a8bSRiver Riddle   void handleSwitch(const T &value, RangeT &&cases) {
791abfd1a8bSRiver Riddle     LLVM_DEBUG({
792abfd1a8bSRiver Riddle       llvm::dbgs() << "  * Value: " << value << "\n"
793abfd1a8bSRiver Riddle                    << "  * Cases: ";
794abfd1a8bSRiver Riddle       llvm::interleaveComma(cases, llvm::dbgs());
795*154cabe7SRiver Riddle       llvm::dbgs() << "\n";
796abfd1a8bSRiver Riddle     });
797abfd1a8bSRiver Riddle 
798abfd1a8bSRiver Riddle     // Check to see if the attribute value is within the case list. Jump to
799abfd1a8bSRiver Riddle     // the correct successor index based on the result.
800f80b6304SRiver Riddle     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
801f80b6304SRiver Riddle       if (*it == value)
802f80b6304SRiver Riddle         return selectJump(size_t((it - cases.begin()) + 1));
803f80b6304SRiver Riddle     selectJump(size_t(0));
804abfd1a8bSRiver Riddle   }
805abfd1a8bSRiver Riddle 
806abfd1a8bSRiver Riddle   /// Internal implementation of reading various data types from the bytecode
807abfd1a8bSRiver Riddle   /// stream.
808abfd1a8bSRiver Riddle   template <typename T>
809abfd1a8bSRiver Riddle   const void *readFromMemory() {
810abfd1a8bSRiver Riddle     size_t index = *curCodeIt++;
811abfd1a8bSRiver Riddle 
812abfd1a8bSRiver Riddle     // If this type is an SSA value, it can only be stored in non-const memory.
813abfd1a8bSRiver Riddle     if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
814abfd1a8bSRiver Riddle       return memory[index];
815abfd1a8bSRiver Riddle 
816abfd1a8bSRiver Riddle     // Otherwise, if this index is not inbounds it is uniqued.
817abfd1a8bSRiver Riddle     return uniquedMemory[index - memory.size()];
818abfd1a8bSRiver Riddle   }
819abfd1a8bSRiver Riddle   template <typename T>
820abfd1a8bSRiver Riddle   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
821abfd1a8bSRiver Riddle     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
822abfd1a8bSRiver Riddle   }
823abfd1a8bSRiver Riddle   template <typename T>
824abfd1a8bSRiver Riddle   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
825abfd1a8bSRiver Riddle                    T>
826abfd1a8bSRiver Riddle   readImpl() {
827abfd1a8bSRiver Riddle     return T(T::getFromOpaquePointer(readFromMemory<T>()));
828abfd1a8bSRiver Riddle   }
829abfd1a8bSRiver Riddle   template <typename T>
830abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
831abfd1a8bSRiver Riddle     switch (static_cast<PDLValueKind>(read())) {
832abfd1a8bSRiver Riddle     case PDLValueKind::Attribute:
833abfd1a8bSRiver Riddle       return read<Attribute>();
834abfd1a8bSRiver Riddle     case PDLValueKind::Operation:
835abfd1a8bSRiver Riddle       return read<Operation *>();
836abfd1a8bSRiver Riddle     case PDLValueKind::Type:
837abfd1a8bSRiver Riddle       return read<Type>();
838abfd1a8bSRiver Riddle     case PDLValueKind::Value:
839abfd1a8bSRiver Riddle       return read<Value>();
840abfd1a8bSRiver Riddle     }
8417dadcd02SMehdi Amini     llvm_unreachable("unhandled PDLValueKind");
842abfd1a8bSRiver Riddle   }
843abfd1a8bSRiver Riddle   template <typename T>
844abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
845abfd1a8bSRiver Riddle     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
846abfd1a8bSRiver Riddle                   "unexpected ByteCode address size");
847abfd1a8bSRiver Riddle     ByteCodeAddr result;
848abfd1a8bSRiver Riddle     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
849abfd1a8bSRiver Riddle     curCodeIt += 2;
850abfd1a8bSRiver Riddle     return result;
851abfd1a8bSRiver Riddle   }
852abfd1a8bSRiver Riddle   template <typename T>
853abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
854abfd1a8bSRiver Riddle     return *curCodeIt++;
855abfd1a8bSRiver Riddle   }
856abfd1a8bSRiver Riddle 
857abfd1a8bSRiver Riddle   /// The underlying bytecode buffer.
858abfd1a8bSRiver Riddle   const ByteCodeField *curCodeIt;
859abfd1a8bSRiver Riddle 
860abfd1a8bSRiver Riddle   /// The current execution memory.
861abfd1a8bSRiver Riddle   MutableArrayRef<const void *> memory;
862abfd1a8bSRiver Riddle 
863abfd1a8bSRiver Riddle   /// References to ByteCode data necessary for execution.
864abfd1a8bSRiver Riddle   ArrayRef<const void *> uniquedMemory;
865abfd1a8bSRiver Riddle   ArrayRef<ByteCodeField> code;
866abfd1a8bSRiver Riddle   ArrayRef<PatternBenefit> currentPatternBenefits;
867abfd1a8bSRiver Riddle   ArrayRef<PDLByteCodePattern> patterns;
868abfd1a8bSRiver Riddle   ArrayRef<PDLConstraintFunction> constraintFunctions;
869abfd1a8bSRiver Riddle   ArrayRef<PDLCreateFunction> createFunctions;
870abfd1a8bSRiver Riddle   ArrayRef<PDLRewriteFunction> rewriteFunctions;
871abfd1a8bSRiver Riddle };
872abfd1a8bSRiver Riddle } // end anonymous namespace
873abfd1a8bSRiver Riddle 
874*154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
875abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
876abfd1a8bSRiver Riddle   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
877abfd1a8bSRiver Riddle   ArrayAttr constParams = read<ArrayAttr>();
878abfd1a8bSRiver Riddle   SmallVector<PDLValue, 16> args;
879abfd1a8bSRiver Riddle   readList<PDLValue>(args);
880*154cabe7SRiver Riddle 
881abfd1a8bSRiver Riddle   LLVM_DEBUG({
882abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Arguments: ";
883abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
884*154cabe7SRiver Riddle     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
885abfd1a8bSRiver Riddle   });
886abfd1a8bSRiver Riddle 
887abfd1a8bSRiver Riddle   // Invoke the constraint and jump to the proper destination.
888abfd1a8bSRiver Riddle   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
889abfd1a8bSRiver Riddle }
890*154cabe7SRiver Riddle 
891*154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
892abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
893abfd1a8bSRiver Riddle   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
894abfd1a8bSRiver Riddle   ArrayAttr constParams = read<ArrayAttr>();
895abfd1a8bSRiver Riddle   Operation *root = read<Operation *>();
896abfd1a8bSRiver Riddle   SmallVector<PDLValue, 16> args;
897abfd1a8bSRiver Riddle   readList<PDLValue>(args);
898abfd1a8bSRiver Riddle 
899abfd1a8bSRiver Riddle   LLVM_DEBUG({
900*154cabe7SRiver Riddle     llvm::dbgs() << "  * Root: " << *root << "\n  * Arguments: ";
901abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
902*154cabe7SRiver Riddle     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
903abfd1a8bSRiver Riddle   });
904*154cabe7SRiver Riddle 
905*154cabe7SRiver Riddle   // Invoke the native rewrite function.
906abfd1a8bSRiver Riddle   rewriteFn(root, args, constParams, rewriter);
907abfd1a8bSRiver Riddle }
908*154cabe7SRiver Riddle 
909*154cabe7SRiver Riddle void ByteCodeExecutor::executeAreEqual() {
910abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
911abfd1a8bSRiver Riddle   const void *lhs = read<const void *>();
912abfd1a8bSRiver Riddle   const void *rhs = read<const void *>();
913abfd1a8bSRiver Riddle 
914*154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
915abfd1a8bSRiver Riddle   selectJump(lhs == rhs);
916abfd1a8bSRiver Riddle }
917*154cabe7SRiver Riddle 
918*154cabe7SRiver Riddle void ByteCodeExecutor::executeBranch() {
919*154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
920abfd1a8bSRiver Riddle   curCodeIt = &code[read<ByteCodeAddr>()];
921abfd1a8bSRiver Riddle }
922*154cabe7SRiver Riddle 
923*154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperandCount() {
924abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
925abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
926abfd1a8bSRiver Riddle   uint32_t expectedCount = read<uint32_t>();
927abfd1a8bSRiver Riddle 
928abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
929*154cabe7SRiver Riddle                           << "  * Expected: " << expectedCount << "\n");
930abfd1a8bSRiver Riddle   selectJump(op->getNumOperands() == expectedCount);
931abfd1a8bSRiver Riddle }
932*154cabe7SRiver Riddle 
933*154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperationName() {
934abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
935abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
936abfd1a8bSRiver Riddle   OperationName expectedName = read<OperationName>();
937abfd1a8bSRiver Riddle 
938*154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
939*154cabe7SRiver Riddle                           << "  * Expected: \"" << expectedName << "\"\n");
940abfd1a8bSRiver Riddle   selectJump(op->getName() == expectedName);
941abfd1a8bSRiver Riddle }
942*154cabe7SRiver Riddle 
943*154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckResultCount() {
944abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
945abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
946abfd1a8bSRiver Riddle   uint32_t expectedCount = read<uint32_t>();
947abfd1a8bSRiver Riddle 
948abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
949*154cabe7SRiver Riddle                           << "  * Expected: " << expectedCount << "\n");
950abfd1a8bSRiver Riddle   selectJump(op->getNumResults() == expectedCount);
951abfd1a8bSRiver Riddle }
952*154cabe7SRiver Riddle 
953*154cabe7SRiver Riddle void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) {
954abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
955abfd1a8bSRiver Riddle   const PDLCreateFunction &createFn = createFunctions[read()];
956abfd1a8bSRiver Riddle   ByteCodeField resultIndex = read();
957abfd1a8bSRiver Riddle   ArrayAttr constParams = read<ArrayAttr>();
958abfd1a8bSRiver Riddle   SmallVector<PDLValue, 16> args;
959abfd1a8bSRiver Riddle   readList<PDLValue>(args);
960abfd1a8bSRiver Riddle 
961abfd1a8bSRiver Riddle   LLVM_DEBUG({
962abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Arguments: ";
963abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
964abfd1a8bSRiver Riddle     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
965abfd1a8bSRiver Riddle   });
966abfd1a8bSRiver Riddle 
967abfd1a8bSRiver Riddle   PDLValue result = createFn(args, constParams, rewriter);
968abfd1a8bSRiver Riddle   memory[resultIndex] = result.getAsOpaquePointer();
969abfd1a8bSRiver Riddle 
970*154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
971abfd1a8bSRiver Riddle }
972*154cabe7SRiver Riddle 
973*154cabe7SRiver Riddle void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
974*154cabe7SRiver Riddle                                               Location mainRewriteLoc) {
975abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
976abfd1a8bSRiver Riddle 
977abfd1a8bSRiver Riddle   unsigned memIndex = read();
978*154cabe7SRiver Riddle   OperationState state(mainRewriteLoc, read<OperationName>());
979abfd1a8bSRiver Riddle   readList<Value>(state.operands);
980abfd1a8bSRiver Riddle   for (unsigned i = 0, e = read(); i != e; ++i) {
981abfd1a8bSRiver Riddle     Identifier name = read<Identifier>();
982abfd1a8bSRiver Riddle     if (Attribute attr = read<Attribute>())
983abfd1a8bSRiver Riddle       state.addAttribute(name, attr);
984abfd1a8bSRiver Riddle   }
985abfd1a8bSRiver Riddle 
986abfd1a8bSRiver Riddle   bool hasInferredTypes = false;
987abfd1a8bSRiver Riddle   for (unsigned i = 0, e = read(); i != e; ++i) {
988abfd1a8bSRiver Riddle     Type resultType = read<Type>();
989abfd1a8bSRiver Riddle     hasInferredTypes |= !resultType;
990abfd1a8bSRiver Riddle     state.types.push_back(resultType);
991abfd1a8bSRiver Riddle   }
992abfd1a8bSRiver Riddle 
993abfd1a8bSRiver Riddle   // Handle the case where the operation has inferred types.
994abfd1a8bSRiver Riddle   if (hasInferredTypes) {
995abfd1a8bSRiver Riddle     InferTypeOpInterface::Concept *concept =
996*154cabe7SRiver Riddle         state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
997abfd1a8bSRiver Riddle 
998abfd1a8bSRiver Riddle     // TODO: Handle failure.
999abfd1a8bSRiver Riddle     SmallVector<Type, 2> inferredTypes;
1000abfd1a8bSRiver Riddle     if (failed(concept->inferReturnTypes(
1001abfd1a8bSRiver Riddle             state.getContext(), state.location, state.operands,
1002*154cabe7SRiver Riddle             state.attributes.getDictionary(state.getContext()), state.regions,
1003*154cabe7SRiver Riddle             inferredTypes)))
1004abfd1a8bSRiver Riddle       return;
1005abfd1a8bSRiver Riddle 
1006abfd1a8bSRiver Riddle     for (unsigned i = 0, e = state.types.size(); i != e; ++i)
1007abfd1a8bSRiver Riddle       if (!state.types[i])
1008abfd1a8bSRiver Riddle         state.types[i] = inferredTypes[i];
1009abfd1a8bSRiver Riddle   }
1010abfd1a8bSRiver Riddle   Operation *resultOp = rewriter.createOperation(state);
1011abfd1a8bSRiver Riddle   memory[memIndex] = resultOp;
1012abfd1a8bSRiver Riddle 
1013abfd1a8bSRiver Riddle   LLVM_DEBUG({
1014abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Attributes: "
1015abfd1a8bSRiver Riddle                  << state.attributes.getDictionary(state.getContext())
1016abfd1a8bSRiver Riddle                  << "\n  * Operands: ";
1017abfd1a8bSRiver Riddle     llvm::interleaveComma(state.operands, llvm::dbgs());
1018abfd1a8bSRiver Riddle     llvm::dbgs() << "\n  * Result Types: ";
1019abfd1a8bSRiver Riddle     llvm::interleaveComma(state.types, llvm::dbgs());
1020*154cabe7SRiver Riddle     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1021abfd1a8bSRiver Riddle   });
1022abfd1a8bSRiver Riddle }
1023*154cabe7SRiver Riddle 
1024*154cabe7SRiver Riddle void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1025abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1026abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1027abfd1a8bSRiver Riddle 
1028*154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1029abfd1a8bSRiver Riddle   rewriter.eraseOp(op);
1030abfd1a8bSRiver Riddle }
1031*154cabe7SRiver Riddle 
1032*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttribute() {
1033abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1034abfd1a8bSRiver Riddle   unsigned memIndex = read();
1035abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1036abfd1a8bSRiver Riddle   Identifier attrName = read<Identifier>();
1037abfd1a8bSRiver Riddle   Attribute attr = op->getAttr(attrName);
1038abfd1a8bSRiver Riddle 
1039abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1040abfd1a8bSRiver Riddle                           << "  * Attribute: " << attrName << "\n"
1041*154cabe7SRiver Riddle                           << "  * Result: " << attr << "\n");
1042abfd1a8bSRiver Riddle   memory[memIndex] = attr.getAsOpaquePointer();
1043abfd1a8bSRiver Riddle }
1044*154cabe7SRiver Riddle 
1045*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttributeType() {
1046abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1047abfd1a8bSRiver Riddle   unsigned memIndex = read();
1048abfd1a8bSRiver Riddle   Attribute attr = read<Attribute>();
1049*154cabe7SRiver Riddle   Type type = attr ? attr.getType() : Type();
1050abfd1a8bSRiver Riddle 
1051abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1052*154cabe7SRiver Riddle                           << "  * Result: " << type << "\n");
1053*154cabe7SRiver Riddle   memory[memIndex] = type.getAsOpaquePointer();
1054abfd1a8bSRiver Riddle }
1055*154cabe7SRiver Riddle 
1056*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetDefiningOp() {
1057abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1058abfd1a8bSRiver Riddle   unsigned memIndex = read();
1059abfd1a8bSRiver Riddle   Value value = read<Value>();
1060abfd1a8bSRiver Riddle   Operation *op = value ? value.getDefiningOp() : nullptr;
1061abfd1a8bSRiver Riddle 
1062abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1063*154cabe7SRiver Riddle                           << "  * Result: " << *op << "\n");
1064abfd1a8bSRiver Riddle   memory[memIndex] = op;
1065abfd1a8bSRiver Riddle }
1066*154cabe7SRiver Riddle 
1067*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetOperand(unsigned index) {
1068abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1069abfd1a8bSRiver Riddle   unsigned memIndex = read();
1070abfd1a8bSRiver Riddle   Value operand =
1071abfd1a8bSRiver Riddle       index < op->getNumOperands() ? op->getOperand(index) : Value();
1072abfd1a8bSRiver Riddle 
1073abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1074abfd1a8bSRiver Riddle                           << "  * Index: " << index << "\n"
1075*154cabe7SRiver Riddle                           << "  * Result: " << operand << "\n");
1076abfd1a8bSRiver Riddle   memory[memIndex] = operand.getAsOpaquePointer();
1077abfd1a8bSRiver Riddle }
1078*154cabe7SRiver Riddle 
1079*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetResult(unsigned index) {
1080abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1081abfd1a8bSRiver Riddle   unsigned memIndex = read();
1082abfd1a8bSRiver Riddle   OpResult result =
1083abfd1a8bSRiver Riddle       index < op->getNumResults() ? op->getResult(index) : OpResult();
1084abfd1a8bSRiver Riddle 
1085abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1086abfd1a8bSRiver Riddle                           << "  * Index: " << index << "\n"
1087*154cabe7SRiver Riddle                           << "  * Result: " << result << "\n");
1088abfd1a8bSRiver Riddle   memory[memIndex] = result.getAsOpaquePointer();
1089abfd1a8bSRiver Riddle }
1090*154cabe7SRiver Riddle 
1091*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetValueType() {
1092abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1093abfd1a8bSRiver Riddle   unsigned memIndex = read();
1094abfd1a8bSRiver Riddle   Value value = read<Value>();
1095*154cabe7SRiver Riddle   Type type = value ? value.getType() : Type();
1096abfd1a8bSRiver Riddle 
1097abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1098*154cabe7SRiver Riddle                           << "  * Result: " << type << "\n");
1099*154cabe7SRiver Riddle   memory[memIndex] = type.getAsOpaquePointer();
1100abfd1a8bSRiver Riddle }
1101*154cabe7SRiver Riddle 
1102*154cabe7SRiver Riddle void ByteCodeExecutor::executeIsNotNull() {
1103abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1104abfd1a8bSRiver Riddle   const void *value = read<const void *>();
1105abfd1a8bSRiver Riddle 
1106*154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1107abfd1a8bSRiver Riddle   selectJump(value != nullptr);
1108abfd1a8bSRiver Riddle }
1109*154cabe7SRiver Riddle 
1110*154cabe7SRiver Riddle void ByteCodeExecutor::executeRecordMatch(
1111*154cabe7SRiver Riddle     PatternRewriter &rewriter,
1112*154cabe7SRiver Riddle     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1113abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1114abfd1a8bSRiver Riddle   unsigned patternIndex = read();
1115abfd1a8bSRiver Riddle   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1116abfd1a8bSRiver Riddle   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1117abfd1a8bSRiver Riddle 
1118abfd1a8bSRiver Riddle   // If the benefit of the pattern is impossible, skip the processing of the
1119abfd1a8bSRiver Riddle   // rest of the pattern.
1120abfd1a8bSRiver Riddle   if (benefit.isImpossibleToMatch()) {
1121*154cabe7SRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1122abfd1a8bSRiver Riddle     curCodeIt = dest;
1123*154cabe7SRiver Riddle     return;
1124abfd1a8bSRiver Riddle   }
1125abfd1a8bSRiver Riddle 
1126abfd1a8bSRiver Riddle   // Create a fused location containing the locations of each of the
1127abfd1a8bSRiver Riddle   // operations used in the match. This will be used as the location for
1128abfd1a8bSRiver Riddle   // created operations during the rewrite that don't already have an
1129abfd1a8bSRiver Riddle   // explicit location set.
1130abfd1a8bSRiver Riddle   unsigned numMatchLocs = read();
1131abfd1a8bSRiver Riddle   SmallVector<Location, 4> matchLocs;
1132abfd1a8bSRiver Riddle   matchLocs.reserve(numMatchLocs);
1133abfd1a8bSRiver Riddle   for (unsigned i = 0; i != numMatchLocs; ++i)
1134abfd1a8bSRiver Riddle     matchLocs.push_back(read<Operation *>()->getLoc());
1135abfd1a8bSRiver Riddle   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1136abfd1a8bSRiver Riddle 
1137abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1138*154cabe7SRiver Riddle                           << "  * Location: " << matchLoc << "\n");
1139*154cabe7SRiver Riddle   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1140*154cabe7SRiver Riddle   readList<const void *>(matches.back().values);
1141abfd1a8bSRiver Riddle   curCodeIt = dest;
1142abfd1a8bSRiver Riddle }
1143*154cabe7SRiver Riddle 
1144*154cabe7SRiver Riddle void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1145abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1146abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1147abfd1a8bSRiver Riddle   SmallVector<Value, 16> args;
1148abfd1a8bSRiver Riddle   readList<Value>(args);
1149abfd1a8bSRiver Riddle 
1150abfd1a8bSRiver Riddle   LLVM_DEBUG({
1151abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Operation: " << *op << "\n"
1152abfd1a8bSRiver Riddle                  << "  * Values: ";
1153abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
1154*154cabe7SRiver Riddle     llvm::dbgs() << "\n";
1155abfd1a8bSRiver Riddle   });
1156abfd1a8bSRiver Riddle   rewriter.replaceOp(op, args);
1157abfd1a8bSRiver Riddle }
1158*154cabe7SRiver Riddle 
1159*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchAttribute() {
1160abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1161abfd1a8bSRiver Riddle   Attribute value = read<Attribute>();
1162abfd1a8bSRiver Riddle   ArrayAttr cases = read<ArrayAttr>();
1163abfd1a8bSRiver Riddle   handleSwitch(value, cases);
1164abfd1a8bSRiver Riddle }
1165*154cabe7SRiver Riddle 
1166*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperandCount() {
1167abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1168abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1169abfd1a8bSRiver Riddle   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1170abfd1a8bSRiver Riddle 
1171abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1172abfd1a8bSRiver Riddle   handleSwitch(op->getNumOperands(), cases);
1173abfd1a8bSRiver Riddle }
1174*154cabe7SRiver Riddle 
1175*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperationName() {
1176abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1177abfd1a8bSRiver Riddle   OperationName value = read<Operation *>()->getName();
1178abfd1a8bSRiver Riddle   size_t caseCount = read();
1179abfd1a8bSRiver Riddle 
1180abfd1a8bSRiver Riddle   // The operation names are stored in-line, so to print them out for
1181abfd1a8bSRiver Riddle   // debugging purposes we need to read the array before executing the
1182abfd1a8bSRiver Riddle   // switch so that we can display all of the possible values.
1183abfd1a8bSRiver Riddle   LLVM_DEBUG({
1184abfd1a8bSRiver Riddle     const ByteCodeField *prevCodeIt = curCodeIt;
1185abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Value: " << value << "\n"
1186abfd1a8bSRiver Riddle                  << "  * Cases: ";
1187abfd1a8bSRiver Riddle     llvm::interleaveComma(
1188abfd1a8bSRiver Riddle         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1189*154cabe7SRiver Riddle                         [&](size_t) { return read<OperationName>(); }),
1190abfd1a8bSRiver Riddle         llvm::dbgs());
1191*154cabe7SRiver Riddle     llvm::dbgs() << "\n";
1192abfd1a8bSRiver Riddle     curCodeIt = prevCodeIt;
1193abfd1a8bSRiver Riddle   });
1194abfd1a8bSRiver Riddle 
1195abfd1a8bSRiver Riddle   // Try to find the switch value within any of the cases.
1196abfd1a8bSRiver Riddle   for (size_t i = 0; i != caseCount; ++i) {
1197abfd1a8bSRiver Riddle     if (read<OperationName>() == value) {
1198abfd1a8bSRiver Riddle       curCodeIt += (caseCount - i - 1);
1199*154cabe7SRiver Riddle       return selectJump(i + 1);
1200abfd1a8bSRiver Riddle     }
1201abfd1a8bSRiver Riddle   }
1202*154cabe7SRiver Riddle   selectJump(size_t(0));
1203abfd1a8bSRiver Riddle }
1204*154cabe7SRiver Riddle 
1205*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchResultCount() {
1206abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1207abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1208abfd1a8bSRiver Riddle   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1209abfd1a8bSRiver Riddle 
1210abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1211abfd1a8bSRiver Riddle   handleSwitch(op->getNumResults(), cases);
1212abfd1a8bSRiver Riddle }
1213*154cabe7SRiver Riddle 
1214*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchType() {
1215abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1216abfd1a8bSRiver Riddle   Type value = read<Type>();
1217abfd1a8bSRiver Riddle   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1218abfd1a8bSRiver Riddle   handleSwitch(value, cases);
1219*154cabe7SRiver Riddle }
1220*154cabe7SRiver Riddle 
1221*154cabe7SRiver Riddle void ByteCodeExecutor::execute(
1222*154cabe7SRiver Riddle     PatternRewriter &rewriter,
1223*154cabe7SRiver Riddle     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
1224*154cabe7SRiver Riddle     Optional<Location> mainRewriteLoc) {
1225*154cabe7SRiver Riddle   while (true) {
1226*154cabe7SRiver Riddle     OpCode opCode = static_cast<OpCode>(read());
1227*154cabe7SRiver Riddle     switch (opCode) {
1228*154cabe7SRiver Riddle     case ApplyConstraint:
1229*154cabe7SRiver Riddle       executeApplyConstraint(rewriter);
1230*154cabe7SRiver Riddle       break;
1231*154cabe7SRiver Riddle     case ApplyRewrite:
1232*154cabe7SRiver Riddle       executeApplyRewrite(rewriter);
1233*154cabe7SRiver Riddle       break;
1234*154cabe7SRiver Riddle     case AreEqual:
1235*154cabe7SRiver Riddle       executeAreEqual();
1236*154cabe7SRiver Riddle       break;
1237*154cabe7SRiver Riddle     case Branch:
1238*154cabe7SRiver Riddle       executeBranch();
1239*154cabe7SRiver Riddle       break;
1240*154cabe7SRiver Riddle     case CheckOperandCount:
1241*154cabe7SRiver Riddle       executeCheckOperandCount();
1242*154cabe7SRiver Riddle       break;
1243*154cabe7SRiver Riddle     case CheckOperationName:
1244*154cabe7SRiver Riddle       executeCheckOperationName();
1245*154cabe7SRiver Riddle       break;
1246*154cabe7SRiver Riddle     case CheckResultCount:
1247*154cabe7SRiver Riddle       executeCheckResultCount();
1248*154cabe7SRiver Riddle       break;
1249*154cabe7SRiver Riddle     case CreateNative:
1250*154cabe7SRiver Riddle       executeCreateNative(rewriter);
1251*154cabe7SRiver Riddle       break;
1252*154cabe7SRiver Riddle     case CreateOperation:
1253*154cabe7SRiver Riddle       executeCreateOperation(rewriter, *mainRewriteLoc);
1254*154cabe7SRiver Riddle       break;
1255*154cabe7SRiver Riddle     case EraseOp:
1256*154cabe7SRiver Riddle       executeEraseOp(rewriter);
1257*154cabe7SRiver Riddle       break;
1258*154cabe7SRiver Riddle     case Finalize:
1259*154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1260*154cabe7SRiver Riddle       return;
1261*154cabe7SRiver Riddle     case GetAttribute:
1262*154cabe7SRiver Riddle       executeGetAttribute();
1263*154cabe7SRiver Riddle       break;
1264*154cabe7SRiver Riddle     case GetAttributeType:
1265*154cabe7SRiver Riddle       executeGetAttributeType();
1266*154cabe7SRiver Riddle       break;
1267*154cabe7SRiver Riddle     case GetDefiningOp:
1268*154cabe7SRiver Riddle       executeGetDefiningOp();
1269*154cabe7SRiver Riddle       break;
1270*154cabe7SRiver Riddle     case GetOperand0:
1271*154cabe7SRiver Riddle     case GetOperand1:
1272*154cabe7SRiver Riddle     case GetOperand2:
1273*154cabe7SRiver Riddle     case GetOperand3: {
1274*154cabe7SRiver Riddle       unsigned index = opCode - GetOperand0;
1275*154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
1276*154cabe7SRiver Riddle       executeGetOperand(opCode - GetOperand0);
1277abfd1a8bSRiver Riddle       break;
1278abfd1a8bSRiver Riddle     }
1279*154cabe7SRiver Riddle     case GetOperandN:
1280*154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
1281*154cabe7SRiver Riddle       executeGetOperand(read<uint32_t>());
1282*154cabe7SRiver Riddle       break;
1283*154cabe7SRiver Riddle     case GetResult0:
1284*154cabe7SRiver Riddle     case GetResult1:
1285*154cabe7SRiver Riddle     case GetResult2:
1286*154cabe7SRiver Riddle     case GetResult3: {
1287*154cabe7SRiver Riddle       unsigned index = opCode - GetResult0;
1288*154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
1289*154cabe7SRiver Riddle       executeGetResult(opCode - GetResult0);
1290*154cabe7SRiver Riddle       break;
1291abfd1a8bSRiver Riddle     }
1292*154cabe7SRiver Riddle     case GetResultN:
1293*154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
1294*154cabe7SRiver Riddle       executeGetResult(read<uint32_t>());
1295*154cabe7SRiver Riddle       break;
1296*154cabe7SRiver Riddle     case GetValueType:
1297*154cabe7SRiver Riddle       executeGetValueType();
1298*154cabe7SRiver Riddle       break;
1299*154cabe7SRiver Riddle     case IsNotNull:
1300*154cabe7SRiver Riddle       executeIsNotNull();
1301*154cabe7SRiver Riddle       break;
1302*154cabe7SRiver Riddle     case RecordMatch:
1303*154cabe7SRiver Riddle       assert(matches &&
1304*154cabe7SRiver Riddle              "expected matches to be provided when executing the matcher");
1305*154cabe7SRiver Riddle       executeRecordMatch(rewriter, *matches);
1306*154cabe7SRiver Riddle       break;
1307*154cabe7SRiver Riddle     case ReplaceOp:
1308*154cabe7SRiver Riddle       executeReplaceOp(rewriter);
1309*154cabe7SRiver Riddle       break;
1310*154cabe7SRiver Riddle     case SwitchAttribute:
1311*154cabe7SRiver Riddle       executeSwitchAttribute();
1312*154cabe7SRiver Riddle       break;
1313*154cabe7SRiver Riddle     case SwitchOperandCount:
1314*154cabe7SRiver Riddle       executeSwitchOperandCount();
1315*154cabe7SRiver Riddle       break;
1316*154cabe7SRiver Riddle     case SwitchOperationName:
1317*154cabe7SRiver Riddle       executeSwitchOperationName();
1318*154cabe7SRiver Riddle       break;
1319*154cabe7SRiver Riddle     case SwitchResultCount:
1320*154cabe7SRiver Riddle       executeSwitchResultCount();
1321*154cabe7SRiver Riddle       break;
1322*154cabe7SRiver Riddle     case SwitchType:
1323*154cabe7SRiver Riddle       executeSwitchType();
1324*154cabe7SRiver Riddle       break;
1325*154cabe7SRiver Riddle     }
1326*154cabe7SRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "\n");
1327abfd1a8bSRiver Riddle   }
1328abfd1a8bSRiver Riddle }
1329abfd1a8bSRiver Riddle 
1330abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched
1331abfd1a8bSRiver Riddle /// patterns in `matches`.
1332abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1333abfd1a8bSRiver Riddle                         SmallVectorImpl<MatchResult> &matches,
1334abfd1a8bSRiver Riddle                         PDLByteCodeMutableState &state) const {
1335abfd1a8bSRiver Riddle   // The first memory slot is always the root operation.
1336abfd1a8bSRiver Riddle   state.memory[0] = op;
1337abfd1a8bSRiver Riddle 
1338abfd1a8bSRiver Riddle   // The matcher function always starts at code address 0.
1339abfd1a8bSRiver Riddle   ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
1340abfd1a8bSRiver Riddle                             matcherByteCode, state.currentPatternBenefits,
1341abfd1a8bSRiver Riddle                             patterns, constraintFunctions, createFunctions,
1342abfd1a8bSRiver Riddle                             rewriteFunctions);
1343abfd1a8bSRiver Riddle   executor.execute(rewriter, &matches);
1344abfd1a8bSRiver Riddle 
1345abfd1a8bSRiver Riddle   // Order the found matches by benefit.
1346abfd1a8bSRiver Riddle   std::stable_sort(matches.begin(), matches.end(),
1347abfd1a8bSRiver Riddle                    [](const MatchResult &lhs, const MatchResult &rhs) {
1348abfd1a8bSRiver Riddle                      return lhs.benefit > rhs.benefit;
1349abfd1a8bSRiver Riddle                    });
1350abfd1a8bSRiver Riddle }
1351abfd1a8bSRiver Riddle 
1352abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`.
1353abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1354abfd1a8bSRiver Riddle                           PDLByteCodeMutableState &state) const {
1355abfd1a8bSRiver Riddle   // The arguments of the rewrite function are stored at the start of the
1356abfd1a8bSRiver Riddle   // memory buffer.
1357abfd1a8bSRiver Riddle   llvm::copy(match.values, state.memory.begin());
1358abfd1a8bSRiver Riddle 
1359abfd1a8bSRiver Riddle   ByteCodeExecutor executor(
1360abfd1a8bSRiver Riddle       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1361abfd1a8bSRiver Riddle       uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
1362abfd1a8bSRiver Riddle       constraintFunctions, createFunctions, rewriteFunctions);
1363abfd1a8bSRiver Riddle   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1364abfd1a8bSRiver Riddle }
1365