1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.getBenefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.getRootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
/// Overwrite the benefit stored for a single bytecode pattern. `patternIndex`
/// is the position of that pattern within the range returned by
/// `PDLByteCode::getPatterns`.
void PDLByteCodePattern_unused_guard_do_not_emit();
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 enum OpCode : ByteCodeField {
80   /// Apply an externally registered constraint.
81   ApplyConstraint,
82   /// Apply an externally registered rewrite.
83   ApplyRewrite,
84   /// Check if two generic values are equal.
85   AreEqual,
86   /// Check if two ranges are equal.
87   AreRangesEqual,
88   /// Unconditional branch.
89   Branch,
90   /// Compare the operand count of an operation with a constant.
91   CheckOperandCount,
92   /// Compare the name of an operation with a constant.
93   CheckOperationName,
94   /// Compare the result count of an operation with a constant.
95   CheckResultCount,
96   /// Compare a range of types to a constant range of types.
97   CheckTypes,
98   /// Continue to the next iteration of a loop.
99   Continue,
100   /// Create an operation.
101   CreateOperation,
102   /// Create a range of types.
103   CreateTypes,
104   /// Erase an operation.
105   EraseOp,
106   /// Extract the op from a range at the specified index.
107   ExtractOp,
108   /// Extract the type from a range at the specified index.
109   ExtractType,
110   /// Extract the value from a range at the specified index.
111   ExtractValue,
112   /// Terminate a matcher or rewrite sequence.
113   Finalize,
114   /// Iterate over a range of values.
115   ForEach,
116   /// Get a specific attribute of an operation.
117   GetAttribute,
118   /// Get the type of an attribute.
119   GetAttributeType,
120   /// Get the defining operation of a value.
121   GetDefiningOp,
122   /// Get a specific operand of an operation.
123   GetOperand0,
124   GetOperand1,
125   GetOperand2,
126   GetOperand3,
127   GetOperandN,
128   /// Get a specific operand group of an operation.
129   GetOperands,
130   /// Get a specific result of an operation.
131   GetResult0,
132   GetResult1,
133   GetResult2,
134   GetResult3,
135   GetResultN,
136   /// Get a specific result group of an operation.
137   GetResults,
138   /// Get the users of a value or a range of values.
139   GetUsers,
140   /// Get the type of a value.
141   GetValueType,
142   /// Get the types of a value range.
143   GetValueRangeTypes,
144   /// Check if a generic value is not null.
145   IsNotNull,
146   /// Record a successful pattern match.
147   RecordMatch,
148   /// Replace an operation.
149   ReplaceOp,
150   /// Compare an attribute with a set of constants.
151   SwitchAttribute,
152   /// Compare the operand count of an operation with a set of constants.
153   SwitchOperandCount,
154   /// Compare the name of an operation with a set of constants.
155   SwitchOperationName,
156   /// Compare the result count of an operation with a set of constants.
157   SwitchResultCount,
158   /// Compare a type with a set of constants.
159   SwitchType,
160   /// Compare a range of types with a set of constants.
161   SwitchTypes,
162 };
163 } // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // ByteCode Generation
167 //===----------------------------------------------------------------------===//
168 
169 //===----------------------------------------------------------------------===//
170 // Generator
171 
172 namespace {
173 struct ByteCodeLiveRange;
174 struct ByteCodeWriter;
175 
176 /// Check if the given class `T` can be converted to an opaque pointer.
177 template <typename T, typename... Args>
178 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
179 
180 /// This class represents the main generator for the pattern bytecode.
181 class Generator {
182 public:
183   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
184             SmallVectorImpl<ByteCodeField> &matcherByteCode,
185             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
186             SmallVectorImpl<PDLByteCodePattern> &patterns,
187             ByteCodeField &maxValueMemoryIndex,
188             ByteCodeField &maxOpRangeMemoryIndex,
189             ByteCodeField &maxTypeRangeMemoryIndex,
190             ByteCodeField &maxValueRangeMemoryIndex,
191             ByteCodeField &maxLoopLevel,
192             llvm::StringMap<PDLConstraintFunction> &constraintFns,
193             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
194       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
195         rewriterByteCode(rewriterByteCode), patterns(patterns),
196         maxValueMemoryIndex(maxValueMemoryIndex),
197         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
198         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
199         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
200         maxLoopLevel(maxLoopLevel) {
201     for (const auto &it : llvm::enumerate(constraintFns))
202       constraintToMemIndex.try_emplace(it.value().first(), it.index());
203     for (const auto &it : llvm::enumerate(rewriteFns))
204       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
205   }
206 
207   /// Generate the bytecode for the given PDL interpreter module.
208   void generate(ModuleOp module);
209 
210   /// Return the memory index to use for the given value.
211   ByteCodeField &getMemIndex(Value value) {
212     assert(valueToMemIndex.count(value) &&
213            "expected memory index to be assigned");
214     return valueToMemIndex[value];
215   }
216 
217   /// Return the range memory index used to store the given range value.
218   ByteCodeField &getRangeStorageIndex(Value value) {
219     assert(valueToRangeIndex.count(value) &&
220            "expected range index to be assigned");
221     return valueToRangeIndex[value];
222   }
223 
224   /// Return an index to use when referring to the given data that is uniqued in
225   /// the MLIR context.
226   template <typename T>
227   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
228   getMemIndex(T val) {
229     const void *opaqueVal = val.getAsOpaquePointer();
230 
231     // Get or insert a reference to this value.
232     auto it = uniquedDataToMemIndex.try_emplace(
233         opaqueVal, maxValueMemoryIndex + uniquedData.size());
234     if (it.second)
235       uniquedData.push_back(opaqueVal);
236     return it.first->second;
237   }
238 
239 private:
240   /// Allocate memory indices for the results of operations within the matcher
241   /// and rewriters.
242   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
243                              ModuleOp rewriterModule);
244 
245   /// Generate the bytecode for the given operation.
246   void generate(Region *region, ByteCodeWriter &writer);
247   void generate(Operation *op, ByteCodeWriter &writer);
248   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
249   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
250   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
251   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
252   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
253   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
254   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
255   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
259   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
260   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
286 
287   /// Mapping from value to its corresponding memory index.
288   DenseMap<Value, ByteCodeField> valueToMemIndex;
289 
290   /// Mapping from a range value to its corresponding range storage index.
291   DenseMap<Value, ByteCodeField> valueToRangeIndex;
292 
293   /// Mapping from the name of an externally registered rewrite to its index in
294   /// the bytecode registry.
295   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
296 
297   /// Mapping from the name of an externally registered constraint to its index
298   /// in the bytecode registry.
299   llvm::StringMap<ByteCodeField> constraintToMemIndex;
300 
301   /// Mapping from rewriter function name to the bytecode address of the
302   /// rewriter function in byte.
303   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
304 
305   /// Mapping from a uniqued storage object to its memory index within
306   /// `uniquedData`.
307   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
308 
309   /// The current level of the foreach loop.
310   ByteCodeField curLoopLevel = 0;
311 
312   /// The current MLIR context.
313   MLIRContext *ctx;
314 
315   /// Mapping from block to its address.
316   DenseMap<Block *, ByteCodeAddr> blockToAddr;
317 
318   /// Data of the ByteCode class to be populated.
319   std::vector<const void *> &uniquedData;
320   SmallVectorImpl<ByteCodeField> &matcherByteCode;
321   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
322   SmallVectorImpl<PDLByteCodePattern> &patterns;
323   ByteCodeField &maxValueMemoryIndex;
324   ByteCodeField &maxOpRangeMemoryIndex;
325   ByteCodeField &maxTypeRangeMemoryIndex;
326   ByteCodeField &maxValueRangeMemoryIndex;
327   ByteCodeField &maxLoopLevel;
328 };
329 
330 /// This class provides utilities for writing a bytecode stream.
331 struct ByteCodeWriter {
332   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
333       : bytecode(bytecode), generator(generator) {}
334 
335   /// Append a field to the bytecode.
336   void append(ByteCodeField field) { bytecode.push_back(field); }
337   void append(OpCode opCode) { bytecode.push_back(opCode); }
338 
339   /// Append an address to the bytecode.
340   void append(ByteCodeAddr field) {
341     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
342                   "unexpected ByteCode address size");
343 
344     ByteCodeField fieldParts[2];
345     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
346     bytecode.append({fieldParts[0], fieldParts[1]});
347   }
348 
349   /// Append a single successor to the bytecode, the exact address will need to
350   /// be resolved later.
351   void append(Block *successor) {
352     // Add back a reference to the successor so that the address can be resolved
353     // later.
354     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
355     append(ByteCodeAddr(0));
356   }
357 
358   /// Append a successor range to the bytecode, the exact address will need to
359   /// be resolved later.
360   void append(SuccessorRange successors) {
361     for (Block *successor : successors)
362       append(successor);
363   }
364 
365   /// Append a range of values that will be read as generic PDLValues.
366   void appendPDLValueList(OperandRange values) {
367     bytecode.push_back(values.size());
368     for (Value value : values)
369       appendPDLValue(value);
370   }
371 
372   /// Append a value as a PDLValue.
373   void appendPDLValue(Value value) {
374     appendPDLValueKind(value);
375     append(value);
376   }
377 
378   /// Append the PDLValue::Kind of the given value.
379   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
380 
381   /// Append the PDLValue::Kind of the given type.
382   void appendPDLValueKind(Type type) {
383     PDLValue::Kind kind =
384         TypeSwitch<Type, PDLValue::Kind>(type)
385             .Case<pdl::AttributeType>(
386                 [](Type) { return PDLValue::Kind::Attribute; })
387             .Case<pdl::OperationType>(
388                 [](Type) { return PDLValue::Kind::Operation; })
389             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
390               if (rangeTy.getElementType().isa<pdl::TypeType>())
391                 return PDLValue::Kind::TypeRange;
392               return PDLValue::Kind::ValueRange;
393             })
394             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
395             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
396     bytecode.push_back(static_cast<ByteCodeField>(kind));
397   }
398 
399   /// Append a value that will be stored in a memory slot and not inline within
400   /// the bytecode.
401   template <typename T>
402   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
403                    std::is_pointer<T>::value>
404   append(T value) {
405     bytecode.push_back(generator.getMemIndex(value));
406   }
407 
408   /// Append a range of values.
409   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
410   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
411   append(T range) {
412     bytecode.push_back(llvm::size(range));
413     for (auto it : range)
414       append(it);
415   }
416 
417   /// Append a variadic number of fields to the bytecode.
418   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
419   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
420     append(field);
421     append(field2, fields...);
422   }
423 
424   /// Appends a value as a pointer, stored inline within the bytecode.
425   template <typename T>
426   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
427   appendInline(T value) {
428     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
429     const void *pointer = value.getAsOpaquePointer();
430     ByteCodeField fieldParts[numParts];
431     std::memcpy(fieldParts, &pointer, sizeof(const void *));
432     bytecode.append(fieldParts, fieldParts + numParts);
433   }
434 
435   /// Successor references in the bytecode that have yet to be resolved.
436   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
437 
438   /// The underlying bytecode buffer.
439   SmallVectorImpl<ByteCodeField> &bytecode;
440 
441   /// The main generator producing PDL.
442   Generator &generator;
443 };
444 
445 /// This class represents a live range of PDL Interpreter values, containing
446 /// information about when values are live within a match/rewrite.
447 struct ByteCodeLiveRange {
448   using Set = llvm::IntervalMap<uint64_t, char, 16>;
449   using Allocator = Set::Allocator;
450 
451   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
452 
453   /// Union this live range with the one provided.
454   void unionWith(const ByteCodeLiveRange &rhs) {
455     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
456          ++it)
457       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
458   }
459 
460   /// Returns true if this range overlaps with the one provided.
461   bool overlaps(const ByteCodeLiveRange &rhs) const {
462     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
463         .valid();
464   }
465 
466   /// A map representing the ranges of the match/rewrite that a value is live in
467   /// the interpreter.
468   ///
469   /// We use std::unique_ptr here, because IntervalMap does not provide a
470   /// correct copy or move constructor. We can eliminate the pointer once
471   /// https://reviews.llvm.org/D113240 lands.
472   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
473 
474   /// The operation range storage index for this range.
475   Optional<unsigned> opRangeIndex;
476 
477   /// The type range storage index for this range.
478   Optional<unsigned> typeRangeIndex;
479 
480   /// The value range storage index for this range.
481   Optional<unsigned> valueRangeIndex;
482 };
483 } // namespace
484 
485 void Generator::generate(ModuleOp module) {
486   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
487       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
488   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
489       pdl_interp::PDLInterpDialect::getRewriterModuleName());
490   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
491 
492   // Allocate memory indices for the results of operations within the matcher
493   // and rewriters.
494   allocateMemoryIndices(matcherFunc, rewriterModule);
495 
496   // Generate code for the rewriter functions.
497   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
498   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
499     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
500     for (Operation &op : rewriterFunc.getOps())
501       generate(&op, rewriterByteCodeWriter);
502   }
503   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
504          "unexpected branches in rewriter function");
505 
506   // Generate code for the matcher function.
507   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
508   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
509 
510   // Resolve successor references in the matcher.
511   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
512     ByteCodeAddr addr = blockToAddr[it.first];
513     for (unsigned offsetToFix : it.second)
514       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
515   }
516 }
517 
518 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
519                                       ModuleOp rewriterModule) {
520   // Rewriters use simplistic allocation scheme that simply assigns an index to
521   // each result.
522   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
523     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
524     auto processRewriterValue = [&](Value val) {
525       valueToMemIndex.try_emplace(val, index++);
526       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
527         Type elementTy = rangeType.getElementType();
528         if (elementTy.isa<pdl::TypeType>())
529           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
530         else if (elementTy.isa<pdl::ValueType>())
531           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
532       }
533     };
534 
535     for (BlockArgument arg : rewriterFunc.getArguments())
536       processRewriterValue(arg);
537     rewriterFunc.getBody().walk([&](Operation *op) {
538       for (Value result : op->getResults())
539         processRewriterValue(result);
540     });
541     if (index > maxValueMemoryIndex)
542       maxValueMemoryIndex = index;
543     if (typeRangeIndex > maxTypeRangeMemoryIndex)
544       maxTypeRangeMemoryIndex = typeRangeIndex;
545     if (valueRangeIndex > maxValueRangeMemoryIndex)
546       maxValueRangeMemoryIndex = valueRangeIndex;
547   }
548 
549   // The matcher function uses a more sophisticated numbering that tries to
550   // minimize the number of memory indices assigned. This is done by determining
551   // a live range of the values within the matcher, then the allocation is just
552   // finding the minimal number of overlapping live ranges. This is essentially
553   // a simplified form of register allocation where we don't necessarily have a
554   // limited number of registers, but we still want to minimize the number used.
555   DenseMap<Operation *, unsigned> opToFirstIndex;
556   DenseMap<Operation *, unsigned> opToLastIndex;
557 
558   // A custom walk that marks the first and the last index of each operation.
559   // The entry marks the beginning of the liveness range for this operation,
560   // followed by nested operations, followed by the end of the liveness range.
561   unsigned index = 0;
562   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
563     opToFirstIndex.try_emplace(op, index++);
564     for (Region &region : op->getRegions())
565       for (Block &block : region.getBlocks())
566         for (Operation &nested : block)
567           walk(&nested);
568     opToLastIndex.try_emplace(op, index++);
569   };
570   walk(matcherFunc);
571 
572   // Liveness info for each of the defs within the matcher.
573   ByteCodeLiveRange::Allocator allocator;
574   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
575 
576   // Assign the root operation being matched to slot 0.
577   BlockArgument rootOpArg = matcherFunc.getArgument(0);
578   valueToMemIndex[rootOpArg] = 0;
579 
580   // Walk each of the blocks, computing the def interval that the value is used.
581   Liveness matcherLiveness(matcherFunc);
582   matcherFunc->walk([&](Block *block) {
583     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
584     assert(info && "expected liveness info for block");
585     auto processValue = [&](Value value, Operation *firstUseOrDef) {
586       // We don't need to process the root op argument, this value is always
587       // assigned to the first memory slot.
588       if (value == rootOpArg)
589         return;
590 
591       // Set indices for the range of this block that the value is used.
592       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
593       defRangeIt->second.liveness->insert(
594           opToFirstIndex[firstUseOrDef],
595           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
596           /*dummyValue*/ 0);
597 
598       // Check to see if this value is a range type.
599       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
600         Type eleType = rangeTy.getElementType();
601         if (eleType.isa<pdl::OperationType>())
602           defRangeIt->second.opRangeIndex = 0;
603         else if (eleType.isa<pdl::TypeType>())
604           defRangeIt->second.typeRangeIndex = 0;
605         else if (eleType.isa<pdl::ValueType>())
606           defRangeIt->second.valueRangeIndex = 0;
607       }
608     };
609 
610     // Process the live-ins of this block.
611     for (Value liveIn : info->in()) {
612       // Only process the value if it has been defined in the current region.
613       // Other values that span across pdl_interp.foreach will be added higher
614       // up. This ensures that the we keep them alive for the entire duration
615       // of the loop.
616       if (liveIn.getParentRegion() == block->getParent())
617         processValue(liveIn, &block->front());
618     }
619 
620     // Process the block arguments for the entry block (those are not live-in).
621     if (block->isEntryBlock()) {
622       for (Value argument : block->getArguments())
623         processValue(argument, &block->front());
624     }
625 
626     // Process any new defs within this block.
627     for (Operation &op : *block)
628       for (Value result : op.getResults())
629         processValue(result, &op);
630   });
631 
632   // Greedily allocate memory slots using the computed def live ranges.
633   std::vector<ByteCodeLiveRange> allocatedIndices;
634 
635   // The number of memory indices currently allocated (and its next value).
636   // Recall that the root gets allocated memory index 0.
637   ByteCodeField numIndices = 1;
638 
639   // The number of memory ranges of various types (and their next values).
640   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
641 
642   for (auto &defIt : valueDefRanges) {
643     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
644     ByteCodeLiveRange &defRange = defIt.second;
645 
646     // Try to allocate to an existing index.
647     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
648       ByteCodeLiveRange &existingRange = existingIndexIt.value();
649       if (!defRange.overlaps(existingRange)) {
650         existingRange.unionWith(defRange);
651         memIndex = existingIndexIt.index() + 1;
652 
653         if (defRange.opRangeIndex) {
654           if (!existingRange.opRangeIndex)
655             existingRange.opRangeIndex = numOpRanges++;
656           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
657         } else if (defRange.typeRangeIndex) {
658           if (!existingRange.typeRangeIndex)
659             existingRange.typeRangeIndex = numTypeRanges++;
660           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
661         } else if (defRange.valueRangeIndex) {
662           if (!existingRange.valueRangeIndex)
663             existingRange.valueRangeIndex = numValueRanges++;
664           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
665         }
666         break;
667       }
668     }
669 
670     // If no existing index could be used, add a new one.
671     if (memIndex == 0) {
672       allocatedIndices.emplace_back(allocator);
673       ByteCodeLiveRange &newRange = allocatedIndices.back();
674       newRange.unionWith(defRange);
675 
676       // Allocate an index for op/type/value ranges.
677       if (defRange.opRangeIndex) {
678         newRange.opRangeIndex = numOpRanges;
679         valueToRangeIndex[defIt.first] = numOpRanges++;
680       } else if (defRange.typeRangeIndex) {
681         newRange.typeRangeIndex = numTypeRanges;
682         valueToRangeIndex[defIt.first] = numTypeRanges++;
683       } else if (defRange.valueRangeIndex) {
684         newRange.valueRangeIndex = numValueRanges;
685         valueToRangeIndex[defIt.first] = numValueRanges++;
686       }
687 
688       memIndex = allocatedIndices.size();
689       ++numIndices;
690     }
691   }
692 
693   // Print the index usage and ensure that we did not run out of index space.
694   LLVM_DEBUG({
695     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
696                  << "(down from initial " << valueDefRanges.size() << ").\n";
697   });
698   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
699          "Ran out of memory for allocated indices");
700 
701   // Update the max number of indices.
702   if (numIndices > maxValueMemoryIndex)
703     maxValueMemoryIndex = numIndices;
704   if (numOpRanges > maxOpRangeMemoryIndex)
705     maxOpRangeMemoryIndex = numOpRanges;
706   if (numTypeRanges > maxTypeRangeMemoryIndex)
707     maxTypeRangeMemoryIndex = numTypeRanges;
708   if (numValueRanges > maxValueRangeMemoryIndex)
709     maxValueRangeMemoryIndex = numValueRanges;
710 }
711 
712 void Generator::generate(Region *region, ByteCodeWriter &writer) {
713   llvm::ReversePostOrderTraversal<Region *> rpot(region);
714   for (Block *block : rpot) {
715     // Keep track of where this block begins within the matcher function.
716     blockToAddr.try_emplace(block, matcherByteCode.size());
717     for (Operation &op : *block)
718       generate(&op, writer);
719   }
720 }
721 
/// Emit the bytecode for a single `pdl_interp` operation by dispatching to the
/// overload for its concrete type. Reaching an operation that is not in the
/// case list is treated as unreachable.
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  LLVM_DEBUG({
    // The following list must contain all the operations that do not
    // produce any bytecode.
    // When debug output is enabled, each op that does emit bytecode is
    // prefixed with its Location stored inline. The executor's
    // `getPrevCodeIt` steps back over this inline data, so the two must stay
    // in agreement.
    if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
             pdl_interp::InferredTypesOp>(op))
      writer.appendInline(op->getLoc());
  });
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
            pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
            pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
            pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
            pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          // Forward to the strongly-typed overload; `this->` is spelled out
          // for the call inside the generic lambda.
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}
755 
756 void Generator::generate(pdl_interp::ApplyConstraintOp op,
757                          ByteCodeWriter &writer) {
758   assert(constraintToMemIndex.count(op.getName()) &&
759          "expected index for constraint function");
760   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
761   writer.appendPDLValueList(op.getArgs());
762   writer.append(op.getSuccessors());
763 }
764 void Generator::generate(pdl_interp::ApplyRewriteOp op,
765                          ByteCodeWriter &writer) {
766   assert(externalRewriterToMemIndex.count(op.getName()) &&
767          "expected index for rewrite function");
768   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
769   writer.appendPDLValueList(op.getArgs());
770 
771   ResultRange results = op.getResults();
772   writer.append(ByteCodeField(results.size()));
773   for (Value result : results) {
774     // In debug mode we also record the expected kind of the result, so that we
775     // can provide extra verification of the native rewrite function.
776 #ifndef NDEBUG
777     writer.appendPDLValueKind(result);
778 #endif
779 
780     // Range results also need to append the range storage index.
781     if (result.getType().isa<pdl::RangeType>())
782       writer.append(getRangeStorageIndex(result));
783     writer.append(result);
784   }
785 }
786 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
787   Value lhs = op.getLhs();
788   if (lhs.getType().isa<pdl::RangeType>()) {
789     writer.append(OpCode::AreRangesEqual);
790     writer.appendPDLValueKind(lhs);
791     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
792     return;
793   }
794 
795   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
796 }
797 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
798   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
799 }
800 void Generator::generate(pdl_interp::CheckAttributeOp op,
801                          ByteCodeWriter &writer) {
802   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
803                 op.getSuccessors());
804 }
805 void Generator::generate(pdl_interp::CheckOperandCountOp op,
806                          ByteCodeWriter &writer) {
807   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
808                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
809                 op.getSuccessors());
810 }
811 void Generator::generate(pdl_interp::CheckOperationNameOp op,
812                          ByteCodeWriter &writer) {
813   writer.append(OpCode::CheckOperationName, op.getInputOp(),
814                 OperationName(op.getName(), ctx), op.getSuccessors());
815 }
816 void Generator::generate(pdl_interp::CheckResultCountOp op,
817                          ByteCodeWriter &writer) {
818   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
819                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
820                 op.getSuccessors());
821 }
822 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
823   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
824                 op.getSuccessors());
825 }
826 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
827   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
828                 op.getSuccessors());
829 }
830 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
831   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
832   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
833 }
834 void Generator::generate(pdl_interp::CreateAttributeOp op,
835                          ByteCodeWriter &writer) {
836   // Simply repoint the memory index of the result to the constant.
837   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
838 }
839 void Generator::generate(pdl_interp::CreateOperationOp op,
840                          ByteCodeWriter &writer) {
841   writer.append(OpCode::CreateOperation, op.getResultOp(),
842                 OperationName(op.getName(), ctx));
843   writer.appendPDLValueList(op.getInputOperands());
844 
845   // Add the attributes.
846   OperandRange attributes = op.getInputAttributes();
847   writer.append(static_cast<ByteCodeField>(attributes.size()));
848   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
849     writer.append(std::get<0>(it), std::get<1>(it));
850   writer.appendPDLValueList(op.getInputResultTypes());
851 }
852 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
853   // Simply repoint the memory index of the result to the constant.
854   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
855 }
856 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
857   writer.append(OpCode::CreateTypes, op.getResult(),
858                 getRangeStorageIndex(op.getResult()), op.getValue());
859 }
860 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
861   writer.append(OpCode::EraseOp, op.getInputOp());
862 }
863 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
864   OpCode opCode =
865       TypeSwitch<Type, OpCode>(op.getResult().getType())
866           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
867           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
868           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
869           .Default([](Type) -> OpCode {
870             llvm_unreachable("unsupported element type");
871           });
872   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
873 }
874 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
875   writer.append(OpCode::Finalize);
876 }
877 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
878   BlockArgument arg = op.getLoopVariable();
879   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
880   writer.appendPDLValueKind(arg.getType());
881   writer.append(curLoopLevel, op.getSuccessor());
882   ++curLoopLevel;
883   if (curLoopLevel > maxLoopLevel)
884     maxLoopLevel = curLoopLevel;
885   generate(&op.getRegion(), writer);
886   --curLoopLevel;
887 }
888 void Generator::generate(pdl_interp::GetAttributeOp op,
889                          ByteCodeWriter &writer) {
890   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
891                 op.getNameAttr());
892 }
893 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
894                          ByteCodeWriter &writer) {
895   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
896 }
897 void Generator::generate(pdl_interp::GetDefiningOpOp op,
898                          ByteCodeWriter &writer) {
899   writer.append(OpCode::GetDefiningOp, op.getInputOp());
900   writer.appendPDLValue(op.getValue());
901 }
902 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
903   uint32_t index = op.getIndex();
904   if (index < 4)
905     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
906   else
907     writer.append(OpCode::GetOperandN, index);
908   writer.append(op.getInputOp(), op.getValue());
909 }
910 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
911   Value result = op.getValue();
912   Optional<uint32_t> index = op.getIndex();
913   writer.append(OpCode::GetOperands,
914                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
915                 op.getInputOp());
916   if (result.getType().isa<pdl::RangeType>())
917     writer.append(getRangeStorageIndex(result));
918   else
919     writer.append(std::numeric_limits<ByteCodeField>::max());
920   writer.append(result);
921 }
922 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
923   uint32_t index = op.getIndex();
924   if (index < 4)
925     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
926   else
927     writer.append(OpCode::GetResultN, index);
928   writer.append(op.getInputOp(), op.getValue());
929 }
930 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
931   Value result = op.getValue();
932   Optional<uint32_t> index = op.getIndex();
933   writer.append(OpCode::GetResults,
934                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
935                 op.getInputOp());
936   if (result.getType().isa<pdl::RangeType>())
937     writer.append(getRangeStorageIndex(result));
938   else
939     writer.append(std::numeric_limits<ByteCodeField>::max());
940   writer.append(result);
941 }
942 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
943   Value operations = op.getOperations();
944   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
945   writer.append(OpCode::GetUsers, operations, rangeIndex);
946   writer.appendPDLValue(op.getValue());
947 }
948 void Generator::generate(pdl_interp::GetValueTypeOp op,
949                          ByteCodeWriter &writer) {
950   if (op.getType().isa<pdl::RangeType>()) {
951     Value result = op.getResult();
952     writer.append(OpCode::GetValueRangeTypes, result,
953                   getRangeStorageIndex(result), op.getValue());
954   } else {
955     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
956   }
957 }
958 
959 void Generator::generate(pdl_interp::InferredTypesOp op,
960                          ByteCodeWriter &writer) {
961   // InferType maps to a null type as a marker for inferring result types.
962   getMemIndex(op.getResult()) = getMemIndex(Type());
963 }
964 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
965   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
966 }
967 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
968   ByteCodeField patternIndex = patterns.size();
969   patterns.emplace_back(PDLByteCodePattern::create(
970       op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
971   writer.append(OpCode::RecordMatch, patternIndex,
972                 SuccessorRange(op.getOperation()), op.getMatchedOps());
973   writer.appendPDLValueList(op.getInputs());
974 }
975 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
976   writer.append(OpCode::ReplaceOp, op.getInputOp());
977   writer.appendPDLValueList(op.getReplValues());
978 }
979 void Generator::generate(pdl_interp::SwitchAttributeOp op,
980                          ByteCodeWriter &writer) {
981   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
982                 op.getCaseValuesAttr(), op.getSuccessors());
983 }
984 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
985                          ByteCodeWriter &writer) {
986   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
987                 op.getCaseValuesAttr(), op.getSuccessors());
988 }
989 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
990                          ByteCodeWriter &writer) {
991   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
992     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
993   });
994   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
995                 op.getSuccessors());
996 }
997 void Generator::generate(pdl_interp::SwitchResultCountOp op,
998                          ByteCodeWriter &writer) {
999   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1000                 op.getCaseValuesAttr(), op.getSuccessors());
1001 }
1002 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1003   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1004                 op.getSuccessors());
1005 }
1006 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1007   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1008                 op.getSuccessors());
1009 }
1010 
1011 //===----------------------------------------------------------------------===//
1012 // PDLByteCode
1013 //===----------------------------------------------------------------------===//
1014 
1015 PDLByteCode::PDLByteCode(ModuleOp module,
1016                          llvm::StringMap<PDLConstraintFunction> constraintFns,
1017                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
1018   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1019                       rewriterByteCode, patterns, maxValueMemoryIndex,
1020                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1021                       maxLoopLevel, constraintFns, rewriteFns);
1022   generator.generate(module);
1023 
1024   // Initialize the external functions.
1025   for (auto &it : constraintFns)
1026     constraintFunctions.push_back(std::move(it.second));
1027   for (auto &it : rewriteFns)
1028     rewriteFunctions.push_back(std::move(it.second));
1029 }
1030 
1031 /// Initialize the given state such that it can be used to execute the current
1032 /// bytecode.
1033 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1034   state.memory.resize(maxValueMemoryIndex, nullptr);
1035   state.opRangeMemory.resize(maxOpRangeCount);
1036   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1037   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1038   state.loopIndex.resize(maxLoopLevel, 0);
1039   state.currentPatternBenefits.reserve(patterns.size());
1040   for (const PDLByteCodePattern &pattern : patterns)
1041     state.currentPatternBenefits.push_back(pattern.getBenefit());
1042 }
1043 
1044 //===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1046 
1047 namespace {
/// This class provides support for executing a bytecode stream.
///
/// Instructions are decoded starting at `curCodeIt`, and the iterator is
/// advanced as each operand is read. The executor does not own its state: the
/// memory slots, range storage, loop counters, code buffer, and function
/// tables are all non-owning views or references supplied by the caller.
class ByteCodeExecutor {
public:
  /// All arguments are stored as non-owning views/references, so the caller
  /// must keep the underlying storage alive while the executor is in use.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context. `mainRewriteLoc` presumably supplies the location handed to
  /// `executeCreateOperation` when rewriting -- TODO confirm.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  /// Each one reads its operands from the stream in the order the generator
  /// wrote them.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops the most recently pushed code iterator from the stack and resumes
  /// execution from it. The stack must not be empty (asserted in debug
  /// builds).
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code.
  /// NOTE(review): this assumes debug output for this DEBUG_TYPE had the same
  /// enabled/disabled state when the bytecode was generated (where the inline
  /// Location is emitted) as it has now; otherwise the result is off.
  const ByteCodeField *getPrevCodeIt() const {
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  /// Non-template convenience overload that reads a single ByteCodeField.
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer: a ByteCodeField count
  /// followed by that many `ValueT` entries. `list` is cleared first.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// as either Value or ValueRange elements.
  /// Each element is prefixed by its PDLValue::Kind: a `Value` is appended as
  /// is, while any other kind is read as a `ValueRange *` and its elements are
  /// appended. Unlike `readList`, `list` is NOT cleared first.
  void readValueList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer.
  /// The pointer is stored directly in the bytecode (not via a memory index)
  /// and spans sizeof(void *) / sizeof(ByteCodeField) fields.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Jump to a specific successor based on a predicate value.
  /// `true` selects successor 0 and `false` selects successor 1.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index.
  /// Each successor address is a ByteCodeAddr occupying two ByteCodeFields, so
  /// `destIndex * 2` fields are skipped before reading the target address.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases.
  /// A match at case position `i` jumps to successor `i + 1`; when no case
  /// matches, successor 0 (the default destination) is taken.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the attribute value is within the case list. Jump to
    // the correct successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  /// Reads a memory index and returns the slot it names. Indices below
  /// `memory.size()` address the mutable memory; larger indices address
  /// `uniquedMemory`, offset by `memory.size()`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer types: reinterpret the stored slot as `T`.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Class types (other than PDLValue): rebuild from the opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// PDLValue: a kind tag followed by a value of that kind.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// ByteCodeAddr: stored inline across two ByteCodeFields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// ByteCodeField: the raw next field.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// PDLValue::Kind: a single field reinterpreted as the enum.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// The underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1307 
1308 /// This class is an instantiation of the PDLResultList that provides access to
1309 /// the returned results. This API is not on `PDLResultList` to avoid
1310 /// overexposing access to information specific solely to the ByteCode.
1311 class ByteCodeRewriteResultList : public PDLResultList {
1312 public:
1313   ByteCodeRewriteResultList(unsigned maxNumResults)
1314       : PDLResultList(maxNumResults) {}
1315 
1316   /// Return the list of PDL results.
1317   MutableArrayRef<PDLValue> getResults() { return results; }
1318 
1319   /// Return the type ranges allocated by this list.
1320   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1321     return allocatedTypeRanges;
1322   }
1323 
1324   /// Return the value ranges allocated by this list.
1325   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1326     return allocatedValueRanges;
1327   }
1328 };
1329 } // namespace
1330 
1331 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1332   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1333   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1334   SmallVector<PDLValue, 16> args;
1335   readList<PDLValue>(args);
1336 
1337   LLVM_DEBUG({
1338     llvm::dbgs() << "  * Arguments: ";
1339     llvm::interleaveComma(args, llvm::dbgs());
1340   });
1341 
1342   // Invoke the constraint and jump to the proper destination.
1343   selectJump(succeeded(constraintFn(args, rewriter)));
1344 }
1345 
1346 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1347   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1348   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1349   SmallVector<PDLValue, 16> args;
1350   readList<PDLValue>(args);
1351 
1352   LLVM_DEBUG({
1353     llvm::dbgs() << "  * Arguments: ";
1354     llvm::interleaveComma(args, llvm::dbgs());
1355   });
1356 
1357   // Execute the rewrite function.
1358   ByteCodeField numResults = read();
1359   ByteCodeRewriteResultList results(numResults);
1360   rewriteFn(args, rewriter, results);
1361 
1362   assert(results.getResults().size() == numResults &&
1363          "native PDL rewrite function returned unexpected number of results");
1364 
1365   // Store the results in the bytecode memory.
1366   for (PDLValue &result : results.getResults()) {
1367     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1368 
1369 // In debug mode we also verify the expected kind of the result.
1370 #ifndef NDEBUG
1371     assert(result.getKind() == read<PDLValue::Kind>() &&
1372            "native PDL rewrite function returned an unexpected type of result");
1373 #endif
1374 
1375     // If the result is a range, we need to copy it over to the bytecodes
1376     // range memory.
1377     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1378       unsigned rangeIndex = read();
1379       typeRangeMemory[rangeIndex] = *typeRange;
1380       memory[read()] = &typeRangeMemory[rangeIndex];
1381     } else if (Optional<ValueRange> valueRange =
1382                    result.dyn_cast<ValueRange>()) {
1383       unsigned rangeIndex = read();
1384       valueRangeMemory[rangeIndex] = *valueRange;
1385       memory[read()] = &valueRangeMemory[rangeIndex];
1386     } else {
1387       memory[read()] = result.getAsOpaquePointer();
1388     }
1389   }
1390 
1391   // Copy over any underlying storage allocated for result ranges.
1392   for (auto &it : results.getAllocatedTypeRanges())
1393     allocatedTypeRangeMemory.push_back(std::move(it));
1394   for (auto &it : results.getAllocatedValueRanges())
1395     allocatedValueRangeMemory.push_back(std::move(it));
1396 }
1397 
1398 void ByteCodeExecutor::executeAreEqual() {
1399   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1400   const void *lhs = read<const void *>();
1401   const void *rhs = read<const void *>();
1402 
1403   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1404   selectJump(lhs == rhs);
1405 }
1406 
1407 void ByteCodeExecutor::executeAreRangesEqual() {
1408   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1409   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1410   const void *lhs = read<const void *>();
1411   const void *rhs = read<const void *>();
1412 
1413   switch (valueKind) {
1414   case PDLValue::Kind::TypeRange: {
1415     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1416     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1417     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1418     selectJump(*lhsRange == *rhsRange);
1419     break;
1420   }
1421   case PDLValue::Kind::ValueRange: {
1422     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1423     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1424     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1425     selectJump(*lhsRange == *rhsRange);
1426     break;
1427   }
1428   default:
1429     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1430   }
1431 }
1432 
1433 void ByteCodeExecutor::executeBranch() {
1434   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1435   curCodeIt = &code[read<ByteCodeAddr>()];
1436 }
1437 
1438 void ByteCodeExecutor::executeCheckOperandCount() {
1439   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1440   Operation *op = read<Operation *>();
1441   uint32_t expectedCount = read<uint32_t>();
1442   bool compareAtLeast = read();
1443 
1444   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1445                           << "  * Expected: " << expectedCount << "\n"
1446                           << "  * Comparator: "
1447                           << (compareAtLeast ? ">=" : "==") << "\n");
1448   if (compareAtLeast)
1449     selectJump(op->getNumOperands() >= expectedCount);
1450   else
1451     selectJump(op->getNumOperands() == expectedCount);
1452 }
1453 
1454 void ByteCodeExecutor::executeCheckOperationName() {
1455   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1456   Operation *op = read<Operation *>();
1457   OperationName expectedName = read<OperationName>();
1458 
1459   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1460                           << "  * Expected: \"" << expectedName << "\"\n");
1461   selectJump(op->getName() == expectedName);
1462 }
1463 
1464 void ByteCodeExecutor::executeCheckResultCount() {
1465   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1466   Operation *op = read<Operation *>();
1467   uint32_t expectedCount = read<uint32_t>();
1468   bool compareAtLeast = read();
1469 
1470   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1471                           << "  * Expected: " << expectedCount << "\n"
1472                           << "  * Comparator: "
1473                           << (compareAtLeast ? ">=" : "==") << "\n");
1474   if (compareAtLeast)
1475     selectJump(op->getNumResults() >= expectedCount);
1476   else
1477     selectJump(op->getNumResults() == expectedCount);
1478 }
1479 
1480 void ByteCodeExecutor::executeCheckTypes() {
1481   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1482   TypeRange *lhs = read<TypeRange *>();
1483   Attribute rhs = read<Attribute>();
1484   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1485 
1486   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1487 }
1488 
1489 void ByteCodeExecutor::executeContinue() {
1490   ByteCodeField level = read();
1491   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1492                           << "  * Level: " << level << "\n");
1493   ++loopIndex[level];
1494   popCodeIt();
1495 }
1496 
1497 void ByteCodeExecutor::executeCreateTypes() {
1498   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1499   unsigned memIndex = read();
1500   unsigned rangeIndex = read();
1501   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1502 
1503   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1504 
1505   // Allocate a buffer for this type range.
1506   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1507   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1508   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1509 
1510   // Assign this to the range slot and use the range as the value for the
1511   // memory index.
1512   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1513   memory[memIndex] = &typeRangeMemory[rangeIndex];
1514 }
1515 
1516 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1517                                               Location mainRewriteLoc) {
1518   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1519 
1520   unsigned memIndex = read();
1521   OperationState state(mainRewriteLoc, read<OperationName>());
1522   readValueList(state.operands);
1523   for (unsigned i = 0, e = read(); i != e; ++i) {
1524     StringAttr name = read<StringAttr>();
1525     if (Attribute attr = read<Attribute>())
1526       state.addAttribute(name, attr);
1527   }
1528 
1529   for (unsigned i = 0, e = read(); i != e; ++i) {
1530     if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1531       state.types.push_back(read<Type>());
1532       continue;
1533     }
1534 
1535     // If we find a null range, this signals that the types are infered.
1536     if (TypeRange *resultTypes = read<TypeRange *>()) {
1537       state.types.append(resultTypes->begin(), resultTypes->end());
1538       continue;
1539     }
1540 
1541     // Handle the case where the operation has inferred types.
1542     InferTypeOpInterface::Concept *inferInterface =
1543         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1544 
1545     // TODO: Handle failure.
1546     state.types.clear();
1547     if (failed(inferInterface->inferReturnTypes(
1548             state.getContext(), state.location, state.operands,
1549             state.attributes.getDictionary(state.getContext()), state.regions,
1550             state.types)))
1551       return;
1552     break;
1553   }
1554 
1555   Operation *resultOp = rewriter.createOperation(state);
1556   memory[memIndex] = resultOp;
1557 
1558   LLVM_DEBUG({
1559     llvm::dbgs() << "  * Attributes: "
1560                  << state.attributes.getDictionary(state.getContext())
1561                  << "\n  * Operands: ";
1562     llvm::interleaveComma(state.operands, llvm::dbgs());
1563     llvm::dbgs() << "\n  * Result Types: ";
1564     llvm::interleaveComma(state.types, llvm::dbgs());
1565     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1566   });
1567 }
1568 
1569 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1570   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1571   Operation *op = read<Operation *>();
1572 
1573   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1574   rewriter.eraseOp(op);
1575 }
1576 
1577 template <typename T, typename Range, PDLValue::Kind kind>
1578 void ByteCodeExecutor::executeExtract() {
1579   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1580   Range *range = read<Range *>();
1581   unsigned index = read<uint32_t>();
1582   unsigned memIndex = read();
1583 
1584   if (!range) {
1585     memory[memIndex] = nullptr;
1586     return;
1587   }
1588 
1589   T result = index < range->size() ? (*range)[index] : T();
1590   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1591                           << "  * Index: " << index << "\n"
1592                           << "  * Result: " << result << "\n");
1593   storeToMemory(memIndex, result);
1594 }
1595 
1596 void ByteCodeExecutor::executeFinalize() {
1597   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1598 }
1599 
/// Execute one step of a `ForEach` loop over a range held in memory.
///
/// The current iteration is tracked in `loopIndex` at the level encoded in the
/// instruction.
///
/// - While elements remain, the current element is stored in memory, the
///   loop's resume position is pushed onto the code-iterator stack (popped
///   again by `executeContinue`), and execution falls into the loop body.
/// - Once the range is exhausted, the counter is reset to 0 and execution
///   jumps to successor 0.
void ByteCodeExecutor::executeForEach() {
  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
  // NOTE(review): presumably the address of this ForEach instruction, so that
  // `Continue` re-executes it for the next iteration -- confirm against
  // `getPrevCodeIt`.
  const ByteCodeField *prevCodeIt = getPrevCodeIt();
  unsigned rangeIndex = read();
  unsigned memIndex = read();
  const void *value = nullptr;

  switch (read<PDLValue::Kind>()) {
  case PDLValue::Kind::Operation: {
    // Reference into `loopIndex`: resetting it below persists across steps.
    unsigned &index = loopIndex[read()];
    ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
    assert(index <= array.size() && "iterated past the end");
    if (index < array.size()) {
      LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
      value = array[index];
      break;
    }

    // The loop is exhausted. Reset the counter so the loop can be re-entered
    // later, and leave through the exit successor.
    LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
    index = 0;
    selectJump(size_t(0));
    return;
  }
  default:
    llvm_unreachable("unexpected `ForEach` value kind");
  }

  // Store the iteration value, and save the resume address on the stack.
  memory[memIndex] = value;
  pushCodeIt(prevCodeIt);

  // Skip over the successor (we will enter the body of the loop).
  read<ByteCodeAddr>();
}
1634 
1635 void ByteCodeExecutor::executeGetAttribute() {
1636   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1637   unsigned memIndex = read();
1638   Operation *op = read<Operation *>();
1639   StringAttr attrName = read<StringAttr>();
1640   Attribute attr = op->getAttr(attrName);
1641 
1642   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1643                           << "  * Attribute: " << attrName << "\n"
1644                           << "  * Result: " << attr << "\n");
1645   memory[memIndex] = attr.getAsOpaquePointer();
1646 }
1647 
1648 void ByteCodeExecutor::executeGetAttributeType() {
1649   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1650   unsigned memIndex = read();
1651   Attribute attr = read<Attribute>();
1652   Type type = attr ? attr.getType() : Type();
1653 
1654   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1655                           << "  * Result: " << type << "\n");
1656   memory[memIndex] = type.getAsOpaquePointer();
1657 }
1658 
1659 void ByteCodeExecutor::executeGetDefiningOp() {
1660   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1661   unsigned memIndex = read();
1662   Operation *op = nullptr;
1663   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1664     Value value = read<Value>();
1665     if (value)
1666       op = value.getDefiningOp();
1667     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1668   } else {
1669     ValueRange *values = read<ValueRange *>();
1670     if (values && !values->empty()) {
1671       op = values->front().getDefiningOp();
1672     }
1673     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1674   }
1675 
1676   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1677   memory[memIndex] = op;
1678 }
1679 
1680 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1681   Operation *op = read<Operation *>();
1682   unsigned memIndex = read();
1683   Value operand =
1684       index < op->getNumOperands() ? op->getOperand(index) : Value();
1685 
1686   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1687                           << "  * Index: " << index << "\n"
1688                           << "  * Result: " << operand << "\n");
1689   memory[memIndex] = operand.getAsOpaquePointer();
1690 }
1691 
1692 /// This function is the internal implementation of `GetResults` and
1693 /// `GetOperands` that provides support for extracting a value range from the
1694 /// given operation.
1695 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1696 static void *
1697 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1698                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1699                           MutableArrayRef<ValueRange> valueRangeMemory) {
1700   // Check for the sentinel index that signals that all values should be
1701   // returned.
1702   if (index == std::numeric_limits<uint32_t>::max()) {
1703     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1704     // `values` is already the full value range.
1705 
1706     // Otherwise, check to see if this operation uses AttrSizedSegments.
1707   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1708     LLVM_DEBUG(llvm::dbgs()
1709                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1710 
1711     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1712     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1713       return nullptr;
1714 
1715     auto segments = segmentAttr.getValues<int32_t>();
1716     unsigned startIndex =
1717         std::accumulate(segments.begin(), segments.begin() + index, 0);
1718     values = values.slice(startIndex, *std::next(segments.begin(), index));
1719 
1720     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1721                             << *std::next(segments.begin(), index) << "]\n");
1722 
1723     // Otherwise, assume this is the last operand group of the operation.
1724     // FIXME: We currently don't support operations with
1725     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1726     // have a way to detect it's presence.
1727   } else if (values.size() >= index) {
1728     LLVM_DEBUG(llvm::dbgs()
1729                << "  * Treating values as trailing variadic range\n");
1730     values = values.drop_front(index);
1731 
1732     // If we couldn't detect a way to compute the values, bail out.
1733   } else {
1734     return nullptr;
1735   }
1736 
1737   // If the range index is valid, we are returning a range.
1738   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1739     valueRangeMemory[rangeIndex] = values;
1740     return &valueRangeMemory[rangeIndex];
1741   }
1742 
1743   // If a range index wasn't provided, the range is required to be non-variadic.
1744   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1745 }
1746 
1747 void ByteCodeExecutor::executeGetOperands() {
1748   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1749   unsigned index = read<uint32_t>();
1750   Operation *op = read<Operation *>();
1751   ByteCodeField rangeIndex = read();
1752 
1753   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1754       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1755       valueRangeMemory);
1756   if (!result)
1757     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1758   memory[read()] = result;
1759 }
1760 
1761 void ByteCodeExecutor::executeGetResult(unsigned index) {
1762   Operation *op = read<Operation *>();
1763   unsigned memIndex = read();
1764   OpResult result =
1765       index < op->getNumResults() ? op->getResult(index) : OpResult();
1766 
1767   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1768                           << "  * Index: " << index << "\n"
1769                           << "  * Result: " << result << "\n");
1770   memory[memIndex] = result.getAsOpaquePointer();
1771 }
1772 
1773 void ByteCodeExecutor::executeGetResults() {
1774   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1775   unsigned index = read<uint32_t>();
1776   Operation *op = read<Operation *>();
1777   ByteCodeField rangeIndex = read();
1778 
1779   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1780       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1781       valueRangeMemory);
1782   if (!result)
1783     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1784   memory[read()] = result;
1785 }
1786 
1787 void ByteCodeExecutor::executeGetUsers() {
1788   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1789   unsigned memIndex = read();
1790   unsigned rangeIndex = read();
1791   OwningOpRange &range = opRangeMemory[rangeIndex];
1792   memory[memIndex] = &range;
1793 
1794   range = OwningOpRange();
1795   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1796     // Read the value.
1797     Value value = read<Value>();
1798     if (!value)
1799       return;
1800     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1801 
1802     // Extract the users of a single value.
1803     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1804     llvm::copy(value.getUsers(), range.begin());
1805   } else {
1806     // Read a range of values.
1807     ValueRange *values = read<ValueRange *>();
1808     if (!values)
1809       return;
1810     LLVM_DEBUG({
1811       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1812       llvm::interleaveComma(*values, llvm::dbgs());
1813       llvm::dbgs() << "\n";
1814     });
1815 
1816     // Extract all the users of a range of values.
1817     SmallVector<Operation *> users;
1818     for (Value value : *values)
1819       users.append(value.user_begin(), value.user_end());
1820     range = OwningOpRange(users.size());
1821     llvm::copy(users, range.begin());
1822   }
1823 
1824   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1825 }
1826 
1827 void ByteCodeExecutor::executeGetValueType() {
1828   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1829   unsigned memIndex = read();
1830   Value value = read<Value>();
1831   Type type = value ? value.getType() : Type();
1832 
1833   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1834                           << "  * Result: " << type << "\n");
1835   memory[memIndex] = type.getAsOpaquePointer();
1836 }
1837 
1838 void ByteCodeExecutor::executeGetValueRangeTypes() {
1839   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1840   unsigned memIndex = read();
1841   unsigned rangeIndex = read();
1842   ValueRange *values = read<ValueRange *>();
1843   if (!values) {
1844     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1845     memory[memIndex] = nullptr;
1846     return;
1847   }
1848 
1849   LLVM_DEBUG({
1850     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1851     llvm::interleaveComma(*values, llvm::dbgs());
1852     llvm::dbgs() << "\n  * Result: ";
1853     llvm::interleaveComma(values->getType(), llvm::dbgs());
1854     llvm::dbgs() << "\n";
1855   });
1856   typeRangeMemory[rangeIndex] = values->getType();
1857   memory[memIndex] = &typeRangeMemory[rangeIndex];
1858 }
1859 
1860 void ByteCodeExecutor::executeIsNotNull() {
1861   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1862   const void *value = read<const void *>();
1863 
1864   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1865   selectJump(value != nullptr);
1866 }
1867 
/// Record a successful match of pattern `patternIndex` into `matches`, then
/// continue execution at the destination address encoded in the instruction.
///
/// The recorded `MatchResult` holds:
/// - a fused location built from the matched operations;
/// - the pattern's current benefit;
/// - a copy of every input value.
///
/// Range inputs are copied into the match's own `typeRangeValues` and
/// `valueRangeValues` lists, and `values` points at those copies rather than
/// at the executor's range memory slots.
void ByteCodeExecutor::executeRecordMatch(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
  unsigned patternIndex = read();
  PatternBenefit benefit = currentPatternBenefits[patternIndex];
  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];

  // If the benefit of the pattern is impossible, skip the processing of the
  // rest of the pattern. Jumping to `dest` also skips the remaining operands
  // of this instruction.
  if (benefit.isImpossibleToMatch()) {
    LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
    curCodeIt = dest;
    return;
  }

  // Create a fused location containing the locations of each of the
  // operations used in the match. This will be used as the location for
  // created operations during the rewrite that don't already have an
  // explicit location set.
  unsigned numMatchLocs = read();
  SmallVector<Location, 4> matchLocs;
  matchLocs.reserve(numMatchLocs);
  for (unsigned i = 0; i != numMatchLocs; ++i)
    matchLocs.push_back(read<Operation *>()->getLoc());
  Location matchLoc = rewriter.getFusedLoc(matchLocs);

  LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
                          << "  * Location: " << matchLoc << "\n");
  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
  PDLByteCode::MatchResult &match = matches.back();

  // Record all of the inputs to the match. If any of the inputs are ranges, we
  // will also need to remap the range pointer to memory stored in the match
  // state.
  //
  // The reserves are required for correctness: `match.values` stores pointers
  // to the back of `typeRangeValues`/`valueRangeValues`, so those lists must
  // not reallocate while inputs are pushed. `numInputs` bounds the size of
  // each list.
  unsigned numInputs = read();
  match.values.reserve(numInputs);
  match.typeRangeValues.reserve(numInputs);
  match.valueRangeValues.reserve(numInputs);
  for (unsigned i = 0; i < numInputs; ++i) {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::TypeRange:
      match.typeRangeValues.push_back(*read<TypeRange *>());
      match.values.push_back(&match.typeRangeValues.back());
      break;
    case PDLValue::Kind::ValueRange:
      match.valueRangeValues.push_back(*read<ValueRange *>());
      match.values.push_back(&match.valueRangeValues.back());
      break;
    default:
      // Non-range inputs are stored directly as opaque pointers.
      match.values.push_back(read<const void *>());
      break;
    }
  }
  curCodeIt = dest;
}
1924 
1925 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1926   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1927   Operation *op = read<Operation *>();
1928   SmallVector<Value, 16> args;
1929   readValueList(args);
1930 
1931   LLVM_DEBUG({
1932     llvm::dbgs() << "  * Operation: " << *op << "\n"
1933                  << "  * Values: ";
1934     llvm::interleaveComma(args, llvm::dbgs());
1935     llvm::dbgs() << "\n";
1936   });
1937   rewriter.replaceOp(op, args);
1938 }
1939 
1940 void ByteCodeExecutor::executeSwitchAttribute() {
1941   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1942   Attribute value = read<Attribute>();
1943   ArrayAttr cases = read<ArrayAttr>();
1944   handleSwitch(value, cases);
1945 }
1946 
1947 void ByteCodeExecutor::executeSwitchOperandCount() {
1948   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1949   Operation *op = read<Operation *>();
1950   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1951 
1952   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1953   handleSwitch(op->getNumOperands(), cases);
1954 }
1955 
/// Dispatch on the name of an operation. The case names are stored inline in
/// the bytecode as `caseCount` consecutive `OperationName` entries.
/// A match against case `i` selects successor `i + 1`; if no case matches,
/// successor 0 is selected.
void ByteCodeExecutor::executeSwitchOperationName() {
  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
  OperationName value = read<Operation *>()->getName();
  size_t caseCount = read();

  // The operation names are stored in-line, so to print them out for
  // debugging purposes we need to read the array before executing the
  // switch so that we can display all of the possible values. The code
  // iterator is restored afterwards so the real scan below starts from the
  // first case.
  LLVM_DEBUG({
    const ByteCodeField *prevCodeIt = curCodeIt;
    llvm::dbgs() << "  * Value: " << value << "\n"
                 << "  * Cases: ";
    llvm::interleaveComma(
        llvm::map_range(llvm::seq<size_t>(0, caseCount),
                        [&](size_t) { return read<OperationName>(); }),
        llvm::dbgs());
    llvm::dbgs() << "\n";
    curCodeIt = prevCodeIt;
  });

  // Try to find the switch value within any of the cases.
  for (size_t i = 0; i != caseCount; ++i) {
    if (read<OperationName>() == value) {
      // Skip the unread case names so the iterator lands on the successor
      // list. NOTE(review): this assumes each inline OperationName occupies
      // exactly one ByteCodeField -- confirm against `read<OperationName>`.
      curCodeIt += (caseCount - i - 1);
      return selectJump(i + 1);
    }
  }
  selectJump(size_t(0));
}
1985 
1986 void ByteCodeExecutor::executeSwitchResultCount() {
1987   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1988   Operation *op = read<Operation *>();
1989   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1990 
1991   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1992   handleSwitch(op->getNumResults(), cases);
1993 }
1994 
1995 void ByteCodeExecutor::executeSwitchType() {
1996   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1997   Type value = read<Type>();
1998   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1999   handleSwitch(value, cases);
2000 }
2001 
2002 void ByteCodeExecutor::executeSwitchTypes() {
2003   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2004   TypeRange *value = read<TypeRange *>();
2005   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2006   if (!value) {
2007     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2008     return selectJump(size_t(0));
2009   }
2010   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2011     return value == caseValue.getAsValueRange<TypeAttr>();
2012   });
2013 }
2014 
/// Interpret bytecode from the current code position until a `Finalize`
/// instruction is reached. Each loop iteration decodes one opcode and
/// dispatches to the matching `execute*` handler, which reads that
/// instruction's operands from the code stream.
///
/// \param rewriter       Passed through to the handlers that need it.
/// \param matches        Output list for `RecordMatch`. It must be non-null
///                       when executing matcher code (asserted below).
/// \param mainRewriteLoc Location used by `CreateOperation`. It is dereferenced
///                       unconditionally for that opcode, so it must be set
///                       when executing rewriter code.
void ByteCodeExecutor::execute(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
    Optional<Location> mainRewriteLoc) {
  while (true) {
    // Print the location of the operation being executed.
    // NOTE(review): this inline read happens only when LLVM_DEBUG is compiled
    // in, so the bytecode presumably contains the location only in those
    // builds -- confirm against the bytecode generator.
    LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");

    OpCode opCode = static_cast<OpCode>(read());
    switch (opCode) {
    case ApplyConstraint:
      executeApplyConstraint(rewriter);
      break;
    case ApplyRewrite:
      executeApplyRewrite(rewriter);
      break;
    case AreEqual:
      executeAreEqual();
      break;
    case AreRangesEqual:
      executeAreRangesEqual();
      break;
    case Branch:
      executeBranch();
      break;
    case CheckOperandCount:
      executeCheckOperandCount();
      break;
    case CheckOperationName:
      executeCheckOperationName();
      break;
    case CheckResultCount:
      executeCheckResultCount();
      break;
    case CheckTypes:
      executeCheckTypes();
      break;
    case Continue:
      executeContinue();
      break;
    case CreateOperation:
      // Only valid when a rewrite location was supplied (see above).
      executeCreateOperation(rewriter, *mainRewriteLoc);
      break;
    case CreateTypes:
      executeCreateTypes();
      break;
    case EraseOp:
      executeEraseOp(rewriter);
      break;
    case ExtractOp:
      executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
      break;
    case ExtractType:
      executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
      break;
    case ExtractValue:
      executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
      break;
    case Finalize:
      // Finalize terminates this code sequence.
      executeFinalize();
      LLVM_DEBUG(llvm::dbgs() << "\n");
      return;
    case ForEach:
      executeForEach();
      break;
    case GetAttribute:
      executeGetAttribute();
      break;
    case GetAttributeType:
      executeGetAttributeType();
      break;
    case GetDefiningOp:
      executeGetDefiningOp();
      break;
    case GetOperand0:
    case GetOperand1:
    case GetOperand2:
    case GetOperand3: {
      // The operand index is encoded in the opcode itself. This relies on
      // GetOperand0..GetOperand3 being consecutive enumerators.
      unsigned index = opCode - GetOperand0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
      executeGetOperand(index);
      break;
    }
    case GetOperandN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
      executeGetOperand(read<uint32_t>());
      break;
    case GetOperands:
      executeGetOperands();
      break;
    case GetResult0:
    case GetResult1:
    case GetResult2:
    case GetResult3: {
      // As above, this relies on GetResult0..GetResult3 being consecutive.
      unsigned index = opCode - GetResult0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
      executeGetResult(index);
      break;
    }
    case GetResultN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
      executeGetResult(read<uint32_t>());
      break;
    case GetResults:
      executeGetResults();
      break;
    case GetUsers:
      executeGetUsers();
      break;
    case GetValueType:
      executeGetValueType();
      break;
    case GetValueRangeTypes:
      executeGetValueRangeTypes();
      break;
    case IsNotNull:
      executeIsNotNull();
      break;
    case RecordMatch:
      assert(matches &&
             "expected matches to be provided when executing the matcher");
      executeRecordMatch(rewriter, *matches);
      break;
    case ReplaceOp:
      executeReplaceOp(rewriter);
      break;
    case SwitchAttribute:
      executeSwitchAttribute();
      break;
    case SwitchOperandCount:
      executeSwitchOperandCount();
      break;
    case SwitchOperationName:
      executeSwitchOperationName();
      break;
    case SwitchResultCount:
      executeSwitchResultCount();
      break;
    case SwitchType:
      executeSwitchType();
      break;
    case SwitchTypes:
      executeSwitchTypes();
      break;
    }
    LLVM_DEBUG(llvm::dbgs() << "\n");
  }
}
2163 
2164 /// Run the pattern matcher on the given root operation, collecting the matched
2165 /// patterns in `matches`.
2166 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2167                         SmallVectorImpl<MatchResult> &matches,
2168                         PDLByteCodeMutableState &state) const {
2169   // The first memory slot is always the root operation.
2170   state.memory[0] = op;
2171 
2172   // The matcher function always starts at code address 0.
2173   ByteCodeExecutor executor(
2174       matcherByteCode.data(), state.memory, state.opRangeMemory,
2175       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2176       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2177       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2178       constraintFunctions, rewriteFunctions);
2179   executor.execute(rewriter, &matches);
2180 
2181   // Order the found matches by benefit.
2182   std::stable_sort(matches.begin(), matches.end(),
2183                    [](const MatchResult &lhs, const MatchResult &rhs) {
2184                      return lhs.benefit > rhs.benefit;
2185                    });
2186 }
2187 
2188 /// Run the rewriter of the given pattern on the root operation `op`.
2189 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2190                           PDLByteCodeMutableState &state) const {
2191   // The arguments of the rewrite function are stored at the start of the
2192   // memory buffer.
2193   llvm::copy(match.values, state.memory.begin());
2194 
2195   ByteCodeExecutor executor(
2196       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2197       state.opRangeMemory, state.typeRangeMemory,
2198       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2199       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2200       rewriterByteCode, state.currentPatternBenefits, patterns,
2201       constraintFunctions, rewriteFunctions);
2202   executor.execute(rewriter, /*matches=*/nullptr, match.location);
2203 }
2204