1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.benefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.rootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62                                                    PatternBenefit benefit) {
63   currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
namespace {
/// The opcodes of the pattern bytecode. The underlying type is `ByteCodeField`,
/// so an opcode occupies a single field when it is written to a bytecode
/// stream. The enumerators have no explicit values: their numeric encoding is
/// determined by their declaration order, so reordering or inserting entries
/// changes the encoding.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): the numbered variants presumably encode the operand index
  /// in the opcode itself, with `GetOperandN` covering every other index --
  /// confirm against the generator and interpreter.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): presumably mirrors the `GetOperand0`..`GetOperandN` scheme
  /// for results -- confirm against the generator and interpreter.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // ByteCode Generation
167 //===----------------------------------------------------------------------===//
168 
169 //===----------------------------------------------------------------------===//
170 // Generator
171 
172 namespace {
173 struct ByteCodeLiveRange;
174 struct ByteCodeWriter;
175 
/// Check if the given class `T` can be converted to an opaque pointer.
/// The alias is only well-formed when `T` provides a `getAsOpaquePointer()`
/// member, which is what allows it to be used as the detector argument of
/// `llvm::is_detected`. The trailing `Args` pack is not used by the alias
/// itself.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
179 
180 /// This class represents the main generator for the pattern bytecode.
181 class Generator {
182 public:
183   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
184             SmallVectorImpl<ByteCodeField> &matcherByteCode,
185             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
186             SmallVectorImpl<PDLByteCodePattern> &patterns,
187             ByteCodeField &maxValueMemoryIndex,
188             ByteCodeField &maxOpRangeMemoryIndex,
189             ByteCodeField &maxTypeRangeMemoryIndex,
190             ByteCodeField &maxValueRangeMemoryIndex,
191             ByteCodeField &maxLoopLevel,
192             llvm::StringMap<PDLConstraintFunction> &constraintFns,
193             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
194       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
195         rewriterByteCode(rewriterByteCode), patterns(patterns),
196         maxValueMemoryIndex(maxValueMemoryIndex),
197         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
198         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
199         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
200         maxLoopLevel(maxLoopLevel) {
201     for (const auto &it : llvm::enumerate(constraintFns))
202       constraintToMemIndex.try_emplace(it.value().first(), it.index());
203     for (const auto &it : llvm::enumerate(rewriteFns))
204       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
205   }
206 
207   /// Generate the bytecode for the given PDL interpreter module.
208   void generate(ModuleOp module);
209 
210   /// Return the memory index to use for the given value.
211   ByteCodeField &getMemIndex(Value value) {
212     assert(valueToMemIndex.count(value) &&
213            "expected memory index to be assigned");
214     return valueToMemIndex[value];
215   }
216 
217   /// Return the range memory index used to store the given range value.
218   ByteCodeField &getRangeStorageIndex(Value value) {
219     assert(valueToRangeIndex.count(value) &&
220            "expected range index to be assigned");
221     return valueToRangeIndex[value];
222   }
223 
224   /// Return an index to use when referring to the given data that is uniqued in
225   /// the MLIR context.
226   template <typename T>
227   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
228   getMemIndex(T val) {
229     const void *opaqueVal = val.getAsOpaquePointer();
230 
231     // Get or insert a reference to this value.
232     auto it = uniquedDataToMemIndex.try_emplace(
233         opaqueVal, maxValueMemoryIndex + uniquedData.size());
234     if (it.second)
235       uniquedData.push_back(opaqueVal);
236     return it.first->second;
237   }
238 
239 private:
240   /// Allocate memory indices for the results of operations within the matcher
241   /// and rewriters.
242   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
243                              ModuleOp rewriterModule);
244 
245   /// Generate the bytecode for the given operation.
246   void generate(Region *region, ByteCodeWriter &writer);
247   void generate(Operation *op, ByteCodeWriter &writer);
248   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
249   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
250   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
251   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
252   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
253   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
254   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
255   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
259   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
260   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
286 
287   /// Mapping from value to its corresponding memory index.
288   DenseMap<Value, ByteCodeField> valueToMemIndex;
289 
290   /// Mapping from a range value to its corresponding range storage index.
291   DenseMap<Value, ByteCodeField> valueToRangeIndex;
292 
293   /// Mapping from the name of an externally registered rewrite to its index in
294   /// the bytecode registry.
295   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
296 
297   /// Mapping from the name of an externally registered constraint to its index
298   /// in the bytecode registry.
299   llvm::StringMap<ByteCodeField> constraintToMemIndex;
300 
301   /// Mapping from rewriter function name to the bytecode address of the
302   /// rewriter function in byte.
303   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
304 
305   /// Mapping from a uniqued storage object to its memory index within
306   /// `uniquedData`.
307   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
308 
309   /// The current level of the foreach loop.
310   ByteCodeField curLoopLevel = 0;
311 
312   /// The current MLIR context.
313   MLIRContext *ctx;
314 
315   /// Mapping from block to its address.
316   DenseMap<Block *, ByteCodeAddr> blockToAddr;
317 
318   /// Data of the ByteCode class to be populated.
319   std::vector<const void *> &uniquedData;
320   SmallVectorImpl<ByteCodeField> &matcherByteCode;
321   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
322   SmallVectorImpl<PDLByteCodePattern> &patterns;
323   ByteCodeField &maxValueMemoryIndex;
324   ByteCodeField &maxOpRangeMemoryIndex;
325   ByteCodeField &maxTypeRangeMemoryIndex;
326   ByteCodeField &maxValueRangeMemoryIndex;
327   ByteCodeField &maxLoopLevel;
328 };
329 
330 /// This class provides utilities for writing a bytecode stream.
331 struct ByteCodeWriter {
332   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
333       : bytecode(bytecode), generator(generator) {}
334 
335   /// Append a field to the bytecode.
336   void append(ByteCodeField field) { bytecode.push_back(field); }
337   void append(OpCode opCode) { bytecode.push_back(opCode); }
338 
339   /// Append an address to the bytecode.
340   void append(ByteCodeAddr field) {
341     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
342                   "unexpected ByteCode address size");
343 
344     ByteCodeField fieldParts[2];
345     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
346     bytecode.append({fieldParts[0], fieldParts[1]});
347   }
348 
349   /// Append a single successor to the bytecode, the exact address will need to
350   /// be resolved later.
351   void append(Block *successor) {
352     // Add back a reference to the successor so that the address can be resolved
353     // later.
354     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
355     append(ByteCodeAddr(0));
356   }
357 
358   /// Append a successor range to the bytecode, the exact address will need to
359   /// be resolved later.
360   void append(SuccessorRange successors) {
361     for (Block *successor : successors)
362       append(successor);
363   }
364 
365   /// Append a range of values that will be read as generic PDLValues.
366   void appendPDLValueList(OperandRange values) {
367     bytecode.push_back(values.size());
368     for (Value value : values)
369       appendPDLValue(value);
370   }
371 
372   /// Append a value as a PDLValue.
373   void appendPDLValue(Value value) {
374     appendPDLValueKind(value);
375     append(value);
376   }
377 
378   /// Append the PDLValue::Kind of the given value.
379   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
380 
381   /// Append the PDLValue::Kind of the given type.
382   void appendPDLValueKind(Type type) {
383     PDLValue::Kind kind =
384         TypeSwitch<Type, PDLValue::Kind>(type)
385             .Case<pdl::AttributeType>(
386                 [](Type) { return PDLValue::Kind::Attribute; })
387             .Case<pdl::OperationType>(
388                 [](Type) { return PDLValue::Kind::Operation; })
389             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
390               if (rangeTy.getElementType().isa<pdl::TypeType>())
391                 return PDLValue::Kind::TypeRange;
392               return PDLValue::Kind::ValueRange;
393             })
394             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
395             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
396     bytecode.push_back(static_cast<ByteCodeField>(kind));
397   }
398 
399   /// Append a value that will be stored in a memory slot and not inline within
400   /// the bytecode.
401   template <typename T>
402   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
403                    std::is_pointer<T>::value>
404   append(T value) {
405     bytecode.push_back(generator.getMemIndex(value));
406   }
407 
408   /// Append a range of values.
409   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
410   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
411   append(T range) {
412     bytecode.push_back(llvm::size(range));
413     for (auto it : range)
414       append(it);
415   }
416 
417   /// Append a variadic number of fields to the bytecode.
418   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
419   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
420     append(field);
421     append(field2, fields...);
422   }
423 
424   /// Appends a value as a pointer, stored inline within the bytecode.
425   template <typename T>
426   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
427   appendInline(T value) {
428     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
429     const void *pointer = value.getAsOpaquePointer();
430     ByteCodeField fieldParts[numParts];
431     std::memcpy(fieldParts, &pointer, sizeof(const void *));
432     bytecode.append(fieldParts, fieldParts + numParts);
433   }
434 
435   /// Successor references in the bytecode that have yet to be resolved.
436   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
437 
438   /// The underlying bytecode buffer.
439   SmallVectorImpl<ByteCodeField> &bytecode;
440 
441   /// The main generator producing PDL.
442   Generator &generator;
443 };
444 
445 /// This class represents a live range of PDL Interpreter values, containing
446 /// information about when values are live within a match/rewrite.
447 struct ByteCodeLiveRange {
448   using Set = llvm::IntervalMap<uint64_t, char, 16>;
449   using Allocator = Set::Allocator;
450 
451   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
452 
453   /// Union this live range with the one provided.
454   void unionWith(const ByteCodeLiveRange &rhs) {
455     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
456          ++it)
457       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
458   }
459 
460   /// Returns true if this range overlaps with the one provided.
461   bool overlaps(const ByteCodeLiveRange &rhs) const {
462     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
463         .valid();
464   }
465 
466   /// A map representing the ranges of the match/rewrite that a value is live in
467   /// the interpreter.
468   ///
469   /// We use std::unique_ptr here, because IntervalMap does not provide a
470   /// correct copy or move constructor. We can eliminate the pointer once
471   /// https://reviews.llvm.org/D113240 lands.
472   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
473 
474   /// The operation range storage index for this range.
475   Optional<unsigned> opRangeIndex;
476 
477   /// The type range storage index for this range.
478   Optional<unsigned> typeRangeIndex;
479 
480   /// The value range storage index for this range.
481   Optional<unsigned> valueRangeIndex;
482 };
483 } // namespace
484 
485 void Generator::generate(ModuleOp module) {
486   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
487       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
488   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
489       pdl_interp::PDLInterpDialect::getRewriterModuleName());
490   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
491 
492   // Allocate memory indices for the results of operations within the matcher
493   // and rewriters.
494   allocateMemoryIndices(matcherFunc, rewriterModule);
495 
496   // Generate code for the rewriter functions.
497   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
498   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
499     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
500     for (Operation &op : rewriterFunc.getOps())
501       generate(&op, rewriterByteCodeWriter);
502   }
503   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
504          "unexpected branches in rewriter function");
505 
506   // Generate code for the matcher function.
507   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
508   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
509 
510   // Resolve successor references in the matcher.
511   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
512     ByteCodeAddr addr = blockToAddr[it.first];
513     for (unsigned offsetToFix : it.second)
514       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
515   }
516 }
517 
518 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
519                                       ModuleOp rewriterModule) {
520   // Rewriters use simplistic allocation scheme that simply assigns an index to
521   // each result.
522   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
523     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
524     auto processRewriterValue = [&](Value val) {
525       valueToMemIndex.try_emplace(val, index++);
526       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
527         Type elementTy = rangeType.getElementType();
528         if (elementTy.isa<pdl::TypeType>())
529           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
530         else if (elementTy.isa<pdl::ValueType>())
531           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
532       }
533     };
534 
535     for (BlockArgument arg : rewriterFunc.getArguments())
536       processRewriterValue(arg);
537     rewriterFunc.getBody().walk([&](Operation *op) {
538       for (Value result : op->getResults())
539         processRewriterValue(result);
540     });
541     if (index > maxValueMemoryIndex)
542       maxValueMemoryIndex = index;
543     if (typeRangeIndex > maxTypeRangeMemoryIndex)
544       maxTypeRangeMemoryIndex = typeRangeIndex;
545     if (valueRangeIndex > maxValueRangeMemoryIndex)
546       maxValueRangeMemoryIndex = valueRangeIndex;
547   }
548 
549   // The matcher function uses a more sophisticated numbering that tries to
550   // minimize the number of memory indices assigned. This is done by determining
551   // a live range of the values within the matcher, then the allocation is just
552   // finding the minimal number of overlapping live ranges. This is essentially
553   // a simplified form of register allocation where we don't necessarily have a
554   // limited number of registers, but we still want to minimize the number used.
555   DenseMap<Operation *, unsigned> opToFirstIndex;
556   DenseMap<Operation *, unsigned> opToLastIndex;
557 
558   // A custom walk that marks the first and the last index of each operation.
559   // The entry marks the beginning of the liveness range for this operation,
560   // followed by nested operations, followed by the end of the liveness range.
561   unsigned index = 0;
562   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
563     opToFirstIndex.try_emplace(op, index++);
564     for (Region &region : op->getRegions())
565       for (Block &block : region.getBlocks())
566         for (Operation &nested : block)
567           walk(&nested);
568     opToLastIndex.try_emplace(op, index++);
569   };
570   walk(matcherFunc);
571 
572   // Liveness info for each of the defs within the matcher.
573   ByteCodeLiveRange::Allocator allocator;
574   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
575 
576   // Assign the root operation being matched to slot 0.
577   BlockArgument rootOpArg = matcherFunc.getArgument(0);
578   valueToMemIndex[rootOpArg] = 0;
579 
580   // Walk each of the blocks, computing the def interval that the value is used.
581   Liveness matcherLiveness(matcherFunc);
582   matcherFunc->walk([&](Block *block) {
583     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
584     assert(info && "expected liveness info for block");
585     auto processValue = [&](Value value, Operation *firstUseOrDef) {
586       // We don't need to process the root op argument, this value is always
587       // assigned to the first memory slot.
588       if (value == rootOpArg)
589         return;
590 
591       // Set indices for the range of this block that the value is used.
592       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
593       defRangeIt->second.liveness->insert(
594           opToFirstIndex[firstUseOrDef],
595           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
596           /*dummyValue*/ 0);
597 
598       // Check to see if this value is a range type.
599       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
600         Type eleType = rangeTy.getElementType();
601         if (eleType.isa<pdl::OperationType>())
602           defRangeIt->second.opRangeIndex = 0;
603         else if (eleType.isa<pdl::TypeType>())
604           defRangeIt->second.typeRangeIndex = 0;
605         else if (eleType.isa<pdl::ValueType>())
606           defRangeIt->second.valueRangeIndex = 0;
607       }
608     };
609 
610     // Process the live-ins of this block.
611     for (Value liveIn : info->in()) {
612       // Only process the value if it has been defined in the current region.
613       // Other values that span across pdl_interp.foreach will be added higher
614       // up. This ensures that the we keep them alive for the entire duration
615       // of the loop.
616       if (liveIn.getParentRegion() == block->getParent())
617         processValue(liveIn, &block->front());
618     }
619 
620     // Process the block arguments for the entry block (those are not live-in).
621     if (block->isEntryBlock()) {
622       for (Value argument : block->getArguments())
623         processValue(argument, &block->front());
624     }
625 
626     // Process any new defs within this block.
627     for (Operation &op : *block)
628       for (Value result : op.getResults())
629         processValue(result, &op);
630   });
631 
632   // Greedily allocate memory slots using the computed def live ranges.
633   std::vector<ByteCodeLiveRange> allocatedIndices;
634 
635   // The number of memory indices currently allocated (and its next value).
636   // Recall that the root gets allocated memory index 0.
637   ByteCodeField numIndices = 1;
638 
639   // The number of memory ranges of various types (and their next values).
640   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
641 
642   for (auto &defIt : valueDefRanges) {
643     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
644     ByteCodeLiveRange &defRange = defIt.second;
645 
646     // Try to allocate to an existing index.
647     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
648       ByteCodeLiveRange &existingRange = existingIndexIt.value();
649       if (!defRange.overlaps(existingRange)) {
650         existingRange.unionWith(defRange);
651         memIndex = existingIndexIt.index() + 1;
652 
653         if (defRange.opRangeIndex) {
654           if (!existingRange.opRangeIndex)
655             existingRange.opRangeIndex = numOpRanges++;
656           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
657         } else if (defRange.typeRangeIndex) {
658           if (!existingRange.typeRangeIndex)
659             existingRange.typeRangeIndex = numTypeRanges++;
660           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
661         } else if (defRange.valueRangeIndex) {
662           if (!existingRange.valueRangeIndex)
663             existingRange.valueRangeIndex = numValueRanges++;
664           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
665         }
666         break;
667       }
668     }
669 
670     // If no existing index could be used, add a new one.
671     if (memIndex == 0) {
672       allocatedIndices.emplace_back(allocator);
673       ByteCodeLiveRange &newRange = allocatedIndices.back();
674       newRange.unionWith(defRange);
675 
676       // Allocate an index for op/type/value ranges.
677       if (defRange.opRangeIndex) {
678         newRange.opRangeIndex = numOpRanges;
679         valueToRangeIndex[defIt.first] = numOpRanges++;
680       } else if (defRange.typeRangeIndex) {
681         newRange.typeRangeIndex = numTypeRanges;
682         valueToRangeIndex[defIt.first] = numTypeRanges++;
683       } else if (defRange.valueRangeIndex) {
684         newRange.valueRangeIndex = numValueRanges;
685         valueToRangeIndex[defIt.first] = numValueRanges++;
686       }
687 
688       memIndex = allocatedIndices.size();
689       ++numIndices;
690     }
691   }
692 
693   // Print the index usage and ensure that we did not run out of index space.
694   LLVM_DEBUG({
695     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
696                  << "(down from initial " << valueDefRanges.size() << ").\n";
697   });
698   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
699          "Ran out of memory for allocated indices");
700 
701   // Update the max number of indices.
702   if (numIndices > maxValueMemoryIndex)
703     maxValueMemoryIndex = numIndices;
704   if (numOpRanges > maxOpRangeMemoryIndex)
705     maxOpRangeMemoryIndex = numOpRanges;
706   if (numTypeRanges > maxTypeRangeMemoryIndex)
707     maxTypeRangeMemoryIndex = numTypeRanges;
708   if (numValueRanges > maxValueRangeMemoryIndex)
709     maxValueRangeMemoryIndex = numValueRanges;
710 }
711 
712 void Generator::generate(Region *region, ByteCodeWriter &writer) {
713   llvm::ReversePostOrderTraversal<Region *> rpot(region);
714   for (Block *block : rpot) {
715     // Keep track of where this block begins within the matcher function.
716     blockToAddr.try_emplace(block, matcherByteCode.size());
717     for (Operation &op : *block)
718       generate(&op, writer);
719   }
720 }
721 
/// Generate bytecode for `op` by dispatching to the `generate` overload that
/// matches its concrete `pdl_interp` operation type.
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  // When debug output is enabled, the location of the operation is stored
  // inline ahead of the bytecode emitted for it.
  LLVM_DEBUG({
    // The following list must contain all the operations that do not
    // produce any bytecode.
    if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
             pdl_interp::InferredTypesOp>(op))
      writer.appendInline(op->getLoc());
  });
  // Every supported operation type must be listed here; anything else is
  // treated as unreachable.
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
            pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
            pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
            pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
            pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}
755 
756 void Generator::generate(pdl_interp::ApplyConstraintOp op,
757                          ByteCodeWriter &writer) {
758   assert(constraintToMemIndex.count(op.name()) &&
759          "expected index for constraint function");
760   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
761                 op.constParamsAttr());
762   writer.appendPDLValueList(op.args());
763   writer.append(op.getSuccessors());
764 }
765 void Generator::generate(pdl_interp::ApplyRewriteOp op,
766                          ByteCodeWriter &writer) {
767   assert(externalRewriterToMemIndex.count(op.name()) &&
768          "expected index for rewrite function");
769   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
770                 op.constParamsAttr());
771   writer.appendPDLValueList(op.args());
772 
773   ResultRange results = op.results();
774   writer.append(ByteCodeField(results.size()));
775   for (Value result : results) {
776     // In debug mode we also record the expected kind of the result, so that we
777     // can provide extra verification of the native rewrite function.
778 #ifndef NDEBUG
779     writer.appendPDLValueKind(result);
780 #endif
781 
782     // Range results also need to append the range storage index.
783     if (result.getType().isa<pdl::RangeType>())
784       writer.append(getRangeStorageIndex(result));
785     writer.append(result);
786   }
787 }
788 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
789   Value lhs = op.lhs();
790   if (lhs.getType().isa<pdl::RangeType>()) {
791     writer.append(OpCode::AreRangesEqual);
792     writer.appendPDLValueKind(lhs);
793     writer.append(op.lhs(), op.rhs(), op.getSuccessors());
794     return;
795   }
796 
797   writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
798 }
799 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
800   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
801 }
802 void Generator::generate(pdl_interp::CheckAttributeOp op,
803                          ByteCodeWriter &writer) {
804   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
805                 op.getSuccessors());
806 }
807 void Generator::generate(pdl_interp::CheckOperandCountOp op,
808                          ByteCodeWriter &writer) {
809   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
810                 static_cast<ByteCodeField>(op.compareAtLeast()),
811                 op.getSuccessors());
812 }
813 void Generator::generate(pdl_interp::CheckOperationNameOp op,
814                          ByteCodeWriter &writer) {
815   writer.append(OpCode::CheckOperationName, op.operation(),
816                 OperationName(op.name(), ctx), op.getSuccessors());
817 }
818 void Generator::generate(pdl_interp::CheckResultCountOp op,
819                          ByteCodeWriter &writer) {
820   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
821                 static_cast<ByteCodeField>(op.compareAtLeast()),
822                 op.getSuccessors());
823 }
824 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
825   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
826 }
827 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
828   writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
829 }
830 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
831   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
832   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
833 }
/// Handle attribute creation without emitting any bytecode: `writer` is left
/// untouched and only the memory index mapping is updated.
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.attribute()) = getMemIndex(op.value());
}
839 void Generator::generate(pdl_interp::CreateOperationOp op,
840                          ByteCodeWriter &writer) {
841   writer.append(OpCode::CreateOperation, op.operation(),
842                 OperationName(op.name(), ctx));
843   writer.appendPDLValueList(op.operands());
844 
845   // Add the attributes.
846   OperandRange attributes = op.attributes();
847   writer.append(static_cast<ByteCodeField>(attributes.size()));
848   for (auto it : llvm::zip(op.attributeNames(), op.attributes()))
849     writer.append(std::get<0>(it), std::get<1>(it));
850   writer.appendPDLValueList(op.types());
851 }
/// Handle type creation without emitting any bytecode: `writer` is left
/// untouched and only the memory index mapping is updated.
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.result()) = getMemIndex(op.value());
}
856 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
857   writer.append(OpCode::CreateTypes, op.result(),
858                 getRangeStorageIndex(op.result()), op.value());
859 }
860 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
861   writer.append(OpCode::EraseOp, op.operation());
862 }
863 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
864   OpCode opCode =
865       TypeSwitch<Type, OpCode>(op.result().getType())
866           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
867           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
868           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
869           .Default([](Type) -> OpCode {
870             llvm_unreachable("unsupported element type");
871           });
872   writer.append(opCode, op.range(), op.index(), op.result());
873 }
/// Emit a Finalize instruction; it consists of the opcode alone and `op` is
/// not inspected.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
877 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
878   BlockArgument arg = op.getLoopVariable();
879   writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg);
880   writer.appendPDLValueKind(arg.getType());
881   writer.append(curLoopLevel, op.successor());
882   ++curLoopLevel;
883   if (curLoopLevel > maxLoopLevel)
884     maxLoopLevel = curLoopLevel;
885   generate(&op.region(), writer);
886   --curLoopLevel;
887 }
888 void Generator::generate(pdl_interp::GetAttributeOp op,
889                          ByteCodeWriter &writer) {
890   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
891                 op.nameAttr());
892 }
893 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
894                          ByteCodeWriter &writer) {
895   writer.append(OpCode::GetAttributeType, op.result(), op.value());
896 }
897 void Generator::generate(pdl_interp::GetDefiningOpOp op,
898                          ByteCodeWriter &writer) {
899   writer.append(OpCode::GetDefiningOp, op.operation());
900   writer.appendPDLValue(op.value());
901 }
902 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
903   uint32_t index = op.index();
904   if (index < 4)
905     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
906   else
907     writer.append(OpCode::GetOperandN, index);
908   writer.append(op.operation(), op.value());
909 }
910 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
911   Value result = op.value();
912   Optional<uint32_t> index = op.index();
913   writer.append(OpCode::GetOperands,
914                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
915                 op.operation());
916   if (result.getType().isa<pdl::RangeType>())
917     writer.append(getRangeStorageIndex(result));
918   else
919     writer.append(std::numeric_limits<ByteCodeField>::max());
920   writer.append(result);
921 }
922 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
923   uint32_t index = op.index();
924   if (index < 4)
925     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
926   else
927     writer.append(OpCode::GetResultN, index);
928   writer.append(op.operation(), op.value());
929 }
930 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
931   Value result = op.value();
932   Optional<uint32_t> index = op.index();
933   writer.append(OpCode::GetResults,
934                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
935                 op.operation());
936   if (result.getType().isa<pdl::RangeType>())
937     writer.append(getRangeStorageIndex(result));
938   else
939     writer.append(std::numeric_limits<ByteCodeField>::max());
940   writer.append(result);
941 }
942 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
943   Value operations = op.operations();
944   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
945   writer.append(OpCode::GetUsers, operations, rangeIndex);
946   writer.appendPDLValue(op.value());
947 }
948 void Generator::generate(pdl_interp::GetValueTypeOp op,
949                          ByteCodeWriter &writer) {
950   if (op.getType().isa<pdl::RangeType>()) {
951     Value result = op.result();
952     writer.append(OpCode::GetValueRangeTypes, result,
953                   getRangeStorageIndex(result), op.value());
954   } else {
955     writer.append(OpCode::GetValueType, op.result(), op.value());
956   }
957 }
958 
/// Handle inferred types without emitting any bytecode: `writer` is left
/// untouched and only the memory index mapping is updated.
void Generator::generate(pdl_interp::InferredTypesOp op,
                         ByteCodeWriter &writer) {
  // InferType maps to a null type as a marker for inferring result types.
  getMemIndex(op.type()) = getMemIndex(Type());
}
964 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
965   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
966 }
967 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
968   ByteCodeField patternIndex = patterns.size();
969   patterns.emplace_back(PDLByteCodePattern::create(
970       op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
971   writer.append(OpCode::RecordMatch, patternIndex,
972                 SuccessorRange(op.getOperation()), op.matchedOps());
973   writer.appendPDLValueList(op.inputs());
974 }
975 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
976   writer.append(OpCode::ReplaceOp, op.operation());
977   writer.appendPDLValueList(op.replValues());
978 }
979 void Generator::generate(pdl_interp::SwitchAttributeOp op,
980                          ByteCodeWriter &writer) {
981   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
982                 op.getSuccessors());
983 }
984 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
985                          ByteCodeWriter &writer) {
986   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
987                 op.getSuccessors());
988 }
989 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
990                          ByteCodeWriter &writer) {
991   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
992     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
993   });
994   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
995                 op.getSuccessors());
996 }
997 void Generator::generate(pdl_interp::SwitchResultCountOp op,
998                          ByteCodeWriter &writer) {
999   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
1000                 op.getSuccessors());
1001 }
1002 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1003   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
1004                 op.getSuccessors());
1005 }
1006 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1007   writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
1008                 op.getSuccessors());
1009 }
1010 
1011 //===----------------------------------------------------------------------===//
1012 // PDLByteCode
1013 //===----------------------------------------------------------------------===//
1014 
1015 PDLByteCode::PDLByteCode(ModuleOp module,
1016                          llvm::StringMap<PDLConstraintFunction> constraintFns,
1017                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
1018   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1019                       rewriterByteCode, patterns, maxValueMemoryIndex,
1020                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1021                       maxLoopLevel, constraintFns, rewriteFns);
1022   generator.generate(module);
1023 
1024   // Initialize the external functions.
1025   for (auto &it : constraintFns)
1026     constraintFunctions.push_back(std::move(it.second));
1027   for (auto &it : rewriteFns)
1028     rewriteFunctions.push_back(std::move(it.second));
1029 }
1030 
1031 /// Initialize the given state such that it can be used to execute the current
1032 /// bytecode.
1033 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1034   state.memory.resize(maxValueMemoryIndex, nullptr);
1035   state.opRangeMemory.resize(maxOpRangeCount);
1036   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1037   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1038   state.loopIndex.resize(maxLoopLevel, 0);
1039   state.currentPatternBenefits.reserve(patterns.size());
1040   for (const PDLByteCodePattern &pattern : patterns)
1041     state.currentPatternBenefits.push_back(pattern.getBenefit());
1042 }
1043 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1046 
1047 namespace {
/// This class provides support for executing a bytecode stream. It holds a
/// cursor into the bytecode (`curCodeIt`) along with references to the memory
/// buffers and lookup tables that the individual instructions read and write.
class ByteCodeExecutor {
public:
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops a code iterator from the stack and makes it the current position.
  /// The stack must not be empty.
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code.
  const ByteCodeField *getPrevCodeIt() const {
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The list is cleared
  /// first; the element count is read from the stream ahead of the elements.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// as either Value or ValueRange elements. Note that, unlike `readList`,
  /// the values are appended to `list` without clearing it first.
  void readValueList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        // Any other kind is read as a ValueRange and flattened into the list.
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    // Advance past the number of fields occupied by one pointer.
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Jump to a specific successor based on a predicate value.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index.
  void selectJump(size_t destIndex) {
    // Each successor address occupies two fields, hence the `* 2` skip.
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the attribute value is within the case list. Jump to
    // the correct successor index based on the result.
    // Successor 0 is taken when no case matches; case `i` maps to `i + 1`.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Read a raw pointer type from memory.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Read a class type (other than PDLValue) that is stored in memory as an
  /// opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// Read a PDLValue: its kind is read first and selects the type of the
  /// value read after it.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Read a bytecode address, which spans two fields of the stream.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// Read a single field directly from the stream.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// Read a PDLValue kind, which is stored as a single field.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// The underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1307 
1308 /// This class is an instantiation of the PDLResultList that provides access to
1309 /// the returned results. This API is not on `PDLResultList` to avoid
1310 /// overexposing access to information specific solely to the ByteCode.
1311 class ByteCodeRewriteResultList : public PDLResultList {
1312 public:
1313   ByteCodeRewriteResultList(unsigned maxNumResults)
1314       : PDLResultList(maxNumResults) {}
1315 
1316   /// Return the list of PDL results.
1317   MutableArrayRef<PDLValue> getResults() { return results; }
1318 
1319   /// Return the type ranges allocated by this list.
1320   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1321     return allocatedTypeRanges;
1322   }
1323 
1324   /// Return the value ranges allocated by this list.
1325   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1326     return allocatedValueRanges;
1327   }
1328 };
1329 } // namespace
1330 
1331 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1332   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1333   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1334   ArrayAttr constParams = read<ArrayAttr>();
1335   SmallVector<PDLValue, 16> args;
1336   readList<PDLValue>(args);
1337 
1338   LLVM_DEBUG({
1339     llvm::dbgs() << "  * Arguments: ";
1340     llvm::interleaveComma(args, llvm::dbgs());
1341     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1342   });
1343 
1344   // Invoke the constraint and jump to the proper destination.
1345   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1346 }
1347 
1348 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1349   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1350   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1351   ArrayAttr constParams = read<ArrayAttr>();
1352   SmallVector<PDLValue, 16> args;
1353   readList<PDLValue>(args);
1354 
1355   LLVM_DEBUG({
1356     llvm::dbgs() << "  * Arguments: ";
1357     llvm::interleaveComma(args, llvm::dbgs());
1358     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1359   });
1360 
1361   // Execute the rewrite function.
1362   ByteCodeField numResults = read();
1363   ByteCodeRewriteResultList results(numResults);
1364   rewriteFn(args, constParams, rewriter, results);
1365 
1366   assert(results.getResults().size() == numResults &&
1367          "native PDL rewrite function returned unexpected number of results");
1368 
1369   // Store the results in the bytecode memory.
1370   for (PDLValue &result : results.getResults()) {
1371     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1372 
1373 // In debug mode we also verify the expected kind of the result.
1374 #ifndef NDEBUG
1375     assert(result.getKind() == read<PDLValue::Kind>() &&
1376            "native PDL rewrite function returned an unexpected type of result");
1377 #endif
1378 
1379     // If the result is a range, we need to copy it over to the bytecodes
1380     // range memory.
1381     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1382       unsigned rangeIndex = read();
1383       typeRangeMemory[rangeIndex] = *typeRange;
1384       memory[read()] = &typeRangeMemory[rangeIndex];
1385     } else if (Optional<ValueRange> valueRange =
1386                    result.dyn_cast<ValueRange>()) {
1387       unsigned rangeIndex = read();
1388       valueRangeMemory[rangeIndex] = *valueRange;
1389       memory[read()] = &valueRangeMemory[rangeIndex];
1390     } else {
1391       memory[read()] = result.getAsOpaquePointer();
1392     }
1393   }
1394 
1395   // Copy over any underlying storage allocated for result ranges.
1396   for (auto &it : results.getAllocatedTypeRanges())
1397     allocatedTypeRangeMemory.push_back(std::move(it));
1398   for (auto &it : results.getAllocatedValueRanges())
1399     allocatedValueRangeMemory.push_back(std::move(it));
1400 }
1401 
1402 void ByteCodeExecutor::executeAreEqual() {
1403   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1404   const void *lhs = read<const void *>();
1405   const void *rhs = read<const void *>();
1406 
1407   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1408   selectJump(lhs == rhs);
1409 }
1410 
1411 void ByteCodeExecutor::executeAreRangesEqual() {
1412   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1413   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1414   const void *lhs = read<const void *>();
1415   const void *rhs = read<const void *>();
1416 
1417   switch (valueKind) {
1418   case PDLValue::Kind::TypeRange: {
1419     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1420     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1421     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1422     selectJump(*lhsRange == *rhsRange);
1423     break;
1424   }
1425   case PDLValue::Kind::ValueRange: {
1426     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1427     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1428     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1429     selectJump(*lhsRange == *rhsRange);
1430     break;
1431   }
1432   default:
1433     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1434   }
1435 }
1436 
1437 void ByteCodeExecutor::executeBranch() {
1438   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1439   curCodeIt = &code[read<ByteCodeAddr>()];
1440 }
1441 
1442 void ByteCodeExecutor::executeCheckOperandCount() {
1443   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1444   Operation *op = read<Operation *>();
1445   uint32_t expectedCount = read<uint32_t>();
1446   bool compareAtLeast = read();
1447 
1448   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1449                           << "  * Expected: " << expectedCount << "\n"
1450                           << "  * Comparator: "
1451                           << (compareAtLeast ? ">=" : "==") << "\n");
1452   if (compareAtLeast)
1453     selectJump(op->getNumOperands() >= expectedCount);
1454   else
1455     selectJump(op->getNumOperands() == expectedCount);
1456 }
1457 
1458 void ByteCodeExecutor::executeCheckOperationName() {
1459   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1460   Operation *op = read<Operation *>();
1461   OperationName expectedName = read<OperationName>();
1462 
1463   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1464                           << "  * Expected: \"" << expectedName << "\"\n");
1465   selectJump(op->getName() == expectedName);
1466 }
1467 
1468 void ByteCodeExecutor::executeCheckResultCount() {
1469   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1470   Operation *op = read<Operation *>();
1471   uint32_t expectedCount = read<uint32_t>();
1472   bool compareAtLeast = read();
1473 
1474   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1475                           << "  * Expected: " << expectedCount << "\n"
1476                           << "  * Comparator: "
1477                           << (compareAtLeast ? ">=" : "==") << "\n");
1478   if (compareAtLeast)
1479     selectJump(op->getNumResults() >= expectedCount);
1480   else
1481     selectJump(op->getNumResults() == expectedCount);
1482 }
1483 
1484 void ByteCodeExecutor::executeCheckTypes() {
1485   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1486   TypeRange *lhs = read<TypeRange *>();
1487   Attribute rhs = read<Attribute>();
1488   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1489 
1490   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1491 }
1492 
1493 void ByteCodeExecutor::executeContinue() {
1494   ByteCodeField level = read();
1495   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1496                           << "  * Level: " << level << "\n");
1497   ++loopIndex[level];
1498   popCodeIt();
1499 }
1500 
1501 void ByteCodeExecutor::executeCreateTypes() {
1502   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1503   unsigned memIndex = read();
1504   unsigned rangeIndex = read();
1505   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1506 
1507   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1508 
1509   // Allocate a buffer for this type range.
1510   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1511   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1512   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1513 
1514   // Assign this to the range slot and use the range as the value for the
1515   // memory index.
1516   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1517   memory[memIndex] = &typeRangeMemory[rangeIndex];
1518 }
1519 
1520 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1521                                               Location mainRewriteLoc) {
1522   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1523 
1524   unsigned memIndex = read();
1525   OperationState state(mainRewriteLoc, read<OperationName>());
1526   readValueList(state.operands);
1527   for (unsigned i = 0, e = read(); i != e; ++i) {
1528     StringAttr name = read<StringAttr>();
1529     if (Attribute attr = read<Attribute>())
1530       state.addAttribute(name, attr);
1531   }
1532 
1533   for (unsigned i = 0, e = read(); i != e; ++i) {
1534     if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1535       state.types.push_back(read<Type>());
1536       continue;
1537     }
1538 
1539     // If we find a null range, this signals that the types are infered.
1540     if (TypeRange *resultTypes = read<TypeRange *>()) {
1541       state.types.append(resultTypes->begin(), resultTypes->end());
1542       continue;
1543     }
1544 
1545     // Handle the case where the operation has inferred types.
1546     InferTypeOpInterface::Concept *inferInterface =
1547         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1548 
1549     // TODO: Handle failure.
1550     state.types.clear();
1551     if (failed(inferInterface->inferReturnTypes(
1552             state.getContext(), state.location, state.operands,
1553             state.attributes.getDictionary(state.getContext()), state.regions,
1554             state.types)))
1555       return;
1556     break;
1557   }
1558 
1559   Operation *resultOp = rewriter.createOperation(state);
1560   memory[memIndex] = resultOp;
1561 
1562   LLVM_DEBUG({
1563     llvm::dbgs() << "  * Attributes: "
1564                  << state.attributes.getDictionary(state.getContext())
1565                  << "\n  * Operands: ";
1566     llvm::interleaveComma(state.operands, llvm::dbgs());
1567     llvm::dbgs() << "\n  * Result Types: ";
1568     llvm::interleaveComma(state.types, llvm::dbgs());
1569     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1570   });
1571 }
1572 
1573 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1574   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1575   Operation *op = read<Operation *>();
1576 
1577   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1578   rewriter.eraseOp(op);
1579 }
1580 
1581 template <typename T, typename Range, PDLValue::Kind kind>
1582 void ByteCodeExecutor::executeExtract() {
1583   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1584   Range *range = read<Range *>();
1585   unsigned index = read<uint32_t>();
1586   unsigned memIndex = read();
1587 
1588   if (!range) {
1589     memory[memIndex] = nullptr;
1590     return;
1591   }
1592 
1593   T result = index < range->size() ? (*range)[index] : T();
1594   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1595                           << "  * Index: " << index << "\n"
1596                           << "  * Result: " << result << "\n");
1597   storeToMemory(memIndex, result);
1598 }
1599 
1600 void ByteCodeExecutor::executeFinalize() {
1601   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1602 }
1603 
1604 void ByteCodeExecutor::executeForEach() {
1605   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1606   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1607   unsigned rangeIndex = read();
1608   unsigned memIndex = read();
1609   const void *value = nullptr;
1610 
1611   switch (read<PDLValue::Kind>()) {
1612   case PDLValue::Kind::Operation: {
1613     unsigned &index = loopIndex[read()];
1614     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1615     assert(index <= array.size() && "iterated past the end");
1616     if (index < array.size()) {
1617       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1618       value = array[index];
1619       break;
1620     }
1621 
1622     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1623     index = 0;
1624     selectJump(size_t(0));
1625     return;
1626   }
1627   default:
1628     llvm_unreachable("unexpected `ForEach` value kind");
1629   }
1630 
1631   // Store the iterate value and the stack address.
1632   memory[memIndex] = value;
1633   pushCodeIt(prevCodeIt);
1634 
1635   // Skip over the successor (we will enter the body of the loop).
1636   read<ByteCodeAddr>();
1637 }
1638 
1639 void ByteCodeExecutor::executeGetAttribute() {
1640   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1641   unsigned memIndex = read();
1642   Operation *op = read<Operation *>();
1643   StringAttr attrName = read<StringAttr>();
1644   Attribute attr = op->getAttr(attrName);
1645 
1646   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1647                           << "  * Attribute: " << attrName << "\n"
1648                           << "  * Result: " << attr << "\n");
1649   memory[memIndex] = attr.getAsOpaquePointer();
1650 }
1651 
1652 void ByteCodeExecutor::executeGetAttributeType() {
1653   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1654   unsigned memIndex = read();
1655   Attribute attr = read<Attribute>();
1656   Type type = attr ? attr.getType() : Type();
1657 
1658   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1659                           << "  * Result: " << type << "\n");
1660   memory[memIndex] = type.getAsOpaquePointer();
1661 }
1662 
1663 void ByteCodeExecutor::executeGetDefiningOp() {
1664   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1665   unsigned memIndex = read();
1666   Operation *op = nullptr;
1667   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1668     Value value = read<Value>();
1669     if (value)
1670       op = value.getDefiningOp();
1671     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1672   } else {
1673     ValueRange *values = read<ValueRange *>();
1674     if (values && !values->empty()) {
1675       op = values->front().getDefiningOp();
1676     }
1677     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1678   }
1679 
1680   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1681   memory[memIndex] = op;
1682 }
1683 
1684 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1685   Operation *op = read<Operation *>();
1686   unsigned memIndex = read();
1687   Value operand =
1688       index < op->getNumOperands() ? op->getOperand(index) : Value();
1689 
1690   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1691                           << "  * Index: " << index << "\n"
1692                           << "  * Result: " << operand << "\n");
1693   memory[memIndex] = operand.getAsOpaquePointer();
1694 }
1695 
1696 /// This function is the internal implementation of `GetResults` and
1697 /// `GetOperands` that provides support for extracting a value range from the
1698 /// given operation.
1699 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1700 static void *
1701 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1702                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1703                           MutableArrayRef<ValueRange> valueRangeMemory) {
1704   // Check for the sentinel index that signals that all values should be
1705   // returned.
1706   if (index == std::numeric_limits<uint32_t>::max()) {
1707     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1708     // `values` is already the full value range.
1709 
1710     // Otherwise, check to see if this operation uses AttrSizedSegments.
1711   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1712     LLVM_DEBUG(llvm::dbgs()
1713                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1714 
1715     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1716     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1717       return nullptr;
1718 
1719     auto segments = segmentAttr.getValues<int32_t>();
1720     unsigned startIndex =
1721         std::accumulate(segments.begin(), segments.begin() + index, 0);
1722     values = values.slice(startIndex, *std::next(segments.begin(), index));
1723 
1724     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1725                             << *std::next(segments.begin(), index) << "]\n");
1726 
1727     // Otherwise, assume this is the last operand group of the operation.
1728     // FIXME: We currently don't support operations with
1729     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1730     // have a way to detect it's presence.
1731   } else if (values.size() >= index) {
1732     LLVM_DEBUG(llvm::dbgs()
1733                << "  * Treating values as trailing variadic range\n");
1734     values = values.drop_front(index);
1735 
1736     // If we couldn't detect a way to compute the values, bail out.
1737   } else {
1738     return nullptr;
1739   }
1740 
1741   // If the range index is valid, we are returning a range.
1742   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1743     valueRangeMemory[rangeIndex] = values;
1744     return &valueRangeMemory[rangeIndex];
1745   }
1746 
1747   // If a range index wasn't provided, the range is required to be non-variadic.
1748   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1749 }
1750 
1751 void ByteCodeExecutor::executeGetOperands() {
1752   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1753   unsigned index = read<uint32_t>();
1754   Operation *op = read<Operation *>();
1755   ByteCodeField rangeIndex = read();
1756 
1757   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1758       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1759       valueRangeMemory);
1760   if (!result)
1761     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1762   memory[read()] = result;
1763 }
1764 
1765 void ByteCodeExecutor::executeGetResult(unsigned index) {
1766   Operation *op = read<Operation *>();
1767   unsigned memIndex = read();
1768   OpResult result =
1769       index < op->getNumResults() ? op->getResult(index) : OpResult();
1770 
1771   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1772                           << "  * Index: " << index << "\n"
1773                           << "  * Result: " << result << "\n");
1774   memory[memIndex] = result.getAsOpaquePointer();
1775 }
1776 
1777 void ByteCodeExecutor::executeGetResults() {
1778   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1779   unsigned index = read<uint32_t>();
1780   Operation *op = read<Operation *>();
1781   ByteCodeField rangeIndex = read();
1782 
1783   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1784       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1785       valueRangeMemory);
1786   if (!result)
1787     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1788   memory[read()] = result;
1789 }
1790 
1791 void ByteCodeExecutor::executeGetUsers() {
1792   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1793   unsigned memIndex = read();
1794   unsigned rangeIndex = read();
1795   OwningOpRange &range = opRangeMemory[rangeIndex];
1796   memory[memIndex] = &range;
1797 
1798   range = OwningOpRange();
1799   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1800     // Read the value.
1801     Value value = read<Value>();
1802     if (!value)
1803       return;
1804     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1805 
1806     // Extract the users of a single value.
1807     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1808     llvm::copy(value.getUsers(), range.begin());
1809   } else {
1810     // Read a range of values.
1811     ValueRange *values = read<ValueRange *>();
1812     if (!values)
1813       return;
1814     LLVM_DEBUG({
1815       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1816       llvm::interleaveComma(*values, llvm::dbgs());
1817       llvm::dbgs() << "\n";
1818     });
1819 
1820     // Extract all the users of a range of values.
1821     SmallVector<Operation *> users;
1822     for (Value value : *values)
1823       users.append(value.user_begin(), value.user_end());
1824     range = OwningOpRange(users.size());
1825     llvm::copy(users, range.begin());
1826   }
1827 
1828   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1829 }
1830 
1831 void ByteCodeExecutor::executeGetValueType() {
1832   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1833   unsigned memIndex = read();
1834   Value value = read<Value>();
1835   Type type = value ? value.getType() : Type();
1836 
1837   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1838                           << "  * Result: " << type << "\n");
1839   memory[memIndex] = type.getAsOpaquePointer();
1840 }
1841 
1842 void ByteCodeExecutor::executeGetValueRangeTypes() {
1843   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1844   unsigned memIndex = read();
1845   unsigned rangeIndex = read();
1846   ValueRange *values = read<ValueRange *>();
1847   if (!values) {
1848     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1849     memory[memIndex] = nullptr;
1850     return;
1851   }
1852 
1853   LLVM_DEBUG({
1854     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1855     llvm::interleaveComma(*values, llvm::dbgs());
1856     llvm::dbgs() << "\n  * Result: ";
1857     llvm::interleaveComma(values->getType(), llvm::dbgs());
1858     llvm::dbgs() << "\n";
1859   });
1860   typeRangeMemory[rangeIndex] = values->getType();
1861   memory[memIndex] = &typeRangeMemory[rangeIndex];
1862 }
1863 
1864 void ByteCodeExecutor::executeIsNotNull() {
1865   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1866   const void *value = read<const void *>();
1867 
1868   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1869   selectJump(value != nullptr);
1870 }
1871 
1872 void ByteCodeExecutor::executeRecordMatch(
1873     PatternRewriter &rewriter,
1874     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1875   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1876   unsigned patternIndex = read();
1877   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1878   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1879 
1880   // If the benefit of the pattern is impossible, skip the processing of the
1881   // rest of the pattern.
1882   if (benefit.isImpossibleToMatch()) {
1883     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1884     curCodeIt = dest;
1885     return;
1886   }
1887 
1888   // Create a fused location containing the locations of each of the
1889   // operations used in the match. This will be used as the location for
1890   // created operations during the rewrite that don't already have an
1891   // explicit location set.
1892   unsigned numMatchLocs = read();
1893   SmallVector<Location, 4> matchLocs;
1894   matchLocs.reserve(numMatchLocs);
1895   for (unsigned i = 0; i != numMatchLocs; ++i)
1896     matchLocs.push_back(read<Operation *>()->getLoc());
1897   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1898 
1899   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1900                           << "  * Location: " << matchLoc << "\n");
1901   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1902   PDLByteCode::MatchResult &match = matches.back();
1903 
1904   // Record all of the inputs to the match. If any of the inputs are ranges, we
1905   // will also need to remap the range pointer to memory stored in the match
1906   // state.
1907   unsigned numInputs = read();
1908   match.values.reserve(numInputs);
1909   match.typeRangeValues.reserve(numInputs);
1910   match.valueRangeValues.reserve(numInputs);
1911   for (unsigned i = 0; i < numInputs; ++i) {
1912     switch (read<PDLValue::Kind>()) {
1913     case PDLValue::Kind::TypeRange:
1914       match.typeRangeValues.push_back(*read<TypeRange *>());
1915       match.values.push_back(&match.typeRangeValues.back());
1916       break;
1917     case PDLValue::Kind::ValueRange:
1918       match.valueRangeValues.push_back(*read<ValueRange *>());
1919       match.values.push_back(&match.valueRangeValues.back());
1920       break;
1921     default:
1922       match.values.push_back(read<const void *>());
1923       break;
1924     }
1925   }
1926   curCodeIt = dest;
1927 }
1928 
1929 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1930   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1931   Operation *op = read<Operation *>();
1932   SmallVector<Value, 16> args;
1933   readValueList(args);
1934 
1935   LLVM_DEBUG({
1936     llvm::dbgs() << "  * Operation: " << *op << "\n"
1937                  << "  * Values: ";
1938     llvm::interleaveComma(args, llvm::dbgs());
1939     llvm::dbgs() << "\n";
1940   });
1941   rewriter.replaceOp(op, args);
1942 }
1943 
1944 void ByteCodeExecutor::executeSwitchAttribute() {
1945   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1946   Attribute value = read<Attribute>();
1947   ArrayAttr cases = read<ArrayAttr>();
1948   handleSwitch(value, cases);
1949 }
1950 
1951 void ByteCodeExecutor::executeSwitchOperandCount() {
1952   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1953   Operation *op = read<Operation *>();
1954   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1955 
1956   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1957   handleSwitch(op->getNumOperands(), cases);
1958 }
1959 
1960 void ByteCodeExecutor::executeSwitchOperationName() {
1961   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1962   OperationName value = read<Operation *>()->getName();
1963   size_t caseCount = read();
1964 
1965   // The operation names are stored in-line, so to print them out for
1966   // debugging purposes we need to read the array before executing the
1967   // switch so that we can display all of the possible values.
1968   LLVM_DEBUG({
1969     const ByteCodeField *prevCodeIt = curCodeIt;
1970     llvm::dbgs() << "  * Value: " << value << "\n"
1971                  << "  * Cases: ";
1972     llvm::interleaveComma(
1973         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1974                         [&](size_t) { return read<OperationName>(); }),
1975         llvm::dbgs());
1976     llvm::dbgs() << "\n";
1977     curCodeIt = prevCodeIt;
1978   });
1979 
1980   // Try to find the switch value within any of the cases.
1981   for (size_t i = 0; i != caseCount; ++i) {
1982     if (read<OperationName>() == value) {
1983       curCodeIt += (caseCount - i - 1);
1984       return selectJump(i + 1);
1985     }
1986   }
1987   selectJump(size_t(0));
1988 }
1989 
1990 void ByteCodeExecutor::executeSwitchResultCount() {
1991   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1992   Operation *op = read<Operation *>();
1993   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1994 
1995   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1996   handleSwitch(op->getNumResults(), cases);
1997 }
1998 
1999 void ByteCodeExecutor::executeSwitchType() {
2000   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2001   Type value = read<Type>();
2002   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2003   handleSwitch(value, cases);
2004 }
2005 
2006 void ByteCodeExecutor::executeSwitchTypes() {
2007   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2008   TypeRange *value = read<TypeRange *>();
2009   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2010   if (!value) {
2011     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2012     return selectJump(size_t(0));
2013   }
2014   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2015     return value == caseValue.getAsValueRange<TypeAttr>();
2016   });
2017 }
2018 
2019 void ByteCodeExecutor::execute(
2020     PatternRewriter &rewriter,
2021     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2022     Optional<Location> mainRewriteLoc) {
2023   while (true) {
2024     // Print the location of the operation being executed.
2025     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2026 
2027     OpCode opCode = static_cast<OpCode>(read());
2028     switch (opCode) {
2029     case ApplyConstraint:
2030       executeApplyConstraint(rewriter);
2031       break;
2032     case ApplyRewrite:
2033       executeApplyRewrite(rewriter);
2034       break;
2035     case AreEqual:
2036       executeAreEqual();
2037       break;
2038     case AreRangesEqual:
2039       executeAreRangesEqual();
2040       break;
2041     case Branch:
2042       executeBranch();
2043       break;
2044     case CheckOperandCount:
2045       executeCheckOperandCount();
2046       break;
2047     case CheckOperationName:
2048       executeCheckOperationName();
2049       break;
2050     case CheckResultCount:
2051       executeCheckResultCount();
2052       break;
2053     case CheckTypes:
2054       executeCheckTypes();
2055       break;
2056     case Continue:
2057       executeContinue();
2058       break;
2059     case CreateOperation:
2060       executeCreateOperation(rewriter, *mainRewriteLoc);
2061       break;
2062     case CreateTypes:
2063       executeCreateTypes();
2064       break;
2065     case EraseOp:
2066       executeEraseOp(rewriter);
2067       break;
2068     case ExtractOp:
2069       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2070       break;
2071     case ExtractType:
2072       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2073       break;
2074     case ExtractValue:
2075       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2076       break;
2077     case Finalize:
2078       executeFinalize();
2079       LLVM_DEBUG(llvm::dbgs() << "\n");
2080       return;
2081     case ForEach:
2082       executeForEach();
2083       break;
2084     case GetAttribute:
2085       executeGetAttribute();
2086       break;
2087     case GetAttributeType:
2088       executeGetAttributeType();
2089       break;
2090     case GetDefiningOp:
2091       executeGetDefiningOp();
2092       break;
2093     case GetOperand0:
2094     case GetOperand1:
2095     case GetOperand2:
2096     case GetOperand3: {
2097       unsigned index = opCode - GetOperand0;
2098       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2099       executeGetOperand(index);
2100       break;
2101     }
2102     case GetOperandN:
2103       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2104       executeGetOperand(read<uint32_t>());
2105       break;
2106     case GetOperands:
2107       executeGetOperands();
2108       break;
2109     case GetResult0:
2110     case GetResult1:
2111     case GetResult2:
2112     case GetResult3: {
2113       unsigned index = opCode - GetResult0;
2114       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2115       executeGetResult(index);
2116       break;
2117     }
2118     case GetResultN:
2119       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2120       executeGetResult(read<uint32_t>());
2121       break;
2122     case GetResults:
2123       executeGetResults();
2124       break;
2125     case GetUsers:
2126       executeGetUsers();
2127       break;
2128     case GetValueType:
2129       executeGetValueType();
2130       break;
2131     case GetValueRangeTypes:
2132       executeGetValueRangeTypes();
2133       break;
2134     case IsNotNull:
2135       executeIsNotNull();
2136       break;
2137     case RecordMatch:
2138       assert(matches &&
2139              "expected matches to be provided when executing the matcher");
2140       executeRecordMatch(rewriter, *matches);
2141       break;
2142     case ReplaceOp:
2143       executeReplaceOp(rewriter);
2144       break;
2145     case SwitchAttribute:
2146       executeSwitchAttribute();
2147       break;
2148     case SwitchOperandCount:
2149       executeSwitchOperandCount();
2150       break;
2151     case SwitchOperationName:
2152       executeSwitchOperationName();
2153       break;
2154     case SwitchResultCount:
2155       executeSwitchResultCount();
2156       break;
2157     case SwitchType:
2158       executeSwitchType();
2159       break;
2160     case SwitchTypes:
2161       executeSwitchTypes();
2162       break;
2163     }
2164     LLVM_DEBUG(llvm::dbgs() << "\n");
2165   }
2166 }
2167 
2168 /// Run the pattern matcher on the given root operation, collecting the matched
2169 /// patterns in `matches`.
2170 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2171                         SmallVectorImpl<MatchResult> &matches,
2172                         PDLByteCodeMutableState &state) const {
2173   // The first memory slot is always the root operation.
2174   state.memory[0] = op;
2175 
2176   // The matcher function always starts at code address 0.
2177   ByteCodeExecutor executor(
2178       matcherByteCode.data(), state.memory, state.opRangeMemory,
2179       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2180       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2181       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2182       constraintFunctions, rewriteFunctions);
2183   executor.execute(rewriter, &matches);
2184 
2185   // Order the found matches by benefit.
2186   std::stable_sort(matches.begin(), matches.end(),
2187                    [](const MatchResult &lhs, const MatchResult &rhs) {
2188                      return lhs.benefit > rhs.benefit;
2189                    });
2190 }
2191 
2192 /// Run the rewriter of the given pattern on the root operation `op`.
2193 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2194                           PDLByteCodeMutableState &state) const {
2195   // The arguments of the rewrite function are stored at the start of the
2196   // memory buffer.
2197   llvm::copy(match.values, state.memory.begin());
2198 
2199   ByteCodeExecutor executor(
2200       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2201       state.opRangeMemory, state.typeRangeMemory,
2202       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2203       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2204       rewriterByteCode, state.currentPatternBenefits, patterns,
2205       constraintFunctions, rewriteFunctions);
2206   executor.execute(rewriter, /*matches=*/nullptr, match.location);
2207 }
2208