1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.getBenefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.getRootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
/// to the position of the pattern within the range returned by
/// `PDLByteCode::getPatterns`.
///
/// \param patternIndex the slot of `currentPatternBenefits` to overwrite.
/// \param benefit the benefit value stored into that slot.
void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
                                                   PatternBenefit benefit) {
  // NOTE(review): no bounds check is performed here; `patternIndex` is
  // presumably guaranteed to be in range by the caller -- confirm.
  currentPatternBenefits[patternIndex] = benefit;
}
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
namespace {
/// The opcodes that may appear in a bytecode stream. The enumerators take
/// their implicit values in declaration order and use `ByteCodeField` as the
/// underlying type, so reordering or inserting entries changes the numeric
/// encoding of every opcode that follows.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): presumably the 0-3 variants encode the operand index in the
  /// opcode itself while the `N` variant reads it from the bytecode -- confirm
  /// against the generator and interpreter for `pdl_interp.get_operand`.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): presumably follows the same 0-3/`N` scheme as the
  /// `GetOperand*` opcodes above -- confirm.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // ByteCode Generation
167 //===----------------------------------------------------------------------===//
168 
169 //===----------------------------------------------------------------------===//
170 // Generator
171 
172 namespace {
173 struct ByteCodeLiveRange;
174 struct ByteCodeWriter;
175 
/// Check if the given class `T` can be converted to an opaque pointer.
///
/// The alias is only well-formed when `T` has a `getAsOpaquePointer()` member;
/// it is meant to be used through `llvm::is_detected`, as the `ByteCodeWriter`
/// overloads in this file do. The trailing `Args` pack is not used by the
/// alias itself.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
179 
180 /// This class represents the main generator for the pattern bytecode.
181 class Generator {
182 public:
183   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
184             SmallVectorImpl<ByteCodeField> &matcherByteCode,
185             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
186             SmallVectorImpl<PDLByteCodePattern> &patterns,
187             ByteCodeField &maxValueMemoryIndex,
188             ByteCodeField &maxOpRangeMemoryIndex,
189             ByteCodeField &maxTypeRangeMemoryIndex,
190             ByteCodeField &maxValueRangeMemoryIndex,
191             ByteCodeField &maxLoopLevel,
192             llvm::StringMap<PDLConstraintFunction> &constraintFns,
193             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
194       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
195         rewriterByteCode(rewriterByteCode), patterns(patterns),
196         maxValueMemoryIndex(maxValueMemoryIndex),
197         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
198         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
199         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
200         maxLoopLevel(maxLoopLevel) {
201     for (const auto &it : llvm::enumerate(constraintFns))
202       constraintToMemIndex.try_emplace(it.value().first(), it.index());
203     for (const auto &it : llvm::enumerate(rewriteFns))
204       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
205   }
206 
207   /// Generate the bytecode for the given PDL interpreter module.
208   void generate(ModuleOp module);
209 
210   /// Return the memory index to use for the given value.
211   ByteCodeField &getMemIndex(Value value) {
212     assert(valueToMemIndex.count(value) &&
213            "expected memory index to be assigned");
214     return valueToMemIndex[value];
215   }
216 
217   /// Return the range memory index used to store the given range value.
218   ByteCodeField &getRangeStorageIndex(Value value) {
219     assert(valueToRangeIndex.count(value) &&
220            "expected range index to be assigned");
221     return valueToRangeIndex[value];
222   }
223 
224   /// Return an index to use when referring to the given data that is uniqued in
225   /// the MLIR context.
226   template <typename T>
227   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
228   getMemIndex(T val) {
229     const void *opaqueVal = val.getAsOpaquePointer();
230 
231     // Get or insert a reference to this value.
232     auto it = uniquedDataToMemIndex.try_emplace(
233         opaqueVal, maxValueMemoryIndex + uniquedData.size());
234     if (it.second)
235       uniquedData.push_back(opaqueVal);
236     return it.first->second;
237   }
238 
239 private:
240   /// Allocate memory indices for the results of operations within the matcher
241   /// and rewriters.
242   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
243                              ModuleOp rewriterModule);
244 
245   /// Generate the bytecode for the given operation.
246   void generate(Region *region, ByteCodeWriter &writer);
247   void generate(Operation *op, ByteCodeWriter &writer);
248   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
249   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
250   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
251   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
252   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
253   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
254   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
255   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
259   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
260   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
286 
287   /// Mapping from value to its corresponding memory index.
288   DenseMap<Value, ByteCodeField> valueToMemIndex;
289 
290   /// Mapping from a range value to its corresponding range storage index.
291   DenseMap<Value, ByteCodeField> valueToRangeIndex;
292 
293   /// Mapping from the name of an externally registered rewrite to its index in
294   /// the bytecode registry.
295   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
296 
297   /// Mapping from the name of an externally registered constraint to its index
298   /// in the bytecode registry.
299   llvm::StringMap<ByteCodeField> constraintToMemIndex;
300 
301   /// Mapping from rewriter function name to the bytecode address of the
302   /// rewriter function in byte.
303   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
304 
305   /// Mapping from a uniqued storage object to its memory index within
306   /// `uniquedData`.
307   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
308 
309   /// The current level of the foreach loop.
310   ByteCodeField curLoopLevel = 0;
311 
312   /// The current MLIR context.
313   MLIRContext *ctx;
314 
315   /// Mapping from block to its address.
316   DenseMap<Block *, ByteCodeAddr> blockToAddr;
317 
318   /// Data of the ByteCode class to be populated.
319   std::vector<const void *> &uniquedData;
320   SmallVectorImpl<ByteCodeField> &matcherByteCode;
321   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
322   SmallVectorImpl<PDLByteCodePattern> &patterns;
323   ByteCodeField &maxValueMemoryIndex;
324   ByteCodeField &maxOpRangeMemoryIndex;
325   ByteCodeField &maxTypeRangeMemoryIndex;
326   ByteCodeField &maxValueRangeMemoryIndex;
327   ByteCodeField &maxLoopLevel;
328 };
329 
330 /// This class provides utilities for writing a bytecode stream.
331 struct ByteCodeWriter {
332   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
333       : bytecode(bytecode), generator(generator) {}
334 
335   /// Append a field to the bytecode.
336   void append(ByteCodeField field) { bytecode.push_back(field); }
337   void append(OpCode opCode) { bytecode.push_back(opCode); }
338 
339   /// Append an address to the bytecode.
340   void append(ByteCodeAddr field) {
341     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
342                   "unexpected ByteCode address size");
343 
344     ByteCodeField fieldParts[2];
345     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
346     bytecode.append({fieldParts[0], fieldParts[1]});
347   }
348 
349   /// Append a single successor to the bytecode, the exact address will need to
350   /// be resolved later.
351   void append(Block *successor) {
352     // Add back a reference to the successor so that the address can be resolved
353     // later.
354     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
355     append(ByteCodeAddr(0));
356   }
357 
358   /// Append a successor range to the bytecode, the exact address will need to
359   /// be resolved later.
360   void append(SuccessorRange successors) {
361     for (Block *successor : successors)
362       append(successor);
363   }
364 
365   /// Append a range of values that will be read as generic PDLValues.
366   void appendPDLValueList(OperandRange values) {
367     bytecode.push_back(values.size());
368     for (Value value : values)
369       appendPDLValue(value);
370   }
371 
372   /// Append a value as a PDLValue.
373   void appendPDLValue(Value value) {
374     appendPDLValueKind(value);
375     append(value);
376   }
377 
378   /// Append the PDLValue::Kind of the given value.
379   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
380 
381   /// Append the PDLValue::Kind of the given type.
382   void appendPDLValueKind(Type type) {
383     PDLValue::Kind kind =
384         TypeSwitch<Type, PDLValue::Kind>(type)
385             .Case<pdl::AttributeType>(
386                 [](Type) { return PDLValue::Kind::Attribute; })
387             .Case<pdl::OperationType>(
388                 [](Type) { return PDLValue::Kind::Operation; })
389             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
390               if (rangeTy.getElementType().isa<pdl::TypeType>())
391                 return PDLValue::Kind::TypeRange;
392               return PDLValue::Kind::ValueRange;
393             })
394             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
395             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
396     bytecode.push_back(static_cast<ByteCodeField>(kind));
397   }
398 
399   /// Append a value that will be stored in a memory slot and not inline within
400   /// the bytecode.
401   template <typename T>
402   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
403                    std::is_pointer<T>::value>
404   append(T value) {
405     bytecode.push_back(generator.getMemIndex(value));
406   }
407 
408   /// Append a range of values.
409   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
410   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
411   append(T range) {
412     bytecode.push_back(llvm::size(range));
413     for (auto it : range)
414       append(it);
415   }
416 
417   /// Append a variadic number of fields to the bytecode.
418   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
419   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
420     append(field);
421     append(field2, fields...);
422   }
423 
424   /// Appends a value as a pointer, stored inline within the bytecode.
425   template <typename T>
426   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
427   appendInline(T value) {
428     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
429     const void *pointer = value.getAsOpaquePointer();
430     ByteCodeField fieldParts[numParts];
431     std::memcpy(fieldParts, &pointer, sizeof(const void *));
432     bytecode.append(fieldParts, fieldParts + numParts);
433   }
434 
435   /// Successor references in the bytecode that have yet to be resolved.
436   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
437 
438   /// The underlying bytecode buffer.
439   SmallVectorImpl<ByteCodeField> &bytecode;
440 
441   /// The main generator producing PDL.
442   Generator &generator;
443 };
444 
445 /// This class represents a live range of PDL Interpreter values, containing
446 /// information about when values are live within a match/rewrite.
447 struct ByteCodeLiveRange {
448   using Set = llvm::IntervalMap<uint64_t, char, 16>;
449   using Allocator = Set::Allocator;
450 
451   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
452 
453   /// Union this live range with the one provided.
454   void unionWith(const ByteCodeLiveRange &rhs) {
455     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
456          ++it)
457       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
458   }
459 
460   /// Returns true if this range overlaps with the one provided.
461   bool overlaps(const ByteCodeLiveRange &rhs) const {
462     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
463         .valid();
464   }
465 
466   /// A map representing the ranges of the match/rewrite that a value is live in
467   /// the interpreter.
468   ///
469   /// We use std::unique_ptr here, because IntervalMap does not provide a
470   /// correct copy or move constructor. We can eliminate the pointer once
471   /// https://reviews.llvm.org/D113240 lands.
472   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
473 
474   /// The operation range storage index for this range.
475   Optional<unsigned> opRangeIndex;
476 
477   /// The type range storage index for this range.
478   Optional<unsigned> typeRangeIndex;
479 
480   /// The value range storage index for this range.
481   Optional<unsigned> valueRangeIndex;
482 };
483 } // namespace
484 
485 void Generator::generate(ModuleOp module) {
486   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
487       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
488   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
489       pdl_interp::PDLInterpDialect::getRewriterModuleName());
490   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
491 
492   // Allocate memory indices for the results of operations within the matcher
493   // and rewriters.
494   allocateMemoryIndices(matcherFunc, rewriterModule);
495 
496   // Generate code for the rewriter functions.
497   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
498   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
499     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
500     for (Operation &op : rewriterFunc.getOps())
501       generate(&op, rewriterByteCodeWriter);
502   }
503   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
504          "unexpected branches in rewriter function");
505 
506   // Generate code for the matcher function.
507   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
508   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
509 
510   // Resolve successor references in the matcher.
511   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
512     ByteCodeAddr addr = blockToAddr[it.first];
513     for (unsigned offsetToFix : it.second)
514       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
515   }
516 }
517 
518 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
519                                       ModuleOp rewriterModule) {
520   // Rewriters use simplistic allocation scheme that simply assigns an index to
521   // each result.
522   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
523     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
524     auto processRewriterValue = [&](Value val) {
525       valueToMemIndex.try_emplace(val, index++);
526       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
527         Type elementTy = rangeType.getElementType();
528         if (elementTy.isa<pdl::TypeType>())
529           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
530         else if (elementTy.isa<pdl::ValueType>())
531           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
532       }
533     };
534 
535     for (BlockArgument arg : rewriterFunc.getArguments())
536       processRewriterValue(arg);
537     rewriterFunc.getBody().walk([&](Operation *op) {
538       for (Value result : op->getResults())
539         processRewriterValue(result);
540     });
541     if (index > maxValueMemoryIndex)
542       maxValueMemoryIndex = index;
543     if (typeRangeIndex > maxTypeRangeMemoryIndex)
544       maxTypeRangeMemoryIndex = typeRangeIndex;
545     if (valueRangeIndex > maxValueRangeMemoryIndex)
546       maxValueRangeMemoryIndex = valueRangeIndex;
547   }
548 
549   // The matcher function uses a more sophisticated numbering that tries to
550   // minimize the number of memory indices assigned. This is done by determining
551   // a live range of the values within the matcher, then the allocation is just
552   // finding the minimal number of overlapping live ranges. This is essentially
553   // a simplified form of register allocation where we don't necessarily have a
554   // limited number of registers, but we still want to minimize the number used.
555   DenseMap<Operation *, unsigned> opToFirstIndex;
556   DenseMap<Operation *, unsigned> opToLastIndex;
557 
558   // A custom walk that marks the first and the last index of each operation.
559   // The entry marks the beginning of the liveness range for this operation,
560   // followed by nested operations, followed by the end of the liveness range.
561   unsigned index = 0;
562   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
563     opToFirstIndex.try_emplace(op, index++);
564     for (Region &region : op->getRegions())
565       for (Block &block : region.getBlocks())
566         for (Operation &nested : block)
567           walk(&nested);
568     opToLastIndex.try_emplace(op, index++);
569   };
570   walk(matcherFunc);
571 
572   // Liveness info for each of the defs within the matcher.
573   ByteCodeLiveRange::Allocator allocator;
574   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
575 
576   // Assign the root operation being matched to slot 0.
577   BlockArgument rootOpArg = matcherFunc.getArgument(0);
578   valueToMemIndex[rootOpArg] = 0;
579 
580   // Walk each of the blocks, computing the def interval that the value is used.
581   Liveness matcherLiveness(matcherFunc);
582   matcherFunc->walk([&](Block *block) {
583     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
584     assert(info && "expected liveness info for block");
585     auto processValue = [&](Value value, Operation *firstUseOrDef) {
586       // We don't need to process the root op argument, this value is always
587       // assigned to the first memory slot.
588       if (value == rootOpArg)
589         return;
590 
591       // Set indices for the range of this block that the value is used.
592       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
593       defRangeIt->second.liveness->insert(
594           opToFirstIndex[firstUseOrDef],
595           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
596           /*dummyValue*/ 0);
597 
598       // Check to see if this value is a range type.
599       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
600         Type eleType = rangeTy.getElementType();
601         if (eleType.isa<pdl::OperationType>())
602           defRangeIt->second.opRangeIndex = 0;
603         else if (eleType.isa<pdl::TypeType>())
604           defRangeIt->second.typeRangeIndex = 0;
605         else if (eleType.isa<pdl::ValueType>())
606           defRangeIt->second.valueRangeIndex = 0;
607       }
608     };
609 
610     // Process the live-ins of this block.
611     for (Value liveIn : info->in()) {
612       // Only process the value if it has been defined in the current region.
613       // Other values that span across pdl_interp.foreach will be added higher
614       // up. This ensures that the we keep them alive for the entire duration
615       // of the loop.
616       if (liveIn.getParentRegion() == block->getParent())
617         processValue(liveIn, &block->front());
618     }
619 
620     // Process the block arguments for the entry block (those are not live-in).
621     if (block->isEntryBlock()) {
622       for (Value argument : block->getArguments())
623         processValue(argument, &block->front());
624     }
625 
626     // Process any new defs within this block.
627     for (Operation &op : *block)
628       for (Value result : op.getResults())
629         processValue(result, &op);
630   });
631 
632   // Greedily allocate memory slots using the computed def live ranges.
633   std::vector<ByteCodeLiveRange> allocatedIndices;
634 
635   // The number of memory indices currently allocated (and its next value).
636   // Recall that the root gets allocated memory index 0.
637   ByteCodeField numIndices = 1;
638 
639   // The number of memory ranges of various types (and their next values).
640   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
641 
642   for (auto &defIt : valueDefRanges) {
643     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
644     ByteCodeLiveRange &defRange = defIt.second;
645 
646     // Try to allocate to an existing index.
647     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
648       ByteCodeLiveRange &existingRange = existingIndexIt.value();
649       if (!defRange.overlaps(existingRange)) {
650         existingRange.unionWith(defRange);
651         memIndex = existingIndexIt.index() + 1;
652 
653         if (defRange.opRangeIndex) {
654           if (!existingRange.opRangeIndex)
655             existingRange.opRangeIndex = numOpRanges++;
656           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
657         } else if (defRange.typeRangeIndex) {
658           if (!existingRange.typeRangeIndex)
659             existingRange.typeRangeIndex = numTypeRanges++;
660           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
661         } else if (defRange.valueRangeIndex) {
662           if (!existingRange.valueRangeIndex)
663             existingRange.valueRangeIndex = numValueRanges++;
664           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
665         }
666         break;
667       }
668     }
669 
670     // If no existing index could be used, add a new one.
671     if (memIndex == 0) {
672       allocatedIndices.emplace_back(allocator);
673       ByteCodeLiveRange &newRange = allocatedIndices.back();
674       newRange.unionWith(defRange);
675 
676       // Allocate an index for op/type/value ranges.
677       if (defRange.opRangeIndex) {
678         newRange.opRangeIndex = numOpRanges;
679         valueToRangeIndex[defIt.first] = numOpRanges++;
680       } else if (defRange.typeRangeIndex) {
681         newRange.typeRangeIndex = numTypeRanges;
682         valueToRangeIndex[defIt.first] = numTypeRanges++;
683       } else if (defRange.valueRangeIndex) {
684         newRange.valueRangeIndex = numValueRanges;
685         valueToRangeIndex[defIt.first] = numValueRanges++;
686       }
687 
688       memIndex = allocatedIndices.size();
689       ++numIndices;
690     }
691   }
692 
693   // Print the index usage and ensure that we did not run out of index space.
694   LLVM_DEBUG({
695     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
696                  << "(down from initial " << valueDefRanges.size() << ").\n";
697   });
698   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
699          "Ran out of memory for allocated indices");
700 
701   // Update the max number of indices.
702   if (numIndices > maxValueMemoryIndex)
703     maxValueMemoryIndex = numIndices;
704   if (numOpRanges > maxOpRangeMemoryIndex)
705     maxOpRangeMemoryIndex = numOpRanges;
706   if (numTypeRanges > maxTypeRangeMemoryIndex)
707     maxTypeRangeMemoryIndex = numTypeRanges;
708   if (numValueRanges > maxValueRangeMemoryIndex)
709     maxValueRangeMemoryIndex = numValueRanges;
710 }
711 
712 void Generator::generate(Region *region, ByteCodeWriter &writer) {
713   llvm::ReversePostOrderTraversal<Region *> rpot(region);
714   for (Block *block : rpot) {
715     // Keep track of where this block begins within the matcher function.
716     blockToAddr.try_emplace(block, matcherByteCode.size());
717     for (Operation &op : *block)
718       generate(&op, writer);
719   }
720 }
721 
722 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
723   LLVM_DEBUG({
724     // The following list must contain all the operations that do not
725     // produce any bytecode.
726     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
727              pdl_interp::InferredTypesOp>(op))
728       writer.appendInline(op->getLoc());
729   });
730   TypeSwitch<Operation *>(op)
731       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
732             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
733             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
734             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
735             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
736             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
737             pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
738             pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
739             pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
740             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
741             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
742             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
743             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
744             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
745             pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
746             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
747             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
748             pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
749             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
750           [&](auto interpOp) { this->generate(interpOp, writer); })
751       .Default([](Operation *) {
752         llvm_unreachable("unknown `pdl_interp` operation");
753       });
754 }
755 
756 void Generator::generate(pdl_interp::ApplyConstraintOp op,
757                          ByteCodeWriter &writer) {
758   assert(constraintToMemIndex.count(op.getName()) &&
759          "expected index for constraint function");
760   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()],
761                 op.getConstParamsAttr());
762   writer.appendPDLValueList(op.getArgs());
763   writer.append(op.getSuccessors());
764 }
765 void Generator::generate(pdl_interp::ApplyRewriteOp op,
766                          ByteCodeWriter &writer) {
767   assert(externalRewriterToMemIndex.count(op.getName()) &&
768          "expected index for rewrite function");
769   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()],
770                 op.getConstParamsAttr());
771   writer.appendPDLValueList(op.getArgs());
772 
773   ResultRange results = op.getResults();
774   writer.append(ByteCodeField(results.size()));
775   for (Value result : results) {
776     // In debug mode we also record the expected kind of the result, so that we
777     // can provide extra verification of the native rewrite function.
778 #ifndef NDEBUG
779     writer.appendPDLValueKind(result);
780 #endif
781 
782     // Range results also need to append the range storage index.
783     if (result.getType().isa<pdl::RangeType>())
784       writer.append(getRangeStorageIndex(result));
785     writer.append(result);
786   }
787 }
788 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
789   Value lhs = op.getLhs();
790   if (lhs.getType().isa<pdl::RangeType>()) {
791     writer.append(OpCode::AreRangesEqual);
792     writer.appendPDLValueKind(lhs);
793     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
794     return;
795   }
796 
797   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
798 }
799 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
800   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
801 }
802 void Generator::generate(pdl_interp::CheckAttributeOp op,
803                          ByteCodeWriter &writer) {
804   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
805                 op.getSuccessors());
806 }
807 void Generator::generate(pdl_interp::CheckOperandCountOp op,
808                          ByteCodeWriter &writer) {
809   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
810                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
811                 op.getSuccessors());
812 }
813 void Generator::generate(pdl_interp::CheckOperationNameOp op,
814                          ByteCodeWriter &writer) {
815   writer.append(OpCode::CheckOperationName, op.getInputOp(),
816                 OperationName(op.getName(), ctx), op.getSuccessors());
817 }
818 void Generator::generate(pdl_interp::CheckResultCountOp op,
819                          ByteCodeWriter &writer) {
820   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
821                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
822                 op.getSuccessors());
823 }
824 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
825   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
826                 op.getSuccessors());
827 }
828 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
829   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
830                 op.getSuccessors());
831 }
832 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
833   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
834   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
835 }
836 void Generator::generate(pdl_interp::CreateAttributeOp op,
837                          ByteCodeWriter &writer) {
838   // Simply repoint the memory index of the result to the constant.
839   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
840 }
841 void Generator::generate(pdl_interp::CreateOperationOp op,
842                          ByteCodeWriter &writer) {
843   writer.append(OpCode::CreateOperation, op.getResultOp(),
844                 OperationName(op.getName(), ctx));
845   writer.appendPDLValueList(op.getInputOperands());
846 
847   // Add the attributes.
848   OperandRange attributes = op.getInputAttributes();
849   writer.append(static_cast<ByteCodeField>(attributes.size()));
850   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
851     writer.append(std::get<0>(it), std::get<1>(it));
852   writer.appendPDLValueList(op.getInputResultTypes());
853 }
854 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
855   // Simply repoint the memory index of the result to the constant.
856   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
857 }
858 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
859   writer.append(OpCode::CreateTypes, op.getResult(),
860                 getRangeStorageIndex(op.getResult()), op.getValue());
861 }
862 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
863   writer.append(OpCode::EraseOp, op.getInputOp());
864 }
865 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
866   OpCode opCode =
867       TypeSwitch<Type, OpCode>(op.getResult().getType())
868           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
869           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
870           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
871           .Default([](Type) -> OpCode {
872             llvm_unreachable("unsupported element type");
873           });
874   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
875 }
876 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
877   writer.append(OpCode::Finalize);
878 }
879 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
880   BlockArgument arg = op.getLoopVariable();
881   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
882   writer.appendPDLValueKind(arg.getType());
883   writer.append(curLoopLevel, op.getSuccessor());
884   ++curLoopLevel;
885   if (curLoopLevel > maxLoopLevel)
886     maxLoopLevel = curLoopLevel;
887   generate(&op.getRegion(), writer);
888   --curLoopLevel;
889 }
890 void Generator::generate(pdl_interp::GetAttributeOp op,
891                          ByteCodeWriter &writer) {
892   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
893                 op.getNameAttr());
894 }
895 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
896                          ByteCodeWriter &writer) {
897   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
898 }
899 void Generator::generate(pdl_interp::GetDefiningOpOp op,
900                          ByteCodeWriter &writer) {
901   writer.append(OpCode::GetDefiningOp, op.getInputOp());
902   writer.appendPDLValue(op.getValue());
903 }
904 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
905   uint32_t index = op.getIndex();
906   if (index < 4)
907     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
908   else
909     writer.append(OpCode::GetOperandN, index);
910   writer.append(op.getInputOp(), op.getValue());
911 }
912 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
913   Value result = op.getValue();
914   Optional<uint32_t> index = op.getIndex();
915   writer.append(OpCode::GetOperands,
916                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
917                 op.getInputOp());
918   if (result.getType().isa<pdl::RangeType>())
919     writer.append(getRangeStorageIndex(result));
920   else
921     writer.append(std::numeric_limits<ByteCodeField>::max());
922   writer.append(result);
923 }
924 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
925   uint32_t index = op.getIndex();
926   if (index < 4)
927     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
928   else
929     writer.append(OpCode::GetResultN, index);
930   writer.append(op.getInputOp(), op.getValue());
931 }
932 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
933   Value result = op.getValue();
934   Optional<uint32_t> index = op.getIndex();
935   writer.append(OpCode::GetResults,
936                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
937                 op.getInputOp());
938   if (result.getType().isa<pdl::RangeType>())
939     writer.append(getRangeStorageIndex(result));
940   else
941     writer.append(std::numeric_limits<ByteCodeField>::max());
942   writer.append(result);
943 }
944 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
945   Value operations = op.getOperations();
946   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
947   writer.append(OpCode::GetUsers, operations, rangeIndex);
948   writer.appendPDLValue(op.getValue());
949 }
950 void Generator::generate(pdl_interp::GetValueTypeOp op,
951                          ByteCodeWriter &writer) {
952   if (op.getType().isa<pdl::RangeType>()) {
953     Value result = op.getResult();
954     writer.append(OpCode::GetValueRangeTypes, result,
955                   getRangeStorageIndex(result), op.getValue());
956   } else {
957     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
958   }
959 }
960 
961 void Generator::generate(pdl_interp::InferredTypesOp op,
962                          ByteCodeWriter &writer) {
963   // InferType maps to a null type as a marker for inferring result types.
964   getMemIndex(op.getResult()) = getMemIndex(Type());
965 }
966 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
967   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
968 }
969 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
970   ByteCodeField patternIndex = patterns.size();
971   patterns.emplace_back(PDLByteCodePattern::create(
972       op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
973   writer.append(OpCode::RecordMatch, patternIndex,
974                 SuccessorRange(op.getOperation()), op.getMatchedOps());
975   writer.appendPDLValueList(op.getInputs());
976 }
977 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
978   writer.append(OpCode::ReplaceOp, op.getInputOp());
979   writer.appendPDLValueList(op.getReplValues());
980 }
981 void Generator::generate(pdl_interp::SwitchAttributeOp op,
982                          ByteCodeWriter &writer) {
983   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
984                 op.getCaseValuesAttr(), op.getSuccessors());
985 }
986 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
987                          ByteCodeWriter &writer) {
988   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
989                 op.getCaseValuesAttr(), op.getSuccessors());
990 }
991 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
992                          ByteCodeWriter &writer) {
993   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
994     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
995   });
996   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
997                 op.getSuccessors());
998 }
999 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1000                          ByteCodeWriter &writer) {
1001   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1002                 op.getCaseValuesAttr(), op.getSuccessors());
1003 }
1004 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1005   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1006                 op.getSuccessors());
1007 }
1008 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1009   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1010                 op.getSuccessors());
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // PDLByteCode
1015 //===----------------------------------------------------------------------===//
1016 
1017 PDLByteCode::PDLByteCode(ModuleOp module,
1018                          llvm::StringMap<PDLConstraintFunction> constraintFns,
1019                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
1020   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1021                       rewriterByteCode, patterns, maxValueMemoryIndex,
1022                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1023                       maxLoopLevel, constraintFns, rewriteFns);
1024   generator.generate(module);
1025 
1026   // Initialize the external functions.
1027   for (auto &it : constraintFns)
1028     constraintFunctions.push_back(std::move(it.second));
1029   for (auto &it : rewriteFns)
1030     rewriteFunctions.push_back(std::move(it.second));
1031 }
1032 
1033 /// Initialize the given state such that it can be used to execute the current
1034 /// bytecode.
1035 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1036   state.memory.resize(maxValueMemoryIndex, nullptr);
1037   state.opRangeMemory.resize(maxOpRangeCount);
1038   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1039   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1040   state.loopIndex.resize(maxLoopLevel, 0);
1041   state.currentPatternBenefits.reserve(patterns.size());
1042   for (const PDLByteCodePattern &pattern : patterns)
1043     state.currentPatternBenefits.push_back(pattern.getBenefit());
1044 }
1045 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1048 
1049 namespace {
/// This class provides support for executing a bytecode stream.
///
/// The executor keeps a cursor (`curCodeIt`) into the bytecode and decodes
/// operands through the `read*` helpers below, each of which advances the
/// cursor past the data it consumed. It does not own any of the memory it
/// operates on: all buffers are handed in by the constructor as array
/// references or container references.
class ByteCodeExecutor {
public:
  /// Construct an executor that starts decoding at `curCodeIt`. The remaining
  /// arguments are stored as-is in the members of the same name.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops a code iterator from the stack and makes it the current position.
  /// The stack must not be empty (asserted).
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code.
  const ByteCodeField *getPrevCodeIt() const {
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  /// Non-template convenience overload that reads a single `ByteCodeField`.
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The list is encoded as a
  /// count followed by that many elements; `list` is cleared first.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// as either Value or ValueRange elements. Unlike `readList`, the decoded
  /// values are appended to `list` without clearing it first.
  void readValueList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      // Each element is prefixed by its kind; any kind other than `Value` is
      // decoded as a `ValueRange` whose elements are flattened into `list`.
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer. The pointer occupies
  /// `sizeof(const void *) / sizeof(ByteCodeField)` fields of the stream.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Jump to a specific successor based on a predicate value: successor 0 when
  /// `isTrue`, successor 1 otherwise.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Each address
  /// spans two bytecode fields, so `destIndex * 2` fields are skipped before
  /// the target address is read.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. The first
  /// case that compares equal to `value` (via `cmp`) selects successor
  /// `caseIndex + 1`; when no case matches, successor 0 is selected.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the value is within the case list. Jump to the correct
    // successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream. Reads one field as a memory index and returns the pointer stored
  /// at that index.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Read a raw pointer type from memory.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Read a class type (other than `PDLValue`) by rebuilding it from the
  /// opaque pointer stored in memory.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// Read a `PDLValue`: a kind field followed by a value of that kind.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Read a bytecode address, which is stored inline across two fields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// Read a single raw field.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// Read a single field and interpret it as a `PDLValue::Kind`.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// The current position within the underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1309 
1310 /// This class is an instantiation of the PDLResultList that provides access to
1311 /// the returned results. This API is not on `PDLResultList` to avoid
1312 /// overexposing access to information specific solely to the ByteCode.
1313 class ByteCodeRewriteResultList : public PDLResultList {
1314 public:
1315   ByteCodeRewriteResultList(unsigned maxNumResults)
1316       : PDLResultList(maxNumResults) {}
1317 
1318   /// Return the list of PDL results.
1319   MutableArrayRef<PDLValue> getResults() { return results; }
1320 
1321   /// Return the type ranges allocated by this list.
1322   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1323     return allocatedTypeRanges;
1324   }
1325 
1326   /// Return the value ranges allocated by this list.
1327   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1328     return allocatedValueRanges;
1329   }
1330 };
1331 } // namespace
1332 
1333 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1334   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1335   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1336   ArrayAttr constParams = read<ArrayAttr>();
1337   SmallVector<PDLValue, 16> args;
1338   readList<PDLValue>(args);
1339 
1340   LLVM_DEBUG({
1341     llvm::dbgs() << "  * Arguments: ";
1342     llvm::interleaveComma(args, llvm::dbgs());
1343     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1344   });
1345 
1346   // Invoke the constraint and jump to the proper destination.
1347   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1348 }
1349 
1350 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1351   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1352   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1353   ArrayAttr constParams = read<ArrayAttr>();
1354   SmallVector<PDLValue, 16> args;
1355   readList<PDLValue>(args);
1356 
1357   LLVM_DEBUG({
1358     llvm::dbgs() << "  * Arguments: ";
1359     llvm::interleaveComma(args, llvm::dbgs());
1360     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1361   });
1362 
1363   // Execute the rewrite function.
1364   ByteCodeField numResults = read();
1365   ByteCodeRewriteResultList results(numResults);
1366   rewriteFn(args, constParams, rewriter, results);
1367 
1368   assert(results.getResults().size() == numResults &&
1369          "native PDL rewrite function returned unexpected number of results");
1370 
1371   // Store the results in the bytecode memory.
1372   for (PDLValue &result : results.getResults()) {
1373     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1374 
1375 // In debug mode we also verify the expected kind of the result.
1376 #ifndef NDEBUG
1377     assert(result.getKind() == read<PDLValue::Kind>() &&
1378            "native PDL rewrite function returned an unexpected type of result");
1379 #endif
1380 
1381     // If the result is a range, we need to copy it over to the bytecodes
1382     // range memory.
1383     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1384       unsigned rangeIndex = read();
1385       typeRangeMemory[rangeIndex] = *typeRange;
1386       memory[read()] = &typeRangeMemory[rangeIndex];
1387     } else if (Optional<ValueRange> valueRange =
1388                    result.dyn_cast<ValueRange>()) {
1389       unsigned rangeIndex = read();
1390       valueRangeMemory[rangeIndex] = *valueRange;
1391       memory[read()] = &valueRangeMemory[rangeIndex];
1392     } else {
1393       memory[read()] = result.getAsOpaquePointer();
1394     }
1395   }
1396 
1397   // Copy over any underlying storage allocated for result ranges.
1398   for (auto &it : results.getAllocatedTypeRanges())
1399     allocatedTypeRangeMemory.push_back(std::move(it));
1400   for (auto &it : results.getAllocatedValueRanges())
1401     allocatedValueRangeMemory.push_back(std::move(it));
1402 }
1403 
1404 void ByteCodeExecutor::executeAreEqual() {
1405   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1406   const void *lhs = read<const void *>();
1407   const void *rhs = read<const void *>();
1408 
1409   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1410   selectJump(lhs == rhs);
1411 }
1412 
1413 void ByteCodeExecutor::executeAreRangesEqual() {
1414   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1415   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1416   const void *lhs = read<const void *>();
1417   const void *rhs = read<const void *>();
1418 
1419   switch (valueKind) {
1420   case PDLValue::Kind::TypeRange: {
1421     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1422     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1423     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1424     selectJump(*lhsRange == *rhsRange);
1425     break;
1426   }
1427   case PDLValue::Kind::ValueRange: {
1428     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1429     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1430     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1431     selectJump(*lhsRange == *rhsRange);
1432     break;
1433   }
1434   default:
1435     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1436   }
1437 }
1438 
1439 void ByteCodeExecutor::executeBranch() {
1440   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1441   curCodeIt = &code[read<ByteCodeAddr>()];
1442 }
1443 
1444 void ByteCodeExecutor::executeCheckOperandCount() {
1445   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1446   Operation *op = read<Operation *>();
1447   uint32_t expectedCount = read<uint32_t>();
1448   bool compareAtLeast = read();
1449 
1450   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1451                           << "  * Expected: " << expectedCount << "\n"
1452                           << "  * Comparator: "
1453                           << (compareAtLeast ? ">=" : "==") << "\n");
1454   if (compareAtLeast)
1455     selectJump(op->getNumOperands() >= expectedCount);
1456   else
1457     selectJump(op->getNumOperands() == expectedCount);
1458 }
1459 
1460 void ByteCodeExecutor::executeCheckOperationName() {
1461   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1462   Operation *op = read<Operation *>();
1463   OperationName expectedName = read<OperationName>();
1464 
1465   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1466                           << "  * Expected: \"" << expectedName << "\"\n");
1467   selectJump(op->getName() == expectedName);
1468 }
1469 
1470 void ByteCodeExecutor::executeCheckResultCount() {
1471   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1472   Operation *op = read<Operation *>();
1473   uint32_t expectedCount = read<uint32_t>();
1474   bool compareAtLeast = read();
1475 
1476   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1477                           << "  * Expected: " << expectedCount << "\n"
1478                           << "  * Comparator: "
1479                           << (compareAtLeast ? ">=" : "==") << "\n");
1480   if (compareAtLeast)
1481     selectJump(op->getNumResults() >= expectedCount);
1482   else
1483     selectJump(op->getNumResults() == expectedCount);
1484 }
1485 
1486 void ByteCodeExecutor::executeCheckTypes() {
1487   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1488   TypeRange *lhs = read<TypeRange *>();
1489   Attribute rhs = read<Attribute>();
1490   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1491 
1492   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1493 }
1494 
1495 void ByteCodeExecutor::executeContinue() {
1496   ByteCodeField level = read();
1497   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1498                           << "  * Level: " << level << "\n");
1499   ++loopIndex[level];
1500   popCodeIt();
1501 }
1502 
1503 void ByteCodeExecutor::executeCreateTypes() {
1504   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1505   unsigned memIndex = read();
1506   unsigned rangeIndex = read();
1507   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1508 
1509   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1510 
1511   // Allocate a buffer for this type range.
1512   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1513   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1514   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1515 
1516   // Assign this to the range slot and use the range as the value for the
1517   // memory index.
1518   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1519   memory[memIndex] = &typeRangeMemory[rangeIndex];
1520 }
1521 
1522 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1523                                               Location mainRewriteLoc) {
1524   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1525 
1526   unsigned memIndex = read();
1527   OperationState state(mainRewriteLoc, read<OperationName>());
1528   readValueList(state.operands);
1529   for (unsigned i = 0, e = read(); i != e; ++i) {
1530     StringAttr name = read<StringAttr>();
1531     if (Attribute attr = read<Attribute>())
1532       state.addAttribute(name, attr);
1533   }
1534 
1535   for (unsigned i = 0, e = read(); i != e; ++i) {
1536     if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1537       state.types.push_back(read<Type>());
1538       continue;
1539     }
1540 
1541     // If we find a null range, this signals that the types are infered.
1542     if (TypeRange *resultTypes = read<TypeRange *>()) {
1543       state.types.append(resultTypes->begin(), resultTypes->end());
1544       continue;
1545     }
1546 
1547     // Handle the case where the operation has inferred types.
1548     InferTypeOpInterface::Concept *inferInterface =
1549         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1550 
1551     // TODO: Handle failure.
1552     state.types.clear();
1553     if (failed(inferInterface->inferReturnTypes(
1554             state.getContext(), state.location, state.operands,
1555             state.attributes.getDictionary(state.getContext()), state.regions,
1556             state.types)))
1557       return;
1558     break;
1559   }
1560 
1561   Operation *resultOp = rewriter.createOperation(state);
1562   memory[memIndex] = resultOp;
1563 
1564   LLVM_DEBUG({
1565     llvm::dbgs() << "  * Attributes: "
1566                  << state.attributes.getDictionary(state.getContext())
1567                  << "\n  * Operands: ";
1568     llvm::interleaveComma(state.operands, llvm::dbgs());
1569     llvm::dbgs() << "\n  * Result Types: ";
1570     llvm::interleaveComma(state.types, llvm::dbgs());
1571     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1572   });
1573 }
1574 
1575 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1576   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1577   Operation *op = read<Operation *>();
1578 
1579   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1580   rewriter.eraseOp(op);
1581 }
1582 
1583 template <typename T, typename Range, PDLValue::Kind kind>
1584 void ByteCodeExecutor::executeExtract() {
1585   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1586   Range *range = read<Range *>();
1587   unsigned index = read<uint32_t>();
1588   unsigned memIndex = read();
1589 
1590   if (!range) {
1591     memory[memIndex] = nullptr;
1592     return;
1593   }
1594 
1595   T result = index < range->size() ? (*range)[index] : T();
1596   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1597                           << "  * Index: " << index << "\n"
1598                           << "  * Result: " << result << "\n");
1599   storeToMemory(memIndex, result);
1600 }
1601 
1602 void ByteCodeExecutor::executeFinalize() {
1603   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1604 }
1605 
1606 void ByteCodeExecutor::executeForEach() {
1607   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1608   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1609   unsigned rangeIndex = read();
1610   unsigned memIndex = read();
1611   const void *value = nullptr;
1612 
1613   switch (read<PDLValue::Kind>()) {
1614   case PDLValue::Kind::Operation: {
1615     unsigned &index = loopIndex[read()];
1616     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1617     assert(index <= array.size() && "iterated past the end");
1618     if (index < array.size()) {
1619       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1620       value = array[index];
1621       break;
1622     }
1623 
1624     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1625     index = 0;
1626     selectJump(size_t(0));
1627     return;
1628   }
1629   default:
1630     llvm_unreachable("unexpected `ForEach` value kind");
1631   }
1632 
1633   // Store the iterate value and the stack address.
1634   memory[memIndex] = value;
1635   pushCodeIt(prevCodeIt);
1636 
1637   // Skip over the successor (we will enter the body of the loop).
1638   read<ByteCodeAddr>();
1639 }
1640 
1641 void ByteCodeExecutor::executeGetAttribute() {
1642   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1643   unsigned memIndex = read();
1644   Operation *op = read<Operation *>();
1645   StringAttr attrName = read<StringAttr>();
1646   Attribute attr = op->getAttr(attrName);
1647 
1648   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1649                           << "  * Attribute: " << attrName << "\n"
1650                           << "  * Result: " << attr << "\n");
1651   memory[memIndex] = attr.getAsOpaquePointer();
1652 }
1653 
1654 void ByteCodeExecutor::executeGetAttributeType() {
1655   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1656   unsigned memIndex = read();
1657   Attribute attr = read<Attribute>();
1658   Type type = attr ? attr.getType() : Type();
1659 
1660   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1661                           << "  * Result: " << type << "\n");
1662   memory[memIndex] = type.getAsOpaquePointer();
1663 }
1664 
1665 void ByteCodeExecutor::executeGetDefiningOp() {
1666   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1667   unsigned memIndex = read();
1668   Operation *op = nullptr;
1669   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1670     Value value = read<Value>();
1671     if (value)
1672       op = value.getDefiningOp();
1673     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1674   } else {
1675     ValueRange *values = read<ValueRange *>();
1676     if (values && !values->empty()) {
1677       op = values->front().getDefiningOp();
1678     }
1679     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1680   }
1681 
1682   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1683   memory[memIndex] = op;
1684 }
1685 
1686 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1687   Operation *op = read<Operation *>();
1688   unsigned memIndex = read();
1689   Value operand =
1690       index < op->getNumOperands() ? op->getOperand(index) : Value();
1691 
1692   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1693                           << "  * Index: " << index << "\n"
1694                           << "  * Result: " << operand << "\n");
1695   memory[memIndex] = operand.getAsOpaquePointer();
1696 }
1697 
1698 /// This function is the internal implementation of `GetResults` and
1699 /// `GetOperands` that provides support for extracting a value range from the
1700 /// given operation.
1701 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1702 static void *
1703 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1704                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1705                           MutableArrayRef<ValueRange> valueRangeMemory) {
1706   // Check for the sentinel index that signals that all values should be
1707   // returned.
1708   if (index == std::numeric_limits<uint32_t>::max()) {
1709     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1710     // `values` is already the full value range.
1711 
1712     // Otherwise, check to see if this operation uses AttrSizedSegments.
1713   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1714     LLVM_DEBUG(llvm::dbgs()
1715                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1716 
1717     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1718     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1719       return nullptr;
1720 
1721     auto segments = segmentAttr.getValues<int32_t>();
1722     unsigned startIndex =
1723         std::accumulate(segments.begin(), segments.begin() + index, 0);
1724     values = values.slice(startIndex, *std::next(segments.begin(), index));
1725 
1726     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1727                             << *std::next(segments.begin(), index) << "]\n");
1728 
1729     // Otherwise, assume this is the last operand group of the operation.
1730     // FIXME: We currently don't support operations with
1731     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1732     // have a way to detect it's presence.
1733   } else if (values.size() >= index) {
1734     LLVM_DEBUG(llvm::dbgs()
1735                << "  * Treating values as trailing variadic range\n");
1736     values = values.drop_front(index);
1737 
1738     // If we couldn't detect a way to compute the values, bail out.
1739   } else {
1740     return nullptr;
1741   }
1742 
1743   // If the range index is valid, we are returning a range.
1744   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1745     valueRangeMemory[rangeIndex] = values;
1746     return &valueRangeMemory[rangeIndex];
1747   }
1748 
1749   // If a range index wasn't provided, the range is required to be non-variadic.
1750   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1751 }
1752 
1753 void ByteCodeExecutor::executeGetOperands() {
1754   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1755   unsigned index = read<uint32_t>();
1756   Operation *op = read<Operation *>();
1757   ByteCodeField rangeIndex = read();
1758 
1759   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1760       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1761       valueRangeMemory);
1762   if (!result)
1763     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1764   memory[read()] = result;
1765 }
1766 
1767 void ByteCodeExecutor::executeGetResult(unsigned index) {
1768   Operation *op = read<Operation *>();
1769   unsigned memIndex = read();
1770   OpResult result =
1771       index < op->getNumResults() ? op->getResult(index) : OpResult();
1772 
1773   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1774                           << "  * Index: " << index << "\n"
1775                           << "  * Result: " << result << "\n");
1776   memory[memIndex] = result.getAsOpaquePointer();
1777 }
1778 
1779 void ByteCodeExecutor::executeGetResults() {
1780   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1781   unsigned index = read<uint32_t>();
1782   Operation *op = read<Operation *>();
1783   ByteCodeField rangeIndex = read();
1784 
1785   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1786       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1787       valueRangeMemory);
1788   if (!result)
1789     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1790   memory[read()] = result;
1791 }
1792 
1793 void ByteCodeExecutor::executeGetUsers() {
1794   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1795   unsigned memIndex = read();
1796   unsigned rangeIndex = read();
1797   OwningOpRange &range = opRangeMemory[rangeIndex];
1798   memory[memIndex] = &range;
1799 
1800   range = OwningOpRange();
1801   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1802     // Read the value.
1803     Value value = read<Value>();
1804     if (!value)
1805       return;
1806     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1807 
1808     // Extract the users of a single value.
1809     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1810     llvm::copy(value.getUsers(), range.begin());
1811   } else {
1812     // Read a range of values.
1813     ValueRange *values = read<ValueRange *>();
1814     if (!values)
1815       return;
1816     LLVM_DEBUG({
1817       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1818       llvm::interleaveComma(*values, llvm::dbgs());
1819       llvm::dbgs() << "\n";
1820     });
1821 
1822     // Extract all the users of a range of values.
1823     SmallVector<Operation *> users;
1824     for (Value value : *values)
1825       users.append(value.user_begin(), value.user_end());
1826     range = OwningOpRange(users.size());
1827     llvm::copy(users, range.begin());
1828   }
1829 
1830   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1831 }
1832 
1833 void ByteCodeExecutor::executeGetValueType() {
1834   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1835   unsigned memIndex = read();
1836   Value value = read<Value>();
1837   Type type = value ? value.getType() : Type();
1838 
1839   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1840                           << "  * Result: " << type << "\n");
1841   memory[memIndex] = type.getAsOpaquePointer();
1842 }
1843 
1844 void ByteCodeExecutor::executeGetValueRangeTypes() {
1845   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1846   unsigned memIndex = read();
1847   unsigned rangeIndex = read();
1848   ValueRange *values = read<ValueRange *>();
1849   if (!values) {
1850     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1851     memory[memIndex] = nullptr;
1852     return;
1853   }
1854 
1855   LLVM_DEBUG({
1856     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1857     llvm::interleaveComma(*values, llvm::dbgs());
1858     llvm::dbgs() << "\n  * Result: ";
1859     llvm::interleaveComma(values->getType(), llvm::dbgs());
1860     llvm::dbgs() << "\n";
1861   });
1862   typeRangeMemory[rangeIndex] = values->getType();
1863   memory[memIndex] = &typeRangeMemory[rangeIndex];
1864 }
1865 
1866 void ByteCodeExecutor::executeIsNotNull() {
1867   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1868   const void *value = read<const void *>();
1869 
1870   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1871   selectJump(value != nullptr);
1872 }
1873 
1874 void ByteCodeExecutor::executeRecordMatch(
1875     PatternRewriter &rewriter,
1876     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1877   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1878   unsigned patternIndex = read();
1879   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1880   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1881 
1882   // If the benefit of the pattern is impossible, skip the processing of the
1883   // rest of the pattern.
1884   if (benefit.isImpossibleToMatch()) {
1885     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1886     curCodeIt = dest;
1887     return;
1888   }
1889 
1890   // Create a fused location containing the locations of each of the
1891   // operations used in the match. This will be used as the location for
1892   // created operations during the rewrite that don't already have an
1893   // explicit location set.
1894   unsigned numMatchLocs = read();
1895   SmallVector<Location, 4> matchLocs;
1896   matchLocs.reserve(numMatchLocs);
1897   for (unsigned i = 0; i != numMatchLocs; ++i)
1898     matchLocs.push_back(read<Operation *>()->getLoc());
1899   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1900 
1901   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1902                           << "  * Location: " << matchLoc << "\n");
1903   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1904   PDLByteCode::MatchResult &match = matches.back();
1905 
1906   // Record all of the inputs to the match. If any of the inputs are ranges, we
1907   // will also need to remap the range pointer to memory stored in the match
1908   // state.
1909   unsigned numInputs = read();
1910   match.values.reserve(numInputs);
1911   match.typeRangeValues.reserve(numInputs);
1912   match.valueRangeValues.reserve(numInputs);
1913   for (unsigned i = 0; i < numInputs; ++i) {
1914     switch (read<PDLValue::Kind>()) {
1915     case PDLValue::Kind::TypeRange:
1916       match.typeRangeValues.push_back(*read<TypeRange *>());
1917       match.values.push_back(&match.typeRangeValues.back());
1918       break;
1919     case PDLValue::Kind::ValueRange:
1920       match.valueRangeValues.push_back(*read<ValueRange *>());
1921       match.values.push_back(&match.valueRangeValues.back());
1922       break;
1923     default:
1924       match.values.push_back(read<const void *>());
1925       break;
1926     }
1927   }
1928   curCodeIt = dest;
1929 }
1930 
1931 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1932   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1933   Operation *op = read<Operation *>();
1934   SmallVector<Value, 16> args;
1935   readValueList(args);
1936 
1937   LLVM_DEBUG({
1938     llvm::dbgs() << "  * Operation: " << *op << "\n"
1939                  << "  * Values: ";
1940     llvm::interleaveComma(args, llvm::dbgs());
1941     llvm::dbgs() << "\n";
1942   });
1943   rewriter.replaceOp(op, args);
1944 }
1945 
1946 void ByteCodeExecutor::executeSwitchAttribute() {
1947   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1948   Attribute value = read<Attribute>();
1949   ArrayAttr cases = read<ArrayAttr>();
1950   handleSwitch(value, cases);
1951 }
1952 
1953 void ByteCodeExecutor::executeSwitchOperandCount() {
1954   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1955   Operation *op = read<Operation *>();
1956   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1957 
1958   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1959   handleSwitch(op->getNumOperands(), cases);
1960 }
1961 
1962 void ByteCodeExecutor::executeSwitchOperationName() {
1963   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1964   OperationName value = read<Operation *>()->getName();
1965   size_t caseCount = read();
1966 
1967   // The operation names are stored in-line, so to print them out for
1968   // debugging purposes we need to read the array before executing the
1969   // switch so that we can display all of the possible values.
1970   LLVM_DEBUG({
1971     const ByteCodeField *prevCodeIt = curCodeIt;
1972     llvm::dbgs() << "  * Value: " << value << "\n"
1973                  << "  * Cases: ";
1974     llvm::interleaveComma(
1975         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1976                         [&](size_t) { return read<OperationName>(); }),
1977         llvm::dbgs());
1978     llvm::dbgs() << "\n";
1979     curCodeIt = prevCodeIt;
1980   });
1981 
1982   // Try to find the switch value within any of the cases.
1983   for (size_t i = 0; i != caseCount; ++i) {
1984     if (read<OperationName>() == value) {
1985       curCodeIt += (caseCount - i - 1);
1986       return selectJump(i + 1);
1987     }
1988   }
1989   selectJump(size_t(0));
1990 }
1991 
1992 void ByteCodeExecutor::executeSwitchResultCount() {
1993   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1994   Operation *op = read<Operation *>();
1995   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1996 
1997   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1998   handleSwitch(op->getNumResults(), cases);
1999 }
2000 
2001 void ByteCodeExecutor::executeSwitchType() {
2002   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2003   Type value = read<Type>();
2004   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2005   handleSwitch(value, cases);
2006 }
2007 
2008 void ByteCodeExecutor::executeSwitchTypes() {
2009   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2010   TypeRange *value = read<TypeRange *>();
2011   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2012   if (!value) {
2013     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2014     return selectJump(size_t(0));
2015   }
2016   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2017     return value == caseValue.getAsValueRange<TypeAttr>();
2018   });
2019 }
2020 
/// Main interpreter loop: repeatedly read an opcode from the bytecode stream
/// and dispatch to the matching `execute*` handler until a `Finalize` opcode
/// is reached, at which point the function returns.
///
/// \param rewriter forwarded to the handlers that take a rewriter.
/// \param matches destination for `RecordMatch`; asserted to be non-null when
///        that opcode is executed.
/// \param mainRewriteLoc dereferenced unconditionally when `CreateOperation`
///        is executed, so it must hold a value whenever that opcode can be
///        reached.
void ByteCodeExecutor::execute(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
    Optional<Location> mainRewriteLoc) {
  while (true) {
    // Print the location of the operation being executed.
    LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");

    // Decode the next opcode; each handler reads its own operands.
    OpCode opCode = static_cast<OpCode>(read());
    switch (opCode) {
    case ApplyConstraint:
      executeApplyConstraint(rewriter);
      break;
    case ApplyRewrite:
      executeApplyRewrite(rewriter);
      break;
    case AreEqual:
      executeAreEqual();
      break;
    case AreRangesEqual:
      executeAreRangesEqual();
      break;
    case Branch:
      executeBranch();
      break;
    case CheckOperandCount:
      executeCheckOperandCount();
      break;
    case CheckOperationName:
      executeCheckOperationName();
      break;
    case CheckResultCount:
      executeCheckResultCount();
      break;
    case CheckTypes:
      executeCheckTypes();
      break;
    case Continue:
      executeContinue();
      break;
    case CreateOperation:
      executeCreateOperation(rewriter, *mainRewriteLoc);
      break;
    case CreateTypes:
      executeCreateTypes();
      break;
    case EraseOp:
      executeEraseOp(rewriter);
      break;
    case ExtractOp:
      executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
      break;
    case ExtractType:
      executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
      break;
    case ExtractValue:
      executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
      break;
    case Finalize:
      // `Finalize` is the only opcode that terminates the loop.
      executeFinalize();
      LLVM_DEBUG(llvm::dbgs() << "\n");
      return;
    case ForEach:
      executeForEach();
      break;
    case GetAttribute:
      executeGetAttribute();
      break;
    case GetAttributeType:
      executeGetAttributeType();
      break;
    case GetDefiningOp:
      executeGetDefiningOp();
      break;
    // For these opcodes the operand index is the offset of the opcode from
    // `GetOperand0`, so no index is read from the bytecode.
    case GetOperand0:
    case GetOperand1:
    case GetOperand2:
    case GetOperand3: {
      unsigned index = opCode - GetOperand0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
      executeGetOperand(index);
      break;
    }
    // The general form reads the operand index from the bytecode.
    case GetOperandN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
      executeGetOperand(read<uint32_t>());
      break;
    case GetOperands:
      executeGetOperands();
      break;
    // For these opcodes the result index is the offset of the opcode from
    // `GetResult0`, so no index is read from the bytecode.
    case GetResult0:
    case GetResult1:
    case GetResult2:
    case GetResult3: {
      unsigned index = opCode - GetResult0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
      executeGetResult(index);
      break;
    }
    // The general form reads the result index from the bytecode.
    case GetResultN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
      executeGetResult(read<uint32_t>());
      break;
    case GetResults:
      executeGetResults();
      break;
    case GetUsers:
      executeGetUsers();
      break;
    case GetValueType:
      executeGetValueType();
      break;
    case GetValueRangeTypes:
      executeGetValueRangeTypes();
      break;
    case IsNotNull:
      executeIsNotNull();
      break;
    case RecordMatch:
      assert(matches &&
             "expected matches to be provided when executing the matcher");
      executeRecordMatch(rewriter, *matches);
      break;
    case ReplaceOp:
      executeReplaceOp(rewriter);
      break;
    case SwitchAttribute:
      executeSwitchAttribute();
      break;
    case SwitchOperandCount:
      executeSwitchOperandCount();
      break;
    case SwitchOperationName:
      executeSwitchOperationName();
      break;
    case SwitchResultCount:
      executeSwitchResultCount();
      break;
    case SwitchType:
      executeSwitchType();
      break;
    case SwitchTypes:
      executeSwitchTypes();
      break;
    }
    // Separate the debug trace of consecutive instructions.
    LLVM_DEBUG(llvm::dbgs() << "\n");
  }
}
2169 
2170 /// Run the pattern matcher on the given root operation, collecting the matched
2171 /// patterns in `matches`.
2172 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2173                         SmallVectorImpl<MatchResult> &matches,
2174                         PDLByteCodeMutableState &state) const {
2175   // The first memory slot is always the root operation.
2176   state.memory[0] = op;
2177 
2178   // The matcher function always starts at code address 0.
2179   ByteCodeExecutor executor(
2180       matcherByteCode.data(), state.memory, state.opRangeMemory,
2181       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2182       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2183       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2184       constraintFunctions, rewriteFunctions);
2185   executor.execute(rewriter, &matches);
2186 
2187   // Order the found matches by benefit.
2188   std::stable_sort(matches.begin(), matches.end(),
2189                    [](const MatchResult &lhs, const MatchResult &rhs) {
2190                      return lhs.benefit > rhs.benefit;
2191                    });
2192 }
2193 
2194 /// Run the rewriter of the given pattern on the root operation `op`.
2195 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2196                           PDLByteCodeMutableState &state) const {
2197   // The arguments of the rewrite function are stored at the start of the
2198   // memory buffer.
2199   llvm::copy(match.values, state.memory.begin());
2200 
2201   ByteCodeExecutor executor(
2202       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2203       state.opRangeMemory, state.typeRangeMemory,
2204       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2205       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2206       rewriterByteCode, state.currentPatternBenefits, patterns,
2207       constraintFunctions, rewriteFunctions);
2208   executor.execute(rewriter, /*matches=*/nullptr, match.location);
2209 }
2210