1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <limits>
#include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.benefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.rootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62                                                    PatternBenefit benefit) {
63   currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
namespace {
/// The instruction set of the PDL bytecode. Each opcode is written into the
/// bytecode stream as a single `ByteCodeField`, followed by the operands of
/// that instruction. Enumerators use implicit, sequential values, so inserting
/// or reordering entries changes the encoding; the generator and the
/// interpreter must agree on this exact ordering.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): `GetOperand0`-`GetOperand3` presumably encode the operand
  /// index in the opcode itself, with `GetOperandN` carrying the index as an
  /// explicit field; confirm against the corresponding generator.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): same presumed 0-3 / N split as the operand opcodes above.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // ByteCode Generation
167 //===----------------------------------------------------------------------===//
168 
169 //===----------------------------------------------------------------------===//
170 // Generator
171 
172 namespace {
struct ByteCodeWriter;
struct ByteCodeLiveRange;

/// Detection alias naming the type of `T::getAsOpaquePointer()`. It is
/// ill-formed when `T` has no such callable member, which lets
/// `llvm::is_detected` test whether `T` can be converted to an opaque pointer.
/// The trailing `Args` pack is unused.
template <class T, class... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
179 
180 /// This class represents the main generator for the pattern bytecode.
181 class Generator {
182 public:
183   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
184             SmallVectorImpl<ByteCodeField> &matcherByteCode,
185             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
186             SmallVectorImpl<PDLByteCodePattern> &patterns,
187             ByteCodeField &maxValueMemoryIndex,
188             ByteCodeField &maxOpRangeMemoryIndex,
189             ByteCodeField &maxTypeRangeMemoryIndex,
190             ByteCodeField &maxValueRangeMemoryIndex,
191             ByteCodeField &maxLoopLevel,
192             llvm::StringMap<PDLConstraintFunction> &constraintFns,
193             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
194       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
195         rewriterByteCode(rewriterByteCode), patterns(patterns),
196         maxValueMemoryIndex(maxValueMemoryIndex),
197         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
198         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
199         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
200         maxLoopLevel(maxLoopLevel) {
201     for (auto it : llvm::enumerate(constraintFns))
202       constraintToMemIndex.try_emplace(it.value().first(), it.index());
203     for (auto it : llvm::enumerate(rewriteFns))
204       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
205   }
206 
207   /// Generate the bytecode for the given PDL interpreter module.
208   void generate(ModuleOp module);
209 
210   /// Return the memory index to use for the given value.
211   ByteCodeField &getMemIndex(Value value) {
212     assert(valueToMemIndex.count(value) &&
213            "expected memory index to be assigned");
214     return valueToMemIndex[value];
215   }
216 
217   /// Return the range memory index used to store the given range value.
218   ByteCodeField &getRangeStorageIndex(Value value) {
219     assert(valueToRangeIndex.count(value) &&
220            "expected range index to be assigned");
221     return valueToRangeIndex[value];
222   }
223 
224   /// Return an index to use when referring to the given data that is uniqued in
225   /// the MLIR context.
226   template <typename T>
227   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
228   getMemIndex(T val) {
229     const void *opaqueVal = val.getAsOpaquePointer();
230 
231     // Get or insert a reference to this value.
232     auto it = uniquedDataToMemIndex.try_emplace(
233         opaqueVal, maxValueMemoryIndex + uniquedData.size());
234     if (it.second)
235       uniquedData.push_back(opaqueVal);
236     return it.first->second;
237   }
238 
239 private:
240   /// Allocate memory indices for the results of operations within the matcher
241   /// and rewriters.
242   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
243 
244   /// Generate the bytecode for the given operation.
245   void generate(Region *region, ByteCodeWriter &writer);
246   void generate(Operation *op, ByteCodeWriter &writer);
247   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
248   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
249   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
250   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
251   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
252   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
253   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
254   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
255   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
259   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
260   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
285 
286   /// Mapping from value to its corresponding memory index.
287   DenseMap<Value, ByteCodeField> valueToMemIndex;
288 
289   /// Mapping from a range value to its corresponding range storage index.
290   DenseMap<Value, ByteCodeField> valueToRangeIndex;
291 
292   /// Mapping from the name of an externally registered rewrite to its index in
293   /// the bytecode registry.
294   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
295 
296   /// Mapping from the name of an externally registered constraint to its index
297   /// in the bytecode registry.
298   llvm::StringMap<ByteCodeField> constraintToMemIndex;
299 
300   /// Mapping from rewriter function name to the bytecode address of the
301   /// rewriter function in byte.
302   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
303 
304   /// Mapping from a uniqued storage object to its memory index within
305   /// `uniquedData`.
306   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
307 
308   /// The current level of the foreach loop.
309   ByteCodeField curLoopLevel = 0;
310 
311   /// The current MLIR context.
312   MLIRContext *ctx;
313 
314   /// Mapping from block to its address.
315   DenseMap<Block *, ByteCodeAddr> blockToAddr;
316 
317   /// Data of the ByteCode class to be populated.
318   std::vector<const void *> &uniquedData;
319   SmallVectorImpl<ByteCodeField> &matcherByteCode;
320   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
321   SmallVectorImpl<PDLByteCodePattern> &patterns;
322   ByteCodeField &maxValueMemoryIndex;
323   ByteCodeField &maxOpRangeMemoryIndex;
324   ByteCodeField &maxTypeRangeMemoryIndex;
325   ByteCodeField &maxValueRangeMemoryIndex;
326   ByteCodeField &maxLoopLevel;
327 };
328 
329 /// This class provides utilities for writing a bytecode stream.
330 struct ByteCodeWriter {
331   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
332       : bytecode(bytecode), generator(generator) {}
333 
334   /// Append a field to the bytecode.
335   void append(ByteCodeField field) { bytecode.push_back(field); }
336   void append(OpCode opCode) { bytecode.push_back(opCode); }
337 
338   /// Append an address to the bytecode.
339   void append(ByteCodeAddr field) {
340     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
341                   "unexpected ByteCode address size");
342 
343     ByteCodeField fieldParts[2];
344     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
345     bytecode.append({fieldParts[0], fieldParts[1]});
346   }
347 
348   /// Append a single successor to the bytecode, the exact address will need to
349   /// be resolved later.
350   void append(Block *successor) {
351     // Add back a reference to the successor so that the address can be resolved
352     // later.
353     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
354     append(ByteCodeAddr(0));
355   }
356 
357   /// Append a successor range to the bytecode, the exact address will need to
358   /// be resolved later.
359   void append(SuccessorRange successors) {
360     for (Block *successor : successors)
361       append(successor);
362   }
363 
364   /// Append a range of values that will be read as generic PDLValues.
365   void appendPDLValueList(OperandRange values) {
366     bytecode.push_back(values.size());
367     for (Value value : values)
368       appendPDLValue(value);
369   }
370 
371   /// Append a value as a PDLValue.
372   void appendPDLValue(Value value) {
373     appendPDLValueKind(value);
374     append(value);
375   }
376 
377   /// Append the PDLValue::Kind of the given value.
378   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
379 
380   /// Append the PDLValue::Kind of the given type.
381   void appendPDLValueKind(Type type) {
382     PDLValue::Kind kind =
383         TypeSwitch<Type, PDLValue::Kind>(type)
384             .Case<pdl::AttributeType>(
385                 [](Type) { return PDLValue::Kind::Attribute; })
386             .Case<pdl::OperationType>(
387                 [](Type) { return PDLValue::Kind::Operation; })
388             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
389               if (rangeTy.getElementType().isa<pdl::TypeType>())
390                 return PDLValue::Kind::TypeRange;
391               return PDLValue::Kind::ValueRange;
392             })
393             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
394             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
395     bytecode.push_back(static_cast<ByteCodeField>(kind));
396   }
397 
398   /// Append a value that will be stored in a memory slot and not inline within
399   /// the bytecode.
400   template <typename T>
401   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
402                    std::is_pointer<T>::value>
403   append(T value) {
404     bytecode.push_back(generator.getMemIndex(value));
405   }
406 
407   /// Append a range of values.
408   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
409   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
410   append(T range) {
411     bytecode.push_back(llvm::size(range));
412     for (auto it : range)
413       append(it);
414   }
415 
416   /// Append a variadic number of fields to the bytecode.
417   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
418   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
419     append(field);
420     append(field2, fields...);
421   }
422 
423   /// Appends a value as a pointer, stored inline within the bytecode.
424   template <typename T>
425   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
426   appendInline(T value) {
427     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
428     const void *pointer = value.getAsOpaquePointer();
429     ByteCodeField fieldParts[numParts];
430     std::memcpy(fieldParts, &pointer, sizeof(const void *));
431     bytecode.append(fieldParts, fieldParts + numParts);
432   }
433 
434   /// Successor references in the bytecode that have yet to be resolved.
435   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
436 
437   /// The underlying bytecode buffer.
438   SmallVectorImpl<ByteCodeField> &bytecode;
439 
440   /// The main generator producing PDL.
441   Generator &generator;
442 };
443 
444 /// This class represents a live range of PDL Interpreter values, containing
445 /// information about when values are live within a match/rewrite.
446 struct ByteCodeLiveRange {
447   using Set = llvm::IntervalMap<uint64_t, char, 16>;
448   using Allocator = Set::Allocator;
449 
450   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
451 
452   /// Union this live range with the one provided.
453   void unionWith(const ByteCodeLiveRange &rhs) {
454     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
455          ++it)
456       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
457   }
458 
459   /// Returns true if this range overlaps with the one provided.
460   bool overlaps(const ByteCodeLiveRange &rhs) const {
461     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
462         .valid();
463   }
464 
465   /// A map representing the ranges of the match/rewrite that a value is live in
466   /// the interpreter.
467   ///
468   /// We use std::unique_ptr here, because IntervalMap does not provide a
469   /// correct copy or move constructor. We can eliminate the pointer once
470   /// https://reviews.llvm.org/D113240 lands.
471   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
472 
473   /// The operation range storage index for this range.
474   Optional<unsigned> opRangeIndex;
475 
476   /// The type range storage index for this range.
477   Optional<unsigned> typeRangeIndex;
478 
479   /// The value range storage index for this range.
480   Optional<unsigned> valueRangeIndex;
481 };
482 } // namespace
483 
484 void Generator::generate(ModuleOp module) {
485   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
486       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
487   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
488       pdl_interp::PDLInterpDialect::getRewriterModuleName());
489   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
490 
491   // Allocate memory indices for the results of operations within the matcher
492   // and rewriters.
493   allocateMemoryIndices(matcherFunc, rewriterModule);
494 
495   // Generate code for the rewriter functions.
496   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
497   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
498     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
499     for (Operation &op : rewriterFunc.getOps())
500       generate(&op, rewriterByteCodeWriter);
501   }
502   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
503          "unexpected branches in rewriter function");
504 
505   // Generate code for the matcher function.
506   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
507   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
508 
509   // Resolve successor references in the matcher.
510   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
511     ByteCodeAddr addr = blockToAddr[it.first];
512     for (unsigned offsetToFix : it.second)
513       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
514   }
515 }
516 
517 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
518                                       ModuleOp rewriterModule) {
519   // Rewriters use simplistic allocation scheme that simply assigns an index to
520   // each result.
521   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
522     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
523     auto processRewriterValue = [&](Value val) {
524       valueToMemIndex.try_emplace(val, index++);
525       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
526         Type elementTy = rangeType.getElementType();
527         if (elementTy.isa<pdl::TypeType>())
528           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
529         else if (elementTy.isa<pdl::ValueType>())
530           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
531       }
532     };
533 
534     for (BlockArgument arg : rewriterFunc.getArguments())
535       processRewriterValue(arg);
536     rewriterFunc.getBody().walk([&](Operation *op) {
537       for (Value result : op->getResults())
538         processRewriterValue(result);
539     });
540     if (index > maxValueMemoryIndex)
541       maxValueMemoryIndex = index;
542     if (typeRangeIndex > maxTypeRangeMemoryIndex)
543       maxTypeRangeMemoryIndex = typeRangeIndex;
544     if (valueRangeIndex > maxValueRangeMemoryIndex)
545       maxValueRangeMemoryIndex = valueRangeIndex;
546   }
547 
548   // The matcher function uses a more sophisticated numbering that tries to
549   // minimize the number of memory indices assigned. This is done by determining
550   // a live range of the values within the matcher, then the allocation is just
551   // finding the minimal number of overlapping live ranges. This is essentially
552   // a simplified form of register allocation where we don't necessarily have a
553   // limited number of registers, but we still want to minimize the number used.
554   DenseMap<Operation *, unsigned> opToIndex;
555   matcherFunc.getBody().walk([&](Operation *op) {
556     opToIndex.insert(std::make_pair(op, opToIndex.size()));
557   });
558 
559   // Liveness info for each of the defs within the matcher.
560   ByteCodeLiveRange::Allocator allocator;
561   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
562 
563   // Assign the root operation being matched to slot 0.
564   BlockArgument rootOpArg = matcherFunc.getArgument(0);
565   valueToMemIndex[rootOpArg] = 0;
566 
567   // Walk each of the blocks, computing the def interval that the value is used.
568   Liveness matcherLiveness(matcherFunc);
569   matcherFunc->walk([&](Block *block) {
570     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
571     assert(info && "expected liveness info for block");
572     auto processValue = [&](Value value, Operation *firstUseOrDef) {
573       // We don't need to process the root op argument, this value is always
574       // assigned to the first memory slot.
575       if (value == rootOpArg)
576         return;
577 
578       // Set indices for the range of this block that the value is used.
579       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
580       defRangeIt->second.liveness->insert(
581           opToIndex[firstUseOrDef],
582           opToIndex[info->getEndOperation(value, firstUseOrDef)],
583           /*dummyValue*/ 0);
584 
585       // Check to see if this value is a range type.
586       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
587         Type eleType = rangeTy.getElementType();
588         if (eleType.isa<pdl::OperationType>())
589           defRangeIt->second.opRangeIndex = 0;
590         else if (eleType.isa<pdl::TypeType>())
591           defRangeIt->second.typeRangeIndex = 0;
592         else if (eleType.isa<pdl::ValueType>())
593           defRangeIt->second.valueRangeIndex = 0;
594       }
595     };
596 
597     // Process the live-ins of this block.
598     for (Value liveIn : info->in()) {
599       // Only process the value if it has been defined in the current region.
600       // Other values that span across pdl_interp.foreach will be added higher
601       // up. This ensures that the we keep them alive for the entire duration
602       // of the loop.
603       if (liveIn.getParentRegion() == block->getParent())
604         processValue(liveIn, &block->front());
605     }
606 
607     // Process the block arguments for the entry block (those are not live-in).
608     if (block->isEntryBlock()) {
609       for (Value argument : block->getArguments())
610         processValue(argument, &block->front());
611     }
612 
613     // Process any new defs within this block.
614     for (Operation &op : *block)
615       for (Value result : op.getResults())
616         processValue(result, &op);
617   });
618 
619   // Greedily allocate memory slots using the computed def live ranges.
620   std::vector<ByteCodeLiveRange> allocatedIndices;
621 
622   // The number of memory indices currently allocated (and its next value).
623   // Recall that the root gets allocated memory index 0.
624   ByteCodeField numIndices = 1;
625 
626   // The number of memory ranges of various types (and their next values).
627   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
628 
629   for (auto &defIt : valueDefRanges) {
630     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
631     ByteCodeLiveRange &defRange = defIt.second;
632 
633     // Try to allocate to an existing index.
634     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
635       ByteCodeLiveRange &existingRange = existingIndexIt.value();
636       if (!defRange.overlaps(existingRange)) {
637         existingRange.unionWith(defRange);
638         memIndex = existingIndexIt.index() + 1;
639 
640         if (defRange.opRangeIndex) {
641           if (!existingRange.opRangeIndex)
642             existingRange.opRangeIndex = numOpRanges++;
643           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
644         } else if (defRange.typeRangeIndex) {
645           if (!existingRange.typeRangeIndex)
646             existingRange.typeRangeIndex = numTypeRanges++;
647           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
648         } else if (defRange.valueRangeIndex) {
649           if (!existingRange.valueRangeIndex)
650             existingRange.valueRangeIndex = numValueRanges++;
651           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
652         }
653         break;
654       }
655     }
656 
657     // If no existing index could be used, add a new one.
658     if (memIndex == 0) {
659       allocatedIndices.emplace_back(allocator);
660       ByteCodeLiveRange &newRange = allocatedIndices.back();
661       newRange.unionWith(defRange);
662 
663       // Allocate an index for op/type/value ranges.
664       if (defRange.opRangeIndex) {
665         newRange.opRangeIndex = numOpRanges;
666         valueToRangeIndex[defIt.first] = numOpRanges++;
667       } else if (defRange.typeRangeIndex) {
668         newRange.typeRangeIndex = numTypeRanges;
669         valueToRangeIndex[defIt.first] = numTypeRanges++;
670       } else if (defRange.valueRangeIndex) {
671         newRange.valueRangeIndex = numValueRanges;
672         valueToRangeIndex[defIt.first] = numValueRanges++;
673       }
674 
675       memIndex = allocatedIndices.size();
676       ++numIndices;
677     }
678   }
679 
680   // Print the index usage and ensure that we did not run out of index space.
681   LLVM_DEBUG({
682     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
683                  << "(down from initial " << valueDefRanges.size() << ").\n";
684   });
685   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
686          "Ran out of memory for allocated indices");
687 
688   // Update the max number of indices.
689   if (numIndices > maxValueMemoryIndex)
690     maxValueMemoryIndex = numIndices;
691   if (numOpRanges > maxOpRangeMemoryIndex)
692     maxOpRangeMemoryIndex = numOpRanges;
693   if (numTypeRanges > maxTypeRangeMemoryIndex)
694     maxTypeRangeMemoryIndex = numTypeRanges;
695   if (numValueRanges > maxValueRangeMemoryIndex)
696     maxValueRangeMemoryIndex = numValueRanges;
697 }
698 
699 void Generator::generate(Region *region, ByteCodeWriter &writer) {
700   llvm::ReversePostOrderTraversal<Region *> rpot(region);
701   for (Block *block : rpot) {
702     // Keep track of where this block begins within the matcher function.
703     blockToAddr.try_emplace(block, matcherByteCode.size());
704     for (Operation &op : *block)
705       generate(&op, writer);
706   }
707 }
708 
/// Emit the bytecode for a single `pdl_interp` operation by dispatching to the
/// `generate` overload for its concrete op class. Reaching an op class that is
/// not listed below is a programming error (`llvm_unreachable`).
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  LLVM_DEBUG({
    // When debug output is enabled, each op that emits bytecode is prefixed
    // with its Location stored inline; the executor accounts for this field
    // (under the same debug condition) in `getPrevCodeIt`.
    // The following list must contain all the operations that do not
    // produce any bytecode.
    if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
             pdl_interp::InferredTypesOp>(op))
      writer.appendInline(op->getLoc());
  });
  // `this->` is spelled out explicitly inside the generic lambda.
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
            pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
            pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
            pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
            pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}
742 
743 void Generator::generate(pdl_interp::ApplyConstraintOp op,
744                          ByteCodeWriter &writer) {
745   assert(constraintToMemIndex.count(op.name()) &&
746          "expected index for constraint function");
747   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
748                 op.constParamsAttr());
749   writer.appendPDLValueList(op.args());
750   writer.append(op.getSuccessors());
751 }
752 void Generator::generate(pdl_interp::ApplyRewriteOp op,
753                          ByteCodeWriter &writer) {
754   assert(externalRewriterToMemIndex.count(op.name()) &&
755          "expected index for rewrite function");
756   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
757                 op.constParamsAttr());
758   writer.appendPDLValueList(op.args());
759 
760   ResultRange results = op.results();
761   writer.append(ByteCodeField(results.size()));
762   for (Value result : results) {
763     // In debug mode we also record the expected kind of the result, so that we
764     // can provide extra verification of the native rewrite function.
765 #ifndef NDEBUG
766     writer.appendPDLValueKind(result);
767 #endif
768 
769     // Range results also need to append the range storage index.
770     if (result.getType().isa<pdl::RangeType>())
771       writer.append(getRangeStorageIndex(result));
772     writer.append(result);
773   }
774 }
775 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
776   Value lhs = op.lhs();
777   if (lhs.getType().isa<pdl::RangeType>()) {
778     writer.append(OpCode::AreRangesEqual);
779     writer.appendPDLValueKind(lhs);
780     writer.append(op.lhs(), op.rhs(), op.getSuccessors());
781     return;
782   }
783 
784   writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
785 }
786 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
787   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
788 }
789 void Generator::generate(pdl_interp::CheckAttributeOp op,
790                          ByteCodeWriter &writer) {
791   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
792                 op.getSuccessors());
793 }
794 void Generator::generate(pdl_interp::CheckOperandCountOp op,
795                          ByteCodeWriter &writer) {
796   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
797                 static_cast<ByteCodeField>(op.compareAtLeast()),
798                 op.getSuccessors());
799 }
800 void Generator::generate(pdl_interp::CheckOperationNameOp op,
801                          ByteCodeWriter &writer) {
802   writer.append(OpCode::CheckOperationName, op.operation(),
803                 OperationName(op.name(), ctx), op.getSuccessors());
804 }
805 void Generator::generate(pdl_interp::CheckResultCountOp op,
806                          ByteCodeWriter &writer) {
807   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
808                 static_cast<ByteCodeField>(op.compareAtLeast()),
809                 op.getSuccessors());
810 }
811 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
812   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
813 }
814 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
815   writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
816 }
817 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
818   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
819   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
820 }
821 void Generator::generate(pdl_interp::CreateAttributeOp op,
822                          ByteCodeWriter &writer) {
823   // Simply repoint the memory index of the result to the constant.
824   getMemIndex(op.attribute()) = getMemIndex(op.value());
825 }
826 void Generator::generate(pdl_interp::CreateOperationOp op,
827                          ByteCodeWriter &writer) {
828   writer.append(OpCode::CreateOperation, op.operation(),
829                 OperationName(op.name(), ctx));
830   writer.appendPDLValueList(op.operands());
831 
832   // Add the attributes.
833   OperandRange attributes = op.attributes();
834   writer.append(static_cast<ByteCodeField>(attributes.size()));
835   for (auto it : llvm::zip(op.attributeNames(), op.attributes()))
836     writer.append(std::get<0>(it), std::get<1>(it));
837   writer.appendPDLValueList(op.types());
838 }
839 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
840   // Simply repoint the memory index of the result to the constant.
841   getMemIndex(op.result()) = getMemIndex(op.value());
842 }
843 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
844   writer.append(OpCode::CreateTypes, op.result(),
845                 getRangeStorageIndex(op.result()), op.value());
846 }
847 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
848   writer.append(OpCode::EraseOp, op.operation());
849 }
850 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
851   OpCode opCode =
852       TypeSwitch<Type, OpCode>(op.result().getType())
853           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
854           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
855           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
856           .Default([](Type) -> OpCode {
857             llvm_unreachable("unsupported element type");
858           });
859   writer.append(opCode, op.range(), op.index(), op.result());
860 }
861 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
862   writer.append(OpCode::Finalize);
863 }
864 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
865   BlockArgument arg = op.getLoopVariable();
866   writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg);
867   writer.appendPDLValueKind(arg.getType());
868   writer.append(curLoopLevel, op.successor());
869   ++curLoopLevel;
870   if (curLoopLevel > maxLoopLevel)
871     maxLoopLevel = curLoopLevel;
872   generate(&op.region(), writer);
873   --curLoopLevel;
874 }
875 void Generator::generate(pdl_interp::GetAttributeOp op,
876                          ByteCodeWriter &writer) {
877   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
878                 op.nameAttr());
879 }
880 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
881                          ByteCodeWriter &writer) {
882   writer.append(OpCode::GetAttributeType, op.result(), op.value());
883 }
884 void Generator::generate(pdl_interp::GetDefiningOpOp op,
885                          ByteCodeWriter &writer) {
886   writer.append(OpCode::GetDefiningOp, op.operation());
887   writer.appendPDLValue(op.value());
888 }
889 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
890   uint32_t index = op.index();
891   if (index < 4)
892     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
893   else
894     writer.append(OpCode::GetOperandN, index);
895   writer.append(op.operation(), op.value());
896 }
897 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
898   Value result = op.value();
899   Optional<uint32_t> index = op.index();
900   writer.append(OpCode::GetOperands,
901                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
902                 op.operation());
903   if (result.getType().isa<pdl::RangeType>())
904     writer.append(getRangeStorageIndex(result));
905   else
906     writer.append(std::numeric_limits<ByteCodeField>::max());
907   writer.append(result);
908 }
909 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
910   uint32_t index = op.index();
911   if (index < 4)
912     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
913   else
914     writer.append(OpCode::GetResultN, index);
915   writer.append(op.operation(), op.value());
916 }
917 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
918   Value result = op.value();
919   Optional<uint32_t> index = op.index();
920   writer.append(OpCode::GetResults,
921                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
922                 op.operation());
923   if (result.getType().isa<pdl::RangeType>())
924     writer.append(getRangeStorageIndex(result));
925   else
926     writer.append(std::numeric_limits<ByteCodeField>::max());
927   writer.append(result);
928 }
929 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
930   Value operations = op.operations();
931   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
932   writer.append(OpCode::GetUsers, operations, rangeIndex);
933   writer.appendPDLValue(op.value());
934 }
935 void Generator::generate(pdl_interp::GetValueTypeOp op,
936                          ByteCodeWriter &writer) {
937   if (op.getType().isa<pdl::RangeType>()) {
938     Value result = op.result();
939     writer.append(OpCode::GetValueRangeTypes, result,
940                   getRangeStorageIndex(result), op.value());
941   } else {
942     writer.append(OpCode::GetValueType, op.result(), op.value());
943   }
944 }
945 
946 void Generator::generate(pdl_interp::InferredTypesOp op,
947                          ByteCodeWriter &writer) {
948   // InferType maps to a null type as a marker for inferring result types.
949   getMemIndex(op.type()) = getMemIndex(Type());
950 }
951 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
952   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
953 }
954 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
955   ByteCodeField patternIndex = patterns.size();
956   patterns.emplace_back(PDLByteCodePattern::create(
957       op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
958   writer.append(OpCode::RecordMatch, patternIndex,
959                 SuccessorRange(op.getOperation()), op.matchedOps());
960   writer.appendPDLValueList(op.inputs());
961 }
962 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
963   writer.append(OpCode::ReplaceOp, op.operation());
964   writer.appendPDLValueList(op.replValues());
965 }
966 void Generator::generate(pdl_interp::SwitchAttributeOp op,
967                          ByteCodeWriter &writer) {
968   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
969                 op.getSuccessors());
970 }
971 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
972                          ByteCodeWriter &writer) {
973   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
974                 op.getSuccessors());
975 }
976 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
977                          ByteCodeWriter &writer) {
978   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
979     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
980   });
981   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
982                 op.getSuccessors());
983 }
984 void Generator::generate(pdl_interp::SwitchResultCountOp op,
985                          ByteCodeWriter &writer) {
986   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
987                 op.getSuccessors());
988 }
989 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
990   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
991                 op.getSuccessors());
992 }
993 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
994   writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
995                 op.getSuccessors());
996 }
997 
998 //===----------------------------------------------------------------------===//
999 // PDLByteCode
1000 //===----------------------------------------------------------------------===//
1001 
1002 PDLByteCode::PDLByteCode(ModuleOp module,
1003                          llvm::StringMap<PDLConstraintFunction> constraintFns,
1004                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
1005   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1006                       rewriterByteCode, patterns, maxValueMemoryIndex,
1007                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1008                       maxLoopLevel, constraintFns, rewriteFns);
1009   generator.generate(module);
1010 
1011   // Initialize the external functions.
1012   for (auto &it : constraintFns)
1013     constraintFunctions.push_back(std::move(it.second));
1014   for (auto &it : rewriteFns)
1015     rewriteFunctions.push_back(std::move(it.second));
1016 }
1017 
1018 /// Initialize the given state such that it can be used to execute the current
1019 /// bytecode.
1020 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1021   state.memory.resize(maxValueMemoryIndex, nullptr);
1022   state.opRangeMemory.resize(maxOpRangeCount);
1023   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1024   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1025   state.loopIndex.resize(maxLoopLevel, 0);
1026   state.currentPatternBenefits.reserve(patterns.size());
1027   for (const PDLByteCodePattern &pattern : patterns)
1028     state.currentPatternBenefits.push_back(pattern.getBenefit());
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // ByteCode Execution
1033 
1034 namespace {
/// This class provides support for executing a bytecode stream.
///
/// Operands are decoded inline from the bytecode, starting at `curCodeIt`.
/// Single values are stored as opaque pointers in `memory`. A memory index at
/// or beyond `memory.size()` refers to `uniquedMemory` instead (see
/// `readFromMemory`). Range values live in the dedicated range memories.
/// All `ArrayRef`/`MutableArrayRef` members and the two `std::vector`
/// references are non-owning, so the caller must keep that storage alive for
/// the executor's lifetime.
class ByteCodeExecutor {
public:
  /// Create an executor positioned at `curCodeIt`, operating over the given
  /// execution memories and bytecode data.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context. NOTE(review): `mainRewriteLoc` appears to be the location that is
  /// forwarded to `executeCreateOperation`; confirm against its definition.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops a code iterator from the stack and makes it the current code
  /// position. The stack must not be empty.
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code.
  ///
  /// NOTE(review): the generator writes the inline Location only when debug
  /// output is enabled (inside `LLVM_DEBUG` in `Generator::generate(Operation
  /// *, ...)`). This function checks for it under the same runtime condition,
  /// so the debug flag is assumed not to change between generation and
  /// execution.
  const ByteCodeField *getPrevCodeIt() const {
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. `list` is cleared first.
  /// The encoding is an element count followed by that many `ValueT` values.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// as either Value or ValueRange elements, and each element is preceded by
  /// its PDLValue kind. Unlike `readList`, this appends to `list` without
  /// clearing it first.
  void readValueList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer (rather than through a memory
  /// index). It occupies sizeof(void *) / sizeof(ByteCodeField) fields.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Jump to a specific successor based on a predicate value: `true` selects
  /// successor 0 and `false` selects successor 1.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Successor
  /// addresses are stored consecutively as ByteCodeAddr values, each spanning
  /// two fields, so successor `destIndex` starts at offset `destIndex * 2`.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. Successor 0
  /// is the default destination. A match on case `i` jumps to successor
  /// `i + 1`.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the value is within the case list. Jump to the correct
    // successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  ///
  /// Reads one memory index from the stream and returns the stored pointer.
  /// Indices past the end of `memory` refer to `uniquedMemory`, except for the
  /// kinds listed below, which are always read from `memory`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer types are read directly from memory.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Class types (other than PDLValue) are rebuilt from their opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// A PDLValue is prefixed by its kind, which selects how the payload is read.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Addresses are stored inline across two consecutive fields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// A raw field is read inline.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// A PDLValue kind is stored inline as a single field.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// The underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory. The `allocated*RangeMemory` vectors receive
  /// the owning storage that native rewrites allocate for their range results
  /// (see `executeApplyRewrite`).
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1294 
1295 /// This class is an instantiation of the PDLResultList that provides access to
1296 /// the returned results. This API is not on `PDLResultList` to avoid
1297 /// overexposing access to information specific solely to the ByteCode.
1298 class ByteCodeRewriteResultList : public PDLResultList {
1299 public:
1300   ByteCodeRewriteResultList(unsigned maxNumResults)
1301       : PDLResultList(maxNumResults) {}
1302 
1303   /// Return the list of PDL results.
1304   MutableArrayRef<PDLValue> getResults() { return results; }
1305 
1306   /// Return the type ranges allocated by this list.
1307   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1308     return allocatedTypeRanges;
1309   }
1310 
1311   /// Return the value ranges allocated by this list.
1312   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1313     return allocatedValueRanges;
1314   }
1315 };
1316 } // namespace
1317 
1318 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1319   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1320   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1321   ArrayAttr constParams = read<ArrayAttr>();
1322   SmallVector<PDLValue, 16> args;
1323   readList<PDLValue>(args);
1324 
1325   LLVM_DEBUG({
1326     llvm::dbgs() << "  * Arguments: ";
1327     llvm::interleaveComma(args, llvm::dbgs());
1328     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1329   });
1330 
1331   // Invoke the constraint and jump to the proper destination.
1332   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1333 }
1334 
1335 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1336   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1337   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1338   ArrayAttr constParams = read<ArrayAttr>();
1339   SmallVector<PDLValue, 16> args;
1340   readList<PDLValue>(args);
1341 
1342   LLVM_DEBUG({
1343     llvm::dbgs() << "  * Arguments: ";
1344     llvm::interleaveComma(args, llvm::dbgs());
1345     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1346   });
1347 
1348   // Execute the rewrite function.
1349   ByteCodeField numResults = read();
1350   ByteCodeRewriteResultList results(numResults);
1351   rewriteFn(args, constParams, rewriter, results);
1352 
1353   assert(results.getResults().size() == numResults &&
1354          "native PDL rewrite function returned unexpected number of results");
1355 
1356   // Store the results in the bytecode memory.
1357   for (PDLValue &result : results.getResults()) {
1358     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1359 
1360 // In debug mode we also verify the expected kind of the result.
1361 #ifndef NDEBUG
1362     assert(result.getKind() == read<PDLValue::Kind>() &&
1363            "native PDL rewrite function returned an unexpected type of result");
1364 #endif
1365 
1366     // If the result is a range, we need to copy it over to the bytecodes
1367     // range memory.
1368     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1369       unsigned rangeIndex = read();
1370       typeRangeMemory[rangeIndex] = *typeRange;
1371       memory[read()] = &typeRangeMemory[rangeIndex];
1372     } else if (Optional<ValueRange> valueRange =
1373                    result.dyn_cast<ValueRange>()) {
1374       unsigned rangeIndex = read();
1375       valueRangeMemory[rangeIndex] = *valueRange;
1376       memory[read()] = &valueRangeMemory[rangeIndex];
1377     } else {
1378       memory[read()] = result.getAsOpaquePointer();
1379     }
1380   }
1381 
1382   // Copy over any underlying storage allocated for result ranges.
1383   for (auto &it : results.getAllocatedTypeRanges())
1384     allocatedTypeRangeMemory.push_back(std::move(it));
1385   for (auto &it : results.getAllocatedValueRanges())
1386     allocatedValueRangeMemory.push_back(std::move(it));
1387 }
1388 
1389 void ByteCodeExecutor::executeAreEqual() {
1390   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1391   const void *lhs = read<const void *>();
1392   const void *rhs = read<const void *>();
1393 
1394   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1395   selectJump(lhs == rhs);
1396 }
1397 
1398 void ByteCodeExecutor::executeAreRangesEqual() {
1399   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1400   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1401   const void *lhs = read<const void *>();
1402   const void *rhs = read<const void *>();
1403 
1404   switch (valueKind) {
1405   case PDLValue::Kind::TypeRange: {
1406     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1407     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1408     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1409     selectJump(*lhsRange == *rhsRange);
1410     break;
1411   }
1412   case PDLValue::Kind::ValueRange: {
1413     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1414     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1415     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1416     selectJump(*lhsRange == *rhsRange);
1417     break;
1418   }
1419   default:
1420     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1421   }
1422 }
1423 
1424 void ByteCodeExecutor::executeBranch() {
1425   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1426   curCodeIt = &code[read<ByteCodeAddr>()];
1427 }
1428 
1429 void ByteCodeExecutor::executeCheckOperandCount() {
1430   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1431   Operation *op = read<Operation *>();
1432   uint32_t expectedCount = read<uint32_t>();
1433   bool compareAtLeast = read();
1434 
1435   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1436                           << "  * Expected: " << expectedCount << "\n"
1437                           << "  * Comparator: "
1438                           << (compareAtLeast ? ">=" : "==") << "\n");
1439   if (compareAtLeast)
1440     selectJump(op->getNumOperands() >= expectedCount);
1441   else
1442     selectJump(op->getNumOperands() == expectedCount);
1443 }
1444 
1445 void ByteCodeExecutor::executeCheckOperationName() {
1446   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1447   Operation *op = read<Operation *>();
1448   OperationName expectedName = read<OperationName>();
1449 
1450   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1451                           << "  * Expected: \"" << expectedName << "\"\n");
1452   selectJump(op->getName() == expectedName);
1453 }
1454 
1455 void ByteCodeExecutor::executeCheckResultCount() {
1456   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1457   Operation *op = read<Operation *>();
1458   uint32_t expectedCount = read<uint32_t>();
1459   bool compareAtLeast = read();
1460 
1461   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1462                           << "  * Expected: " << expectedCount << "\n"
1463                           << "  * Comparator: "
1464                           << (compareAtLeast ? ">=" : "==") << "\n");
1465   if (compareAtLeast)
1466     selectJump(op->getNumResults() >= expectedCount);
1467   else
1468     selectJump(op->getNumResults() == expectedCount);
1469 }
1470 
1471 void ByteCodeExecutor::executeCheckTypes() {
1472   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1473   TypeRange *lhs = read<TypeRange *>();
1474   Attribute rhs = read<Attribute>();
1475   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1476 
1477   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1478 }
1479 
1480 void ByteCodeExecutor::executeContinue() {
1481   ByteCodeField level = read();
1482   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1483                           << "  * Level: " << level << "\n");
1484   ++loopIndex[level];
1485   popCodeIt();
1486 }
1487 
1488 void ByteCodeExecutor::executeCreateTypes() {
1489   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1490   unsigned memIndex = read();
1491   unsigned rangeIndex = read();
1492   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1493 
1494   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1495 
1496   // Allocate a buffer for this type range.
1497   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1498   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1499   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1500 
1501   // Assign this to the range slot and use the range as the value for the
1502   // memory index.
1503   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1504   memory[memIndex] = &typeRangeMemory[rangeIndex];
1505 }
1506 
1507 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1508                                               Location mainRewriteLoc) {
1509   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1510 
1511   unsigned memIndex = read();
1512   OperationState state(mainRewriteLoc, read<OperationName>());
1513   readValueList(state.operands);
1514   for (unsigned i = 0, e = read(); i != e; ++i) {
1515     StringAttr name = read<StringAttr>();
1516     if (Attribute attr = read<Attribute>())
1517       state.addAttribute(name, attr);
1518   }
1519 
1520   for (unsigned i = 0, e = read(); i != e; ++i) {
1521     if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1522       state.types.push_back(read<Type>());
1523       continue;
1524     }
1525 
1526     // If we find a null range, this signals that the types are infered.
1527     if (TypeRange *resultTypes = read<TypeRange *>()) {
1528       state.types.append(resultTypes->begin(), resultTypes->end());
1529       continue;
1530     }
1531 
1532     // Handle the case where the operation has inferred types.
1533     InferTypeOpInterface::Concept *concept =
1534         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1535 
1536     // TODO: Handle failure.
1537     state.types.clear();
1538     if (failed(concept->inferReturnTypes(
1539             state.getContext(), state.location, state.operands,
1540             state.attributes.getDictionary(state.getContext()), state.regions,
1541             state.types)))
1542       return;
1543     break;
1544   }
1545 
1546   Operation *resultOp = rewriter.createOperation(state);
1547   memory[memIndex] = resultOp;
1548 
1549   LLVM_DEBUG({
1550     llvm::dbgs() << "  * Attributes: "
1551                  << state.attributes.getDictionary(state.getContext())
1552                  << "\n  * Operands: ";
1553     llvm::interleaveComma(state.operands, llvm::dbgs());
1554     llvm::dbgs() << "\n  * Result Types: ";
1555     llvm::interleaveComma(state.types, llvm::dbgs());
1556     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1557   });
1558 }
1559 
1560 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1561   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1562   Operation *op = read<Operation *>();
1563 
1564   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1565   rewriter.eraseOp(op);
1566 }
1567 
1568 template <typename T, typename Range, PDLValue::Kind kind>
1569 void ByteCodeExecutor::executeExtract() {
1570   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1571   Range *range = read<Range *>();
1572   unsigned index = read<uint32_t>();
1573   unsigned memIndex = read();
1574 
1575   if (!range) {
1576     memory[memIndex] = nullptr;
1577     return;
1578   }
1579 
1580   T result = index < range->size() ? (*range)[index] : T();
1581   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1582                           << "  * Index: " << index << "\n"
1583                           << "  * Result: " << result << "\n");
1584   storeToMemory(memIndex, result);
1585 }
1586 
/// Handle the `Finalize` opcode. It reads no operands and only emits a debug
/// trace; `execute` returns immediately after dispatching this opcode.
void ByteCodeExecutor::executeFinalize() {
  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
}
1590 
/// Execute one step of a `ForEach` loop over a range of operations.
///
/// Each time it runs, it reads the following operands:
///  - the range slot,
///  - the memory slot for the iteration value,
///  - the value kind,
///  - the loop-counter index.
///
/// While the counter still indexes an element:
///  - the element is stored to `memory[memIndex]`,
///  - a code position is pushed for `Continue` to pop,
///  - the successor address is skipped so execution falls through into the
///    loop body.
///
/// Once the range is exhausted, the counter is reset to zero and execution
/// jumps to successor 0.
void ByteCodeExecutor::executeForEach() {
  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
  // Captured before reading any operands. NOTE(review): presumably this is
  // the start of this ForEach instruction, so that `Continue` (which pops the
  // code stack) re-executes this check — confirm against getPrevCodeIt().
  const ByteCodeField *prevCodeIt = getPrevCodeIt();
  unsigned rangeIndex = read();
  unsigned memIndex = read();
  const void *value = nullptr;

  switch (read<PDLValue::Kind>()) {
  case PDLValue::Kind::Operation: {
    // The loop counter is a reference so the reset below persists across
    // iterations. `Continue` increments it.
    unsigned &index = loopIndex[read()];
    ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
    assert(index <= array.size() && "iterated past the end");
    if (index < array.size()) {
      LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
      value = array[index];
      break;
    }

    // Exhausted: reset the counter for the next time this loop is entered,
    // then take the exit successor.
    LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
    index = 0;
    selectJump(size_t(0));
    return;
  }
  default:
    llvm_unreachable("unexpected `ForEach` value kind");
  }

  // Store the iterate value and the stack address.
  memory[memIndex] = value;
  pushCodeIt(prevCodeIt);

  // Skip over the successor (we will enter the body of the loop).
  read<ByteCodeAddr>();
}
1625 
1626 void ByteCodeExecutor::executeGetAttribute() {
1627   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1628   unsigned memIndex = read();
1629   Operation *op = read<Operation *>();
1630   StringAttr attrName = read<StringAttr>();
1631   Attribute attr = op->getAttr(attrName);
1632 
1633   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1634                           << "  * Attribute: " << attrName << "\n"
1635                           << "  * Result: " << attr << "\n");
1636   memory[memIndex] = attr.getAsOpaquePointer();
1637 }
1638 
1639 void ByteCodeExecutor::executeGetAttributeType() {
1640   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1641   unsigned memIndex = read();
1642   Attribute attr = read<Attribute>();
1643   Type type = attr ? attr.getType() : Type();
1644 
1645   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1646                           << "  * Result: " << type << "\n");
1647   memory[memIndex] = type.getAsOpaquePointer();
1648 }
1649 
1650 void ByteCodeExecutor::executeGetDefiningOp() {
1651   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1652   unsigned memIndex = read();
1653   Operation *op = nullptr;
1654   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1655     Value value = read<Value>();
1656     if (value)
1657       op = value.getDefiningOp();
1658     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1659   } else {
1660     ValueRange *values = read<ValueRange *>();
1661     if (values && !values->empty()) {
1662       op = values->front().getDefiningOp();
1663     }
1664     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1665   }
1666 
1667   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1668   memory[memIndex] = op;
1669 }
1670 
1671 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1672   Operation *op = read<Operation *>();
1673   unsigned memIndex = read();
1674   Value operand =
1675       index < op->getNumOperands() ? op->getOperand(index) : Value();
1676 
1677   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1678                           << "  * Index: " << index << "\n"
1679                           << "  * Result: " << operand << "\n");
1680   memory[memIndex] = operand.getAsOpaquePointer();
1681 }
1682 
1683 /// This function is the internal implementation of `GetResults` and
1684 /// `GetOperands` that provides support for extracting a value range from the
1685 /// given operation.
1686 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1687 static void *
1688 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1689                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1690                           MutableArrayRef<ValueRange> valueRangeMemory) {
1691   // Check for the sentinel index that signals that all values should be
1692   // returned.
1693   if (index == std::numeric_limits<uint32_t>::max()) {
1694     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1695     // `values` is already the full value range.
1696 
1697     // Otherwise, check to see if this operation uses AttrSizedSegments.
1698   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1699     LLVM_DEBUG(llvm::dbgs()
1700                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1701 
1702     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1703     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1704       return nullptr;
1705 
1706     auto segments = segmentAttr.getValues<int32_t>();
1707     unsigned startIndex =
1708         std::accumulate(segments.begin(), segments.begin() + index, 0);
1709     values = values.slice(startIndex, *std::next(segments.begin(), index));
1710 
1711     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1712                             << *std::next(segments.begin(), index) << "]\n");
1713 
1714     // Otherwise, assume this is the last operand group of the operation.
1715     // FIXME: We currently don't support operations with
1716     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1717     // have a way to detect it's presence.
1718   } else if (values.size() >= index) {
1719     LLVM_DEBUG(llvm::dbgs()
1720                << "  * Treating values as trailing variadic range\n");
1721     values = values.drop_front(index);
1722 
1723     // If we couldn't detect a way to compute the values, bail out.
1724   } else {
1725     return nullptr;
1726   }
1727 
1728   // If the range index is valid, we are returning a range.
1729   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1730     valueRangeMemory[rangeIndex] = values;
1731     return &valueRangeMemory[rangeIndex];
1732   }
1733 
1734   // If a range index wasn't provided, the range is required to be non-variadic.
1735   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1736 }
1737 
1738 void ByteCodeExecutor::executeGetOperands() {
1739   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1740   unsigned index = read<uint32_t>();
1741   Operation *op = read<Operation *>();
1742   ByteCodeField rangeIndex = read();
1743 
1744   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1745       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1746       valueRangeMemory);
1747   if (!result)
1748     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1749   memory[read()] = result;
1750 }
1751 
1752 void ByteCodeExecutor::executeGetResult(unsigned index) {
1753   Operation *op = read<Operation *>();
1754   unsigned memIndex = read();
1755   OpResult result =
1756       index < op->getNumResults() ? op->getResult(index) : OpResult();
1757 
1758   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1759                           << "  * Index: " << index << "\n"
1760                           << "  * Result: " << result << "\n");
1761   memory[memIndex] = result.getAsOpaquePointer();
1762 }
1763 
1764 void ByteCodeExecutor::executeGetResults() {
1765   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1766   unsigned index = read<uint32_t>();
1767   Operation *op = read<Operation *>();
1768   ByteCodeField rangeIndex = read();
1769 
1770   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1771       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1772       valueRangeMemory);
1773   if (!result)
1774     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1775   memory[read()] = result;
1776 }
1777 
1778 void ByteCodeExecutor::executeGetUsers() {
1779   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1780   unsigned memIndex = read();
1781   unsigned rangeIndex = read();
1782   OwningOpRange &range = opRangeMemory[rangeIndex];
1783   memory[memIndex] = &range;
1784 
1785   range = OwningOpRange();
1786   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1787     // Read the value.
1788     Value value = read<Value>();
1789     if (!value)
1790       return;
1791     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1792 
1793     // Extract the users of a single value.
1794     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1795     llvm::copy(value.getUsers(), range.begin());
1796   } else {
1797     // Read a range of values.
1798     ValueRange *values = read<ValueRange *>();
1799     if (!values)
1800       return;
1801     LLVM_DEBUG({
1802       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1803       llvm::interleaveComma(*values, llvm::dbgs());
1804       llvm::dbgs() << "\n";
1805     });
1806 
1807     // Extract all the users of a range of values.
1808     SmallVector<Operation *> users;
1809     for (Value value : *values)
1810       users.append(value.user_begin(), value.user_end());
1811     range = OwningOpRange(users.size());
1812     llvm::copy(users, range.begin());
1813   }
1814 
1815   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1816 }
1817 
1818 void ByteCodeExecutor::executeGetValueType() {
1819   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1820   unsigned memIndex = read();
1821   Value value = read<Value>();
1822   Type type = value ? value.getType() : Type();
1823 
1824   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1825                           << "  * Result: " << type << "\n");
1826   memory[memIndex] = type.getAsOpaquePointer();
1827 }
1828 
1829 void ByteCodeExecutor::executeGetValueRangeTypes() {
1830   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1831   unsigned memIndex = read();
1832   unsigned rangeIndex = read();
1833   ValueRange *values = read<ValueRange *>();
1834   if (!values) {
1835     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1836     memory[memIndex] = nullptr;
1837     return;
1838   }
1839 
1840   LLVM_DEBUG({
1841     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1842     llvm::interleaveComma(*values, llvm::dbgs());
1843     llvm::dbgs() << "\n  * Result: ";
1844     llvm::interleaveComma(values->getType(), llvm::dbgs());
1845     llvm::dbgs() << "\n";
1846   });
1847   typeRangeMemory[rangeIndex] = values->getType();
1848   memory[memIndex] = &typeRangeMemory[rangeIndex];
1849 }
1850 
1851 void ByteCodeExecutor::executeIsNotNull() {
1852   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1853   const void *value = read<const void *>();
1854 
1855   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1856   selectJump(value != nullptr);
1857 }
1858 
/// Record a successful match of the pattern at the encoded pattern index.
///
/// The instruction reads these operands:
///  - the pattern index,
///  - a destination address,
///  - a list of operations whose locations are fused into the match location,
///  - the list of inputs passed to the rewriter.
///
/// A new MatchResult carrying the fused location, the pattern and its current
/// benefit is appended to `matches`. Execution always continues at the
/// destination address.
void ByteCodeExecutor::executeRecordMatch(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
  unsigned patternIndex = read();
  PatternBenefit benefit = currentPatternBenefits[patternIndex];
  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];

  // If the benefit of the pattern is impossible, skip the processing of the
  // rest of the pattern.
  // Jumping to `dest` also skips over this instruction's remaining operands
  // (locations and inputs) without reading them.
  if (benefit.isImpossibleToMatch()) {
    LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
    curCodeIt = dest;
    return;
  }

  // Create a fused location containing the locations of each of the
  // operations used in the match. This will be used as the location for
  // created operations during the rewrite that don't already have an
  // explicit location set.
  unsigned numMatchLocs = read();
  SmallVector<Location, 4> matchLocs;
  matchLocs.reserve(numMatchLocs);
  for (unsigned i = 0; i != numMatchLocs; ++i)
    matchLocs.push_back(read<Operation *>()->getLoc());
  Location matchLoc = rewriter.getFusedLoc(matchLocs);

  LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
                          << "  * Location: " << matchLoc << "\n");
  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
  PDLByteCode::MatchResult &match = matches.back();

  // Record all of the inputs to the match. If any of the inputs are ranges, we
  // will also need to remap the range pointer to memory stored in the match
  // state.
  unsigned numInputs = read();
  match.values.reserve(numInputs);
  // Reserving `numInputs` up front is load-bearing. At most `numInputs`
  // ranges are pushed below, so these vectors never reallocate inside this
  // loop, and the `&...back()` pointers stored into `match.values` stay valid.
  match.typeRangeValues.reserve(numInputs);
  match.valueRangeValues.reserve(numInputs);
  for (unsigned i = 0; i < numInputs; ++i) {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::TypeRange:
      // Copy the range (a view) into the match so it outlives the executor's
      // range slots; the input pointer then refers to the copy.
      match.typeRangeValues.push_back(*read<TypeRange *>());
      match.values.push_back(&match.typeRangeValues.back());
      break;
    case PDLValue::Kind::ValueRange:
      match.valueRangeValues.push_back(*read<ValueRange *>());
      match.values.push_back(&match.valueRangeValues.back());
      break;
    default:
      // Non-range inputs are stored as their opaque pointer directly.
      match.values.push_back(read<const void *>());
      break;
    }
  }
  curCodeIt = dest;
}
1915 
1916 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1917   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1918   Operation *op = read<Operation *>();
1919   SmallVector<Value, 16> args;
1920   readValueList(args);
1921 
1922   LLVM_DEBUG({
1923     llvm::dbgs() << "  * Operation: " << *op << "\n"
1924                  << "  * Values: ";
1925     llvm::interleaveComma(args, llvm::dbgs());
1926     llvm::dbgs() << "\n";
1927   });
1928   rewriter.replaceOp(op, args);
1929 }
1930 
1931 void ByteCodeExecutor::executeSwitchAttribute() {
1932   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1933   Attribute value = read<Attribute>();
1934   ArrayAttr cases = read<ArrayAttr>();
1935   handleSwitch(value, cases);
1936 }
1937 
1938 void ByteCodeExecutor::executeSwitchOperandCount() {
1939   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1940   Operation *op = read<Operation *>();
1941   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1942 
1943   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1944   handleSwitch(op->getNumOperands(), cases);
1945 }
1946 
/// Switch on an operation's name.
///
/// Unlike the other switches, the case names are encoded inline in the
/// bytecode (`caseCount` consecutive OperationName entries). On a match at
/// case `i`, the remaining entries are skipped and successor `i + 1` is taken.
/// Successor 0 is the default.
void ByteCodeExecutor::executeSwitchOperationName() {
  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
  OperationName value = read<Operation *>()->getName();
  size_t caseCount = read();

  // The operation names are stored in-line, so to print them out for
  // debugging purposes we need to read the array before executing the
  // switch so that we can display all of the possible values.
  // The code position is restored afterwards so the real scan below re-reads
  // the same entries.
  LLVM_DEBUG({
    const ByteCodeField *prevCodeIt = curCodeIt;
    llvm::dbgs() << "  * Value: " << value << "\n"
                 << "  * Cases: ";
    llvm::interleaveComma(
        llvm::map_range(llvm::seq<size_t>(0, caseCount),
                        [&](size_t) { return read<OperationName>(); }),
        llvm::dbgs());
    llvm::dbgs() << "\n";
    curCodeIt = prevCodeIt;
  });

  // Try to find the switch value within any of the cases.
  for (size_t i = 0; i != caseCount; ++i) {
    if (read<OperationName>() == value) {
      // Skip the `caseCount - i - 1` case entries not yet read, so that the
      // successor list that follows is positioned correctly.
      // NOTE(review): this assumes each inline OperationName occupies exactly
      // one ByteCodeField — confirm against read<OperationName>().
      curCodeIt += (caseCount - i - 1);
      return selectJump(i + 1);
    }
  }
  selectJump(size_t(0));
}
1976 
1977 void ByteCodeExecutor::executeSwitchResultCount() {
1978   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1979   Operation *op = read<Operation *>();
1980   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1981 
1982   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1983   handleSwitch(op->getNumResults(), cases);
1984 }
1985 
1986 void ByteCodeExecutor::executeSwitchType() {
1987   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1988   Type value = read<Type>();
1989   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1990   handleSwitch(value, cases);
1991 }
1992 
1993 void ByteCodeExecutor::executeSwitchTypes() {
1994   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
1995   TypeRange *value = read<TypeRange *>();
1996   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
1997   if (!value) {
1998     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
1999     return selectJump(size_t(0));
2000   }
2001   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2002     return value == caseValue.getAsValueRange<TypeAttr>();
2003   });
2004 }
2005 
/// Run the interpreter loop: repeatedly decode an opcode from the current code
/// position and dispatch to its handler, until a `Finalize` opcode returns.
///
/// `matches` must be non-null when the stream contains `RecordMatch`; this is
/// asserted. `mainRewriteLoc` is dereferenced by `CreateOperation`. In this
/// file it is supplied by PDLByteCode::rewrite but not by PDLByteCode::match.
void ByteCodeExecutor::execute(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
    Optional<Location> mainRewriteLoc) {
  while (true) {
    // Print the location of the operation being executed.
    // NOTE(review): this read only happens when LLVM_DEBUG is enabled, so it
    // presumably relies on the location being encoded inline only in debug
    // builds — confirm against the bytecode generator.
    LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");

    OpCode opCode = static_cast<OpCode>(read());
    switch (opCode) {
    case ApplyConstraint:
      executeApplyConstraint(rewriter);
      break;
    case ApplyRewrite:
      executeApplyRewrite(rewriter);
      break;
    case AreEqual:
      executeAreEqual();
      break;
    case AreRangesEqual:
      executeAreRangesEqual();
      break;
    case Branch:
      executeBranch();
      break;
    case CheckOperandCount:
      executeCheckOperandCount();
      break;
    case CheckOperationName:
      executeCheckOperationName();
      break;
    case CheckResultCount:
      executeCheckResultCount();
      break;
    case CheckTypes:
      executeCheckTypes();
      break;
    case Continue:
      executeContinue();
      break;
    case CreateOperation:
      executeCreateOperation(rewriter, *mainRewriteLoc);
      break;
    case CreateTypes:
      executeCreateTypes();
      break;
    case EraseOp:
      executeEraseOp(rewriter);
      break;
    case ExtractOp:
      executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
      break;
    case ExtractType:
      executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
      break;
    case ExtractValue:
      executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
      break;
    case Finalize:
      // The only exit from the interpreter loop.
      executeFinalize();
      LLVM_DEBUG(llvm::dbgs() << "\n");
      return;
    case ForEach:
      executeForEach();
      break;
    case GetAttribute:
      executeGetAttribute();
      break;
    case GetAttributeType:
      executeGetAttributeType();
      break;
    case GetDefiningOp:
      executeGetDefiningOp();
      break;
    // Specialized opcodes encode small operand indices in the opcode itself;
    // the index is recovered from the offset relative to GetOperand0.
    case GetOperand0:
    case GetOperand1:
    case GetOperand2:
    case GetOperand3: {
      unsigned index = opCode - GetOperand0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
      executeGetOperand(index);
      break;
    }
    case GetOperandN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
      executeGetOperand(read<uint32_t>());
      break;
    case GetOperands:
      executeGetOperands();
      break;
    // Same encoding scheme as GetOperand0-3, relative to GetResult0.
    case GetResult0:
    case GetResult1:
    case GetResult2:
    case GetResult3: {
      unsigned index = opCode - GetResult0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
      executeGetResult(index);
      break;
    }
    case GetResultN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
      executeGetResult(read<uint32_t>());
      break;
    case GetResults:
      executeGetResults();
      break;
    case GetUsers:
      executeGetUsers();
      break;
    case GetValueType:
      executeGetValueType();
      break;
    case GetValueRangeTypes:
      executeGetValueRangeTypes();
      break;
    case IsNotNull:
      executeIsNotNull();
      break;
    case RecordMatch:
      assert(matches &&
             "expected matches to be provided when executing the matcher");
      executeRecordMatch(rewriter, *matches);
      break;
    case ReplaceOp:
      executeReplaceOp(rewriter);
      break;
    case SwitchAttribute:
      executeSwitchAttribute();
      break;
    case SwitchOperandCount:
      executeSwitchOperandCount();
      break;
    case SwitchOperationName:
      executeSwitchOperationName();
      break;
    case SwitchResultCount:
      executeSwitchResultCount();
      break;
    case SwitchType:
      executeSwitchType();
      break;
    case SwitchTypes:
      executeSwitchTypes();
      break;
    }
    LLVM_DEBUG(llvm::dbgs() << "\n");
  }
}
2154 
2155 /// Run the pattern matcher on the given root operation, collecting the matched
2156 /// patterns in `matches`.
2157 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2158                         SmallVectorImpl<MatchResult> &matches,
2159                         PDLByteCodeMutableState &state) const {
2160   // The first memory slot is always the root operation.
2161   state.memory[0] = op;
2162 
2163   // The matcher function always starts at code address 0.
2164   ByteCodeExecutor executor(
2165       matcherByteCode.data(), state.memory, state.opRangeMemory,
2166       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2167       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2168       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2169       constraintFunctions, rewriteFunctions);
2170   executor.execute(rewriter, &matches);
2171 
2172   // Order the found matches by benefit.
2173   std::stable_sort(matches.begin(), matches.end(),
2174                    [](const MatchResult &lhs, const MatchResult &rhs) {
2175                      return lhs.benefit > rhs.benefit;
2176                    });
2177 }
2178 
2179 /// Run the rewriter of the given pattern on the root operation `op`.
2180 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2181                           PDLByteCodeMutableState &state) const {
2182   // The arguments of the rewrite function are stored at the start of the
2183   // memory buffer.
2184   llvm::copy(match.values, state.memory.begin());
2185 
2186   ByteCodeExecutor executor(
2187       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2188       state.opRangeMemory, state.typeRangeMemory,
2189       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2190       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2191       rewriterByteCode, state.currentPatternBenefits, patterns,
2192       constraintFunctions, rewriteFunctions);
2193   executor.execute(rewriter, /*matches=*/nullptr, match.location);
2194 }
2195