1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.benefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.rootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62                                                    PatternBenefit benefit) {
63   currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 enum OpCode : ByteCodeField {
80   /// Apply an externally registered constraint.
81   ApplyConstraint,
82   /// Apply an externally registered rewrite.
83   ApplyRewrite,
84   /// Check if two generic values are equal.
85   AreEqual,
86   /// Check if two ranges are equal.
87   AreRangesEqual,
88   /// Unconditional branch.
89   Branch,
90   /// Compare the operand count of an operation with a constant.
91   CheckOperandCount,
92   /// Compare the name of an operation with a constant.
93   CheckOperationName,
94   /// Compare the result count of an operation with a constant.
95   CheckResultCount,
96   /// Compare a range of types to a constant range of types.
97   CheckTypes,
98   /// Continue to the next iteration of a loop.
99   Continue,
100   /// Create an operation.
101   CreateOperation,
102   /// Create a range of types.
103   CreateTypes,
104   /// Erase an operation.
105   EraseOp,
106   /// Extract the op from a range at the specified index.
107   ExtractOp,
108   /// Extract the type from a range at the specified index.
109   ExtractType,
110   /// Extract the value from a range at the specified index.
111   ExtractValue,
112   /// Terminate a matcher or rewrite sequence.
113   Finalize,
114   /// Iterate over a range of values.
115   ForEach,
116   /// Get a specific attribute of an operation.
117   GetAttribute,
118   /// Get the type of an attribute.
119   GetAttributeType,
120   /// Get the defining operation of a value.
121   GetDefiningOp,
122   /// Get a specific operand of an operation.
123   GetOperand0,
124   GetOperand1,
125   GetOperand2,
126   GetOperand3,
127   GetOperandN,
128   /// Get a specific operand group of an operation.
129   GetOperands,
130   /// Get a specific result of an operation.
131   GetResult0,
132   GetResult1,
133   GetResult2,
134   GetResult3,
135   GetResultN,
136   /// Get a specific result group of an operation.
137   GetResults,
138   /// Get the users of a value or a range of values.
139   GetUsers,
140   /// Get the type of a value.
141   GetValueType,
142   /// Get the types of a value range.
143   GetValueRangeTypes,
144   /// Check if a generic value is not null.
145   IsNotNull,
146   /// Record a successful pattern match.
147   RecordMatch,
148   /// Replace an operation.
149   ReplaceOp,
150   /// Compare an attribute with a set of constants.
151   SwitchAttribute,
152   /// Compare the operand count of an operation with a set of constants.
153   SwitchOperandCount,
154   /// Compare the name of an operation with a set of constants.
155   SwitchOperationName,
156   /// Compare the result count of an operation with a set of constants.
157   SwitchResultCount,
158   /// Compare a type with a set of constants.
159   SwitchType,
160   /// Compare a range of types with a set of constants.
161   SwitchTypes,
162 };
163 } // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // ByteCode Generation
167 //===----------------------------------------------------------------------===//
168 
169 //===----------------------------------------------------------------------===//
170 // Generator
171 
172 namespace {
struct ByteCodeWriter;
struct ByteCodeLiveRange;

/// Detection helper: this alias is only well-formed when an object of class
/// `T` provides a `getAsOpaquePointer()` member, i.e. when it can be converted
/// to an opaque pointer.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
179 
180 /// This class represents the main generator for the pattern bytecode.
181 class Generator {
182 public:
183   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
184             SmallVectorImpl<ByteCodeField> &matcherByteCode,
185             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
186             SmallVectorImpl<PDLByteCodePattern> &patterns,
187             ByteCodeField &maxValueMemoryIndex,
188             ByteCodeField &maxOpRangeMemoryIndex,
189             ByteCodeField &maxTypeRangeMemoryIndex,
190             ByteCodeField &maxValueRangeMemoryIndex,
191             ByteCodeField &maxLoopLevel,
192             llvm::StringMap<PDLConstraintFunction> &constraintFns,
193             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
194       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
195         rewriterByteCode(rewriterByteCode), patterns(patterns),
196         maxValueMemoryIndex(maxValueMemoryIndex),
197         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
198         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
199         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
200         maxLoopLevel(maxLoopLevel) {
201     for (const auto &it : llvm::enumerate(constraintFns))
202       constraintToMemIndex.try_emplace(it.value().first(), it.index());
203     for (const auto &it : llvm::enumerate(rewriteFns))
204       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
205   }
206 
207   /// Generate the bytecode for the given PDL interpreter module.
208   void generate(ModuleOp module);
209 
210   /// Return the memory index to use for the given value.
211   ByteCodeField &getMemIndex(Value value) {
212     assert(valueToMemIndex.count(value) &&
213            "expected memory index to be assigned");
214     return valueToMemIndex[value];
215   }
216 
217   /// Return the range memory index used to store the given range value.
218   ByteCodeField &getRangeStorageIndex(Value value) {
219     assert(valueToRangeIndex.count(value) &&
220            "expected range index to be assigned");
221     return valueToRangeIndex[value];
222   }
223 
224   /// Return an index to use when referring to the given data that is uniqued in
225   /// the MLIR context.
226   template <typename T>
227   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
228   getMemIndex(T val) {
229     const void *opaqueVal = val.getAsOpaquePointer();
230 
231     // Get or insert a reference to this value.
232     auto it = uniquedDataToMemIndex.try_emplace(
233         opaqueVal, maxValueMemoryIndex + uniquedData.size());
234     if (it.second)
235       uniquedData.push_back(opaqueVal);
236     return it.first->second;
237   }
238 
239 private:
240   /// Allocate memory indices for the results of operations within the matcher
241   /// and rewriters.
242   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
243 
244   /// Generate the bytecode for the given operation.
245   void generate(Region *region, ByteCodeWriter &writer);
246   void generate(Operation *op, ByteCodeWriter &writer);
247   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
248   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
249   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
250   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
251   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
252   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
253   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
254   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
255   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
259   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
260   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
285 
286   /// Mapping from value to its corresponding memory index.
287   DenseMap<Value, ByteCodeField> valueToMemIndex;
288 
289   /// Mapping from a range value to its corresponding range storage index.
290   DenseMap<Value, ByteCodeField> valueToRangeIndex;
291 
292   /// Mapping from the name of an externally registered rewrite to its index in
293   /// the bytecode registry.
294   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
295 
296   /// Mapping from the name of an externally registered constraint to its index
297   /// in the bytecode registry.
298   llvm::StringMap<ByteCodeField> constraintToMemIndex;
299 
300   /// Mapping from rewriter function name to the bytecode address of the
301   /// rewriter function in byte.
302   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
303 
304   /// Mapping from a uniqued storage object to its memory index within
305   /// `uniquedData`.
306   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
307 
308   /// The current level of the foreach loop.
309   ByteCodeField curLoopLevel = 0;
310 
311   /// The current MLIR context.
312   MLIRContext *ctx;
313 
314   /// Mapping from block to its address.
315   DenseMap<Block *, ByteCodeAddr> blockToAddr;
316 
317   /// Data of the ByteCode class to be populated.
318   std::vector<const void *> &uniquedData;
319   SmallVectorImpl<ByteCodeField> &matcherByteCode;
320   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
321   SmallVectorImpl<PDLByteCodePattern> &patterns;
322   ByteCodeField &maxValueMemoryIndex;
323   ByteCodeField &maxOpRangeMemoryIndex;
324   ByteCodeField &maxTypeRangeMemoryIndex;
325   ByteCodeField &maxValueRangeMemoryIndex;
326   ByteCodeField &maxLoopLevel;
327 };
328 
329 /// This class provides utilities for writing a bytecode stream.
330 struct ByteCodeWriter {
331   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
332       : bytecode(bytecode), generator(generator) {}
333 
334   /// Append a field to the bytecode.
335   void append(ByteCodeField field) { bytecode.push_back(field); }
336   void append(OpCode opCode) { bytecode.push_back(opCode); }
337 
338   /// Append an address to the bytecode.
339   void append(ByteCodeAddr field) {
340     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
341                   "unexpected ByteCode address size");
342 
343     ByteCodeField fieldParts[2];
344     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
345     bytecode.append({fieldParts[0], fieldParts[1]});
346   }
347 
348   /// Append a single successor to the bytecode, the exact address will need to
349   /// be resolved later.
350   void append(Block *successor) {
351     // Add back a reference to the successor so that the address can be resolved
352     // later.
353     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
354     append(ByteCodeAddr(0));
355   }
356 
357   /// Append a successor range to the bytecode, the exact address will need to
358   /// be resolved later.
359   void append(SuccessorRange successors) {
360     for (Block *successor : successors)
361       append(successor);
362   }
363 
364   /// Append a range of values that will be read as generic PDLValues.
365   void appendPDLValueList(OperandRange values) {
366     bytecode.push_back(values.size());
367     for (Value value : values)
368       appendPDLValue(value);
369   }
370 
371   /// Append a value as a PDLValue.
372   void appendPDLValue(Value value) {
373     appendPDLValueKind(value);
374     append(value);
375   }
376 
377   /// Append the PDLValue::Kind of the given value.
378   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
379 
380   /// Append the PDLValue::Kind of the given type.
381   void appendPDLValueKind(Type type) {
382     PDLValue::Kind kind =
383         TypeSwitch<Type, PDLValue::Kind>(type)
384             .Case<pdl::AttributeType>(
385                 [](Type) { return PDLValue::Kind::Attribute; })
386             .Case<pdl::OperationType>(
387                 [](Type) { return PDLValue::Kind::Operation; })
388             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
389               if (rangeTy.getElementType().isa<pdl::TypeType>())
390                 return PDLValue::Kind::TypeRange;
391               return PDLValue::Kind::ValueRange;
392             })
393             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
394             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
395     bytecode.push_back(static_cast<ByteCodeField>(kind));
396   }
397 
398   /// Append a value that will be stored in a memory slot and not inline within
399   /// the bytecode.
400   template <typename T>
401   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
402                    std::is_pointer<T>::value>
403   append(T value) {
404     bytecode.push_back(generator.getMemIndex(value));
405   }
406 
407   /// Append a range of values.
408   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
409   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
410   append(T range) {
411     bytecode.push_back(llvm::size(range));
412     for (auto it : range)
413       append(it);
414   }
415 
416   /// Append a variadic number of fields to the bytecode.
417   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
418   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
419     append(field);
420     append(field2, fields...);
421   }
422 
423   /// Appends a value as a pointer, stored inline within the bytecode.
424   template <typename T>
425   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
426   appendInline(T value) {
427     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
428     const void *pointer = value.getAsOpaquePointer();
429     ByteCodeField fieldParts[numParts];
430     std::memcpy(fieldParts, &pointer, sizeof(const void *));
431     bytecode.append(fieldParts, fieldParts + numParts);
432   }
433 
434   /// Successor references in the bytecode that have yet to be resolved.
435   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
436 
437   /// The underlying bytecode buffer.
438   SmallVectorImpl<ByteCodeField> &bytecode;
439 
440   /// The main generator producing PDL.
441   Generator &generator;
442 };
443 
444 /// This class represents a live range of PDL Interpreter values, containing
445 /// information about when values are live within a match/rewrite.
446 struct ByteCodeLiveRange {
447   using Set = llvm::IntervalMap<uint64_t, char, 16>;
448   using Allocator = Set::Allocator;
449 
450   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
451 
452   /// Union this live range with the one provided.
453   void unionWith(const ByteCodeLiveRange &rhs) {
454     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
455          ++it)
456       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
457   }
458 
459   /// Returns true if this range overlaps with the one provided.
460   bool overlaps(const ByteCodeLiveRange &rhs) const {
461     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
462         .valid();
463   }
464 
465   /// A map representing the ranges of the match/rewrite that a value is live in
466   /// the interpreter.
467   ///
468   /// We use std::unique_ptr here, because IntervalMap does not provide a
469   /// correct copy or move constructor. We can eliminate the pointer once
470   /// https://reviews.llvm.org/D113240 lands.
471   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
472 
473   /// The operation range storage index for this range.
474   Optional<unsigned> opRangeIndex;
475 
476   /// The type range storage index for this range.
477   Optional<unsigned> typeRangeIndex;
478 
479   /// The value range storage index for this range.
480   Optional<unsigned> valueRangeIndex;
481 };
482 } // namespace
483 
484 void Generator::generate(ModuleOp module) {
485   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
486       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
487   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
488       pdl_interp::PDLInterpDialect::getRewriterModuleName());
489   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
490 
491   // Allocate memory indices for the results of operations within the matcher
492   // and rewriters.
493   allocateMemoryIndices(matcherFunc, rewriterModule);
494 
495   // Generate code for the rewriter functions.
496   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
497   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
498     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
499     for (Operation &op : rewriterFunc.getOps())
500       generate(&op, rewriterByteCodeWriter);
501   }
502   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
503          "unexpected branches in rewriter function");
504 
505   // Generate code for the matcher function.
506   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
507   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
508 
509   // Resolve successor references in the matcher.
510   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
511     ByteCodeAddr addr = blockToAddr[it.first];
512     for (unsigned offsetToFix : it.second)
513       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
514   }
515 }
516 
517 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
518                                       ModuleOp rewriterModule) {
519   // Rewriters use simplistic allocation scheme that simply assigns an index to
520   // each result.
521   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
522     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
523     auto processRewriterValue = [&](Value val) {
524       valueToMemIndex.try_emplace(val, index++);
525       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
526         Type elementTy = rangeType.getElementType();
527         if (elementTy.isa<pdl::TypeType>())
528           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
529         else if (elementTy.isa<pdl::ValueType>())
530           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
531       }
532     };
533 
534     for (BlockArgument arg : rewriterFunc.getArguments())
535       processRewriterValue(arg);
536     rewriterFunc.getBody().walk([&](Operation *op) {
537       for (Value result : op->getResults())
538         processRewriterValue(result);
539     });
540     if (index > maxValueMemoryIndex)
541       maxValueMemoryIndex = index;
542     if (typeRangeIndex > maxTypeRangeMemoryIndex)
543       maxTypeRangeMemoryIndex = typeRangeIndex;
544     if (valueRangeIndex > maxValueRangeMemoryIndex)
545       maxValueRangeMemoryIndex = valueRangeIndex;
546   }
547 
548   // The matcher function uses a more sophisticated numbering that tries to
549   // minimize the number of memory indices assigned. This is done by determining
550   // a live range of the values within the matcher, then the allocation is just
551   // finding the minimal number of overlapping live ranges. This is essentially
552   // a simplified form of register allocation where we don't necessarily have a
553   // limited number of registers, but we still want to minimize the number used.
554   DenseMap<Operation *, unsigned> opToFirstIndex;
555   DenseMap<Operation *, unsigned> opToLastIndex;
556 
557   // A custom walk that marks the first and the last index of each operation.
558   // The entry marks the beginning of the liveness range for this operation,
559   // followed by nested operations, followed by the end of the liveness range.
560   unsigned index = 0;
561   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
562     opToFirstIndex.try_emplace(op, index++);
563     for (Region &region : op->getRegions())
564       for (Block &block : region.getBlocks())
565         for (Operation &nested : block)
566           walk(&nested);
567     opToLastIndex.try_emplace(op, index++);
568   };
569   walk(matcherFunc);
570 
571   // Liveness info for each of the defs within the matcher.
572   ByteCodeLiveRange::Allocator allocator;
573   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
574 
575   // Assign the root operation being matched to slot 0.
576   BlockArgument rootOpArg = matcherFunc.getArgument(0);
577   valueToMemIndex[rootOpArg] = 0;
578 
579   // Walk each of the blocks, computing the def interval that the value is used.
580   Liveness matcherLiveness(matcherFunc);
581   matcherFunc->walk([&](Block *block) {
582     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
583     assert(info && "expected liveness info for block");
584     auto processValue = [&](Value value, Operation *firstUseOrDef) {
585       // We don't need to process the root op argument, this value is always
586       // assigned to the first memory slot.
587       if (value == rootOpArg)
588         return;
589 
590       // Set indices for the range of this block that the value is used.
591       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
592       defRangeIt->second.liveness->insert(
593           opToFirstIndex[firstUseOrDef],
594           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
595           /*dummyValue*/ 0);
596 
597       // Check to see if this value is a range type.
598       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
599         Type eleType = rangeTy.getElementType();
600         if (eleType.isa<pdl::OperationType>())
601           defRangeIt->second.opRangeIndex = 0;
602         else if (eleType.isa<pdl::TypeType>())
603           defRangeIt->second.typeRangeIndex = 0;
604         else if (eleType.isa<pdl::ValueType>())
605           defRangeIt->second.valueRangeIndex = 0;
606       }
607     };
608 
609     // Process the live-ins of this block.
610     for (Value liveIn : info->in()) {
611       // Only process the value if it has been defined in the current region.
612       // Other values that span across pdl_interp.foreach will be added higher
613       // up. This ensures that the we keep them alive for the entire duration
614       // of the loop.
615       if (liveIn.getParentRegion() == block->getParent())
616         processValue(liveIn, &block->front());
617     }
618 
619     // Process the block arguments for the entry block (those are not live-in).
620     if (block->isEntryBlock()) {
621       for (Value argument : block->getArguments())
622         processValue(argument, &block->front());
623     }
624 
625     // Process any new defs within this block.
626     for (Operation &op : *block)
627       for (Value result : op.getResults())
628         processValue(result, &op);
629   });
630 
631   // Greedily allocate memory slots using the computed def live ranges.
632   std::vector<ByteCodeLiveRange> allocatedIndices;
633 
634   // The number of memory indices currently allocated (and its next value).
635   // Recall that the root gets allocated memory index 0.
636   ByteCodeField numIndices = 1;
637 
638   // The number of memory ranges of various types (and their next values).
639   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
640 
641   for (auto &defIt : valueDefRanges) {
642     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
643     ByteCodeLiveRange &defRange = defIt.second;
644 
645     // Try to allocate to an existing index.
646     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
647       ByteCodeLiveRange &existingRange = existingIndexIt.value();
648       if (!defRange.overlaps(existingRange)) {
649         existingRange.unionWith(defRange);
650         memIndex = existingIndexIt.index() + 1;
651 
652         if (defRange.opRangeIndex) {
653           if (!existingRange.opRangeIndex)
654             existingRange.opRangeIndex = numOpRanges++;
655           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
656         } else if (defRange.typeRangeIndex) {
657           if (!existingRange.typeRangeIndex)
658             existingRange.typeRangeIndex = numTypeRanges++;
659           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
660         } else if (defRange.valueRangeIndex) {
661           if (!existingRange.valueRangeIndex)
662             existingRange.valueRangeIndex = numValueRanges++;
663           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
664         }
665         break;
666       }
667     }
668 
669     // If no existing index could be used, add a new one.
670     if (memIndex == 0) {
671       allocatedIndices.emplace_back(allocator);
672       ByteCodeLiveRange &newRange = allocatedIndices.back();
673       newRange.unionWith(defRange);
674 
675       // Allocate an index for op/type/value ranges.
676       if (defRange.opRangeIndex) {
677         newRange.opRangeIndex = numOpRanges;
678         valueToRangeIndex[defIt.first] = numOpRanges++;
679       } else if (defRange.typeRangeIndex) {
680         newRange.typeRangeIndex = numTypeRanges;
681         valueToRangeIndex[defIt.first] = numTypeRanges++;
682       } else if (defRange.valueRangeIndex) {
683         newRange.valueRangeIndex = numValueRanges;
684         valueToRangeIndex[defIt.first] = numValueRanges++;
685       }
686 
687       memIndex = allocatedIndices.size();
688       ++numIndices;
689     }
690   }
691 
692   // Print the index usage and ensure that we did not run out of index space.
693   LLVM_DEBUG({
694     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
695                  << "(down from initial " << valueDefRanges.size() << ").\n";
696   });
697   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
698          "Ran out of memory for allocated indices");
699 
700   // Update the max number of indices.
701   if (numIndices > maxValueMemoryIndex)
702     maxValueMemoryIndex = numIndices;
703   if (numOpRanges > maxOpRangeMemoryIndex)
704     maxOpRangeMemoryIndex = numOpRanges;
705   if (numTypeRanges > maxTypeRangeMemoryIndex)
706     maxTypeRangeMemoryIndex = numTypeRanges;
707   if (numValueRanges > maxValueRangeMemoryIndex)
708     maxValueRangeMemoryIndex = numValueRanges;
709 }
710 
711 void Generator::generate(Region *region, ByteCodeWriter &writer) {
712   llvm::ReversePostOrderTraversal<Region *> rpot(region);
713   for (Block *block : rpot) {
714     // Keep track of where this block begins within the matcher function.
715     blockToAddr.try_emplace(block, matcherByteCode.size());
716     for (Operation &op : *block)
717       generate(&op, writer);
718   }
719 }
720 
721 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
722   LLVM_DEBUG({
723     // The following list must contain all the operations that do not
724     // produce any bytecode.
725     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
726              pdl_interp::InferredTypesOp>(op))
727       writer.appendInline(op->getLoc());
728   });
729   TypeSwitch<Operation *>(op)
730       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
731             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
732             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
733             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
734             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
735             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
736             pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
737             pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
738             pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
739             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
740             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
741             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
742             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
743             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
744             pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
745             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
746             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
747             pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
748             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
749           [&](auto interpOp) { this->generate(interpOp, writer); })
750       .Default([](Operation *) {
751         llvm_unreachable("unknown `pdl_interp` operation");
752       });
753 }
754 
755 void Generator::generate(pdl_interp::ApplyConstraintOp op,
756                          ByteCodeWriter &writer) {
757   assert(constraintToMemIndex.count(op.name()) &&
758          "expected index for constraint function");
759   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
760                 op.constParamsAttr());
761   writer.appendPDLValueList(op.args());
762   writer.append(op.getSuccessors());
763 }
764 void Generator::generate(pdl_interp::ApplyRewriteOp op,
765                          ByteCodeWriter &writer) {
766   assert(externalRewriterToMemIndex.count(op.name()) &&
767          "expected index for rewrite function");
768   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
769                 op.constParamsAttr());
770   writer.appendPDLValueList(op.args());
771 
772   ResultRange results = op.results();
773   writer.append(ByteCodeField(results.size()));
774   for (Value result : results) {
775     // In debug mode we also record the expected kind of the result, so that we
776     // can provide extra verification of the native rewrite function.
777 #ifndef NDEBUG
778     writer.appendPDLValueKind(result);
779 #endif
780 
781     // Range results also need to append the range storage index.
782     if (result.getType().isa<pdl::RangeType>())
783       writer.append(getRangeStorageIndex(result));
784     writer.append(result);
785   }
786 }
787 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
788   Value lhs = op.lhs();
789   if (lhs.getType().isa<pdl::RangeType>()) {
790     writer.append(OpCode::AreRangesEqual);
791     writer.appendPDLValueKind(lhs);
792     writer.append(op.lhs(), op.rhs(), op.getSuccessors());
793     return;
794   }
795 
796   writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
797 }
798 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
799   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
800 }
801 void Generator::generate(pdl_interp::CheckAttributeOp op,
802                          ByteCodeWriter &writer) {
803   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
804                 op.getSuccessors());
805 }
806 void Generator::generate(pdl_interp::CheckOperandCountOp op,
807                          ByteCodeWriter &writer) {
808   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
809                 static_cast<ByteCodeField>(op.compareAtLeast()),
810                 op.getSuccessors());
811 }
812 void Generator::generate(pdl_interp::CheckOperationNameOp op,
813                          ByteCodeWriter &writer) {
814   writer.append(OpCode::CheckOperationName, op.operation(),
815                 OperationName(op.name(), ctx), op.getSuccessors());
816 }
817 void Generator::generate(pdl_interp::CheckResultCountOp op,
818                          ByteCodeWriter &writer) {
819   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
820                 static_cast<ByteCodeField>(op.compareAtLeast()),
821                 op.getSuccessors());
822 }
823 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
824   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
825 }
826 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
827   writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
828 }
829 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
830   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
831   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
832 }
833 void Generator::generate(pdl_interp::CreateAttributeOp op,
834                          ByteCodeWriter &writer) {
835   // Simply repoint the memory index of the result to the constant.
836   getMemIndex(op.attribute()) = getMemIndex(op.value());
837 }
838 void Generator::generate(pdl_interp::CreateOperationOp op,
839                          ByteCodeWriter &writer) {
840   writer.append(OpCode::CreateOperation, op.operation(),
841                 OperationName(op.name(), ctx));
842   writer.appendPDLValueList(op.operands());
843 
844   // Add the attributes.
845   OperandRange attributes = op.attributes();
846   writer.append(static_cast<ByteCodeField>(attributes.size()));
847   for (auto it : llvm::zip(op.attributeNames(), op.attributes()))
848     writer.append(std::get<0>(it), std::get<1>(it));
849   writer.appendPDLValueList(op.types());
850 }
851 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
852   // Simply repoint the memory index of the result to the constant.
853   getMemIndex(op.result()) = getMemIndex(op.value());
854 }
855 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
856   writer.append(OpCode::CreateTypes, op.result(),
857                 getRangeStorageIndex(op.result()), op.value());
858 }
859 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
860   writer.append(OpCode::EraseOp, op.operation());
861 }
862 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
863   OpCode opCode =
864       TypeSwitch<Type, OpCode>(op.result().getType())
865           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
866           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
867           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
868           .Default([](Type) -> OpCode {
869             llvm_unreachable("unsupported element type");
870           });
871   writer.append(opCode, op.range(), op.index(), op.result());
872 }
/// Emit a `Finalize` instruction. The instruction carries no operands; `op`
/// itself is not inspected.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
876 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
877   BlockArgument arg = op.getLoopVariable();
878   writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg);
879   writer.appendPDLValueKind(arg.getType());
880   writer.append(curLoopLevel, op.successor());
881   ++curLoopLevel;
882   if (curLoopLevel > maxLoopLevel)
883     maxLoopLevel = curLoopLevel;
884   generate(&op.region(), writer);
885   --curLoopLevel;
886 }
887 void Generator::generate(pdl_interp::GetAttributeOp op,
888                          ByteCodeWriter &writer) {
889   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
890                 op.nameAttr());
891 }
892 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
893                          ByteCodeWriter &writer) {
894   writer.append(OpCode::GetAttributeType, op.result(), op.value());
895 }
896 void Generator::generate(pdl_interp::GetDefiningOpOp op,
897                          ByteCodeWriter &writer) {
898   writer.append(OpCode::GetDefiningOp, op.operation());
899   writer.appendPDLValue(op.value());
900 }
901 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
902   uint32_t index = op.index();
903   if (index < 4)
904     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
905   else
906     writer.append(OpCode::GetOperandN, index);
907   writer.append(op.operation(), op.value());
908 }
909 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
910   Value result = op.value();
911   Optional<uint32_t> index = op.index();
912   writer.append(OpCode::GetOperands,
913                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
914                 op.operation());
915   if (result.getType().isa<pdl::RangeType>())
916     writer.append(getRangeStorageIndex(result));
917   else
918     writer.append(std::numeric_limits<ByteCodeField>::max());
919   writer.append(result);
920 }
921 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
922   uint32_t index = op.index();
923   if (index < 4)
924     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
925   else
926     writer.append(OpCode::GetResultN, index);
927   writer.append(op.operation(), op.value());
928 }
929 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
930   Value result = op.value();
931   Optional<uint32_t> index = op.index();
932   writer.append(OpCode::GetResults,
933                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
934                 op.operation());
935   if (result.getType().isa<pdl::RangeType>())
936     writer.append(getRangeStorageIndex(result));
937   else
938     writer.append(std::numeric_limits<ByteCodeField>::max());
939   writer.append(result);
940 }
941 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
942   Value operations = op.operations();
943   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
944   writer.append(OpCode::GetUsers, operations, rangeIndex);
945   writer.appendPDLValue(op.value());
946 }
947 void Generator::generate(pdl_interp::GetValueTypeOp op,
948                          ByteCodeWriter &writer) {
949   if (op.getType().isa<pdl::RangeType>()) {
950     Value result = op.result();
951     writer.append(OpCode::GetValueRangeTypes, result,
952                   getRangeStorageIndex(result), op.value());
953   } else {
954     writer.append(OpCode::GetValueType, op.result(), op.value());
955   }
956 }
957 
958 void Generator::generate(pdl_interp::InferredTypesOp op,
959                          ByteCodeWriter &writer) {
960   // InferType maps to a null type as a marker for inferring result types.
961   getMemIndex(op.type()) = getMemIndex(Type());
962 }
963 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
964   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
965 }
966 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
967   ByteCodeField patternIndex = patterns.size();
968   patterns.emplace_back(PDLByteCodePattern::create(
969       op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
970   writer.append(OpCode::RecordMatch, patternIndex,
971                 SuccessorRange(op.getOperation()), op.matchedOps());
972   writer.appendPDLValueList(op.inputs());
973 }
974 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
975   writer.append(OpCode::ReplaceOp, op.operation());
976   writer.appendPDLValueList(op.replValues());
977 }
978 void Generator::generate(pdl_interp::SwitchAttributeOp op,
979                          ByteCodeWriter &writer) {
980   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
981                 op.getSuccessors());
982 }
983 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
984                          ByteCodeWriter &writer) {
985   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
986                 op.getSuccessors());
987 }
988 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
989                          ByteCodeWriter &writer) {
990   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
991     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
992   });
993   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
994                 op.getSuccessors());
995 }
996 void Generator::generate(pdl_interp::SwitchResultCountOp op,
997                          ByteCodeWriter &writer) {
998   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
999                 op.getSuccessors());
1000 }
1001 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1002   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
1003                 op.getSuccessors());
1004 }
1005 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1006   writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
1007                 op.getSuccessors());
1008 }
1009 
1010 //===----------------------------------------------------------------------===//
1011 // PDLByteCode
1012 //===----------------------------------------------------------------------===//
1013 
1014 PDLByteCode::PDLByteCode(ModuleOp module,
1015                          llvm::StringMap<PDLConstraintFunction> constraintFns,
1016                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
1017   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1018                       rewriterByteCode, patterns, maxValueMemoryIndex,
1019                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1020                       maxLoopLevel, constraintFns, rewriteFns);
1021   generator.generate(module);
1022 
1023   // Initialize the external functions.
1024   for (auto &it : constraintFns)
1025     constraintFunctions.push_back(std::move(it.second));
1026   for (auto &it : rewriteFns)
1027     rewriteFunctions.push_back(std::move(it.second));
1028 }
1029 
1030 /// Initialize the given state such that it can be used to execute the current
1031 /// bytecode.
1032 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1033   state.memory.resize(maxValueMemoryIndex, nullptr);
1034   state.opRangeMemory.resize(maxOpRangeCount);
1035   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1036   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1037   state.loopIndex.resize(maxLoopLevel, 0);
1038   state.currentPatternBenefits.reserve(patterns.size());
1039   for (const PDLByteCodePattern &pattern : patterns)
1040     state.currentPatternBenefits.push_back(pattern.getBenefit());
1041 }
1042 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1045 
1046 namespace {
/// This class provides support for executing a bytecode stream. It holds the
/// current position within the stream, views over the mutable execution
/// memory, and read-only references to the data of the bytecode being
/// executed.
class ByteCodeExecutor {
public:
  /// Create an executor that starts reading at `curCodeIt`. All of the
  /// provided memory is referenced rather than copied.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops a code iterator from the stack and makes it the current position.
  /// The stack must not be empty.
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code. This
  /// assumes that only the op code (and, when debugging is enabled, the inline
  /// location) has been read so far.
  const ByteCodeField *getPrevCodeIt() const {
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The list is cleared
  /// first; the element count is read from the buffer ahead of the elements.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// as either Value or ValueRange elements. Unlike `readList`, the values are
  /// appended to `list` without clearing it; a range contributes all of its
  /// elements.
  void readValueList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer. The pointer occupies
  /// `sizeof(const void *) / sizeof(ByteCodeField)` fields of the buffer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Jump to a specific successor based on a predicate value: the first
  /// successor when `isTrue`, the second otherwise.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Each
  /// successor address occupies two fields, so `destIndex * 2` fields are
  /// skipped before reading the destination.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. The first
  /// case matching `value` under `cmp` selects the successor at its index
  /// plus one; if no case matches, successor zero is selected.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the attribute value is within the case list. Jump to
    // the correct successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream. The next field is a memory index: indices within the bounds of
  /// `memory` refer to the mutable execution memory, and larger indices refer
  /// to `uniquedMemory`, offset by the size of `memory`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Read a raw pointer from memory.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Read a class type (other than PDLValue) that is reconstructed from the
  /// opaque pointer stored in memory.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// Read a generic PDLValue: a kind field followed by a value of that kind.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Read a bytecode address, which is stored inline across two fields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// Read a single field stored inline.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// Read a PDLValue kind, stored inline as a single field.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// The current position within the underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1306 
1307 /// This class is an instantiation of the PDLResultList that provides access to
1308 /// the returned results. This API is not on `PDLResultList` to avoid
1309 /// overexposing access to information specific solely to the ByteCode.
1310 class ByteCodeRewriteResultList : public PDLResultList {
1311 public:
1312   ByteCodeRewriteResultList(unsigned maxNumResults)
1313       : PDLResultList(maxNumResults) {}
1314 
1315   /// Return the list of PDL results.
1316   MutableArrayRef<PDLValue> getResults() { return results; }
1317 
1318   /// Return the type ranges allocated by this list.
1319   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1320     return allocatedTypeRanges;
1321   }
1322 
1323   /// Return the value ranges allocated by this list.
1324   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1325     return allocatedValueRanges;
1326   }
1327 };
1328 } // namespace
1329 
1330 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1331   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1332   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1333   ArrayAttr constParams = read<ArrayAttr>();
1334   SmallVector<PDLValue, 16> args;
1335   readList<PDLValue>(args);
1336 
1337   LLVM_DEBUG({
1338     llvm::dbgs() << "  * Arguments: ";
1339     llvm::interleaveComma(args, llvm::dbgs());
1340     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1341   });
1342 
1343   // Invoke the constraint and jump to the proper destination.
1344   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1345 }
1346 
1347 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1348   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1349   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1350   ArrayAttr constParams = read<ArrayAttr>();
1351   SmallVector<PDLValue, 16> args;
1352   readList<PDLValue>(args);
1353 
1354   LLVM_DEBUG({
1355     llvm::dbgs() << "  * Arguments: ";
1356     llvm::interleaveComma(args, llvm::dbgs());
1357     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1358   });
1359 
1360   // Execute the rewrite function.
1361   ByteCodeField numResults = read();
1362   ByteCodeRewriteResultList results(numResults);
1363   rewriteFn(args, constParams, rewriter, results);
1364 
1365   assert(results.getResults().size() == numResults &&
1366          "native PDL rewrite function returned unexpected number of results");
1367 
1368   // Store the results in the bytecode memory.
1369   for (PDLValue &result : results.getResults()) {
1370     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1371 
1372 // In debug mode we also verify the expected kind of the result.
1373 #ifndef NDEBUG
1374     assert(result.getKind() == read<PDLValue::Kind>() &&
1375            "native PDL rewrite function returned an unexpected type of result");
1376 #endif
1377 
1378     // If the result is a range, we need to copy it over to the bytecodes
1379     // range memory.
1380     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1381       unsigned rangeIndex = read();
1382       typeRangeMemory[rangeIndex] = *typeRange;
1383       memory[read()] = &typeRangeMemory[rangeIndex];
1384     } else if (Optional<ValueRange> valueRange =
1385                    result.dyn_cast<ValueRange>()) {
1386       unsigned rangeIndex = read();
1387       valueRangeMemory[rangeIndex] = *valueRange;
1388       memory[read()] = &valueRangeMemory[rangeIndex];
1389     } else {
1390       memory[read()] = result.getAsOpaquePointer();
1391     }
1392   }
1393 
1394   // Copy over any underlying storage allocated for result ranges.
1395   for (auto &it : results.getAllocatedTypeRanges())
1396     allocatedTypeRangeMemory.push_back(std::move(it));
1397   for (auto &it : results.getAllocatedValueRanges())
1398     allocatedValueRangeMemory.push_back(std::move(it));
1399 }
1400 
1401 void ByteCodeExecutor::executeAreEqual() {
1402   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1403   const void *lhs = read<const void *>();
1404   const void *rhs = read<const void *>();
1405 
1406   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1407   selectJump(lhs == rhs);
1408 }
1409 
1410 void ByteCodeExecutor::executeAreRangesEqual() {
1411   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1412   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1413   const void *lhs = read<const void *>();
1414   const void *rhs = read<const void *>();
1415 
1416   switch (valueKind) {
1417   case PDLValue::Kind::TypeRange: {
1418     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1419     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1420     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1421     selectJump(*lhsRange == *rhsRange);
1422     break;
1423   }
1424   case PDLValue::Kind::ValueRange: {
1425     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1426     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1427     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1428     selectJump(*lhsRange == *rhsRange);
1429     break;
1430   }
1431   default:
1432     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1433   }
1434 }
1435 
1436 void ByteCodeExecutor::executeBranch() {
1437   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1438   curCodeIt = &code[read<ByteCodeAddr>()];
1439 }
1440 
1441 void ByteCodeExecutor::executeCheckOperandCount() {
1442   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1443   Operation *op = read<Operation *>();
1444   uint32_t expectedCount = read<uint32_t>();
1445   bool compareAtLeast = read();
1446 
1447   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1448                           << "  * Expected: " << expectedCount << "\n"
1449                           << "  * Comparator: "
1450                           << (compareAtLeast ? ">=" : "==") << "\n");
1451   if (compareAtLeast)
1452     selectJump(op->getNumOperands() >= expectedCount);
1453   else
1454     selectJump(op->getNumOperands() == expectedCount);
1455 }
1456 
1457 void ByteCodeExecutor::executeCheckOperationName() {
1458   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1459   Operation *op = read<Operation *>();
1460   OperationName expectedName = read<OperationName>();
1461 
1462   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1463                           << "  * Expected: \"" << expectedName << "\"\n");
1464   selectJump(op->getName() == expectedName);
1465 }
1466 
1467 void ByteCodeExecutor::executeCheckResultCount() {
1468   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1469   Operation *op = read<Operation *>();
1470   uint32_t expectedCount = read<uint32_t>();
1471   bool compareAtLeast = read();
1472 
1473   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1474                           << "  * Expected: " << expectedCount << "\n"
1475                           << "  * Comparator: "
1476                           << (compareAtLeast ? ">=" : "==") << "\n");
1477   if (compareAtLeast)
1478     selectJump(op->getNumResults() >= expectedCount);
1479   else
1480     selectJump(op->getNumResults() == expectedCount);
1481 }
1482 
1483 void ByteCodeExecutor::executeCheckTypes() {
1484   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1485   TypeRange *lhs = read<TypeRange *>();
1486   Attribute rhs = read<Attribute>();
1487   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1488 
1489   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1490 }
1491 
1492 void ByteCodeExecutor::executeContinue() {
1493   ByteCodeField level = read();
1494   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1495                           << "  * Level: " << level << "\n");
1496   ++loopIndex[level];
1497   popCodeIt();
1498 }
1499 
1500 void ByteCodeExecutor::executeCreateTypes() {
1501   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1502   unsigned memIndex = read();
1503   unsigned rangeIndex = read();
1504   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1505 
1506   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1507 
1508   // Allocate a buffer for this type range.
1509   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1510   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1511   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1512 
1513   // Assign this to the range slot and use the range as the value for the
1514   // memory index.
1515   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1516   memory[memIndex] = &typeRangeMemory[rangeIndex];
1517 }
1518 
1519 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1520                                               Location mainRewriteLoc) {
1521   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1522 
1523   unsigned memIndex = read();
1524   OperationState state(mainRewriteLoc, read<OperationName>());
1525   readValueList(state.operands);
1526   for (unsigned i = 0, e = read(); i != e; ++i) {
1527     StringAttr name = read<StringAttr>();
1528     if (Attribute attr = read<Attribute>())
1529       state.addAttribute(name, attr);
1530   }
1531 
1532   for (unsigned i = 0, e = read(); i != e; ++i) {
1533     if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1534       state.types.push_back(read<Type>());
1535       continue;
1536     }
1537 
1538     // If we find a null range, this signals that the types are infered.
1539     if (TypeRange *resultTypes = read<TypeRange *>()) {
1540       state.types.append(resultTypes->begin(), resultTypes->end());
1541       continue;
1542     }
1543 
1544     // Handle the case where the operation has inferred types.
1545     InferTypeOpInterface::Concept *concept =
1546         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1547 
1548     // TODO: Handle failure.
1549     state.types.clear();
1550     if (failed(concept->inferReturnTypes(
1551             state.getContext(), state.location, state.operands,
1552             state.attributes.getDictionary(state.getContext()), state.regions,
1553             state.types)))
1554       return;
1555     break;
1556   }
1557 
1558   Operation *resultOp = rewriter.createOperation(state);
1559   memory[memIndex] = resultOp;
1560 
1561   LLVM_DEBUG({
1562     llvm::dbgs() << "  * Attributes: "
1563                  << state.attributes.getDictionary(state.getContext())
1564                  << "\n  * Operands: ";
1565     llvm::interleaveComma(state.operands, llvm::dbgs());
1566     llvm::dbgs() << "\n  * Result Types: ";
1567     llvm::interleaveComma(state.types, llvm::dbgs());
1568     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1569   });
1570 }
1571 
1572 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1573   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1574   Operation *op = read<Operation *>();
1575 
1576   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1577   rewriter.eraseOp(op);
1578 }
1579 
1580 template <typename T, typename Range, PDLValue::Kind kind>
1581 void ByteCodeExecutor::executeExtract() {
1582   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1583   Range *range = read<Range *>();
1584   unsigned index = read<uint32_t>();
1585   unsigned memIndex = read();
1586 
1587   if (!range) {
1588     memory[memIndex] = nullptr;
1589     return;
1590   }
1591 
1592   T result = index < range->size() ? (*range)[index] : T();
1593   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1594                           << "  * Index: " << index << "\n"
1595                           << "  * Result: " << result << "\n");
1596   storeToMemory(memIndex, result);
1597 }
1598 
/// Execute `Finalize`: this only emits a debug trace and reads nothing from
/// the bytecode stream.
void ByteCodeExecutor::executeFinalize() {
  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
}
1602 
1603 void ByteCodeExecutor::executeForEach() {
1604   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1605   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1606   unsigned rangeIndex = read();
1607   unsigned memIndex = read();
1608   const void *value = nullptr;
1609 
1610   switch (read<PDLValue::Kind>()) {
1611   case PDLValue::Kind::Operation: {
1612     unsigned &index = loopIndex[read()];
1613     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1614     assert(index <= array.size() && "iterated past the end");
1615     if (index < array.size()) {
1616       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1617       value = array[index];
1618       break;
1619     }
1620 
1621     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1622     index = 0;
1623     selectJump(size_t(0));
1624     return;
1625   }
1626   default:
1627     llvm_unreachable("unexpected `ForEach` value kind");
1628   }
1629 
1630   // Store the iterate value and the stack address.
1631   memory[memIndex] = value;
1632   pushCodeIt(prevCodeIt);
1633 
1634   // Skip over the successor (we will enter the body of the loop).
1635   read<ByteCodeAddr>();
1636 }
1637 
1638 void ByteCodeExecutor::executeGetAttribute() {
1639   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1640   unsigned memIndex = read();
1641   Operation *op = read<Operation *>();
1642   StringAttr attrName = read<StringAttr>();
1643   Attribute attr = op->getAttr(attrName);
1644 
1645   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1646                           << "  * Attribute: " << attrName << "\n"
1647                           << "  * Result: " << attr << "\n");
1648   memory[memIndex] = attr.getAsOpaquePointer();
1649 }
1650 
1651 void ByteCodeExecutor::executeGetAttributeType() {
1652   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1653   unsigned memIndex = read();
1654   Attribute attr = read<Attribute>();
1655   Type type = attr ? attr.getType() : Type();
1656 
1657   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1658                           << "  * Result: " << type << "\n");
1659   memory[memIndex] = type.getAsOpaquePointer();
1660 }
1661 
1662 void ByteCodeExecutor::executeGetDefiningOp() {
1663   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1664   unsigned memIndex = read();
1665   Operation *op = nullptr;
1666   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1667     Value value = read<Value>();
1668     if (value)
1669       op = value.getDefiningOp();
1670     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1671   } else {
1672     ValueRange *values = read<ValueRange *>();
1673     if (values && !values->empty()) {
1674       op = values->front().getDefiningOp();
1675     }
1676     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1677   }
1678 
1679   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1680   memory[memIndex] = op;
1681 }
1682 
1683 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1684   Operation *op = read<Operation *>();
1685   unsigned memIndex = read();
1686   Value operand =
1687       index < op->getNumOperands() ? op->getOperand(index) : Value();
1688 
1689   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1690                           << "  * Index: " << index << "\n"
1691                           << "  * Result: " << operand << "\n");
1692   memory[memIndex] = operand.getAsOpaquePointer();
1693 }
1694 
1695 /// This function is the internal implementation of `GetResults` and
1696 /// `GetOperands` that provides support for extracting a value range from the
1697 /// given operation.
1698 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1699 static void *
1700 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1701                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1702                           MutableArrayRef<ValueRange> valueRangeMemory) {
1703   // Check for the sentinel index that signals that all values should be
1704   // returned.
1705   if (index == std::numeric_limits<uint32_t>::max()) {
1706     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1707     // `values` is already the full value range.
1708 
1709     // Otherwise, check to see if this operation uses AttrSizedSegments.
1710   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1711     LLVM_DEBUG(llvm::dbgs()
1712                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1713 
1714     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1715     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1716       return nullptr;
1717 
1718     auto segments = segmentAttr.getValues<int32_t>();
1719     unsigned startIndex =
1720         std::accumulate(segments.begin(), segments.begin() + index, 0);
1721     values = values.slice(startIndex, *std::next(segments.begin(), index));
1722 
1723     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1724                             << *std::next(segments.begin(), index) << "]\n");
1725 
1726     // Otherwise, assume this is the last operand group of the operation.
1727     // FIXME: We currently don't support operations with
1728     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1729     // have a way to detect it's presence.
1730   } else if (values.size() >= index) {
1731     LLVM_DEBUG(llvm::dbgs()
1732                << "  * Treating values as trailing variadic range\n");
1733     values = values.drop_front(index);
1734 
1735     // If we couldn't detect a way to compute the values, bail out.
1736   } else {
1737     return nullptr;
1738   }
1739 
1740   // If the range index is valid, we are returning a range.
1741   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1742     valueRangeMemory[rangeIndex] = values;
1743     return &valueRangeMemory[rangeIndex];
1744   }
1745 
1746   // If a range index wasn't provided, the range is required to be non-variadic.
1747   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1748 }
1749 
1750 void ByteCodeExecutor::executeGetOperands() {
1751   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1752   unsigned index = read<uint32_t>();
1753   Operation *op = read<Operation *>();
1754   ByteCodeField rangeIndex = read();
1755 
1756   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1757       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1758       valueRangeMemory);
1759   if (!result)
1760     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1761   memory[read()] = result;
1762 }
1763 
1764 void ByteCodeExecutor::executeGetResult(unsigned index) {
1765   Operation *op = read<Operation *>();
1766   unsigned memIndex = read();
1767   OpResult result =
1768       index < op->getNumResults() ? op->getResult(index) : OpResult();
1769 
1770   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1771                           << "  * Index: " << index << "\n"
1772                           << "  * Result: " << result << "\n");
1773   memory[memIndex] = result.getAsOpaquePointer();
1774 }
1775 
1776 void ByteCodeExecutor::executeGetResults() {
1777   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1778   unsigned index = read<uint32_t>();
1779   Operation *op = read<Operation *>();
1780   ByteCodeField rangeIndex = read();
1781 
1782   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1783       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1784       valueRangeMemory);
1785   if (!result)
1786     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1787   memory[read()] = result;
1788 }
1789 
1790 void ByteCodeExecutor::executeGetUsers() {
1791   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1792   unsigned memIndex = read();
1793   unsigned rangeIndex = read();
1794   OwningOpRange &range = opRangeMemory[rangeIndex];
1795   memory[memIndex] = &range;
1796 
1797   range = OwningOpRange();
1798   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1799     // Read the value.
1800     Value value = read<Value>();
1801     if (!value)
1802       return;
1803     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1804 
1805     // Extract the users of a single value.
1806     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1807     llvm::copy(value.getUsers(), range.begin());
1808   } else {
1809     // Read a range of values.
1810     ValueRange *values = read<ValueRange *>();
1811     if (!values)
1812       return;
1813     LLVM_DEBUG({
1814       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1815       llvm::interleaveComma(*values, llvm::dbgs());
1816       llvm::dbgs() << "\n";
1817     });
1818 
1819     // Extract all the users of a range of values.
1820     SmallVector<Operation *> users;
1821     for (Value value : *values)
1822       users.append(value.user_begin(), value.user_end());
1823     range = OwningOpRange(users.size());
1824     llvm::copy(users, range.begin());
1825   }
1826 
1827   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1828 }
1829 
1830 void ByteCodeExecutor::executeGetValueType() {
1831   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1832   unsigned memIndex = read();
1833   Value value = read<Value>();
1834   Type type = value ? value.getType() : Type();
1835 
1836   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1837                           << "  * Result: " << type << "\n");
1838   memory[memIndex] = type.getAsOpaquePointer();
1839 }
1840 
1841 void ByteCodeExecutor::executeGetValueRangeTypes() {
1842   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1843   unsigned memIndex = read();
1844   unsigned rangeIndex = read();
1845   ValueRange *values = read<ValueRange *>();
1846   if (!values) {
1847     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1848     memory[memIndex] = nullptr;
1849     return;
1850   }
1851 
1852   LLVM_DEBUG({
1853     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1854     llvm::interleaveComma(*values, llvm::dbgs());
1855     llvm::dbgs() << "\n  * Result: ";
1856     llvm::interleaveComma(values->getType(), llvm::dbgs());
1857     llvm::dbgs() << "\n";
1858   });
1859   typeRangeMemory[rangeIndex] = values->getType();
1860   memory[memIndex] = &typeRangeMemory[rangeIndex];
1861 }
1862 
1863 void ByteCodeExecutor::executeIsNotNull() {
1864   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1865   const void *value = read<const void *>();
1866 
1867   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1868   selectJump(value != nullptr);
1869 }
1870 
1871 void ByteCodeExecutor::executeRecordMatch(
1872     PatternRewriter &rewriter,
1873     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1874   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1875   unsigned patternIndex = read();
1876   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1877   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1878 
1879   // If the benefit of the pattern is impossible, skip the processing of the
1880   // rest of the pattern.
1881   if (benefit.isImpossibleToMatch()) {
1882     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1883     curCodeIt = dest;
1884     return;
1885   }
1886 
1887   // Create a fused location containing the locations of each of the
1888   // operations used in the match. This will be used as the location for
1889   // created operations during the rewrite that don't already have an
1890   // explicit location set.
1891   unsigned numMatchLocs = read();
1892   SmallVector<Location, 4> matchLocs;
1893   matchLocs.reserve(numMatchLocs);
1894   for (unsigned i = 0; i != numMatchLocs; ++i)
1895     matchLocs.push_back(read<Operation *>()->getLoc());
1896   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1897 
1898   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1899                           << "  * Location: " << matchLoc << "\n");
1900   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1901   PDLByteCode::MatchResult &match = matches.back();
1902 
1903   // Record all of the inputs to the match. If any of the inputs are ranges, we
1904   // will also need to remap the range pointer to memory stored in the match
1905   // state.
1906   unsigned numInputs = read();
1907   match.values.reserve(numInputs);
1908   match.typeRangeValues.reserve(numInputs);
1909   match.valueRangeValues.reserve(numInputs);
1910   for (unsigned i = 0; i < numInputs; ++i) {
1911     switch (read<PDLValue::Kind>()) {
1912     case PDLValue::Kind::TypeRange:
1913       match.typeRangeValues.push_back(*read<TypeRange *>());
1914       match.values.push_back(&match.typeRangeValues.back());
1915       break;
1916     case PDLValue::Kind::ValueRange:
1917       match.valueRangeValues.push_back(*read<ValueRange *>());
1918       match.values.push_back(&match.valueRangeValues.back());
1919       break;
1920     default:
1921       match.values.push_back(read<const void *>());
1922       break;
1923     }
1924   }
1925   curCodeIt = dest;
1926 }
1927 
1928 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1929   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1930   Operation *op = read<Operation *>();
1931   SmallVector<Value, 16> args;
1932   readValueList(args);
1933 
1934   LLVM_DEBUG({
1935     llvm::dbgs() << "  * Operation: " << *op << "\n"
1936                  << "  * Values: ";
1937     llvm::interleaveComma(args, llvm::dbgs());
1938     llvm::dbgs() << "\n";
1939   });
1940   rewriter.replaceOp(op, args);
1941 }
1942 
1943 void ByteCodeExecutor::executeSwitchAttribute() {
1944   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1945   Attribute value = read<Attribute>();
1946   ArrayAttr cases = read<ArrayAttr>();
1947   handleSwitch(value, cases);
1948 }
1949 
1950 void ByteCodeExecutor::executeSwitchOperandCount() {
1951   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1952   Operation *op = read<Operation *>();
1953   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1954 
1955   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1956   handleSwitch(op->getNumOperands(), cases);
1957 }
1958 
1959 void ByteCodeExecutor::executeSwitchOperationName() {
1960   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1961   OperationName value = read<Operation *>()->getName();
1962   size_t caseCount = read();
1963 
1964   // The operation names are stored in-line, so to print them out for
1965   // debugging purposes we need to read the array before executing the
1966   // switch so that we can display all of the possible values.
1967   LLVM_DEBUG({
1968     const ByteCodeField *prevCodeIt = curCodeIt;
1969     llvm::dbgs() << "  * Value: " << value << "\n"
1970                  << "  * Cases: ";
1971     llvm::interleaveComma(
1972         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1973                         [&](size_t) { return read<OperationName>(); }),
1974         llvm::dbgs());
1975     llvm::dbgs() << "\n";
1976     curCodeIt = prevCodeIt;
1977   });
1978 
1979   // Try to find the switch value within any of the cases.
1980   for (size_t i = 0; i != caseCount; ++i) {
1981     if (read<OperationName>() == value) {
1982       curCodeIt += (caseCount - i - 1);
1983       return selectJump(i + 1);
1984     }
1985   }
1986   selectJump(size_t(0));
1987 }
1988 
1989 void ByteCodeExecutor::executeSwitchResultCount() {
1990   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1991   Operation *op = read<Operation *>();
1992   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1993 
1994   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1995   handleSwitch(op->getNumResults(), cases);
1996 }
1997 
1998 void ByteCodeExecutor::executeSwitchType() {
1999   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2000   Type value = read<Type>();
2001   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2002   handleSwitch(value, cases);
2003 }
2004 
2005 void ByteCodeExecutor::executeSwitchTypes() {
2006   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2007   TypeRange *value = read<TypeRange *>();
2008   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2009   if (!value) {
2010     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2011     return selectJump(size_t(0));
2012   }
2013   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2014     return value == caseValue.getAsValueRange<TypeAttr>();
2015   });
2016 }
2017 
2018 void ByteCodeExecutor::execute(
2019     PatternRewriter &rewriter,
2020     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2021     Optional<Location> mainRewriteLoc) {
2022   while (true) {
2023     // Print the location of the operation being executed.
2024     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2025 
2026     OpCode opCode = static_cast<OpCode>(read());
2027     switch (opCode) {
2028     case ApplyConstraint:
2029       executeApplyConstraint(rewriter);
2030       break;
2031     case ApplyRewrite:
2032       executeApplyRewrite(rewriter);
2033       break;
2034     case AreEqual:
2035       executeAreEqual();
2036       break;
2037     case AreRangesEqual:
2038       executeAreRangesEqual();
2039       break;
2040     case Branch:
2041       executeBranch();
2042       break;
2043     case CheckOperandCount:
2044       executeCheckOperandCount();
2045       break;
2046     case CheckOperationName:
2047       executeCheckOperationName();
2048       break;
2049     case CheckResultCount:
2050       executeCheckResultCount();
2051       break;
2052     case CheckTypes:
2053       executeCheckTypes();
2054       break;
2055     case Continue:
2056       executeContinue();
2057       break;
2058     case CreateOperation:
2059       executeCreateOperation(rewriter, *mainRewriteLoc);
2060       break;
2061     case CreateTypes:
2062       executeCreateTypes();
2063       break;
2064     case EraseOp:
2065       executeEraseOp(rewriter);
2066       break;
2067     case ExtractOp:
2068       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2069       break;
2070     case ExtractType:
2071       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2072       break;
2073     case ExtractValue:
2074       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2075       break;
2076     case Finalize:
2077       executeFinalize();
2078       LLVM_DEBUG(llvm::dbgs() << "\n");
2079       return;
2080     case ForEach:
2081       executeForEach();
2082       break;
2083     case GetAttribute:
2084       executeGetAttribute();
2085       break;
2086     case GetAttributeType:
2087       executeGetAttributeType();
2088       break;
2089     case GetDefiningOp:
2090       executeGetDefiningOp();
2091       break;
2092     case GetOperand0:
2093     case GetOperand1:
2094     case GetOperand2:
2095     case GetOperand3: {
2096       unsigned index = opCode - GetOperand0;
2097       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2098       executeGetOperand(index);
2099       break;
2100     }
2101     case GetOperandN:
2102       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2103       executeGetOperand(read<uint32_t>());
2104       break;
2105     case GetOperands:
2106       executeGetOperands();
2107       break;
2108     case GetResult0:
2109     case GetResult1:
2110     case GetResult2:
2111     case GetResult3: {
2112       unsigned index = opCode - GetResult0;
2113       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2114       executeGetResult(index);
2115       break;
2116     }
2117     case GetResultN:
2118       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2119       executeGetResult(read<uint32_t>());
2120       break;
2121     case GetResults:
2122       executeGetResults();
2123       break;
2124     case GetUsers:
2125       executeGetUsers();
2126       break;
2127     case GetValueType:
2128       executeGetValueType();
2129       break;
2130     case GetValueRangeTypes:
2131       executeGetValueRangeTypes();
2132       break;
2133     case IsNotNull:
2134       executeIsNotNull();
2135       break;
2136     case RecordMatch:
2137       assert(matches &&
2138              "expected matches to be provided when executing the matcher");
2139       executeRecordMatch(rewriter, *matches);
2140       break;
2141     case ReplaceOp:
2142       executeReplaceOp(rewriter);
2143       break;
2144     case SwitchAttribute:
2145       executeSwitchAttribute();
2146       break;
2147     case SwitchOperandCount:
2148       executeSwitchOperandCount();
2149       break;
2150     case SwitchOperationName:
2151       executeSwitchOperationName();
2152       break;
2153     case SwitchResultCount:
2154       executeSwitchResultCount();
2155       break;
2156     case SwitchType:
2157       executeSwitchType();
2158       break;
2159     case SwitchTypes:
2160       executeSwitchTypes();
2161       break;
2162     }
2163     LLVM_DEBUG(llvm::dbgs() << "\n");
2164   }
2165 }
2166 
2167 /// Run the pattern matcher on the given root operation, collecting the matched
2168 /// patterns in `matches`.
2169 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2170                         SmallVectorImpl<MatchResult> &matches,
2171                         PDLByteCodeMutableState &state) const {
2172   // The first memory slot is always the root operation.
2173   state.memory[0] = op;
2174 
2175   // The matcher function always starts at code address 0.
2176   ByteCodeExecutor executor(
2177       matcherByteCode.data(), state.memory, state.opRangeMemory,
2178       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2179       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2180       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2181       constraintFunctions, rewriteFunctions);
2182   executor.execute(rewriter, &matches);
2183 
2184   // Order the found matches by benefit.
2185   std::stable_sort(matches.begin(), matches.end(),
2186                    [](const MatchResult &lhs, const MatchResult &rhs) {
2187                      return lhs.benefit > rhs.benefit;
2188                    });
2189 }
2190 
2191 /// Run the rewriter of the given pattern on the root operation `op`.
2192 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2193                           PDLByteCodeMutableState &state) const {
2194   // The arguments of the rewrite function are stored at the start of the
2195   // memory buffer.
2196   llvm::copy(match.values, state.memory.begin());
2197 
2198   ByteCodeExecutor executor(
2199       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2200       state.opRangeMemory, state.typeRangeMemory,
2201       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2202       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2203       rewriterByteCode, state.currentPatternBenefits, patterns,
2204       constraintFunctions, rewriteFunctions);
2205   executor.execute(rewriter, /*matches=*/nullptr, match.location);
2206 }
2207