1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2abfd1a8bSRiver Riddle //
3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6abfd1a8bSRiver Riddle //
7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle //
9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter.
10abfd1a8bSRiver Riddle //
11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
12abfd1a8bSRiver Riddle 
13abfd1a8bSRiver Riddle #include "ByteCode.h"
14abfd1a8bSRiver Riddle #include "mlir/Analysis/Liveness.h"
15abfd1a8bSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16abfd1a8bSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17e66c2e25SRiver Riddle #include "mlir/IR/BuiltinOps.h"
18abfd1a8bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h"
19abfd1a8bSRiver Riddle #include "llvm/ADT/IntervalMap.h"
20abfd1a8bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h"
21abfd1a8bSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
22abfd1a8bSRiver Riddle #include "llvm/Support/Debug.h"
23*85ab413bSRiver Riddle #include "llvm/Support/Format.h"
24*85ab413bSRiver Riddle #include "llvm/Support/FormatVariadic.h"
25*85ab413bSRiver Riddle #include <numeric>
26abfd1a8bSRiver Riddle 
27abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode"
28abfd1a8bSRiver Riddle 
29abfd1a8bSRiver Riddle using namespace mlir;
30abfd1a8bSRiver Riddle using namespace mlir::detail;
31abfd1a8bSRiver Riddle 
32abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
33abfd1a8bSRiver Riddle // PDLByteCodePattern
34abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
35abfd1a8bSRiver Riddle 
36abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37abfd1a8bSRiver Riddle                                               ByteCodeAddr rewriterAddr) {
38abfd1a8bSRiver Riddle   SmallVector<StringRef, 8> generatedOps;
39abfd1a8bSRiver Riddle   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40abfd1a8bSRiver Riddle     generatedOps =
41abfd1a8bSRiver Riddle         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42abfd1a8bSRiver Riddle 
43abfd1a8bSRiver Riddle   PatternBenefit benefit = matchOp.benefit();
44abfd1a8bSRiver Riddle   MLIRContext *ctx = matchOp.getContext();
45abfd1a8bSRiver Riddle 
46abfd1a8bSRiver Riddle   // Check to see if this is pattern matches a specific operation type.
47abfd1a8bSRiver Riddle   if (Optional<StringRef> rootKind = matchOp.rootKind())
48abfd1a8bSRiver Riddle     return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
49abfd1a8bSRiver Riddle                               ctx);
50abfd1a8bSRiver Riddle   return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
51abfd1a8bSRiver Riddle                             MatchAnyOpTypeTag());
52abfd1a8bSRiver Riddle }
53abfd1a8bSRiver Riddle 
54abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
55abfd1a8bSRiver Riddle // PDLByteCodeMutableState
56abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
57abfd1a8bSRiver Riddle 
58abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by
60abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`.
61abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62abfd1a8bSRiver Riddle                                                    PatternBenefit benefit) {
63abfd1a8bSRiver Riddle   currentPatternBenefits[patternIndex] = benefit;
64abfd1a8bSRiver Riddle }
65abfd1a8bSRiver Riddle 
66*85ab413bSRiver Riddle /// Cleanup any allocated state after a full match/rewrite has been completed.
67*85ab413bSRiver Riddle /// This method should be called irregardless of whether the match+rewrite was a
68*85ab413bSRiver Riddle /// success or not.
69*85ab413bSRiver Riddle void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70*85ab413bSRiver Riddle   allocatedTypeRangeMemory.clear();
71*85ab413bSRiver Riddle   allocatedValueRangeMemory.clear();
72*85ab413bSRiver Riddle }
73*85ab413bSRiver Riddle 
74abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
75abfd1a8bSRiver Riddle // Bytecode OpCodes
76abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
77abfd1a8bSRiver Riddle 
namespace {
/// The instruction set of the PDL bytecode. An opcode is written into the
/// bytecode stream as a single `ByteCodeField` (see
/// `ByteCodeWriter::append(OpCode)`).
///
/// The enumerators take their values implicitly from their declaration order,
/// so inserting or reordering entries changes the encoding of every opcode
/// that follows.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. The numbered variants cover the
  /// first four operand indices; `GetOperandN` covers any other index.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation. The numbered variants cover the
  /// first four result indices; `GetResultN` covers any other index.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // end anonymous namespace
152abfd1a8bSRiver Riddle 
153abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
154abfd1a8bSRiver Riddle // ByteCode Generation
155abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
156abfd1a8bSRiver Riddle 
157abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
158abfd1a8bSRiver Riddle // Generator
159abfd1a8bSRiver Riddle 
160abfd1a8bSRiver Riddle namespace {
161abfd1a8bSRiver Riddle struct ByteCodeWriter;
162abfd1a8bSRiver Riddle 
163abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode.
164abfd1a8bSRiver Riddle class Generator {
165abfd1a8bSRiver Riddle public:
166abfd1a8bSRiver Riddle   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
167abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &matcherByteCode,
168abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
169abfd1a8bSRiver Riddle             SmallVectorImpl<PDLByteCodePattern> &patterns,
170abfd1a8bSRiver Riddle             ByteCodeField &maxValueMemoryIndex,
171*85ab413bSRiver Riddle             ByteCodeField &maxTypeRangeMemoryIndex,
172*85ab413bSRiver Riddle             ByteCodeField &maxValueRangeMemoryIndex,
173abfd1a8bSRiver Riddle             llvm::StringMap<PDLConstraintFunction> &constraintFns,
174abfd1a8bSRiver Riddle             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
175abfd1a8bSRiver Riddle       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
176abfd1a8bSRiver Riddle         rewriterByteCode(rewriterByteCode), patterns(patterns),
177*85ab413bSRiver Riddle         maxValueMemoryIndex(maxValueMemoryIndex),
178*85ab413bSRiver Riddle         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
179*85ab413bSRiver Riddle         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
180abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(constraintFns))
181abfd1a8bSRiver Riddle       constraintToMemIndex.try_emplace(it.value().first(), it.index());
182abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(rewriteFns))
183abfd1a8bSRiver Riddle       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
184abfd1a8bSRiver Riddle   }
185abfd1a8bSRiver Riddle 
186abfd1a8bSRiver Riddle   /// Generate the bytecode for the given PDL interpreter module.
187abfd1a8bSRiver Riddle   void generate(ModuleOp module);
188abfd1a8bSRiver Riddle 
189abfd1a8bSRiver Riddle   /// Return the memory index to use for the given value.
190abfd1a8bSRiver Riddle   ByteCodeField &getMemIndex(Value value) {
191abfd1a8bSRiver Riddle     assert(valueToMemIndex.count(value) &&
192abfd1a8bSRiver Riddle            "expected memory index to be assigned");
193abfd1a8bSRiver Riddle     return valueToMemIndex[value];
194abfd1a8bSRiver Riddle   }
195abfd1a8bSRiver Riddle 
196*85ab413bSRiver Riddle   /// Return the range memory index used to store the given range value.
197*85ab413bSRiver Riddle   ByteCodeField &getRangeStorageIndex(Value value) {
198*85ab413bSRiver Riddle     assert(valueToRangeIndex.count(value) &&
199*85ab413bSRiver Riddle            "expected range index to be assigned");
200*85ab413bSRiver Riddle     return valueToRangeIndex[value];
201*85ab413bSRiver Riddle   }
202*85ab413bSRiver Riddle 
203abfd1a8bSRiver Riddle   /// Return an index to use when referring to the given data that is uniqued in
204abfd1a8bSRiver Riddle   /// the MLIR context.
205abfd1a8bSRiver Riddle   template <typename T>
206abfd1a8bSRiver Riddle   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
207abfd1a8bSRiver Riddle   getMemIndex(T val) {
208abfd1a8bSRiver Riddle     const void *opaqueVal = val.getAsOpaquePointer();
209abfd1a8bSRiver Riddle 
210abfd1a8bSRiver Riddle     // Get or insert a reference to this value.
211abfd1a8bSRiver Riddle     auto it = uniquedDataToMemIndex.try_emplace(
212abfd1a8bSRiver Riddle         opaqueVal, maxValueMemoryIndex + uniquedData.size());
213abfd1a8bSRiver Riddle     if (it.second)
214abfd1a8bSRiver Riddle       uniquedData.push_back(opaqueVal);
215abfd1a8bSRiver Riddle     return it.first->second;
216abfd1a8bSRiver Riddle   }
217abfd1a8bSRiver Riddle 
218abfd1a8bSRiver Riddle private:
219abfd1a8bSRiver Riddle   /// Allocate memory indices for the results of operations within the matcher
220abfd1a8bSRiver Riddle   /// and rewriters.
221abfd1a8bSRiver Riddle   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
222abfd1a8bSRiver Riddle 
223abfd1a8bSRiver Riddle   /// Generate the bytecode for the given operation.
224abfd1a8bSRiver Riddle   void generate(Operation *op, ByteCodeWriter &writer);
225abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
226abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
227abfd1a8bSRiver Riddle   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
228abfd1a8bSRiver Riddle   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
229abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
230abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
231abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
232abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
233abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
234*85ab413bSRiver Riddle   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
235abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
236abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
237abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
238*85ab413bSRiver Riddle   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
239abfd1a8bSRiver Riddle   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
240abfd1a8bSRiver Riddle   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
241abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
242abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
243abfd1a8bSRiver Riddle   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
244abfd1a8bSRiver Riddle   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
245*85ab413bSRiver Riddle   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
246abfd1a8bSRiver Riddle   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
247*85ab413bSRiver Riddle   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
248abfd1a8bSRiver Riddle   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
2493a833a0eSRiver Riddle   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
250abfd1a8bSRiver Riddle   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
251abfd1a8bSRiver Riddle   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
252abfd1a8bSRiver Riddle   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
253abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
254abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
255*85ab413bSRiver Riddle   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
256abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
257abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
258abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
259abfd1a8bSRiver Riddle 
260abfd1a8bSRiver Riddle   /// Mapping from value to its corresponding memory index.
261abfd1a8bSRiver Riddle   DenseMap<Value, ByteCodeField> valueToMemIndex;
262abfd1a8bSRiver Riddle 
263*85ab413bSRiver Riddle   /// Mapping from a range value to its corresponding range storage index.
264*85ab413bSRiver Riddle   DenseMap<Value, ByteCodeField> valueToRangeIndex;
265*85ab413bSRiver Riddle 
266abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered rewrite to its index in
267abfd1a8bSRiver Riddle   /// the bytecode registry.
268abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
269abfd1a8bSRiver Riddle 
270abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered constraint to its index
271abfd1a8bSRiver Riddle   /// in the bytecode registry.
272abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> constraintToMemIndex;
273abfd1a8bSRiver Riddle 
274abfd1a8bSRiver Riddle   /// Mapping from rewriter function name to the bytecode address of the
275abfd1a8bSRiver Riddle   /// rewriter function in byte.
276abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
277abfd1a8bSRiver Riddle 
278abfd1a8bSRiver Riddle   /// Mapping from a uniqued storage object to its memory index within
279abfd1a8bSRiver Riddle   /// `uniquedData`.
280abfd1a8bSRiver Riddle   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
281abfd1a8bSRiver Riddle 
282abfd1a8bSRiver Riddle   /// The current MLIR context.
283abfd1a8bSRiver Riddle   MLIRContext *ctx;
284abfd1a8bSRiver Riddle 
285abfd1a8bSRiver Riddle   /// Data of the ByteCode class to be populated.
286abfd1a8bSRiver Riddle   std::vector<const void *> &uniquedData;
287abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &matcherByteCode;
288abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
289abfd1a8bSRiver Riddle   SmallVectorImpl<PDLByteCodePattern> &patterns;
290abfd1a8bSRiver Riddle   ByteCodeField &maxValueMemoryIndex;
291*85ab413bSRiver Riddle   ByteCodeField &maxTypeRangeMemoryIndex;
292*85ab413bSRiver Riddle   ByteCodeField &maxValueRangeMemoryIndex;
293abfd1a8bSRiver Riddle };
294abfd1a8bSRiver Riddle 
/// Helper for appending fields to one bytecode stream (the matcher or a
/// rewriter). Values that are not encoded inline are translated to memory
/// indices through the owning `Generator`.
struct ByteCodeWriter {
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a single raw field to the bytecode.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  /// Append an opcode; it occupies a single `ByteCodeField`.
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode. An address occupies exactly two
  /// consecutive fields; its bytes are copied verbatim (host byte order).
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a successor range to the bytecode. Each successor is written as a
  /// zero address placeholder; the exact address will need to be resolved
  /// later using `unresolvedSuccessorRefs`.
  void append(SuccessorRange successors) {
    // Record the offset of each placeholder so the real block address can be
    // patched in once it is known.
    for (Block *successor : successors) {
      unresolvedSuccessorRefs[successor].push_back(bytecode.size());
      append(ByteCodeAddr(0));
    }
  }

  /// Append a range of values that will be read as generic PDLValues: a count
  /// field followed by a (kind, memory index) pair for each value.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values)
      appendPDLValue(value);
  }

  /// Append a value as a PDLValue: its kind field followed by its memory index.
  void appendPDLValue(Value value) {
    appendPDLValueKind(value);
    append(value);
  }

  /// Append the PDLValue::Kind of the given value, derived from its PDL type.
  /// A range type maps to TypeRange when its element type is `pdl::TypeType`
  /// and to ValueRange otherwise.
  void appendPDLValueKind(Value value) {
    // Map the PDL type of the value onto its interpreter value kind.
    // NOTE(review): there is no Default case, so any type other than those
    // listed is unhandled -- presumably callers only pass PDL-typed values.
    PDLValue::Kind kind =
        TypeSwitch<Type, PDLValue::Kind>(value.getType())
            .Case<pdl::AttributeType>(
                [](Type) { return PDLValue::Kind::Attribute; })
            .Case<pdl::OperationType>(
                [](Type) { return PDLValue::Kind::Operation; })
            .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
              if (rangeTy.getElementType().isa<pdl::TypeType>())
                return PDLValue::Kind::TypeRange;
              return PDLValue::Kind::ValueRange;
            })
            .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
            .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
    bytecode.push_back(static_cast<ByteCodeField>(kind));
  }

  /// Detects whether `T` has a `getAsOpaquePointer()` member, which selects
  /// between the memory-slot `append(T value)` overload and the range
  /// `append(T range)` overload below. NOTE(review): the unused `Args` pack
  /// presumably lets `llvm::is_detected` expand its own argument pack into
  /// this alias; verify before removing it.
  template <typename T, typename... Args>
  using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());

  /// Append a value that will be stored in a memory slot and not inline within
  /// the bytecode: only the slot index from `Generator::getMemIndex` is
  /// written. Enabled for types with `getAsOpaquePointer()` and for pointers.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values: a count field followed by each element, each
  /// written through the matching `append` overload. Enabled only for types
  /// without `getAsOpaquePointer()`; the `IteratorT` default presumably
  /// restricts it further to iterable types.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode, in order, by
  /// dispatching each one to the matching single-argument `append`.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Offsets of successor address placeholders that still need to be patched,
  /// keyed by the successor block they refer to.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The bytecode buffer being appended to.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The generator used to resolve memory indices of values.
  Generator &generator;
};
395*85ab413bSRiver Riddle 
396*85ab413bSRiver Riddle /// This class represents a live range of PDL Interpreter values, containing
397*85ab413bSRiver Riddle /// information about when values are live within a match/rewrite.
398*85ab413bSRiver Riddle struct ByteCodeLiveRange {
399*85ab413bSRiver Riddle   using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
400*85ab413bSRiver Riddle   using Allocator = Set::Allocator;
401*85ab413bSRiver Riddle 
402*85ab413bSRiver Riddle   ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
403*85ab413bSRiver Riddle 
404*85ab413bSRiver Riddle   /// Union this live range with the one provided.
405*85ab413bSRiver Riddle   void unionWith(const ByteCodeLiveRange &rhs) {
406*85ab413bSRiver Riddle     for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
407*85ab413bSRiver Riddle       liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
408*85ab413bSRiver Riddle   }
409*85ab413bSRiver Riddle 
410*85ab413bSRiver Riddle   /// Returns true if this range overlaps with the one provided.
411*85ab413bSRiver Riddle   bool overlaps(const ByteCodeLiveRange &rhs) const {
412*85ab413bSRiver Riddle     return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
413*85ab413bSRiver Riddle   }
414*85ab413bSRiver Riddle 
415*85ab413bSRiver Riddle   /// A map representing the ranges of the match/rewrite that a value is live in
416*85ab413bSRiver Riddle   /// the interpreter.
417*85ab413bSRiver Riddle   llvm::IntervalMap<ByteCodeField, char, 16> liveness;
418*85ab413bSRiver Riddle 
419*85ab413bSRiver Riddle   /// The type range storage index for this range.
420*85ab413bSRiver Riddle   Optional<unsigned> typeRangeIndex;
421*85ab413bSRiver Riddle 
422*85ab413bSRiver Riddle   /// The value range storage index for this range.
423*85ab413bSRiver Riddle   Optional<unsigned> valueRangeIndex;
424*85ab413bSRiver Riddle };
425abfd1a8bSRiver Riddle } // end anonymous namespace
426abfd1a8bSRiver Riddle 
427abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) {
428abfd1a8bSRiver Riddle   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
429abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
430abfd1a8bSRiver Riddle   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
431abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getRewriterModuleName());
432abfd1a8bSRiver Riddle   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
433abfd1a8bSRiver Riddle 
434abfd1a8bSRiver Riddle   // Allocate memory indices for the results of operations within the matcher
435abfd1a8bSRiver Riddle   // and rewriters.
436abfd1a8bSRiver Riddle   allocateMemoryIndices(matcherFunc, rewriterModule);
437abfd1a8bSRiver Riddle 
438abfd1a8bSRiver Riddle   // Generate code for the rewriter functions.
439abfd1a8bSRiver Riddle   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
440abfd1a8bSRiver Riddle   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
441abfd1a8bSRiver Riddle     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
442abfd1a8bSRiver Riddle     for (Operation &op : rewriterFunc.getOps())
443abfd1a8bSRiver Riddle       generate(&op, rewriterByteCodeWriter);
444abfd1a8bSRiver Riddle   }
445abfd1a8bSRiver Riddle   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
446abfd1a8bSRiver Riddle          "unexpected branches in rewriter function");
447abfd1a8bSRiver Riddle 
448abfd1a8bSRiver Riddle   // Generate code for the matcher function.
449abfd1a8bSRiver Riddle   DenseMap<Block *, ByteCodeAddr> blockToAddr;
450abfd1a8bSRiver Riddle   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
451abfd1a8bSRiver Riddle   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
452abfd1a8bSRiver Riddle   for (Block *block : rpot) {
453abfd1a8bSRiver Riddle     // Keep track of where this block begins within the matcher function.
454abfd1a8bSRiver Riddle     blockToAddr.try_emplace(block, matcherByteCode.size());
455abfd1a8bSRiver Riddle     for (Operation &op : *block)
456abfd1a8bSRiver Riddle       generate(&op, matcherByteCodeWriter);
457abfd1a8bSRiver Riddle   }
458abfd1a8bSRiver Riddle 
459abfd1a8bSRiver Riddle   // Resolve successor references in the matcher.
460abfd1a8bSRiver Riddle   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
461abfd1a8bSRiver Riddle     ByteCodeAddr addr = blockToAddr[it.first];
462abfd1a8bSRiver Riddle     for (unsigned offsetToFix : it.second)
463abfd1a8bSRiver Riddle       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
464abfd1a8bSRiver Riddle   }
465abfd1a8bSRiver Riddle }
466abfd1a8bSRiver Riddle 
467abfd1a8bSRiver Riddle void Generator::allocateMemoryIndices(FuncOp matcherFunc,
468abfd1a8bSRiver Riddle                                       ModuleOp rewriterModule) {
469abfd1a8bSRiver Riddle   // Rewriters use simplistic allocation scheme that simply assigns an index to
470abfd1a8bSRiver Riddle   // each result.
471abfd1a8bSRiver Riddle   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
472*85ab413bSRiver Riddle     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
473*85ab413bSRiver Riddle     auto processRewriterValue = [&](Value val) {
474*85ab413bSRiver Riddle       valueToMemIndex.try_emplace(val, index++);
475*85ab413bSRiver Riddle       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
476*85ab413bSRiver Riddle         Type elementTy = rangeType.getElementType();
477*85ab413bSRiver Riddle         if (elementTy.isa<pdl::TypeType>())
478*85ab413bSRiver Riddle           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
479*85ab413bSRiver Riddle         else if (elementTy.isa<pdl::ValueType>())
480*85ab413bSRiver Riddle           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
481*85ab413bSRiver Riddle       }
482*85ab413bSRiver Riddle     };
483*85ab413bSRiver Riddle 
484abfd1a8bSRiver Riddle     for (BlockArgument arg : rewriterFunc.getArguments())
485*85ab413bSRiver Riddle       processRewriterValue(arg);
486abfd1a8bSRiver Riddle     rewriterFunc.getBody().walk([&](Operation *op) {
487abfd1a8bSRiver Riddle       for (Value result : op->getResults())
488*85ab413bSRiver Riddle         processRewriterValue(result);
489abfd1a8bSRiver Riddle     });
490abfd1a8bSRiver Riddle     if (index > maxValueMemoryIndex)
491abfd1a8bSRiver Riddle       maxValueMemoryIndex = index;
492*85ab413bSRiver Riddle     if (typeRangeIndex > maxTypeRangeMemoryIndex)
493*85ab413bSRiver Riddle       maxTypeRangeMemoryIndex = typeRangeIndex;
494*85ab413bSRiver Riddle     if (valueRangeIndex > maxValueRangeMemoryIndex)
495*85ab413bSRiver Riddle       maxValueRangeMemoryIndex = valueRangeIndex;
496abfd1a8bSRiver Riddle   }
497abfd1a8bSRiver Riddle 
498abfd1a8bSRiver Riddle   // The matcher function uses a more sophisticated numbering that tries to
499abfd1a8bSRiver Riddle   // minimize the number of memory indices assigned. This is done by determining
500abfd1a8bSRiver Riddle   // a live range of the values within the matcher, then the allocation is just
501abfd1a8bSRiver Riddle   // finding the minimal number of overlapping live ranges. This is essentially
502abfd1a8bSRiver Riddle   // a simplified form of register allocation where we don't necessarily have a
503abfd1a8bSRiver Riddle   // limited number of registers, but we still want to minimize the number used.
504abfd1a8bSRiver Riddle   DenseMap<Operation *, ByteCodeField> opToIndex;
505abfd1a8bSRiver Riddle   matcherFunc.getBody().walk([&](Operation *op) {
506abfd1a8bSRiver Riddle     opToIndex.insert(std::make_pair(op, opToIndex.size()));
507abfd1a8bSRiver Riddle   });
508abfd1a8bSRiver Riddle 
509abfd1a8bSRiver Riddle   // Liveness info for each of the defs within the matcher.
510*85ab413bSRiver Riddle   ByteCodeLiveRange::Allocator allocator;
511*85ab413bSRiver Riddle   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
512abfd1a8bSRiver Riddle 
513abfd1a8bSRiver Riddle   // Assign the root operation being matched to slot 0.
514abfd1a8bSRiver Riddle   BlockArgument rootOpArg = matcherFunc.getArgument(0);
515abfd1a8bSRiver Riddle   valueToMemIndex[rootOpArg] = 0;
516abfd1a8bSRiver Riddle 
517abfd1a8bSRiver Riddle   // Walk each of the blocks, computing the def interval that the value is used.
518abfd1a8bSRiver Riddle   Liveness matcherLiveness(matcherFunc);
519abfd1a8bSRiver Riddle   for (Block &block : matcherFunc.getBody()) {
520abfd1a8bSRiver Riddle     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
521abfd1a8bSRiver Riddle     assert(info && "expected liveness info for block");
522abfd1a8bSRiver Riddle     auto processValue = [&](Value value, Operation *firstUseOrDef) {
523abfd1a8bSRiver Riddle       // We don't need to process the root op argument, this value is always
524abfd1a8bSRiver Riddle       // assigned to the first memory slot.
525abfd1a8bSRiver Riddle       if (value == rootOpArg)
526abfd1a8bSRiver Riddle         return;
527abfd1a8bSRiver Riddle 
528abfd1a8bSRiver Riddle       // Set indices for the range of this block that the value is used.
529abfd1a8bSRiver Riddle       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
530*85ab413bSRiver Riddle       defRangeIt->second.liveness.insert(
531abfd1a8bSRiver Riddle           opToIndex[firstUseOrDef],
532abfd1a8bSRiver Riddle           opToIndex[info->getEndOperation(value, firstUseOrDef)],
533abfd1a8bSRiver Riddle           /*dummyValue*/ 0);
534*85ab413bSRiver Riddle 
535*85ab413bSRiver Riddle       // Check to see if this value is a range type.
536*85ab413bSRiver Riddle       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
537*85ab413bSRiver Riddle         Type eleType = rangeTy.getElementType();
538*85ab413bSRiver Riddle         if (eleType.isa<pdl::TypeType>())
539*85ab413bSRiver Riddle           defRangeIt->second.typeRangeIndex = 0;
540*85ab413bSRiver Riddle         else if (eleType.isa<pdl::ValueType>())
541*85ab413bSRiver Riddle           defRangeIt->second.valueRangeIndex = 0;
542*85ab413bSRiver Riddle       }
543abfd1a8bSRiver Riddle     };
544abfd1a8bSRiver Riddle 
545abfd1a8bSRiver Riddle     // Process the live-ins of this block.
546abfd1a8bSRiver Riddle     for (Value liveIn : info->in())
547abfd1a8bSRiver Riddle       processValue(liveIn, &block.front());
548abfd1a8bSRiver Riddle 
549abfd1a8bSRiver Riddle     // Process any new defs within this block.
550abfd1a8bSRiver Riddle     for (Operation &op : block)
551abfd1a8bSRiver Riddle       for (Value result : op.getResults())
552abfd1a8bSRiver Riddle         processValue(result, &op);
553abfd1a8bSRiver Riddle   }
554abfd1a8bSRiver Riddle 
555abfd1a8bSRiver Riddle   // Greedily allocate memory slots using the computed def live ranges.
556*85ab413bSRiver Riddle   std::vector<ByteCodeLiveRange> allocatedIndices;
557*85ab413bSRiver Riddle   ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
558abfd1a8bSRiver Riddle   for (auto &defIt : valueDefRanges) {
559abfd1a8bSRiver Riddle     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
560*85ab413bSRiver Riddle     ByteCodeLiveRange &defRange = defIt.second;
561abfd1a8bSRiver Riddle 
562abfd1a8bSRiver Riddle     // Try to allocate to an existing index.
563abfd1a8bSRiver Riddle     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
564*85ab413bSRiver Riddle       ByteCodeLiveRange &existingRange = existingIndexIt.value();
565*85ab413bSRiver Riddle       if (!defRange.overlaps(existingRange)) {
566*85ab413bSRiver Riddle         existingRange.unionWith(defRange);
567abfd1a8bSRiver Riddle         memIndex = existingIndexIt.index() + 1;
568*85ab413bSRiver Riddle 
569*85ab413bSRiver Riddle         if (defRange.typeRangeIndex) {
570*85ab413bSRiver Riddle           if (!existingRange.typeRangeIndex)
571*85ab413bSRiver Riddle             existingRange.typeRangeIndex = numTypeRanges++;
572*85ab413bSRiver Riddle           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
573*85ab413bSRiver Riddle         } else if (defRange.valueRangeIndex) {
574*85ab413bSRiver Riddle           if (!existingRange.valueRangeIndex)
575*85ab413bSRiver Riddle             existingRange.valueRangeIndex = numValueRanges++;
576*85ab413bSRiver Riddle           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
577*85ab413bSRiver Riddle         }
578*85ab413bSRiver Riddle         break;
579*85ab413bSRiver Riddle       }
580abfd1a8bSRiver Riddle     }
581abfd1a8bSRiver Riddle 
582abfd1a8bSRiver Riddle     // If no existing index could be used, add a new one.
583abfd1a8bSRiver Riddle     if (memIndex == 0) {
584abfd1a8bSRiver Riddle       allocatedIndices.emplace_back(allocator);
585*85ab413bSRiver Riddle       ByteCodeLiveRange &newRange = allocatedIndices.back();
586*85ab413bSRiver Riddle       newRange.unionWith(defRange);
587*85ab413bSRiver Riddle 
588*85ab413bSRiver Riddle       // Allocate an index for type/value ranges.
589*85ab413bSRiver Riddle       if (defRange.typeRangeIndex) {
590*85ab413bSRiver Riddle         newRange.typeRangeIndex = numTypeRanges;
591*85ab413bSRiver Riddle         valueToRangeIndex[defIt.first] = numTypeRanges++;
592*85ab413bSRiver Riddle       } else if (defRange.valueRangeIndex) {
593*85ab413bSRiver Riddle         newRange.valueRangeIndex = numValueRanges;
594*85ab413bSRiver Riddle         valueToRangeIndex[defIt.first] = numValueRanges++;
595*85ab413bSRiver Riddle       }
596*85ab413bSRiver Riddle 
597abfd1a8bSRiver Riddle       memIndex = allocatedIndices.size();
598*85ab413bSRiver Riddle       ++numIndices;
599abfd1a8bSRiver Riddle     }
600abfd1a8bSRiver Riddle   }
601abfd1a8bSRiver Riddle 
602abfd1a8bSRiver Riddle   // Update the max number of indices.
603*85ab413bSRiver Riddle   if (numIndices > maxValueMemoryIndex)
604*85ab413bSRiver Riddle     maxValueMemoryIndex = numIndices;
605*85ab413bSRiver Riddle   if (numTypeRanges > maxTypeRangeMemoryIndex)
606*85ab413bSRiver Riddle     maxTypeRangeMemoryIndex = numTypeRanges;
607*85ab413bSRiver Riddle   if (numValueRanges > maxValueRangeMemoryIndex)
608*85ab413bSRiver Riddle     maxValueRangeMemoryIndex = numValueRanges;
609abfd1a8bSRiver Riddle }
610abfd1a8bSRiver Riddle 
611abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) {
612abfd1a8bSRiver Riddle   TypeSwitch<Operation *>(op)
613abfd1a8bSRiver Riddle       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
614abfd1a8bSRiver Riddle             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
615abfd1a8bSRiver Riddle             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
616abfd1a8bSRiver Riddle             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
617*85ab413bSRiver Riddle             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
618*85ab413bSRiver Riddle             pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
619*85ab413bSRiver Riddle             pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
62002c4c0d5SRiver Riddle             pdl_interp::EraseOp, pdl_interp::FinalizeOp,
62102c4c0d5SRiver Riddle             pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
62202c4c0d5SRiver Riddle             pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
623*85ab413bSRiver Riddle             pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
624*85ab413bSRiver Riddle             pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
6253a833a0eSRiver Riddle             pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
62602c4c0d5SRiver Riddle             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
62702c4c0d5SRiver Riddle             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
628*85ab413bSRiver Riddle             pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
629*85ab413bSRiver Riddle             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
630abfd1a8bSRiver Riddle           [&](auto interpOp) { this->generate(interpOp, writer); })
631abfd1a8bSRiver Riddle       .Default([](Operation *) {
632abfd1a8bSRiver Riddle         llvm_unreachable("unknown `pdl_interp` operation");
633abfd1a8bSRiver Riddle       });
634abfd1a8bSRiver Riddle }
635abfd1a8bSRiver Riddle 
636abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op,
637abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
638abfd1a8bSRiver Riddle   assert(constraintToMemIndex.count(op.name()) &&
639abfd1a8bSRiver Riddle          "expected index for constraint function");
640abfd1a8bSRiver Riddle   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
641abfd1a8bSRiver Riddle                 op.constParamsAttr());
642abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
643abfd1a8bSRiver Riddle   writer.append(op.getSuccessors());
644abfd1a8bSRiver Riddle }
645abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op,
646abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
647abfd1a8bSRiver Riddle   assert(externalRewriterToMemIndex.count(op.name()) &&
648abfd1a8bSRiver Riddle          "expected index for rewrite function");
649abfd1a8bSRiver Riddle   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
65002c4c0d5SRiver Riddle                 op.constParamsAttr());
651abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
65202c4c0d5SRiver Riddle 
653*85ab413bSRiver Riddle   ResultRange results = op.results();
654*85ab413bSRiver Riddle   writer.append(ByteCodeField(results.size()));
655*85ab413bSRiver Riddle   for (Value result : results) {
656*85ab413bSRiver Riddle     // In debug mode we also record the expected kind of the result, so that we
657*85ab413bSRiver Riddle     // can provide extra verification of the native rewrite function.
65802c4c0d5SRiver Riddle #ifndef NDEBUG
659*85ab413bSRiver Riddle     writer.appendPDLValueKind(result);
66002c4c0d5SRiver Riddle #endif
661*85ab413bSRiver Riddle 
662*85ab413bSRiver Riddle     // Range results also need to append the range storage index.
663*85ab413bSRiver Riddle     if (result.getType().isa<pdl::RangeType>())
664*85ab413bSRiver Riddle       writer.append(getRangeStorageIndex(result));
66502c4c0d5SRiver Riddle     writer.append(result);
666abfd1a8bSRiver Riddle   }
667*85ab413bSRiver Riddle }
668abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
669*85ab413bSRiver Riddle   Value lhs = op.lhs();
670*85ab413bSRiver Riddle   if (lhs.getType().isa<pdl::RangeType>()) {
671*85ab413bSRiver Riddle     writer.append(OpCode::AreRangesEqual);
672*85ab413bSRiver Riddle     writer.appendPDLValueKind(lhs);
673*85ab413bSRiver Riddle     writer.append(op.lhs(), op.rhs(), op.getSuccessors());
674*85ab413bSRiver Riddle     return;
675*85ab413bSRiver Riddle   }
676*85ab413bSRiver Riddle 
677*85ab413bSRiver Riddle   writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
678abfd1a8bSRiver Riddle }
679abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
6808affe881SRiver Riddle   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
681abfd1a8bSRiver Riddle }
682abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op,
683abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
684abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
685abfd1a8bSRiver Riddle                 op.getSuccessors());
686abfd1a8bSRiver Riddle }
687abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op,
688abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
689abfd1a8bSRiver Riddle   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
690*85ab413bSRiver Riddle                 static_cast<ByteCodeField>(op.compareAtLeast()),
691abfd1a8bSRiver Riddle                 op.getSuccessors());
692abfd1a8bSRiver Riddle }
693abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op,
694abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
695abfd1a8bSRiver Riddle   writer.append(OpCode::CheckOperationName, op.operation(),
696abfd1a8bSRiver Riddle                 OperationName(op.name(), ctx), op.getSuccessors());
697abfd1a8bSRiver Riddle }
698abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op,
699abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
700abfd1a8bSRiver Riddle   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
701*85ab413bSRiver Riddle                 static_cast<ByteCodeField>(op.compareAtLeast()),
702abfd1a8bSRiver Riddle                 op.getSuccessors());
703abfd1a8bSRiver Riddle }
704abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
705abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
706abfd1a8bSRiver Riddle }
707*85ab413bSRiver Riddle void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
708*85ab413bSRiver Riddle   writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
709*85ab413bSRiver Riddle }
710abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op,
711abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
712abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
713abfd1a8bSRiver Riddle   getMemIndex(op.attribute()) = getMemIndex(op.value());
714abfd1a8bSRiver Riddle }
715abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op,
716abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
717abfd1a8bSRiver Riddle   writer.append(OpCode::CreateOperation, op.operation(),
718*85ab413bSRiver Riddle                 OperationName(op.name(), ctx));
719*85ab413bSRiver Riddle   writer.appendPDLValueList(op.operands());
720abfd1a8bSRiver Riddle 
721abfd1a8bSRiver Riddle   // Add the attributes.
722abfd1a8bSRiver Riddle   OperandRange attributes = op.attributes();
723abfd1a8bSRiver Riddle   writer.append(static_cast<ByteCodeField>(attributes.size()));
724abfd1a8bSRiver Riddle   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
725abfd1a8bSRiver Riddle     writer.append(
726abfd1a8bSRiver Riddle         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
727abfd1a8bSRiver Riddle         std::get<1>(it));
728abfd1a8bSRiver Riddle   }
729*85ab413bSRiver Riddle   writer.appendPDLValueList(op.types());
730abfd1a8bSRiver Riddle }
731abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
732abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
733abfd1a8bSRiver Riddle   getMemIndex(op.result()) = getMemIndex(op.value());
734abfd1a8bSRiver Riddle }
735*85ab413bSRiver Riddle void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
736*85ab413bSRiver Riddle   writer.append(OpCode::CreateTypes, op.result(),
737*85ab413bSRiver Riddle                 getRangeStorageIndex(op.result()), op.value());
738*85ab413bSRiver Riddle }
739abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
740abfd1a8bSRiver Riddle   writer.append(OpCode::EraseOp, op.operation());
741abfd1a8bSRiver Riddle }
742abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
743abfd1a8bSRiver Riddle   writer.append(OpCode::Finalize);
744abfd1a8bSRiver Riddle }
745abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op,
746abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
747abfd1a8bSRiver Riddle   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
748abfd1a8bSRiver Riddle                 Identifier::get(op.name(), ctx));
749abfd1a8bSRiver Riddle }
750abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op,
751abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
752abfd1a8bSRiver Riddle   writer.append(OpCode::GetAttributeType, op.result(), op.value());
753abfd1a8bSRiver Riddle }
754abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op,
755abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
756*85ab413bSRiver Riddle   writer.append(OpCode::GetDefiningOp, op.operation());
757*85ab413bSRiver Riddle   writer.appendPDLValue(op.value());
758abfd1a8bSRiver Riddle }
759abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
760abfd1a8bSRiver Riddle   uint32_t index = op.index();
761abfd1a8bSRiver Riddle   if (index < 4)
762abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
763abfd1a8bSRiver Riddle   else
764abfd1a8bSRiver Riddle     writer.append(OpCode::GetOperandN, index);
765abfd1a8bSRiver Riddle   writer.append(op.operation(), op.value());
766abfd1a8bSRiver Riddle }
767*85ab413bSRiver Riddle void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
768*85ab413bSRiver Riddle   Value result = op.value();
769*85ab413bSRiver Riddle   Optional<uint32_t> index = op.index();
770*85ab413bSRiver Riddle   writer.append(OpCode::GetOperands,
771*85ab413bSRiver Riddle                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
772*85ab413bSRiver Riddle                 op.operation());
773*85ab413bSRiver Riddle   if (result.getType().isa<pdl::RangeType>())
774*85ab413bSRiver Riddle     writer.append(getRangeStorageIndex(result));
775*85ab413bSRiver Riddle   else
776*85ab413bSRiver Riddle     writer.append(std::numeric_limits<ByteCodeField>::max());
777*85ab413bSRiver Riddle   writer.append(result);
778*85ab413bSRiver Riddle }
779abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
780abfd1a8bSRiver Riddle   uint32_t index = op.index();
781abfd1a8bSRiver Riddle   if (index < 4)
782abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
783abfd1a8bSRiver Riddle   else
784abfd1a8bSRiver Riddle     writer.append(OpCode::GetResultN, index);
785abfd1a8bSRiver Riddle   writer.append(op.operation(), op.value());
786abfd1a8bSRiver Riddle }
787*85ab413bSRiver Riddle void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
788*85ab413bSRiver Riddle   Value result = op.value();
789*85ab413bSRiver Riddle   Optional<uint32_t> index = op.index();
790*85ab413bSRiver Riddle   writer.append(OpCode::GetResults,
791*85ab413bSRiver Riddle                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
792*85ab413bSRiver Riddle                 op.operation());
793*85ab413bSRiver Riddle   if (result.getType().isa<pdl::RangeType>())
794*85ab413bSRiver Riddle     writer.append(getRangeStorageIndex(result));
795*85ab413bSRiver Riddle   else
796*85ab413bSRiver Riddle     writer.append(std::numeric_limits<ByteCodeField>::max());
797*85ab413bSRiver Riddle   writer.append(result);
798*85ab413bSRiver Riddle }
799abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op,
800abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
801*85ab413bSRiver Riddle   if (op.getType().isa<pdl::RangeType>()) {
802*85ab413bSRiver Riddle     Value result = op.result();
803*85ab413bSRiver Riddle     writer.append(OpCode::GetValueRangeTypes, result,
804*85ab413bSRiver Riddle                   getRangeStorageIndex(result), op.value());
805*85ab413bSRiver Riddle   } else {
806abfd1a8bSRiver Riddle     writer.append(OpCode::GetValueType, op.result(), op.value());
807abfd1a8bSRiver Riddle   }
808*85ab413bSRiver Riddle }
809*85ab413bSRiver Riddle 
8103a833a0eSRiver Riddle void Generator::generate(pdl_interp::InferredTypesOp op,
811abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
8123a833a0eSRiver Riddle   // InferType maps to a null type as a marker for inferring result types.
813abfd1a8bSRiver Riddle   getMemIndex(op.type()) = getMemIndex(Type());
814abfd1a8bSRiver Riddle }
815abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
816abfd1a8bSRiver Riddle   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
817abfd1a8bSRiver Riddle }
818abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
819abfd1a8bSRiver Riddle   ByteCodeField patternIndex = patterns.size();
820abfd1a8bSRiver Riddle   patterns.emplace_back(PDLByteCodePattern::create(
821abfd1a8bSRiver Riddle       op, rewriterToAddr[op.rewriter().getLeafReference()]));
8228affe881SRiver Riddle   writer.append(OpCode::RecordMatch, patternIndex,
823*85ab413bSRiver Riddle                 SuccessorRange(op.getOperation()), op.matchedOps());
824*85ab413bSRiver Riddle   writer.appendPDLValueList(op.inputs());
825abfd1a8bSRiver Riddle }
826abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
827*85ab413bSRiver Riddle   writer.append(OpCode::ReplaceOp, op.operation());
828*85ab413bSRiver Riddle   writer.appendPDLValueList(op.replValues());
829abfd1a8bSRiver Riddle }
830abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op,
831abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
832abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
833abfd1a8bSRiver Riddle                 op.getSuccessors());
834abfd1a8bSRiver Riddle }
835abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op,
836abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
837abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
838abfd1a8bSRiver Riddle                 op.getSuccessors());
839abfd1a8bSRiver Riddle }
840abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op,
841abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
842abfd1a8bSRiver Riddle   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
843abfd1a8bSRiver Riddle     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
844abfd1a8bSRiver Riddle   });
845abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
846abfd1a8bSRiver Riddle                 op.getSuccessors());
847abfd1a8bSRiver Riddle }
848abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op,
849abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
850abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
851abfd1a8bSRiver Riddle                 op.getSuccessors());
852abfd1a8bSRiver Riddle }
853abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
854abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
855abfd1a8bSRiver Riddle                 op.getSuccessors());
856abfd1a8bSRiver Riddle }
857*85ab413bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
858*85ab413bSRiver Riddle   writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
859*85ab413bSRiver Riddle                 op.getSuccessors());
860*85ab413bSRiver Riddle }
861abfd1a8bSRiver Riddle 
862abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
863abfd1a8bSRiver Riddle // PDLByteCode
864abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
865abfd1a8bSRiver Riddle 
866abfd1a8bSRiver Riddle PDLByteCode::PDLByteCode(ModuleOp module,
867abfd1a8bSRiver Riddle                          llvm::StringMap<PDLConstraintFunction> constraintFns,
868abfd1a8bSRiver Riddle                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
869abfd1a8bSRiver Riddle   Generator generator(module.getContext(), uniquedData, matcherByteCode,
870abfd1a8bSRiver Riddle                       rewriterByteCode, patterns, maxValueMemoryIndex,
871*85ab413bSRiver Riddle                       maxTypeRangeCount, maxValueRangeCount, constraintFns,
872*85ab413bSRiver Riddle                       rewriteFns);
873abfd1a8bSRiver Riddle   generator.generate(module);
874abfd1a8bSRiver Riddle 
875abfd1a8bSRiver Riddle   // Initialize the external functions.
876abfd1a8bSRiver Riddle   for (auto &it : constraintFns)
877abfd1a8bSRiver Riddle     constraintFunctions.push_back(std::move(it.second));
878abfd1a8bSRiver Riddle   for (auto &it : rewriteFns)
879abfd1a8bSRiver Riddle     rewriteFunctions.push_back(std::move(it.second));
880abfd1a8bSRiver Riddle }
881abfd1a8bSRiver Riddle 
882abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current
883abfd1a8bSRiver Riddle /// bytecode.
884abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
885abfd1a8bSRiver Riddle   state.memory.resize(maxValueMemoryIndex, nullptr);
886*85ab413bSRiver Riddle   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
887*85ab413bSRiver Riddle   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
888abfd1a8bSRiver Riddle   state.currentPatternBenefits.reserve(patterns.size());
889abfd1a8bSRiver Riddle   for (const PDLByteCodePattern &pattern : patterns)
890abfd1a8bSRiver Riddle     state.currentPatternBenefits.push_back(pattern.getBenefit());
891abfd1a8bSRiver Riddle }
892abfd1a8bSRiver Riddle 
893abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
894abfd1a8bSRiver Riddle // ByteCode Execution
895abfd1a8bSRiver Riddle 
896abfd1a8bSRiver Riddle namespace {
897abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream.
898abfd1a8bSRiver Riddle class ByteCodeExecutor {
899abfd1a8bSRiver Riddle public:
900*85ab413bSRiver Riddle   ByteCodeExecutor(
901*85ab413bSRiver Riddle       const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
902*85ab413bSRiver Riddle       MutableArrayRef<TypeRange> typeRangeMemory,
903*85ab413bSRiver Riddle       std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
904*85ab413bSRiver Riddle       MutableArrayRef<ValueRange> valueRangeMemory,
905*85ab413bSRiver Riddle       std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
906*85ab413bSRiver Riddle       ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
907abfd1a8bSRiver Riddle       ArrayRef<PatternBenefit> currentPatternBenefits,
908abfd1a8bSRiver Riddle       ArrayRef<PDLByteCodePattern> patterns,
909abfd1a8bSRiver Riddle       ArrayRef<PDLConstraintFunction> constraintFunctions,
910abfd1a8bSRiver Riddle       ArrayRef<PDLRewriteFunction> rewriteFunctions)
911*85ab413bSRiver Riddle       : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
912*85ab413bSRiver Riddle         allocatedTypeRangeMemory(allocatedTypeRangeMemory),
913*85ab413bSRiver Riddle         valueRangeMemory(valueRangeMemory),
914*85ab413bSRiver Riddle         allocatedValueRangeMemory(allocatedValueRangeMemory),
915*85ab413bSRiver Riddle         uniquedMemory(uniquedMemory), code(code),
916*85ab413bSRiver Riddle         currentPatternBenefits(currentPatternBenefits), patterns(patterns),
917*85ab413bSRiver Riddle         constraintFunctions(constraintFunctions),
91802c4c0d5SRiver Riddle         rewriteFunctions(rewriteFunctions) {}
919abfd1a8bSRiver Riddle 
920abfd1a8bSRiver Riddle   /// Start executing the code at the current bytecode index. `matches` is an
921abfd1a8bSRiver Riddle   /// optional field provided when this function is executed in a matching
922abfd1a8bSRiver Riddle   /// context.
923abfd1a8bSRiver Riddle   void execute(PatternRewriter &rewriter,
924abfd1a8bSRiver Riddle                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
925abfd1a8bSRiver Riddle                Optional<Location> mainRewriteLoc = {});
926abfd1a8bSRiver Riddle 
private:
  /// Internal implementation of executing each of the bytecode commands.
  /// The implementations decode their operands from the bytecode stream via
  /// `read` and leave `curCodeIt` positioned after the consumed fields (or at
  /// a jump target for the branching commands).
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  /// Newly created operations are given the location `mainRewriteLoc`.
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  /// `index` is the operand position to fetch from the decoded operation.
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  /// `index` is the result position to fetch from the decoded operation.
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();
961154cabe7SRiver Riddle 
962abfd1a8bSRiver Riddle   /// Read a value from the bytecode buffer, optionally skipping a certain
963abfd1a8bSRiver Riddle   /// number of prefix values. These methods always update the buffer to point
964abfd1a8bSRiver Riddle   /// to the next field after the read data.
965abfd1a8bSRiver Riddle   template <typename T = ByteCodeField>
966abfd1a8bSRiver Riddle   T read(size_t skipN = 0) {
967abfd1a8bSRiver Riddle     curCodeIt += skipN;
968abfd1a8bSRiver Riddle     return readImpl<T>();
969abfd1a8bSRiver Riddle   }
970abfd1a8bSRiver Riddle   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
971abfd1a8bSRiver Riddle 
972abfd1a8bSRiver Riddle   /// Read a list of values from the bytecode buffer.
973abfd1a8bSRiver Riddle   template <typename ValueT, typename T>
974abfd1a8bSRiver Riddle   void readList(SmallVectorImpl<T> &list) {
975abfd1a8bSRiver Riddle     list.clear();
976abfd1a8bSRiver Riddle     for (unsigned i = 0, e = read(); i != e; ++i)
977abfd1a8bSRiver Riddle       list.push_back(read<ValueT>());
978abfd1a8bSRiver Riddle   }
979abfd1a8bSRiver Riddle 
980*85ab413bSRiver Riddle   /// Read a list of values from the bytecode buffer. The values may be encoded
981*85ab413bSRiver Riddle   /// as either Value or ValueRange elements.
982*85ab413bSRiver Riddle   void readValueList(SmallVectorImpl<Value> &list) {
983*85ab413bSRiver Riddle     for (unsigned i = 0, e = read(); i != e; ++i) {
984*85ab413bSRiver Riddle       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
985*85ab413bSRiver Riddle         list.push_back(read<Value>());
986*85ab413bSRiver Riddle       } else {
987*85ab413bSRiver Riddle         ValueRange *values = read<ValueRange *>();
988*85ab413bSRiver Riddle         list.append(values->begin(), values->end());
989*85ab413bSRiver Riddle       }
990*85ab413bSRiver Riddle     }
991*85ab413bSRiver Riddle   }
992*85ab413bSRiver Riddle 
993abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a predicate value.
994abfd1a8bSRiver Riddle   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
995abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a destination index.
996abfd1a8bSRiver Riddle   void selectJump(size_t destIndex) {
997abfd1a8bSRiver Riddle     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
998abfd1a8bSRiver Riddle   }
999abfd1a8bSRiver Riddle 
1000abfd1a8bSRiver Riddle   /// Handle a switch operation with the provided value and cases.
1001*85ab413bSRiver Riddle   template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1002*85ab413bSRiver Riddle   void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1003abfd1a8bSRiver Riddle     LLVM_DEBUG({
1004abfd1a8bSRiver Riddle       llvm::dbgs() << "  * Value: " << value << "\n"
1005abfd1a8bSRiver Riddle                    << "  * Cases: ";
1006abfd1a8bSRiver Riddle       llvm::interleaveComma(cases, llvm::dbgs());
1007154cabe7SRiver Riddle       llvm::dbgs() << "\n";
1008abfd1a8bSRiver Riddle     });
1009abfd1a8bSRiver Riddle 
1010abfd1a8bSRiver Riddle     // Check to see if the attribute value is within the case list. Jump to
1011abfd1a8bSRiver Riddle     // the correct successor index based on the result.
1012f80b6304SRiver Riddle     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1013*85ab413bSRiver Riddle       if (cmp(*it, value))
1014f80b6304SRiver Riddle         return selectJump(size_t((it - cases.begin()) + 1));
1015f80b6304SRiver Riddle     selectJump(size_t(0));
1016abfd1a8bSRiver Riddle   }
1017abfd1a8bSRiver Riddle 
1018abfd1a8bSRiver Riddle   /// Internal implementation of reading various data types from the bytecode
1019abfd1a8bSRiver Riddle   /// stream.
1020abfd1a8bSRiver Riddle   template <typename T>
1021abfd1a8bSRiver Riddle   const void *readFromMemory() {
1022abfd1a8bSRiver Riddle     size_t index = *curCodeIt++;
1023abfd1a8bSRiver Riddle 
1024abfd1a8bSRiver Riddle     // If this type is an SSA value, it can only be stored in non-const memory.
1025*85ab413bSRiver Riddle     if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1026*85ab413bSRiver Riddle                         Value>::value ||
1027*85ab413bSRiver Riddle         index < memory.size())
1028abfd1a8bSRiver Riddle       return memory[index];
1029abfd1a8bSRiver Riddle 
1030abfd1a8bSRiver Riddle     // Otherwise, if this index is not inbounds it is uniqued.
1031abfd1a8bSRiver Riddle     return uniquedMemory[index - memory.size()];
1032abfd1a8bSRiver Riddle   }
1033abfd1a8bSRiver Riddle   template <typename T>
1034abfd1a8bSRiver Riddle   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1035abfd1a8bSRiver Riddle     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1036abfd1a8bSRiver Riddle   }
1037abfd1a8bSRiver Riddle   template <typename T>
1038abfd1a8bSRiver Riddle   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1039abfd1a8bSRiver Riddle                    T>
1040abfd1a8bSRiver Riddle   readImpl() {
1041abfd1a8bSRiver Riddle     return T(T::getFromOpaquePointer(readFromMemory<T>()));
1042abfd1a8bSRiver Riddle   }
1043abfd1a8bSRiver Riddle   template <typename T>
1044abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1045*85ab413bSRiver Riddle     switch (read<PDLValue::Kind>()) {
1046*85ab413bSRiver Riddle     case PDLValue::Kind::Attribute:
1047abfd1a8bSRiver Riddle       return read<Attribute>();
1048*85ab413bSRiver Riddle     case PDLValue::Kind::Operation:
1049abfd1a8bSRiver Riddle       return read<Operation *>();
1050*85ab413bSRiver Riddle     case PDLValue::Kind::Type:
1051abfd1a8bSRiver Riddle       return read<Type>();
1052*85ab413bSRiver Riddle     case PDLValue::Kind::Value:
1053abfd1a8bSRiver Riddle       return read<Value>();
1054*85ab413bSRiver Riddle     case PDLValue::Kind::TypeRange:
1055*85ab413bSRiver Riddle       return read<TypeRange *>();
1056*85ab413bSRiver Riddle     case PDLValue::Kind::ValueRange:
1057*85ab413bSRiver Riddle       return read<ValueRange *>();
1058abfd1a8bSRiver Riddle     }
1059*85ab413bSRiver Riddle     llvm_unreachable("unhandled PDLValue::Kind");
1060abfd1a8bSRiver Riddle   }
1061abfd1a8bSRiver Riddle   template <typename T>
1062abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1063abfd1a8bSRiver Riddle     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1064abfd1a8bSRiver Riddle                   "unexpected ByteCode address size");
1065abfd1a8bSRiver Riddle     ByteCodeAddr result;
1066abfd1a8bSRiver Riddle     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1067abfd1a8bSRiver Riddle     curCodeIt += 2;
1068abfd1a8bSRiver Riddle     return result;
1069abfd1a8bSRiver Riddle   }
1070abfd1a8bSRiver Riddle   template <typename T>
1071abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1072abfd1a8bSRiver Riddle     return *curCodeIt++;
1073abfd1a8bSRiver Riddle   }
1074*85ab413bSRiver Riddle   template <typename T>
1075*85ab413bSRiver Riddle   std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1076*85ab413bSRiver Riddle     return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1077*85ab413bSRiver Riddle   }
1078abfd1a8bSRiver Riddle 
  /// The underlying bytecode buffer. This is the cursor pointing at the next
  /// field to decode; it advances as values are read and is reassigned on
  /// jumps.
  const ByteCodeField *curCodeIt;

  /// The current execution memory. Each slot holds an opaque pointer to a
  /// decoded entity. Indices at or past `memory.size()` are looked up in
  /// `uniquedMemory` instead.
  MutableArrayRef<const void *> memory;
  /// Slots for TypeRange values; memory entries holding a type range point
  /// into this array.
  MutableArrayRef<TypeRange> typeRangeMemory;
  /// Owning buffers that back type ranges created during execution. This is a
  /// reference, so the vector itself lives outside this executor.
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  /// Slots for ValueRange values; memory entries holding a value range point
  /// into this array.
  MutableArrayRef<ValueRange> valueRangeMemory;
  /// Owning buffers that back value ranges created during execution. This is a
  /// reference, so the vector itself lives outside this executor.
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// References to ByteCode data necessary for execution.
  /// `uniquedMemory` is read-only storage indexed past the end of `memory`.
  ArrayRef<const void *> uniquedMemory;
  /// The full bytecode; decoded ByteCodeAddr jump targets index into it.
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  /// Native constraint/rewrite callbacks, indexed by a field read from the
  /// bytecode.
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1096abfd1a8bSRiver Riddle };
109702c4c0d5SRiver Riddle 
109802c4c0d5SRiver Riddle /// This class is an instantiation of the PDLResultList that provides access to
109902c4c0d5SRiver Riddle /// the returned results. This API is not on `PDLResultList` to avoid
110002c4c0d5SRiver Riddle /// overexposing access to information specific solely to the ByteCode.
110102c4c0d5SRiver Riddle class ByteCodeRewriteResultList : public PDLResultList {
110202c4c0d5SRiver Riddle public:
1103*85ab413bSRiver Riddle   ByteCodeRewriteResultList(unsigned maxNumResults)
1104*85ab413bSRiver Riddle       : PDLResultList(maxNumResults) {}
1105*85ab413bSRiver Riddle 
110602c4c0d5SRiver Riddle   /// Return the list of PDL results.
110702c4c0d5SRiver Riddle   MutableArrayRef<PDLValue> getResults() { return results; }
1108*85ab413bSRiver Riddle 
1109*85ab413bSRiver Riddle   /// Return the type ranges allocated by this list.
1110*85ab413bSRiver Riddle   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1111*85ab413bSRiver Riddle     return allocatedTypeRanges;
1112*85ab413bSRiver Riddle   }
1113*85ab413bSRiver Riddle 
1114*85ab413bSRiver Riddle   /// Return the value ranges allocated by this list.
1115*85ab413bSRiver Riddle   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1116*85ab413bSRiver Riddle     return allocatedValueRanges;
1117*85ab413bSRiver Riddle   }
111802c4c0d5SRiver Riddle };
1119abfd1a8bSRiver Riddle } // end anonymous namespace
1120abfd1a8bSRiver Riddle 
1121154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1122abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1123abfd1a8bSRiver Riddle   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1124abfd1a8bSRiver Riddle   ArrayAttr constParams = read<ArrayAttr>();
1125abfd1a8bSRiver Riddle   SmallVector<PDLValue, 16> args;
1126abfd1a8bSRiver Riddle   readList<PDLValue>(args);
1127154cabe7SRiver Riddle 
1128abfd1a8bSRiver Riddle   LLVM_DEBUG({
1129abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Arguments: ";
1130abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
1131154cabe7SRiver Riddle     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1132abfd1a8bSRiver Riddle   });
1133abfd1a8bSRiver Riddle 
1134abfd1a8bSRiver Riddle   // Invoke the constraint and jump to the proper destination.
1135abfd1a8bSRiver Riddle   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1136abfd1a8bSRiver Riddle }
1137154cabe7SRiver Riddle 
1138154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1139abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1140abfd1a8bSRiver Riddle   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1141abfd1a8bSRiver Riddle   ArrayAttr constParams = read<ArrayAttr>();
1142abfd1a8bSRiver Riddle   SmallVector<PDLValue, 16> args;
1143abfd1a8bSRiver Riddle   readList<PDLValue>(args);
1144abfd1a8bSRiver Riddle 
1145abfd1a8bSRiver Riddle   LLVM_DEBUG({
114602c4c0d5SRiver Riddle     llvm::dbgs() << "  * Arguments: ";
1147abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
1148154cabe7SRiver Riddle     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1149abfd1a8bSRiver Riddle   });
1150*85ab413bSRiver Riddle 
1151*85ab413bSRiver Riddle   // Execute the rewrite function.
1152*85ab413bSRiver Riddle   ByteCodeField numResults = read();
1153*85ab413bSRiver Riddle   ByteCodeRewriteResultList results(numResults);
115402c4c0d5SRiver Riddle   rewriteFn(args, constParams, rewriter, results);
1155154cabe7SRiver Riddle 
1156*85ab413bSRiver Riddle   assert(results.getResults().size() == numResults &&
115702c4c0d5SRiver Riddle          "native PDL rewrite function returned unexpected number of results");
115802c4c0d5SRiver Riddle 
115902c4c0d5SRiver Riddle   // Store the results in the bytecode memory.
116002c4c0d5SRiver Riddle   for (PDLValue &result : results.getResults()) {
116102c4c0d5SRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1162*85ab413bSRiver Riddle 
1163*85ab413bSRiver Riddle // In debug mode we also verify the expected kind of the result.
1164*85ab413bSRiver Riddle #ifndef NDEBUG
1165*85ab413bSRiver Riddle     assert(result.getKind() == read<PDLValue::Kind>() &&
1166*85ab413bSRiver Riddle            "native PDL rewrite function returned an unexpected type of result");
1167*85ab413bSRiver Riddle #endif
1168*85ab413bSRiver Riddle 
1169*85ab413bSRiver Riddle     // If the result is a range, we need to copy it over to the bytecodes
1170*85ab413bSRiver Riddle     // range memory.
1171*85ab413bSRiver Riddle     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1172*85ab413bSRiver Riddle       unsigned rangeIndex = read();
1173*85ab413bSRiver Riddle       typeRangeMemory[rangeIndex] = *typeRange;
1174*85ab413bSRiver Riddle       memory[read()] = &typeRangeMemory[rangeIndex];
1175*85ab413bSRiver Riddle     } else if (Optional<ValueRange> valueRange =
1176*85ab413bSRiver Riddle                    result.dyn_cast<ValueRange>()) {
1177*85ab413bSRiver Riddle       unsigned rangeIndex = read();
1178*85ab413bSRiver Riddle       valueRangeMemory[rangeIndex] = *valueRange;
1179*85ab413bSRiver Riddle       memory[read()] = &valueRangeMemory[rangeIndex];
1180*85ab413bSRiver Riddle     } else {
118102c4c0d5SRiver Riddle       memory[read()] = result.getAsOpaquePointer();
118202c4c0d5SRiver Riddle     }
1183abfd1a8bSRiver Riddle   }
1184154cabe7SRiver Riddle 
1185*85ab413bSRiver Riddle   // Copy over any underlying storage allocated for result ranges.
1186*85ab413bSRiver Riddle   for (auto &it : results.getAllocatedTypeRanges())
1187*85ab413bSRiver Riddle     allocatedTypeRangeMemory.push_back(std::move(it));
1188*85ab413bSRiver Riddle   for (auto &it : results.getAllocatedValueRanges())
1189*85ab413bSRiver Riddle     allocatedValueRangeMemory.push_back(std::move(it));
1190*85ab413bSRiver Riddle }
1191*85ab413bSRiver Riddle 
1192154cabe7SRiver Riddle void ByteCodeExecutor::executeAreEqual() {
1193abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1194abfd1a8bSRiver Riddle   const void *lhs = read<const void *>();
1195abfd1a8bSRiver Riddle   const void *rhs = read<const void *>();
1196abfd1a8bSRiver Riddle 
1197154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1198abfd1a8bSRiver Riddle   selectJump(lhs == rhs);
1199abfd1a8bSRiver Riddle }
1200154cabe7SRiver Riddle 
1201*85ab413bSRiver Riddle void ByteCodeExecutor::executeAreRangesEqual() {
1202*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1203*85ab413bSRiver Riddle   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1204*85ab413bSRiver Riddle   const void *lhs = read<const void *>();
1205*85ab413bSRiver Riddle   const void *rhs = read<const void *>();
1206*85ab413bSRiver Riddle 
1207*85ab413bSRiver Riddle   switch (valueKind) {
1208*85ab413bSRiver Riddle   case PDLValue::Kind::TypeRange: {
1209*85ab413bSRiver Riddle     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1210*85ab413bSRiver Riddle     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1211*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1212*85ab413bSRiver Riddle     selectJump(*lhsRange == *rhsRange);
1213*85ab413bSRiver Riddle     break;
1214*85ab413bSRiver Riddle   }
1215*85ab413bSRiver Riddle   case PDLValue::Kind::ValueRange: {
1216*85ab413bSRiver Riddle     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1217*85ab413bSRiver Riddle     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1218*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1219*85ab413bSRiver Riddle     selectJump(*lhsRange == *rhsRange);
1220*85ab413bSRiver Riddle     break;
1221*85ab413bSRiver Riddle   }
1222*85ab413bSRiver Riddle   default:
1223*85ab413bSRiver Riddle     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1224*85ab413bSRiver Riddle   }
1225*85ab413bSRiver Riddle }
1226*85ab413bSRiver Riddle 
1227154cabe7SRiver Riddle void ByteCodeExecutor::executeBranch() {
1228154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1229abfd1a8bSRiver Riddle   curCodeIt = &code[read<ByteCodeAddr>()];
1230abfd1a8bSRiver Riddle }
1231154cabe7SRiver Riddle 
1232154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperandCount() {
1233abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1234abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1235abfd1a8bSRiver Riddle   uint32_t expectedCount = read<uint32_t>();
1236*85ab413bSRiver Riddle   bool compareAtLeast = read();
1237abfd1a8bSRiver Riddle 
1238abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1239*85ab413bSRiver Riddle                           << "  * Expected: " << expectedCount << "\n"
1240*85ab413bSRiver Riddle                           << "  * Comparator: "
1241*85ab413bSRiver Riddle                           << (compareAtLeast ? ">=" : "==") << "\n");
1242*85ab413bSRiver Riddle   if (compareAtLeast)
1243*85ab413bSRiver Riddle     selectJump(op->getNumOperands() >= expectedCount);
1244*85ab413bSRiver Riddle   else
1245abfd1a8bSRiver Riddle     selectJump(op->getNumOperands() == expectedCount);
1246abfd1a8bSRiver Riddle }
1247154cabe7SRiver Riddle 
1248154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperationName() {
1249abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1250abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1251abfd1a8bSRiver Riddle   OperationName expectedName = read<OperationName>();
1252abfd1a8bSRiver Riddle 
1253154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1254154cabe7SRiver Riddle                           << "  * Expected: \"" << expectedName << "\"\n");
1255abfd1a8bSRiver Riddle   selectJump(op->getName() == expectedName);
1256abfd1a8bSRiver Riddle }
1257154cabe7SRiver Riddle 
1258154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckResultCount() {
1259abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1260abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1261abfd1a8bSRiver Riddle   uint32_t expectedCount = read<uint32_t>();
1262*85ab413bSRiver Riddle   bool compareAtLeast = read();
1263abfd1a8bSRiver Riddle 
1264abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1265*85ab413bSRiver Riddle                           << "  * Expected: " << expectedCount << "\n"
1266*85ab413bSRiver Riddle                           << "  * Comparator: "
1267*85ab413bSRiver Riddle                           << (compareAtLeast ? ">=" : "==") << "\n");
1268*85ab413bSRiver Riddle   if (compareAtLeast)
1269*85ab413bSRiver Riddle     selectJump(op->getNumResults() >= expectedCount);
1270*85ab413bSRiver Riddle   else
1271abfd1a8bSRiver Riddle     selectJump(op->getNumResults() == expectedCount);
1272abfd1a8bSRiver Riddle }
1273154cabe7SRiver Riddle 
1274*85ab413bSRiver Riddle void ByteCodeExecutor::executeCheckTypes() {
1275*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1276*85ab413bSRiver Riddle   TypeRange *lhs = read<TypeRange *>();
1277*85ab413bSRiver Riddle   Attribute rhs = read<Attribute>();
1278*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1279*85ab413bSRiver Riddle 
1280*85ab413bSRiver Riddle   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1281*85ab413bSRiver Riddle }
1282*85ab413bSRiver Riddle 
1283*85ab413bSRiver Riddle void ByteCodeExecutor::executeCreateTypes() {
1284*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1285*85ab413bSRiver Riddle   unsigned memIndex = read();
1286*85ab413bSRiver Riddle   unsigned rangeIndex = read();
1287*85ab413bSRiver Riddle   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1288*85ab413bSRiver Riddle 
1289*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1290*85ab413bSRiver Riddle 
1291*85ab413bSRiver Riddle   // Allocate a buffer for this type range.
1292*85ab413bSRiver Riddle   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1293*85ab413bSRiver Riddle   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1294*85ab413bSRiver Riddle   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1295*85ab413bSRiver Riddle 
1296*85ab413bSRiver Riddle   // Assign this to the range slot and use the range as the value for the
1297*85ab413bSRiver Riddle   // memory index.
1298*85ab413bSRiver Riddle   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1299*85ab413bSRiver Riddle   memory[memIndex] = &typeRangeMemory[rangeIndex];
1300*85ab413bSRiver Riddle }
1301*85ab413bSRiver Riddle 
/// Create a new operation from its encoded description and store it into
/// memory. The encoding is, in order:
///  * the destination memory index,
///  * the operation name,
///  * the operand list,
///  * the attribute (name, value) pairs,
///  * the result type entries.
/// The new operation is created at `mainRewriteLoc`.
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
                                              Location mainRewriteLoc) {
  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");

  unsigned memIndex = read();
  OperationState state(mainRewriteLoc, read<OperationName>());
  readValueList(state.operands);
  // Attributes are encoded as (name, value) pairs; null values are skipped.
  for (unsigned i = 0, e = read(); i != e; ++i) {
    Identifier name = read<Identifier>();
    if (Attribute attr = read<Attribute>())
      state.addAttribute(name, attr);
  }

  // Each result entry is tagged with its kind: either a single Type, or a
  // TypeRange pointer whose types are all appended.
  for (unsigned i = 0, e = read(); i != e; ++i) {
    if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
      state.types.push_back(read<Type>());
      continue;
    }

    // If we find a null range, this signals that the types are inferred.
    if (TypeRange *resultTypes = read<TypeRange *>()) {
      state.types.append(resultTypes->begin(), resultTypes->end());
      continue;
    }

    // Handle the case where the operation has inferred types.
    // NOTE(review): neither `getAbstractOperation()` nor the returned interface
    // concept is null-checked; this assumes the operation is registered and
    // implements InferTypeOpInterface. Also, `concept` is a reserved keyword
    // in C++20, so this local will need renaming if the language standard is
    // bumped.
    InferTypeOpInterface::Concept *concept =
        state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();

    // TODO: Handle failure.
    // On failure this returns early: no operation is created and
    // `memory[memIndex]` is left unchanged.
    state.types.clear();
    if (failed(concept->inferReturnTypes(
            state.getContext(), state.location, state.operands,
            state.attributes.getDictionary(state.getContext()), state.regions,
            state.types)))
      return;
    // NOTE(review): any remaining result entries are not decoded after
    // inference; this presumably relies on the null-range marker being the
    // last entry. Confirm against the bytecode generator.
    break;
  }

  Operation *resultOp = rewriter.createOperation(state);
  memory[memIndex] = resultOp;

  LLVM_DEBUG({
    llvm::dbgs() << "  * Attributes: "
                 << state.attributes.getDictionary(state.getContext())
                 << "\n  * Operands: ";
    llvm::interleaveComma(state.operands, llvm::dbgs());
    llvm::dbgs() << "\n  * Result Types: ";
    llvm::interleaveComma(state.types, llvm::dbgs());
    llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
  });
}
1354154cabe7SRiver Riddle 
1355154cabe7SRiver Riddle void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1356abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1357abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1358abfd1a8bSRiver Riddle 
1359154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1360abfd1a8bSRiver Riddle   rewriter.eraseOp(op);
1361abfd1a8bSRiver Riddle }
1362154cabe7SRiver Riddle 
1363154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttribute() {
1364abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1365abfd1a8bSRiver Riddle   unsigned memIndex = read();
1366abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1367abfd1a8bSRiver Riddle   Identifier attrName = read<Identifier>();
1368abfd1a8bSRiver Riddle   Attribute attr = op->getAttr(attrName);
1369abfd1a8bSRiver Riddle 
1370abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1371abfd1a8bSRiver Riddle                           << "  * Attribute: " << attrName << "\n"
1372154cabe7SRiver Riddle                           << "  * Result: " << attr << "\n");
1373abfd1a8bSRiver Riddle   memory[memIndex] = attr.getAsOpaquePointer();
1374abfd1a8bSRiver Riddle }
1375154cabe7SRiver Riddle 
1376154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttributeType() {
1377abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1378abfd1a8bSRiver Riddle   unsigned memIndex = read();
1379abfd1a8bSRiver Riddle   Attribute attr = read<Attribute>();
1380154cabe7SRiver Riddle   Type type = attr ? attr.getType() : Type();
1381abfd1a8bSRiver Riddle 
1382abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1383154cabe7SRiver Riddle                           << "  * Result: " << type << "\n");
1384154cabe7SRiver Riddle   memory[memIndex] = type.getAsOpaquePointer();
1385abfd1a8bSRiver Riddle }
1386154cabe7SRiver Riddle 
1387154cabe7SRiver Riddle void ByteCodeExecutor::executeGetDefiningOp() {
1388abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1389abfd1a8bSRiver Riddle   unsigned memIndex = read();
1390*85ab413bSRiver Riddle   Operation *op = nullptr;
1391*85ab413bSRiver Riddle   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1392abfd1a8bSRiver Riddle     Value value = read<Value>();
1393*85ab413bSRiver Riddle     if (value)
1394*85ab413bSRiver Riddle       op = value.getDefiningOp();
1395*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1396*85ab413bSRiver Riddle   } else {
1397*85ab413bSRiver Riddle     ValueRange *values = read<ValueRange *>();
1398*85ab413bSRiver Riddle     if (values && !values->empty()) {
1399*85ab413bSRiver Riddle       op = values->front().getDefiningOp();
1400*85ab413bSRiver Riddle     }
1401*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1402*85ab413bSRiver Riddle   }
1403abfd1a8bSRiver Riddle 
1404*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1405abfd1a8bSRiver Riddle   memory[memIndex] = op;
1406abfd1a8bSRiver Riddle }
1407154cabe7SRiver Riddle 
1408154cabe7SRiver Riddle void ByteCodeExecutor::executeGetOperand(unsigned index) {
1409abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1410abfd1a8bSRiver Riddle   unsigned memIndex = read();
1411abfd1a8bSRiver Riddle   Value operand =
1412abfd1a8bSRiver Riddle       index < op->getNumOperands() ? op->getOperand(index) : Value();
1413abfd1a8bSRiver Riddle 
1414abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1415abfd1a8bSRiver Riddle                           << "  * Index: " << index << "\n"
1416154cabe7SRiver Riddle                           << "  * Result: " << operand << "\n");
1417abfd1a8bSRiver Riddle   memory[memIndex] = operand.getAsOpaquePointer();
1418abfd1a8bSRiver Riddle }
1419154cabe7SRiver Riddle 
1420*85ab413bSRiver Riddle /// This function is the internal implementation of `GetResults` and
1421*85ab413bSRiver Riddle /// `GetOperands` that provides support for extracting a value range from the
1422*85ab413bSRiver Riddle /// given operation.
1423*85ab413bSRiver Riddle template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1424*85ab413bSRiver Riddle static void *
1425*85ab413bSRiver Riddle executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1426*85ab413bSRiver Riddle                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1427*85ab413bSRiver Riddle                           MutableArrayRef<ValueRange> &valueRangeMemory) {
1428*85ab413bSRiver Riddle   // Check for the sentinel index that signals that all values should be
1429*85ab413bSRiver Riddle   // returned.
1430*85ab413bSRiver Riddle   if (index == std::numeric_limits<uint32_t>::max()) {
1431*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1432*85ab413bSRiver Riddle     // `values` is already the full value range.
1433*85ab413bSRiver Riddle 
1434*85ab413bSRiver Riddle     // Otherwise, check to see if this operation uses AttrSizedSegments.
1435*85ab413bSRiver Riddle   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1436*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs()
1437*85ab413bSRiver Riddle                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1438*85ab413bSRiver Riddle 
1439*85ab413bSRiver Riddle     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1440*85ab413bSRiver Riddle     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1441*85ab413bSRiver Riddle       return nullptr;
1442*85ab413bSRiver Riddle 
1443*85ab413bSRiver Riddle     auto segments = segmentAttr.getValues<int32_t>();
1444*85ab413bSRiver Riddle     unsigned startIndex =
1445*85ab413bSRiver Riddle         std::accumulate(segments.begin(), segments.begin() + index, 0);
1446*85ab413bSRiver Riddle     values = values.slice(startIndex, *std::next(segments.begin(), index));
1447*85ab413bSRiver Riddle 
1448*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1449*85ab413bSRiver Riddle                             << *std::next(segments.begin(), index) << "]\n");
1450*85ab413bSRiver Riddle 
1451*85ab413bSRiver Riddle     // Otherwise, assume this is the last operand group of the operation.
1452*85ab413bSRiver Riddle     // FIXME: We currently don't support operations with
1453*85ab413bSRiver Riddle     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1454*85ab413bSRiver Riddle     // have a way to detect it's presence.
1455*85ab413bSRiver Riddle   } else if (values.size() >= index) {
1456*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs()
1457*85ab413bSRiver Riddle                << "  * Treating values as trailing variadic range\n");
1458*85ab413bSRiver Riddle     values = values.drop_front(index);
1459*85ab413bSRiver Riddle 
1460*85ab413bSRiver Riddle     // If we couldn't detect a way to compute the values, bail out.
1461*85ab413bSRiver Riddle   } else {
1462*85ab413bSRiver Riddle     return nullptr;
1463*85ab413bSRiver Riddle   }
1464*85ab413bSRiver Riddle 
1465*85ab413bSRiver Riddle   // If the range index is valid, we are returning a range.
1466*85ab413bSRiver Riddle   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1467*85ab413bSRiver Riddle     valueRangeMemory[rangeIndex] = values;
1468*85ab413bSRiver Riddle     return &valueRangeMemory[rangeIndex];
1469*85ab413bSRiver Riddle   }
1470*85ab413bSRiver Riddle 
1471*85ab413bSRiver Riddle   // If a range index wasn't provided, the range is required to be non-variadic.
1472*85ab413bSRiver Riddle   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1473*85ab413bSRiver Riddle }
1474*85ab413bSRiver Riddle 
1475*85ab413bSRiver Riddle void ByteCodeExecutor::executeGetOperands() {
1476*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1477*85ab413bSRiver Riddle   unsigned index = read<uint32_t>();
1478*85ab413bSRiver Riddle   Operation *op = read<Operation *>();
1479*85ab413bSRiver Riddle   ByteCodeField rangeIndex = read();
1480*85ab413bSRiver Riddle 
1481*85ab413bSRiver Riddle   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1482*85ab413bSRiver Riddle       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1483*85ab413bSRiver Riddle       valueRangeMemory);
1484*85ab413bSRiver Riddle   if (!result)
1485*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1486*85ab413bSRiver Riddle   memory[read()] = result;
1487*85ab413bSRiver Riddle }
1488*85ab413bSRiver Riddle 
1489154cabe7SRiver Riddle void ByteCodeExecutor::executeGetResult(unsigned index) {
1490abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1491abfd1a8bSRiver Riddle   unsigned memIndex = read();
1492abfd1a8bSRiver Riddle   OpResult result =
1493abfd1a8bSRiver Riddle       index < op->getNumResults() ? op->getResult(index) : OpResult();
1494abfd1a8bSRiver Riddle 
1495abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1496abfd1a8bSRiver Riddle                           << "  * Index: " << index << "\n"
1497154cabe7SRiver Riddle                           << "  * Result: " << result << "\n");
1498abfd1a8bSRiver Riddle   memory[memIndex] = result.getAsOpaquePointer();
1499abfd1a8bSRiver Riddle }
1500154cabe7SRiver Riddle 
1501*85ab413bSRiver Riddle void ByteCodeExecutor::executeGetResults() {
1502*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1503*85ab413bSRiver Riddle   unsigned index = read<uint32_t>();
1504*85ab413bSRiver Riddle   Operation *op = read<Operation *>();
1505*85ab413bSRiver Riddle   ByteCodeField rangeIndex = read();
1506*85ab413bSRiver Riddle 
1507*85ab413bSRiver Riddle   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1508*85ab413bSRiver Riddle       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1509*85ab413bSRiver Riddle       valueRangeMemory);
1510*85ab413bSRiver Riddle   if (!result)
1511*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1512*85ab413bSRiver Riddle   memory[read()] = result;
1513*85ab413bSRiver Riddle }
1514*85ab413bSRiver Riddle 
1515154cabe7SRiver Riddle void ByteCodeExecutor::executeGetValueType() {
1516abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1517abfd1a8bSRiver Riddle   unsigned memIndex = read();
1518abfd1a8bSRiver Riddle   Value value = read<Value>();
1519154cabe7SRiver Riddle   Type type = value ? value.getType() : Type();
1520abfd1a8bSRiver Riddle 
1521abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1522154cabe7SRiver Riddle                           << "  * Result: " << type << "\n");
1523154cabe7SRiver Riddle   memory[memIndex] = type.getAsOpaquePointer();
1524abfd1a8bSRiver Riddle }
1525154cabe7SRiver Riddle 
1526*85ab413bSRiver Riddle void ByteCodeExecutor::executeGetValueRangeTypes() {
1527*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1528*85ab413bSRiver Riddle   unsigned memIndex = read();
1529*85ab413bSRiver Riddle   unsigned rangeIndex = read();
1530*85ab413bSRiver Riddle   ValueRange *values = read<ValueRange *>();
1531*85ab413bSRiver Riddle   if (!values) {
1532*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1533*85ab413bSRiver Riddle     memory[memIndex] = nullptr;
1534*85ab413bSRiver Riddle     return;
1535*85ab413bSRiver Riddle   }
1536*85ab413bSRiver Riddle 
1537*85ab413bSRiver Riddle   LLVM_DEBUG({
1538*85ab413bSRiver Riddle     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1539*85ab413bSRiver Riddle     llvm::interleaveComma(*values, llvm::dbgs());
1540*85ab413bSRiver Riddle     llvm::dbgs() << "\n  * Result: ";
1541*85ab413bSRiver Riddle     llvm::interleaveComma(values->getType(), llvm::dbgs());
1542*85ab413bSRiver Riddle     llvm::dbgs() << "\n";
1543*85ab413bSRiver Riddle   });
1544*85ab413bSRiver Riddle   typeRangeMemory[rangeIndex] = values->getType();
1545*85ab413bSRiver Riddle   memory[memIndex] = &typeRangeMemory[rangeIndex];
1546*85ab413bSRiver Riddle }
1547*85ab413bSRiver Riddle 
1548154cabe7SRiver Riddle void ByteCodeExecutor::executeIsNotNull() {
1549abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1550abfd1a8bSRiver Riddle   const void *value = read<const void *>();
1551abfd1a8bSRiver Riddle 
1552154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1553abfd1a8bSRiver Riddle   selectJump(value != nullptr);
1554abfd1a8bSRiver Riddle }
1555154cabe7SRiver Riddle 
1556154cabe7SRiver Riddle void ByteCodeExecutor::executeRecordMatch(
1557154cabe7SRiver Riddle     PatternRewriter &rewriter,
1558154cabe7SRiver Riddle     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1559abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1560abfd1a8bSRiver Riddle   unsigned patternIndex = read();
1561abfd1a8bSRiver Riddle   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1562abfd1a8bSRiver Riddle   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1563abfd1a8bSRiver Riddle 
1564abfd1a8bSRiver Riddle   // If the benefit of the pattern is impossible, skip the processing of the
1565abfd1a8bSRiver Riddle   // rest of the pattern.
1566abfd1a8bSRiver Riddle   if (benefit.isImpossibleToMatch()) {
1567154cabe7SRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1568abfd1a8bSRiver Riddle     curCodeIt = dest;
1569154cabe7SRiver Riddle     return;
1570abfd1a8bSRiver Riddle   }
1571abfd1a8bSRiver Riddle 
1572abfd1a8bSRiver Riddle   // Create a fused location containing the locations of each of the
1573abfd1a8bSRiver Riddle   // operations used in the match. This will be used as the location for
1574abfd1a8bSRiver Riddle   // created operations during the rewrite that don't already have an
1575abfd1a8bSRiver Riddle   // explicit location set.
1576abfd1a8bSRiver Riddle   unsigned numMatchLocs = read();
1577abfd1a8bSRiver Riddle   SmallVector<Location, 4> matchLocs;
1578abfd1a8bSRiver Riddle   matchLocs.reserve(numMatchLocs);
1579abfd1a8bSRiver Riddle   for (unsigned i = 0; i != numMatchLocs; ++i)
1580abfd1a8bSRiver Riddle     matchLocs.push_back(read<Operation *>()->getLoc());
1581abfd1a8bSRiver Riddle   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1582abfd1a8bSRiver Riddle 
1583abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1584154cabe7SRiver Riddle                           << "  * Location: " << matchLoc << "\n");
1585154cabe7SRiver Riddle   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1586*85ab413bSRiver Riddle   PDLByteCode::MatchResult &match = matches.back();
1587*85ab413bSRiver Riddle 
1588*85ab413bSRiver Riddle   // Record all of the inputs to the match. If any of the inputs are ranges, we
1589*85ab413bSRiver Riddle   // will also need to remap the range pointer to memory stored in the match
1590*85ab413bSRiver Riddle   // state.
1591*85ab413bSRiver Riddle   unsigned numInputs = read();
1592*85ab413bSRiver Riddle   match.values.reserve(numInputs);
1593*85ab413bSRiver Riddle   match.typeRangeValues.reserve(numInputs);
1594*85ab413bSRiver Riddle   match.valueRangeValues.reserve(numInputs);
1595*85ab413bSRiver Riddle   for (unsigned i = 0; i < numInputs; ++i) {
1596*85ab413bSRiver Riddle     switch (read<PDLValue::Kind>()) {
1597*85ab413bSRiver Riddle     case PDLValue::Kind::TypeRange:
1598*85ab413bSRiver Riddle       match.typeRangeValues.push_back(*read<TypeRange *>());
1599*85ab413bSRiver Riddle       match.values.push_back(&match.typeRangeValues.back());
1600*85ab413bSRiver Riddle       break;
1601*85ab413bSRiver Riddle     case PDLValue::Kind::ValueRange:
1602*85ab413bSRiver Riddle       match.valueRangeValues.push_back(*read<ValueRange *>());
1603*85ab413bSRiver Riddle       match.values.push_back(&match.valueRangeValues.back());
1604*85ab413bSRiver Riddle       break;
1605*85ab413bSRiver Riddle     default:
1606*85ab413bSRiver Riddle       match.values.push_back(read<const void *>());
1607*85ab413bSRiver Riddle       break;
1608*85ab413bSRiver Riddle     }
1609*85ab413bSRiver Riddle   }
1610abfd1a8bSRiver Riddle   curCodeIt = dest;
1611abfd1a8bSRiver Riddle }
1612154cabe7SRiver Riddle 
1613154cabe7SRiver Riddle void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1614abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1615abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1616abfd1a8bSRiver Riddle   SmallVector<Value, 16> args;
1617*85ab413bSRiver Riddle   readValueList(args);
1618abfd1a8bSRiver Riddle 
1619abfd1a8bSRiver Riddle   LLVM_DEBUG({
1620abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Operation: " << *op << "\n"
1621abfd1a8bSRiver Riddle                  << "  * Values: ";
1622abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
1623154cabe7SRiver Riddle     llvm::dbgs() << "\n";
1624abfd1a8bSRiver Riddle   });
1625abfd1a8bSRiver Riddle   rewriter.replaceOp(op, args);
1626abfd1a8bSRiver Riddle }
1627154cabe7SRiver Riddle 
1628154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchAttribute() {
1629abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1630abfd1a8bSRiver Riddle   Attribute value = read<Attribute>();
1631abfd1a8bSRiver Riddle   ArrayAttr cases = read<ArrayAttr>();
1632abfd1a8bSRiver Riddle   handleSwitch(value, cases);
1633abfd1a8bSRiver Riddle }
1634154cabe7SRiver Riddle 
1635154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperandCount() {
1636abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1637abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1638abfd1a8bSRiver Riddle   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1639abfd1a8bSRiver Riddle 
1640abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1641abfd1a8bSRiver Riddle   handleSwitch(op->getNumOperands(), cases);
1642abfd1a8bSRiver Riddle }
1643154cabe7SRiver Riddle 
1644154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperationName() {
1645abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1646abfd1a8bSRiver Riddle   OperationName value = read<Operation *>()->getName();
1647abfd1a8bSRiver Riddle   size_t caseCount = read();
1648abfd1a8bSRiver Riddle 
1649abfd1a8bSRiver Riddle   // The operation names are stored in-line, so to print them out for
1650abfd1a8bSRiver Riddle   // debugging purposes we need to read the array before executing the
1651abfd1a8bSRiver Riddle   // switch so that we can display all of the possible values.
1652abfd1a8bSRiver Riddle   LLVM_DEBUG({
1653abfd1a8bSRiver Riddle     const ByteCodeField *prevCodeIt = curCodeIt;
1654abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Value: " << value << "\n"
1655abfd1a8bSRiver Riddle                  << "  * Cases: ";
1656abfd1a8bSRiver Riddle     llvm::interleaveComma(
1657abfd1a8bSRiver Riddle         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1658154cabe7SRiver Riddle                         [&](size_t) { return read<OperationName>(); }),
1659abfd1a8bSRiver Riddle         llvm::dbgs());
1660154cabe7SRiver Riddle     llvm::dbgs() << "\n";
1661abfd1a8bSRiver Riddle     curCodeIt = prevCodeIt;
1662abfd1a8bSRiver Riddle   });
1663abfd1a8bSRiver Riddle 
1664abfd1a8bSRiver Riddle   // Try to find the switch value within any of the cases.
1665abfd1a8bSRiver Riddle   for (size_t i = 0; i != caseCount; ++i) {
1666abfd1a8bSRiver Riddle     if (read<OperationName>() == value) {
1667abfd1a8bSRiver Riddle       curCodeIt += (caseCount - i - 1);
1668154cabe7SRiver Riddle       return selectJump(i + 1);
1669abfd1a8bSRiver Riddle     }
1670abfd1a8bSRiver Riddle   }
1671154cabe7SRiver Riddle   selectJump(size_t(0));
1672abfd1a8bSRiver Riddle }
1673154cabe7SRiver Riddle 
1674154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchResultCount() {
1675abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1676abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1677abfd1a8bSRiver Riddle   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1678abfd1a8bSRiver Riddle 
1679abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1680abfd1a8bSRiver Riddle   handleSwitch(op->getNumResults(), cases);
1681abfd1a8bSRiver Riddle }
1682154cabe7SRiver Riddle 
1683154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchType() {
1684abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1685abfd1a8bSRiver Riddle   Type value = read<Type>();
1686abfd1a8bSRiver Riddle   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1687abfd1a8bSRiver Riddle   handleSwitch(value, cases);
1688154cabe7SRiver Riddle }
1689154cabe7SRiver Riddle 
1690*85ab413bSRiver Riddle void ByteCodeExecutor::executeSwitchTypes() {
1691*85ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
1692*85ab413bSRiver Riddle   TypeRange *value = read<TypeRange *>();
1693*85ab413bSRiver Riddle   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
1694*85ab413bSRiver Riddle   if (!value) {
1695*85ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
1696*85ab413bSRiver Riddle     return selectJump(size_t(0));
1697*85ab413bSRiver Riddle   }
1698*85ab413bSRiver Riddle   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
1699*85ab413bSRiver Riddle     return value == caseValue.getAsValueRange<TypeAttr>();
1700*85ab413bSRiver Riddle   });
1701*85ab413bSRiver Riddle }
1702*85ab413bSRiver Riddle 
/// Run the bytecode interpreter loop: decode the opcode at the current code
/// position, dispatch to its `execute*` handler, and repeat until a `Finalize`
/// opcode is reached.
///
/// `matches` must be provided when running matcher bytecode, since the
/// `RecordMatch` handler appends to it (asserted below). `mainRewriteLoc` is
/// dereferenced for `CreateOperation`, so it must be provided when running
/// rewriter bytecode that creates operations.
void ByteCodeExecutor::execute(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
    Optional<Location> mainRewriteLoc) {
  while (true) {
    OpCode opCode = static_cast<OpCode>(read());
    switch (opCode) {
    case ApplyConstraint:
      executeApplyConstraint(rewriter);
      break;
    case ApplyRewrite:
      executeApplyRewrite(rewriter);
      break;
    case AreEqual:
      executeAreEqual();
      break;
    case AreRangesEqual:
      executeAreRangesEqual();
      break;
    case Branch:
      executeBranch();
      break;
    case CheckOperandCount:
      executeCheckOperandCount();
      break;
    case CheckOperationName:
      executeCheckOperationName();
      break;
    case CheckResultCount:
      executeCheckResultCount();
      break;
    case CheckTypes:
      executeCheckTypes();
      break;
    case CreateOperation:
      executeCreateOperation(rewriter, *mainRewriteLoc);
      break;
    case CreateTypes:
      executeCreateTypes();
      break;
    case EraseOp:
      executeEraseOp(rewriter);
      break;
    case Finalize:
      // Finalize is the only way out of the interpreter loop.
      LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
      return;
    case GetAttribute:
      executeGetAttribute();
      break;
    case GetAttributeType:
      executeGetAttributeType();
      break;
    case GetDefiningOp:
      executeGetDefiningOp();
      break;
    // Operand indices 0-3 have dedicated opcodes that encode the index in the
    // opcode itself; other indices use GetOperandN with an inline 32-bit index.
    case GetOperand0:
    case GetOperand1:
    case GetOperand2:
    case GetOperand3: {
      unsigned index = opCode - GetOperand0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
      executeGetOperand(index);
      break;
    }
    case GetOperandN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
      executeGetOperand(read<uint32_t>());
      break;
    case GetOperands:
      executeGetOperands();
      break;
    // Same encoding scheme as the operand opcodes above.
    case GetResult0:
    case GetResult1:
    case GetResult2:
    case GetResult3: {
      unsigned index = opCode - GetResult0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
      executeGetResult(index);
      break;
    }
    case GetResultN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
      executeGetResult(read<uint32_t>());
      break;
    case GetResults:
      executeGetResults();
      break;
    case GetValueType:
      executeGetValueType();
      break;
    case GetValueRangeTypes:
      executeGetValueRangeTypes();
      break;
    case IsNotNull:
      executeIsNotNull();
      break;
    case RecordMatch:
      assert(matches &&
             "expected matches to be provided when executing the matcher");
      executeRecordMatch(rewriter, *matches);
      break;
    case ReplaceOp:
      executeReplaceOp(rewriter);
      break;
    case SwitchAttribute:
      executeSwitchAttribute();
      break;
    case SwitchOperandCount:
      executeSwitchOperandCount();
      break;
    case SwitchOperationName:
      executeSwitchOperationName();
      break;
    case SwitchResultCount:
      executeSwitchResultCount();
      break;
    case SwitchType:
      executeSwitchType();
      break;
    case SwitchTypes:
      executeSwitchTypes();
      break;
    }
    // Separate the debug output of consecutive instructions.
    LLVM_DEBUG(llvm::dbgs() << "\n");
  }
}
1829abfd1a8bSRiver Riddle 
1830abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched
1831abfd1a8bSRiver Riddle /// patterns in `matches`.
1832abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1833abfd1a8bSRiver Riddle                         SmallVectorImpl<MatchResult> &matches,
1834abfd1a8bSRiver Riddle                         PDLByteCodeMutableState &state) const {
1835abfd1a8bSRiver Riddle   // The first memory slot is always the root operation.
1836abfd1a8bSRiver Riddle   state.memory[0] = op;
1837abfd1a8bSRiver Riddle 
1838abfd1a8bSRiver Riddle   // The matcher function always starts at code address 0.
1839*85ab413bSRiver Riddle   ByteCodeExecutor executor(
1840*85ab413bSRiver Riddle       matcherByteCode.data(), state.memory, state.typeRangeMemory,
1841*85ab413bSRiver Riddle       state.allocatedTypeRangeMemory, state.valueRangeMemory,
1842*85ab413bSRiver Riddle       state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
1843*85ab413bSRiver Riddle       state.currentPatternBenefits, patterns, constraintFunctions,
1844*85ab413bSRiver Riddle       rewriteFunctions);
1845abfd1a8bSRiver Riddle   executor.execute(rewriter, &matches);
1846abfd1a8bSRiver Riddle 
1847abfd1a8bSRiver Riddle   // Order the found matches by benefit.
1848abfd1a8bSRiver Riddle   std::stable_sort(matches.begin(), matches.end(),
1849abfd1a8bSRiver Riddle                    [](const MatchResult &lhs, const MatchResult &rhs) {
1850abfd1a8bSRiver Riddle                      return lhs.benefit > rhs.benefit;
1851abfd1a8bSRiver Riddle                    });
1852abfd1a8bSRiver Riddle }
1853abfd1a8bSRiver Riddle 
1854abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`.
1855abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1856abfd1a8bSRiver Riddle                           PDLByteCodeMutableState &state) const {
1857abfd1a8bSRiver Riddle   // The arguments of the rewrite function are stored at the start of the
1858abfd1a8bSRiver Riddle   // memory buffer.
1859abfd1a8bSRiver Riddle   llvm::copy(match.values, state.memory.begin());
1860abfd1a8bSRiver Riddle 
1861*85ab413bSRiver Riddle   ByteCodeExecutor executor(
1862*85ab413bSRiver Riddle       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1863*85ab413bSRiver Riddle       state.typeRangeMemory, state.allocatedTypeRangeMemory,
1864*85ab413bSRiver Riddle       state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
1865*85ab413bSRiver Riddle       rewriterByteCode, state.currentPatternBenefits, patterns,
186602c4c0d5SRiver Riddle       constraintFunctions, rewriteFunctions);
1867abfd1a8bSRiver Riddle   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1868abfd1a8bSRiver Riddle }
1869