1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.benefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.rootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62                                                    PatternBenefit benefit) {
63   currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
namespace {
/// The opcodes of the PDL bytecode. Each opcode is written to the bytecode as a
/// single `ByteCodeField` (see `ByteCodeWriter::append(OpCode)`). The encoded
/// value of each opcode is its (implicit, sequential) position in this enum, so
/// reordering the enumerators changes the encoding.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): the numbered variants presumably encode small fixed
  /// indices directly in the opcode, with GetOperandN carrying an explicit
  /// index — confirm against the generator/interpreter.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): numbered variants presumably mirror GetOperand0-3/N above.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // end anonymous namespace
164 
165 //===----------------------------------------------------------------------===//
166 // ByteCode Generation
167 //===----------------------------------------------------------------------===//
168 
169 //===----------------------------------------------------------------------===//
170 // Generator
171 
namespace {
struct ByteCodeLiveRange;
struct ByteCodeWriter;

/// Check if the given class `T` can be converted to an opaque pointer.
/// The alias is only well-formed when `T` provides `getAsOpaquePointer()`, so
/// it is intended for use with `llvm::is_detected`. `Args` is unused.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());

/// This class represents the main generator for the pattern bytecode.
///
/// The generator does not own its output. The bytecode buffers, the pattern
/// list, `uniquedData`, and the `max*` counters are all references to storage
/// supplied by the caller, which the generator populates.
class Generator {
public:
  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
            SmallVectorImpl<ByteCodeField> &matcherByteCode,
            SmallVectorImpl<ByteCodeField> &rewriterByteCode,
            SmallVectorImpl<PDLByteCodePattern> &patterns,
            ByteCodeField &maxValueMemoryIndex,
            ByteCodeField &maxOpRangeMemoryIndex,
            ByteCodeField &maxTypeRangeMemoryIndex,
            ByteCodeField &maxValueRangeMemoryIndex,
            ByteCodeField &maxLoopLevel,
            llvm::StringMap<PDLConstraintFunction> &constraintFns,
            llvm::StringMap<PDLRewriteFunction> &rewriteFns)
      : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
        rewriterByteCode(rewriterByteCode), patterns(patterns),
        maxValueMemoryIndex(maxValueMemoryIndex),
        maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
        maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
        maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
        maxLoopLevel(maxLoopLevel) {
    // Number the registered constraint and rewrite functions by their position
    // in each map's iteration order. The bytecode refers to an external
    // function by this number.
    for (auto it : llvm::enumerate(constraintFns))
      constraintToMemIndex.try_emplace(it.value().first(), it.index());
    for (auto it : llvm::enumerate(rewriteFns))
      externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
  }

  /// Generate the bytecode for the given PDL interpreter module.
  void generate(ModuleOp module);

  /// Return the memory index to use for the given value.
  ByteCodeField &getMemIndex(Value value) {
    assert(valueToMemIndex.count(value) &&
           "expected memory index to be assigned");
    return valueToMemIndex[value];
  }

  /// Return the range memory index used to store the given range value.
  ByteCodeField &getRangeStorageIndex(Value value) {
    assert(valueToRangeIndex.count(value) &&
           "expected range index to be assigned");
    return valueToRangeIndex[value];
  }

  /// Return an index to use when referring to the given data that is uniqued in
  /// the MLIR context.
  ///
  /// The opaque pointer of `val` is appended to `uniquedData` the first time it
  /// is seen. Its index is offset by the current `maxValueMemoryIndex`, so
  /// uniqued entries are numbered after the value memory slots.
  template <typename T>
  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
  getMemIndex(T val) {
    const void *opaqueVal = val.getAsOpaquePointer();

    // Get or insert a reference to this value.
    auto it = uniquedDataToMemIndex.try_emplace(
        opaqueVal, maxValueMemoryIndex + uniquedData.size());
    if (it.second)
      uniquedData.push_back(opaqueVal);
    return it.first->second;
  }

private:
  /// Allocate memory indices for the results of operations within the matcher
  /// and rewriters.
  void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);

  /// Generate the bytecode for the given operation.
  void generate(Region *region, ByteCodeWriter &writer);
  void generate(Operation *op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);

  /// Mapping from value to its corresponding memory index.
  DenseMap<Value, ByteCodeField> valueToMemIndex;

  /// Mapping from a range value to its corresponding range storage index.
  DenseMap<Value, ByteCodeField> valueToRangeIndex;

  /// Mapping from the name of an externally registered rewrite to its index in
  /// the bytecode registry.
  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;

  /// Mapping from the name of an externally registered constraint to its index
  /// in the bytecode registry.
  llvm::StringMap<ByteCodeField> constraintToMemIndex;

  /// Mapping from rewriter function name to the address at which that
  /// rewriter function starts within the rewriter bytecode.
  llvm::StringMap<ByteCodeAddr> rewriterToAddr;

  /// Mapping from a uniqued storage object to its memory index within
  /// `uniquedData`.
  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;

  /// The current level of the foreach loop.
  ByteCodeField curLoopLevel = 0;

  /// The current MLIR context.
  MLIRContext *ctx;

  /// Mapping from block to its address within the matcher bytecode.
  DenseMap<Block *, ByteCodeAddr> blockToAddr;

  /// Data of the ByteCode class to be populated.
  std::vector<const void *> &uniquedData;
  SmallVectorImpl<ByteCodeField> &matcherByteCode;
  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
  SmallVectorImpl<PDLByteCodePattern> &patterns;
  ByteCodeField &maxValueMemoryIndex;
  ByteCodeField &maxOpRangeMemoryIndex;
  ByteCodeField &maxTypeRangeMemoryIndex;
  ByteCodeField &maxValueRangeMemoryIndex;
  ByteCodeField &maxLoopLevel;
};

/// This class provides utilities for writing a bytecode stream.
struct ByteCodeWriter {
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a field to the bytecode.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode.
  ///
  /// An address occupies two fields. Its bytes are copied verbatim (in host
  /// byte order) into two consecutive fields.
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a single successor to the bytecode, the exact address will need to
  /// be resolved later.
  ///
  /// A zero placeholder address is written, and the offset of its first field
  /// is recorded in `unresolvedSuccessorRefs` so that it can be patched.
  void append(Block *successor) {
    // Add back a reference to the successor so that the address can be resolved
    // later.
    unresolvedSuccessorRefs[successor].push_back(bytecode.size());
    append(ByteCodeAddr(0));
  }

  /// Append a successor range to the bytecode, the exact address will need to
  /// be resolved later.
  void append(SuccessorRange successors) {
    for (Block *successor : successors)
      append(successor);
  }

  /// Append a range of values that will be read as generic PDLValues.
  /// The list is encoded as the number of values, followed by the kind and the
  /// memory index of each value.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values)
      appendPDLValue(value);
  }

  /// Append a value as a PDLValue.
  void appendPDLValue(Value value) {
    appendPDLValueKind(value);
    append(value);
  }

  /// Append the PDLValue::Kind of the given value.
  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }

  /// Append the PDLValue::Kind of the given type.
  /// A range type is classified as a TypeRange when its element type is
  /// `!pdl.type`, and as a ValueRange otherwise.
  void appendPDLValueKind(Type type) {
    PDLValue::Kind kind =
        TypeSwitch<Type, PDLValue::Kind>(type)
            .Case<pdl::AttributeType>(
                [](Type) { return PDLValue::Kind::Attribute; })
            .Case<pdl::OperationType>(
                [](Type) { return PDLValue::Kind::Operation; })
            .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
              if (rangeTy.getElementType().isa<pdl::TypeType>())
                return PDLValue::Kind::TypeRange;
              return PDLValue::Kind::ValueRange;
            })
            .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
            .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
    bytecode.push_back(static_cast<ByteCodeField>(kind));
  }

  /// Append a value that will be stored in a memory slot and not inline within
  /// the bytecode.
  /// This overload is selected for types that provide `getAsOpaquePointer()`
  /// and for raw pointers. The field written is the memory index returned by
  /// the generator.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values.
  /// The range is encoded as its size, followed by each element appended
  /// individually.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Successor references in the bytecode that have yet to be resolved.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The generator used to resolve the memory indices of appended values.
  Generator &generator;
};

/// This class represents a live range of PDL Interpreter values, containing
/// information about when values are live within a match/rewrite.
struct ByteCodeLiveRange {
  /// Only the intervals are meaningful. The `char` mapped value is a dummy
  /// that is always inserted as 0.
  using Set = llvm::IntervalMap<uint64_t, char, 16>;
  using Allocator = Set::Allocator;

  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}

  /// Union this live range with the one provided.
  void unionWith(const ByteCodeLiveRange &rhs) {
    for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
         ++it)
      liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
  }

  /// Returns true if this range overlaps with the one provided.
  bool overlaps(const ByteCodeLiveRange &rhs) const {
    return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
        .valid();
  }

  /// A map representing the ranges of the match/rewrite that a value is live in
  /// the interpreter.
  ///
  /// We use std::unique_ptr here, because IntervalMap does not provide a
  /// correct copy or move constructor. We can eliminate the pointer once
  /// https://reviews.llvm.org/D113240 lands.
  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;

  // Range storage indices. In the live range of a single value, only the field
  // matching the value's range kind is set (to 0), acting as a flag. In the
  // live range of an allocated memory slot, the field holds the storage index
  // assigned to that slot.

  /// The operation range storage index for this range.
  Optional<unsigned> opRangeIndex;

  /// The type range storage index for this range.
  Optional<unsigned> typeRangeIndex;

  /// The value range storage index for this range.
  Optional<unsigned> valueRangeIndex;
};
} // end anonymous namespace
472 
473 void Generator::generate(ModuleOp module) {
474   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
475       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
476   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
477       pdl_interp::PDLInterpDialect::getRewriterModuleName());
478   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
479 
480   // Allocate memory indices for the results of operations within the matcher
481   // and rewriters.
482   allocateMemoryIndices(matcherFunc, rewriterModule);
483 
484   // Generate code for the rewriter functions.
485   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
486   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
487     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
488     for (Operation &op : rewriterFunc.getOps())
489       generate(&op, rewriterByteCodeWriter);
490   }
491   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
492          "unexpected branches in rewriter function");
493 
494   // Generate code for the matcher function.
495   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
496   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
497 
498   // Resolve successor references in the matcher.
499   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
500     ByteCodeAddr addr = blockToAddr[it.first];
501     for (unsigned offsetToFix : it.second)
502       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
503   }
504 }
505 
506 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
507                                       ModuleOp rewriterModule) {
508   // Rewriters use simplistic allocation scheme that simply assigns an index to
509   // each result.
510   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
511     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
512     auto processRewriterValue = [&](Value val) {
513       valueToMemIndex.try_emplace(val, index++);
514       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
515         Type elementTy = rangeType.getElementType();
516         if (elementTy.isa<pdl::TypeType>())
517           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
518         else if (elementTy.isa<pdl::ValueType>())
519           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
520       }
521     };
522 
523     for (BlockArgument arg : rewriterFunc.getArguments())
524       processRewriterValue(arg);
525     rewriterFunc.getBody().walk([&](Operation *op) {
526       for (Value result : op->getResults())
527         processRewriterValue(result);
528     });
529     if (index > maxValueMemoryIndex)
530       maxValueMemoryIndex = index;
531     if (typeRangeIndex > maxTypeRangeMemoryIndex)
532       maxTypeRangeMemoryIndex = typeRangeIndex;
533     if (valueRangeIndex > maxValueRangeMemoryIndex)
534       maxValueRangeMemoryIndex = valueRangeIndex;
535   }
536 
537   // The matcher function uses a more sophisticated numbering that tries to
538   // minimize the number of memory indices assigned. This is done by determining
539   // a live range of the values within the matcher, then the allocation is just
540   // finding the minimal number of overlapping live ranges. This is essentially
541   // a simplified form of register allocation where we don't necessarily have a
542   // limited number of registers, but we still want to minimize the number used.
543   DenseMap<Operation *, unsigned> opToIndex;
544   matcherFunc.getBody().walk([&](Operation *op) {
545     opToIndex.insert(std::make_pair(op, opToIndex.size()));
546   });
547 
548   // Liveness info for each of the defs within the matcher.
549   ByteCodeLiveRange::Allocator allocator;
550   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
551 
552   // Assign the root operation being matched to slot 0.
553   BlockArgument rootOpArg = matcherFunc.getArgument(0);
554   valueToMemIndex[rootOpArg] = 0;
555 
556   // Walk each of the blocks, computing the def interval that the value is used.
557   Liveness matcherLiveness(matcherFunc);
558   matcherFunc->walk([&](Block *block) {
559     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
560     assert(info && "expected liveness info for block");
561     auto processValue = [&](Value value, Operation *firstUseOrDef) {
562       // We don't need to process the root op argument, this value is always
563       // assigned to the first memory slot.
564       if (value == rootOpArg)
565         return;
566 
567       // Set indices for the range of this block that the value is used.
568       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
569       defRangeIt->second.liveness->insert(
570           opToIndex[firstUseOrDef],
571           opToIndex[info->getEndOperation(value, firstUseOrDef)],
572           /*dummyValue*/ 0);
573 
574       // Check to see if this value is a range type.
575       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
576         Type eleType = rangeTy.getElementType();
577         if (eleType.isa<pdl::OperationType>())
578           defRangeIt->second.opRangeIndex = 0;
579         else if (eleType.isa<pdl::TypeType>())
580           defRangeIt->second.typeRangeIndex = 0;
581         else if (eleType.isa<pdl::ValueType>())
582           defRangeIt->second.valueRangeIndex = 0;
583       }
584     };
585 
586     // Process the live-ins of this block.
587     for (Value liveIn : info->in()) {
588       // Only process the value if it has been defined in the current region.
589       // Other values that span across pdl_interp.foreach will be added higher
590       // up. This ensures that the we keep them alive for the entire duration
591       // of the loop.
592       if (liveIn.getParentRegion() == block->getParent())
593         processValue(liveIn, &block->front());
594     }
595 
596     // Process the block arguments for the entry block (those are not live-in).
597     if (block->isEntryBlock()) {
598       for (Value argument : block->getArguments())
599         processValue(argument, &block->front());
600     }
601 
602     // Process any new defs within this block.
603     for (Operation &op : *block)
604       for (Value result : op.getResults())
605         processValue(result, &op);
606   });
607 
608   // Greedily allocate memory slots using the computed def live ranges.
609   std::vector<ByteCodeLiveRange> allocatedIndices;
610 
611   // The number of memory indices currently allocated (and its next value).
612   // Recall that the root gets allocated memory index 0.
613   ByteCodeField numIndices = 1;
614 
615   // The number of memory ranges of various types (and their next values).
616   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
617 
618   for (auto &defIt : valueDefRanges) {
619     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
620     ByteCodeLiveRange &defRange = defIt.second;
621 
622     // Try to allocate to an existing index.
623     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
624       ByteCodeLiveRange &existingRange = existingIndexIt.value();
625       if (!defRange.overlaps(existingRange)) {
626         existingRange.unionWith(defRange);
627         memIndex = existingIndexIt.index() + 1;
628 
629         if (defRange.opRangeIndex) {
630           if (!existingRange.opRangeIndex)
631             existingRange.opRangeIndex = numOpRanges++;
632           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
633         } else if (defRange.typeRangeIndex) {
634           if (!existingRange.typeRangeIndex)
635             existingRange.typeRangeIndex = numTypeRanges++;
636           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
637         } else if (defRange.valueRangeIndex) {
638           if (!existingRange.valueRangeIndex)
639             existingRange.valueRangeIndex = numValueRanges++;
640           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
641         }
642         break;
643       }
644     }
645 
646     // If no existing index could be used, add a new one.
647     if (memIndex == 0) {
648       allocatedIndices.emplace_back(allocator);
649       ByteCodeLiveRange &newRange = allocatedIndices.back();
650       newRange.unionWith(defRange);
651 
652       // Allocate an index for op/type/value ranges.
653       if (defRange.opRangeIndex) {
654         newRange.opRangeIndex = numOpRanges;
655         valueToRangeIndex[defIt.first] = numOpRanges++;
656       } else if (defRange.typeRangeIndex) {
657         newRange.typeRangeIndex = numTypeRanges;
658         valueToRangeIndex[defIt.first] = numTypeRanges++;
659       } else if (defRange.valueRangeIndex) {
660         newRange.valueRangeIndex = numValueRanges;
661         valueToRangeIndex[defIt.first] = numValueRanges++;
662       }
663 
664       memIndex = allocatedIndices.size();
665       ++numIndices;
666     }
667   }
668 
669   // Print the index usage and ensure that we did not run out of index space.
670   LLVM_DEBUG({
671     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
672                  << "(down from initial " << valueDefRanges.size() << ").\n";
673   });
674   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
675          "Ran out of memory for allocated indices");
676 
677   // Update the max number of indices.
678   if (numIndices > maxValueMemoryIndex)
679     maxValueMemoryIndex = numIndices;
680   if (numOpRanges > maxOpRangeMemoryIndex)
681     maxOpRangeMemoryIndex = numOpRanges;
682   if (numTypeRanges > maxTypeRangeMemoryIndex)
683     maxTypeRangeMemoryIndex = numTypeRanges;
684   if (numValueRanges > maxValueRangeMemoryIndex)
685     maxValueRangeMemoryIndex = numValueRanges;
686 }
687 
688 void Generator::generate(Region *region, ByteCodeWriter &writer) {
689   llvm::ReversePostOrderTraversal<Region *> rpot(region);
690   for (Block *block : rpot) {
691     // Keep track of where this block begins within the matcher function.
692     blockToAddr.try_emplace(block, matcherByteCode.size());
693     for (Operation &op : *block)
694       generate(&op, writer);
695   }
696 }
697 
698 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
699   TypeSwitch<Operation *>(op)
700       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
701             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
702             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
703             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
704             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
705             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
706             pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
707             pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
708             pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
709             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
710             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
711             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
712             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
713             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
714             pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
715             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
716             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
717             pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
718             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
719           [&](auto interpOp) { this->generate(interpOp, writer); })
720       .Default([](Operation *) {
721         llvm_unreachable("unknown `pdl_interp` operation");
722       });
723 }
724 
725 void Generator::generate(pdl_interp::ApplyConstraintOp op,
726                          ByteCodeWriter &writer) {
727   assert(constraintToMemIndex.count(op.name()) &&
728          "expected index for constraint function");
729   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
730                 op.constParamsAttr());
731   writer.appendPDLValueList(op.args());
732   writer.append(op.getSuccessors());
733 }
734 void Generator::generate(pdl_interp::ApplyRewriteOp op,
735                          ByteCodeWriter &writer) {
736   assert(externalRewriterToMemIndex.count(op.name()) &&
737          "expected index for rewrite function");
738   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
739                 op.constParamsAttr());
740   writer.appendPDLValueList(op.args());
741 
742   ResultRange results = op.results();
743   writer.append(ByteCodeField(results.size()));
744   for (Value result : results) {
745     // In debug mode we also record the expected kind of the result, so that we
746     // can provide extra verification of the native rewrite function.
747 #ifndef NDEBUG
748     writer.appendPDLValueKind(result);
749 #endif
750 
751     // Range results also need to append the range storage index.
752     if (result.getType().isa<pdl::RangeType>())
753       writer.append(getRangeStorageIndex(result));
754     writer.append(result);
755   }
756 }
757 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
758   Value lhs = op.lhs();
759   if (lhs.getType().isa<pdl::RangeType>()) {
760     writer.append(OpCode::AreRangesEqual);
761     writer.appendPDLValueKind(lhs);
762     writer.append(op.lhs(), op.rhs(), op.getSuccessors());
763     return;
764   }
765 
766   writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
767 }
768 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
769   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
770 }
771 void Generator::generate(pdl_interp::CheckAttributeOp op,
772                          ByteCodeWriter &writer) {
773   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
774                 op.getSuccessors());
775 }
776 void Generator::generate(pdl_interp::CheckOperandCountOp op,
777                          ByteCodeWriter &writer) {
778   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
779                 static_cast<ByteCodeField>(op.compareAtLeast()),
780                 op.getSuccessors());
781 }
782 void Generator::generate(pdl_interp::CheckOperationNameOp op,
783                          ByteCodeWriter &writer) {
784   writer.append(OpCode::CheckOperationName, op.operation(),
785                 OperationName(op.name(), ctx), op.getSuccessors());
786 }
787 void Generator::generate(pdl_interp::CheckResultCountOp op,
788                          ByteCodeWriter &writer) {
789   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
790                 static_cast<ByteCodeField>(op.compareAtLeast()),
791                 op.getSuccessors());
792 }
793 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
794   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
795 }
796 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
797   writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
798 }
799 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
800   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
801   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
802 }
803 void Generator::generate(pdl_interp::CreateAttributeOp op,
804                          ByteCodeWriter &writer) {
805   // Simply repoint the memory index of the result to the constant.
806   getMemIndex(op.attribute()) = getMemIndex(op.value());
807 }
808 void Generator::generate(pdl_interp::CreateOperationOp op,
809                          ByteCodeWriter &writer) {
810   writer.append(OpCode::CreateOperation, op.operation(),
811                 OperationName(op.name(), ctx));
812   writer.appendPDLValueList(op.operands());
813 
814   // Add the attributes.
815   OperandRange attributes = op.attributes();
816   writer.append(static_cast<ByteCodeField>(attributes.size()));
817   for (auto it : llvm::zip(op.attributeNames(), op.attributes()))
818     writer.append(std::get<0>(it), std::get<1>(it));
819   writer.appendPDLValueList(op.types());
820 }
821 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
822   // Simply repoint the memory index of the result to the constant.
823   getMemIndex(op.result()) = getMemIndex(op.value());
824 }
825 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
826   writer.append(OpCode::CreateTypes, op.result(),
827                 getRangeStorageIndex(op.result()), op.value());
828 }
829 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
830   writer.append(OpCode::EraseOp, op.operation());
831 }
832 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
833   OpCode opCode =
834       TypeSwitch<Type, OpCode>(op.result().getType())
835           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
836           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
837           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
838           .Default([](Type) -> OpCode {
839             llvm_unreachable("unsupported element type");
840           });
841   writer.append(opCode, op.range(), op.index(), op.result());
842 }
843 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
844   writer.append(OpCode::Finalize);
845 }
846 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
847   BlockArgument arg = op.getLoopVariable();
848   writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg);
849   writer.appendPDLValueKind(arg.getType());
850   writer.append(curLoopLevel, op.successor());
851   ++curLoopLevel;
852   if (curLoopLevel > maxLoopLevel)
853     maxLoopLevel = curLoopLevel;
854   generate(&op.region(), writer);
855   --curLoopLevel;
856 }
857 void Generator::generate(pdl_interp::GetAttributeOp op,
858                          ByteCodeWriter &writer) {
859   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
860                 op.nameAttr());
861 }
862 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
863                          ByteCodeWriter &writer) {
864   writer.append(OpCode::GetAttributeType, op.result(), op.value());
865 }
866 void Generator::generate(pdl_interp::GetDefiningOpOp op,
867                          ByteCodeWriter &writer) {
868   writer.append(OpCode::GetDefiningOp, op.operation());
869   writer.appendPDLValue(op.value());
870 }
871 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
872   uint32_t index = op.index();
873   if (index < 4)
874     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
875   else
876     writer.append(OpCode::GetOperandN, index);
877   writer.append(op.operation(), op.value());
878 }
879 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
880   Value result = op.value();
881   Optional<uint32_t> index = op.index();
882   writer.append(OpCode::GetOperands,
883                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
884                 op.operation());
885   if (result.getType().isa<pdl::RangeType>())
886     writer.append(getRangeStorageIndex(result));
887   else
888     writer.append(std::numeric_limits<ByteCodeField>::max());
889   writer.append(result);
890 }
891 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
892   uint32_t index = op.index();
893   if (index < 4)
894     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
895   else
896     writer.append(OpCode::GetResultN, index);
897   writer.append(op.operation(), op.value());
898 }
899 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
900   Value result = op.value();
901   Optional<uint32_t> index = op.index();
902   writer.append(OpCode::GetResults,
903                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
904                 op.operation());
905   if (result.getType().isa<pdl::RangeType>())
906     writer.append(getRangeStorageIndex(result));
907   else
908     writer.append(std::numeric_limits<ByteCodeField>::max());
909   writer.append(result);
910 }
911 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
912   Value operations = op.operations();
913   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
914   writer.append(OpCode::GetUsers, operations, rangeIndex);
915   writer.appendPDLValue(op.value());
916 }
917 void Generator::generate(pdl_interp::GetValueTypeOp op,
918                          ByteCodeWriter &writer) {
919   if (op.getType().isa<pdl::RangeType>()) {
920     Value result = op.result();
921     writer.append(OpCode::GetValueRangeTypes, result,
922                   getRangeStorageIndex(result), op.value());
923   } else {
924     writer.append(OpCode::GetValueType, op.result(), op.value());
925   }
926 }
927 
928 void Generator::generate(pdl_interp::InferredTypesOp op,
929                          ByteCodeWriter &writer) {
930   // InferType maps to a null type as a marker for inferring result types.
931   getMemIndex(op.type()) = getMemIndex(Type());
932 }
933 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
934   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
935 }
936 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
937   ByteCodeField patternIndex = patterns.size();
938   patterns.emplace_back(PDLByteCodePattern::create(
939       op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
940   writer.append(OpCode::RecordMatch, patternIndex,
941                 SuccessorRange(op.getOperation()), op.matchedOps());
942   writer.appendPDLValueList(op.inputs());
943 }
944 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
945   writer.append(OpCode::ReplaceOp, op.operation());
946   writer.appendPDLValueList(op.replValues());
947 }
948 void Generator::generate(pdl_interp::SwitchAttributeOp op,
949                          ByteCodeWriter &writer) {
950   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
951                 op.getSuccessors());
952 }
953 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
954                          ByteCodeWriter &writer) {
955   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
956                 op.getSuccessors());
957 }
958 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
959                          ByteCodeWriter &writer) {
960   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
961     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
962   });
963   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
964                 op.getSuccessors());
965 }
966 void Generator::generate(pdl_interp::SwitchResultCountOp op,
967                          ByteCodeWriter &writer) {
968   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
969                 op.getSuccessors());
970 }
971 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
972   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
973                 op.getSuccessors());
974 }
975 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
976   writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
977                 op.getSuccessors());
978 }
979 
980 //===----------------------------------------------------------------------===//
981 // PDLByteCode
982 //===----------------------------------------------------------------------===//
983 
984 PDLByteCode::PDLByteCode(ModuleOp module,
985                          llvm::StringMap<PDLConstraintFunction> constraintFns,
986                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
987   Generator generator(module.getContext(), uniquedData, matcherByteCode,
988                       rewriterByteCode, patterns, maxValueMemoryIndex,
989                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
990                       maxLoopLevel, constraintFns, rewriteFns);
991   generator.generate(module);
992 
993   // Initialize the external functions.
994   for (auto &it : constraintFns)
995     constraintFunctions.push_back(std::move(it.second));
996   for (auto &it : rewriteFns)
997     rewriteFunctions.push_back(std::move(it.second));
998 }
999 
1000 /// Initialize the given state such that it can be used to execute the current
1001 /// bytecode.
1002 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1003   state.memory.resize(maxValueMemoryIndex, nullptr);
1004   state.opRangeMemory.resize(maxOpRangeCount);
1005   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1006   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1007   state.loopIndex.resize(maxLoopLevel, 0);
1008   state.currentPatternBenefits.reserve(patterns.size());
1009   for (const PDLByteCodePattern &pattern : patterns)
1010     state.currentPatternBenefits.push_back(pattern.getBenefit());
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // ByteCode Execution
1015 
1016 namespace {
/// Interpreter for a PDL bytecode stream. Operands are decoded directly from
/// the bytecode, while values are read from and written to the memory views
/// supplied at construction; the executor does not own that storage.
class ByteCodeExecutor {
public:
  /// `curCodeIt` is the bytecode position to start executing from. The memory
  /// arguments are non-owning views of (or references to) caller-owned state.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context. `mainRewriteLoc` is presumably the location given to operations
  /// created while rewriting (see executeCreateOperation) - TODO confirm.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands. Each
  /// one decodes its operands from `curCodeIt` in the order the generator
  /// emitted them.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Save `it` on the stack of bytecode positions to resume execution from.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Resume execution at the most recently saved position and remove it from
  /// the stack. The stack must not be empty (asserted); nothing is returned.
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  /// Non-template convenience overload for reading a plain ByteCodeField.
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a count-prefixed list of `ValueT` elements into `list`. The list is
  /// cleared first.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a count-prefixed list of values from the bytecode buffer. Each entry
  /// is tagged with its PDLValue kind and is either a single Value or a
  /// ValueRange, whose elements are flattened into `list`. Unlike `readList`,
  /// this appends to `list` without clearing it.
  void readValueList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Jump to a specific successor based on a predicate value: `true` selects
  /// successor 0 and `false` selects successor 1.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Successor
  /// addresses are stored consecutively, each occupying two fields (one
  /// ByteCodeAddr), so the chosen address is `destIndex * 2` fields ahead.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. Jumps to
  /// successor `i + 1` for the first case `i` with `cmp(case, value)` true, or
  /// to successor 0 (the default destination) when no case matches.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the value is within the case list. Jump to the correct
    // successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream. Reads a memory index and returns the referenced entry; indices
  /// past the end of `memory` address `uniquedMemory` (offset by the size of
  /// `memory`).
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer types are stored in memory as-is.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Class types are rebuilt from their opaque pointer representation.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// A PDLValue is encoded as its kind followed by a value of that kind.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Addresses are stored inline and span two bytecode fields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// Plain fields are stored inline.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// Value kinds are stored inline as a single field.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// The current read position within the bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1255 
1256 /// This class is an instantiation of the PDLResultList that provides access to
1257 /// the returned results. This API is not on `PDLResultList` to avoid
1258 /// overexposing access to information specific solely to the ByteCode.
1259 class ByteCodeRewriteResultList : public PDLResultList {
1260 public:
1261   ByteCodeRewriteResultList(unsigned maxNumResults)
1262       : PDLResultList(maxNumResults) {}
1263 
1264   /// Return the list of PDL results.
1265   MutableArrayRef<PDLValue> getResults() { return results; }
1266 
1267   /// Return the type ranges allocated by this list.
1268   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1269     return allocatedTypeRanges;
1270   }
1271 
1272   /// Return the value ranges allocated by this list.
1273   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1274     return allocatedValueRanges;
1275   }
1276 };
1277 } // end anonymous namespace
1278 
1279 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1280   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1281   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1282   ArrayAttr constParams = read<ArrayAttr>();
1283   SmallVector<PDLValue, 16> args;
1284   readList<PDLValue>(args);
1285 
1286   LLVM_DEBUG({
1287     llvm::dbgs() << "  * Arguments: ";
1288     llvm::interleaveComma(args, llvm::dbgs());
1289     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1290   });
1291 
1292   // Invoke the constraint and jump to the proper destination.
1293   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1294 }
1295 
/// Execute a native rewrite function. The bytecode holds the function index,
/// its constant parameters and arguments, the expected result count, and then
/// one descriptor per result: the result kind (debug builds only), a range
/// storage index (range results only), and the memory index to store into.
void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
  ArrayAttr constParams = read<ArrayAttr>();
  SmallVector<PDLValue, 16> args;
  readList<PDLValue>(args);

  LLVM_DEBUG({
    llvm::dbgs() << "  * Arguments: ";
    llvm::interleaveComma(args, llvm::dbgs());
    llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
  });

  // Execute the rewrite function.
  ByteCodeField numResults = read();
  ByteCodeRewriteResultList results(numResults);
  rewriteFn(args, constParams, rewriter, results);

  // Only checked in debug builds; a mismatch would misalign the reads below.
  assert(results.getResults().size() == numResults &&
         "native PDL rewrite function returned unexpected number of results");

  // Store the results in the bytecode memory. The reads below must match the
  // per-result layout emitted by the generator exactly.
  for (PDLValue &result : results.getResults()) {
    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");

// In debug mode we also verify the expected kind of the result. The generator
// only emits the kind field in debug builds, so this read exists only there as
// well; the #ifndef makes that pairing explicit.
#ifndef NDEBUG
    assert(result.getKind() == read<PDLValue::Kind>() &&
           "native PDL rewrite function returned an unexpected type of result");
#endif

    // If the result is a range, we need to copy it over to the bytecodes
    // range memory. The memory slot then points at that range slot.
    if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
      unsigned rangeIndex = read();
      typeRangeMemory[rangeIndex] = *typeRange;
      memory[read()] = &typeRangeMemory[rangeIndex];
    } else if (Optional<ValueRange> valueRange =
                   result.dyn_cast<ValueRange>()) {
      unsigned rangeIndex = read();
      valueRangeMemory[rangeIndex] = *valueRange;
      memory[read()] = &valueRangeMemory[rangeIndex];
    } else {
      memory[read()] = result.getAsOpaquePointer();
    }
  }

  // Take ownership of any storage the native function allocated for result
  // ranges. Moving an OwningArrayRef transfers its buffer without relocating
  // it, so ranges stored above that point into this storage stay valid.
  for (auto &it : results.getAllocatedTypeRanges())
    allocatedTypeRangeMemory.push_back(std::move(it));
  for (auto &it : results.getAllocatedValueRanges())
    allocatedValueRangeMemory.push_back(std::move(it));
}
1349 
1350 void ByteCodeExecutor::executeAreEqual() {
1351   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1352   const void *lhs = read<const void *>();
1353   const void *rhs = read<const void *>();
1354 
1355   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1356   selectJump(lhs == rhs);
1357 }
1358 
1359 void ByteCodeExecutor::executeAreRangesEqual() {
1360   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1361   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1362   const void *lhs = read<const void *>();
1363   const void *rhs = read<const void *>();
1364 
1365   switch (valueKind) {
1366   case PDLValue::Kind::TypeRange: {
1367     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1368     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1369     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1370     selectJump(*lhsRange == *rhsRange);
1371     break;
1372   }
1373   case PDLValue::Kind::ValueRange: {
1374     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1375     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1376     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1377     selectJump(*lhsRange == *rhsRange);
1378     break;
1379   }
1380   default:
1381     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1382   }
1383 }
1384 
1385 void ByteCodeExecutor::executeBranch() {
1386   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1387   curCodeIt = &code[read<ByteCodeAddr>()];
1388 }
1389 
1390 void ByteCodeExecutor::executeCheckOperandCount() {
1391   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1392   Operation *op = read<Operation *>();
1393   uint32_t expectedCount = read<uint32_t>();
1394   bool compareAtLeast = read();
1395 
1396   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1397                           << "  * Expected: " << expectedCount << "\n"
1398                           << "  * Comparator: "
1399                           << (compareAtLeast ? ">=" : "==") << "\n");
1400   if (compareAtLeast)
1401     selectJump(op->getNumOperands() >= expectedCount);
1402   else
1403     selectJump(op->getNumOperands() == expectedCount);
1404 }
1405 
1406 void ByteCodeExecutor::executeCheckOperationName() {
1407   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1408   Operation *op = read<Operation *>();
1409   OperationName expectedName = read<OperationName>();
1410 
1411   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1412                           << "  * Expected: \"" << expectedName << "\"\n");
1413   selectJump(op->getName() == expectedName);
1414 }
1415 
1416 void ByteCodeExecutor::executeCheckResultCount() {
1417   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1418   Operation *op = read<Operation *>();
1419   uint32_t expectedCount = read<uint32_t>();
1420   bool compareAtLeast = read();
1421 
1422   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1423                           << "  * Expected: " << expectedCount << "\n"
1424                           << "  * Comparator: "
1425                           << (compareAtLeast ? ">=" : "==") << "\n");
1426   if (compareAtLeast)
1427     selectJump(op->getNumResults() >= expectedCount);
1428   else
1429     selectJump(op->getNumResults() == expectedCount);
1430 }
1431 
1432 void ByteCodeExecutor::executeCheckTypes() {
1433   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1434   TypeRange *lhs = read<TypeRange *>();
1435   Attribute rhs = read<Attribute>();
1436   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1437 
1438   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1439 }
1440 
1441 void ByteCodeExecutor::executeContinue() {
1442   ByteCodeField level = read();
1443   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1444                           << "  * Level: " << level << "\n");
1445   ++loopIndex[level];
1446   popCodeIt();
1447 }
1448 
1449 void ByteCodeExecutor::executeCreateTypes() {
1450   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1451   unsigned memIndex = read();
1452   unsigned rangeIndex = read();
1453   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1454 
1455   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1456 
1457   // Allocate a buffer for this type range.
1458   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1459   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1460   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1461 
1462   // Assign this to the range slot and use the range as the value for the
1463   // memory index.
1464   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1465   memory[memIndex] = &typeRangeMemory[rangeIndex];
1466 }
1467 
1468 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1469                                               Location mainRewriteLoc) {
1470   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1471 
1472   unsigned memIndex = read();
1473   OperationState state(mainRewriteLoc, read<OperationName>());
1474   readValueList(state.operands);
1475   for (unsigned i = 0, e = read(); i != e; ++i) {
1476     StringAttr name = read<StringAttr>();
1477     if (Attribute attr = read<Attribute>())
1478       state.addAttribute(name, attr);
1479   }
1480 
1481   for (unsigned i = 0, e = read(); i != e; ++i) {
1482     if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1483       state.types.push_back(read<Type>());
1484       continue;
1485     }
1486 
1487     // If we find a null range, this signals that the types are infered.
1488     if (TypeRange *resultTypes = read<TypeRange *>()) {
1489       state.types.append(resultTypes->begin(), resultTypes->end());
1490       continue;
1491     }
1492 
1493     // Handle the case where the operation has inferred types.
1494     InferTypeOpInterface::Concept *concept =
1495         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1496 
1497     // TODO: Handle failure.
1498     state.types.clear();
1499     if (failed(concept->inferReturnTypes(
1500             state.getContext(), state.location, state.operands,
1501             state.attributes.getDictionary(state.getContext()), state.regions,
1502             state.types)))
1503       return;
1504     break;
1505   }
1506 
1507   Operation *resultOp = rewriter.createOperation(state);
1508   memory[memIndex] = resultOp;
1509 
1510   LLVM_DEBUG({
1511     llvm::dbgs() << "  * Attributes: "
1512                  << state.attributes.getDictionary(state.getContext())
1513                  << "\n  * Operands: ";
1514     llvm::interleaveComma(state.operands, llvm::dbgs());
1515     llvm::dbgs() << "\n  * Result Types: ";
1516     llvm::interleaveComma(state.types, llvm::dbgs());
1517     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1518   });
1519 }
1520 
1521 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1522   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1523   Operation *op = read<Operation *>();
1524 
1525   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1526   rewriter.eraseOp(op);
1527 }
1528 
1529 template <typename T, typename Range, PDLValue::Kind kind>
1530 void ByteCodeExecutor::executeExtract() {
1531   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1532   Range *range = read<Range *>();
1533   unsigned index = read<uint32_t>();
1534   unsigned memIndex = read();
1535 
1536   if (!range) {
1537     memory[memIndex] = nullptr;
1538     return;
1539   }
1540 
1541   T result = index < range->size() ? (*range)[index] : T();
1542   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1543                           << "  * Index: " << index << "\n"
1544                           << "  * Result: " << result << "\n");
1545   storeToMemory(memIndex, result);
1546 }
1547 
1548 void ByteCodeExecutor::executeFinalize() {
1549   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1550 }
1551 
/// Execute one step of a `ForEach` loop over a range stored in
/// `opRangeMemory`. The current iteration is tracked in `loopIndex`; the loop
/// level is encoded inline, and `executeContinue` increments the counter.
///
/// If elements remain, the current element is stored into `memory`. The
/// address of this `ForEach` instruction is pushed so that `Continue` can
/// return to it, and execution falls into the loop body. Once the range is
/// exhausted, the counter is reset to 0 and control jumps to successor 0,
/// which is the loop exit.
void ByteCodeExecutor::executeForEach() {
  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
  // Subtract 1 for the op code.
  // This must be captured before any operand is read, so that it points at
  // the start of this instruction.
  const ByteCodeField *it = curCodeIt - 1;
  unsigned rangeIndex = read();
  unsigned memIndex = read();
  const void *value = nullptr;

  switch (read<PDLValue::Kind>()) {
  case PDLValue::Kind::Operation: {
    unsigned &index = loopIndex[read()];
    ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
    assert(index <= array.size() && "iterated past the end");
    if (index < array.size()) {
      LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
      value = array[index];
      break;
    }

    LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
    // Reset the counter so that a later execution of this loop starts over.
    index = 0;
    selectJump(size_t(0));
    return;
  }
  default:
    llvm_unreachable("unexpected `ForEach` value kind");
  }

  // Store the iterate value and the stack address.
  memory[memIndex] = value;
  pushCodeIt(it);

  // Skip over the successor (we will enter the body of the loop).
  read<ByteCodeAddr>();
}
1587 
1588 void ByteCodeExecutor::executeGetAttribute() {
1589   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1590   unsigned memIndex = read();
1591   Operation *op = read<Operation *>();
1592   StringAttr attrName = read<StringAttr>();
1593   Attribute attr = op->getAttr(attrName);
1594 
1595   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1596                           << "  * Attribute: " << attrName << "\n"
1597                           << "  * Result: " << attr << "\n");
1598   memory[memIndex] = attr.getAsOpaquePointer();
1599 }
1600 
1601 void ByteCodeExecutor::executeGetAttributeType() {
1602   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1603   unsigned memIndex = read();
1604   Attribute attr = read<Attribute>();
1605   Type type = attr ? attr.getType() : Type();
1606 
1607   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1608                           << "  * Result: " << type << "\n");
1609   memory[memIndex] = type.getAsOpaquePointer();
1610 }
1611 
1612 void ByteCodeExecutor::executeGetDefiningOp() {
1613   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1614   unsigned memIndex = read();
1615   Operation *op = nullptr;
1616   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1617     Value value = read<Value>();
1618     if (value)
1619       op = value.getDefiningOp();
1620     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1621   } else {
1622     ValueRange *values = read<ValueRange *>();
1623     if (values && !values->empty()) {
1624       op = values->front().getDefiningOp();
1625     }
1626     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1627   }
1628 
1629   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1630   memory[memIndex] = op;
1631 }
1632 
1633 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1634   Operation *op = read<Operation *>();
1635   unsigned memIndex = read();
1636   Value operand =
1637       index < op->getNumOperands() ? op->getOperand(index) : Value();
1638 
1639   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1640                           << "  * Index: " << index << "\n"
1641                           << "  * Result: " << operand << "\n");
1642   memory[memIndex] = operand.getAsOpaquePointer();
1643 }
1644 
1645 /// This function is the internal implementation of `GetResults` and
1646 /// `GetOperands` that provides support for extracting a value range from the
1647 /// given operation.
1648 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1649 static void *
1650 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1651                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1652                           MutableArrayRef<ValueRange> valueRangeMemory) {
1653   // Check for the sentinel index that signals that all values should be
1654   // returned.
1655   if (index == std::numeric_limits<uint32_t>::max()) {
1656     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1657     // `values` is already the full value range.
1658 
1659     // Otherwise, check to see if this operation uses AttrSizedSegments.
1660   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1661     LLVM_DEBUG(llvm::dbgs()
1662                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1663 
1664     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1665     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1666       return nullptr;
1667 
1668     auto segments = segmentAttr.getValues<int32_t>();
1669     unsigned startIndex =
1670         std::accumulate(segments.begin(), segments.begin() + index, 0);
1671     values = values.slice(startIndex, *std::next(segments.begin(), index));
1672 
1673     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1674                             << *std::next(segments.begin(), index) << "]\n");
1675 
1676     // Otherwise, assume this is the last operand group of the operation.
1677     // FIXME: We currently don't support operations with
1678     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1679     // have a way to detect it's presence.
1680   } else if (values.size() >= index) {
1681     LLVM_DEBUG(llvm::dbgs()
1682                << "  * Treating values as trailing variadic range\n");
1683     values = values.drop_front(index);
1684 
1685     // If we couldn't detect a way to compute the values, bail out.
1686   } else {
1687     return nullptr;
1688   }
1689 
1690   // If the range index is valid, we are returning a range.
1691   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1692     valueRangeMemory[rangeIndex] = values;
1693     return &valueRangeMemory[rangeIndex];
1694   }
1695 
1696   // If a range index wasn't provided, the range is required to be non-variadic.
1697   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1698 }
1699 
1700 void ByteCodeExecutor::executeGetOperands() {
1701   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1702   unsigned index = read<uint32_t>();
1703   Operation *op = read<Operation *>();
1704   ByteCodeField rangeIndex = read();
1705 
1706   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1707       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1708       valueRangeMemory);
1709   if (!result)
1710     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1711   memory[read()] = result;
1712 }
1713 
1714 void ByteCodeExecutor::executeGetResult(unsigned index) {
1715   Operation *op = read<Operation *>();
1716   unsigned memIndex = read();
1717   OpResult result =
1718       index < op->getNumResults() ? op->getResult(index) : OpResult();
1719 
1720   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1721                           << "  * Index: " << index << "\n"
1722                           << "  * Result: " << result << "\n");
1723   memory[memIndex] = result.getAsOpaquePointer();
1724 }
1725 
1726 void ByteCodeExecutor::executeGetResults() {
1727   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1728   unsigned index = read<uint32_t>();
1729   Operation *op = read<Operation *>();
1730   ByteCodeField rangeIndex = read();
1731 
1732   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1733       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1734       valueRangeMemory);
1735   if (!result)
1736     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1737   memory[read()] = result;
1738 }
1739 
1740 void ByteCodeExecutor::executeGetUsers() {
1741   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1742   unsigned memIndex = read();
1743   unsigned rangeIndex = read();
1744   OwningOpRange &range = opRangeMemory[rangeIndex];
1745   memory[memIndex] = &range;
1746 
1747   range = OwningOpRange();
1748   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1749     // Read the value.
1750     Value value = read<Value>();
1751     if (!value)
1752       return;
1753     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1754 
1755     // Extract the users of a single value.
1756     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1757     llvm::copy(value.getUsers(), range.begin());
1758   } else {
1759     // Read a range of values.
1760     ValueRange *values = read<ValueRange *>();
1761     if (!values)
1762       return;
1763     LLVM_DEBUG({
1764       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1765       llvm::interleaveComma(*values, llvm::dbgs());
1766       llvm::dbgs() << "\n";
1767     });
1768 
1769     // Extract all the users of a range of values.
1770     SmallVector<Operation *> users;
1771     for (Value value : *values)
1772       users.append(value.user_begin(), value.user_end());
1773     range = OwningOpRange(users.size());
1774     llvm::copy(users, range.begin());
1775   }
1776 
1777   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1778 }
1779 
1780 void ByteCodeExecutor::executeGetValueType() {
1781   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1782   unsigned memIndex = read();
1783   Value value = read<Value>();
1784   Type type = value ? value.getType() : Type();
1785 
1786   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1787                           << "  * Result: " << type << "\n");
1788   memory[memIndex] = type.getAsOpaquePointer();
1789 }
1790 
1791 void ByteCodeExecutor::executeGetValueRangeTypes() {
1792   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1793   unsigned memIndex = read();
1794   unsigned rangeIndex = read();
1795   ValueRange *values = read<ValueRange *>();
1796   if (!values) {
1797     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1798     memory[memIndex] = nullptr;
1799     return;
1800   }
1801 
1802   LLVM_DEBUG({
1803     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1804     llvm::interleaveComma(*values, llvm::dbgs());
1805     llvm::dbgs() << "\n  * Result: ";
1806     llvm::interleaveComma(values->getType(), llvm::dbgs());
1807     llvm::dbgs() << "\n";
1808   });
1809   typeRangeMemory[rangeIndex] = values->getType();
1810   memory[memIndex] = &typeRangeMemory[rangeIndex];
1811 }
1812 
1813 void ByteCodeExecutor::executeIsNotNull() {
1814   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1815   const void *value = read<const void *>();
1816 
1817   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1818   selectJump(value != nullptr);
1819 }
1820 
/// Record a successful match of the pattern at the encoded index into
/// `matches`, then continue at the encoded destination address.
///
/// Each recorded match carries its benefit and a fused location built from
/// the locations of the encoded operations. It also captures the input values
/// needed by the rewriter. Patterns whose benefit is "impossible to match" are
/// skipped, and their remaining encoded operands are not read.
void ByteCodeExecutor::executeRecordMatch(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
  unsigned patternIndex = read();
  PatternBenefit benefit = currentPatternBenefits[patternIndex];
  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];

  // If the benefit of the pattern is impossible, skip the processing of the
  // rest of the pattern.
  if (benefit.isImpossibleToMatch()) {
    LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
    curCodeIt = dest;
    return;
  }

  // Create a fused location containing the locations of each of the
  // operations used in the match. This will be used as the location for
  // created operations during the rewrite that don't already have an
  // explicit location set.
  unsigned numMatchLocs = read();
  SmallVector<Location, 4> matchLocs;
  matchLocs.reserve(numMatchLocs);
  for (unsigned i = 0; i != numMatchLocs; ++i)
    matchLocs.push_back(read<Operation *>()->getLoc());
  Location matchLoc = rewriter.getFusedLoc(matchLocs);

  LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
                          << "  * Location: " << matchLoc << "\n");
  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
  PDLByteCode::MatchResult &match = matches.back();

  // Record all of the inputs to the match. If any of the inputs are ranges, we
  // will also need to remap the range pointer to memory stored in the match
  // state.
  unsigned numInputs = read();
  // These reservations are load-bearing: `match.values` stores pointers to
  // elements of `typeRangeValues`/`valueRangeValues`. Reserving `numInputs`
  // up front guarantees that the `push_back`s below never reallocate and
  // invalidate those pointers.
  // NOTE(review): the pointers must also survive later reallocation of
  // `matches` itself. That presumably relies on these containers having no
  // inline storage (so moves keep the heap buffer); verify in ByteCode.h.
  match.values.reserve(numInputs);
  match.typeRangeValues.reserve(numInputs);
  match.valueRangeValues.reserve(numInputs);
  for (unsigned i = 0; i < numInputs; ++i) {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::TypeRange:
      match.typeRangeValues.push_back(*read<TypeRange *>());
      match.values.push_back(&match.typeRangeValues.back());
      break;
    case PDLValue::Kind::ValueRange:
      match.valueRangeValues.push_back(*read<ValueRange *>());
      match.values.push_back(&match.valueRangeValues.back());
      break;
    default:
      match.values.push_back(read<const void *>());
      break;
    }
  }
  curCodeIt = dest;
}
1877 
1878 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1879   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1880   Operation *op = read<Operation *>();
1881   SmallVector<Value, 16> args;
1882   readValueList(args);
1883 
1884   LLVM_DEBUG({
1885     llvm::dbgs() << "  * Operation: " << *op << "\n"
1886                  << "  * Values: ";
1887     llvm::interleaveComma(args, llvm::dbgs());
1888     llvm::dbgs() << "\n";
1889   });
1890   rewriter.replaceOp(op, args);
1891 }
1892 
1893 void ByteCodeExecutor::executeSwitchAttribute() {
1894   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1895   Attribute value = read<Attribute>();
1896   ArrayAttr cases = read<ArrayAttr>();
1897   handleSwitch(value, cases);
1898 }
1899 
1900 void ByteCodeExecutor::executeSwitchOperandCount() {
1901   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1902   Operation *op = read<Operation *>();
1903   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1904 
1905   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1906   handleSwitch(op->getNumOperands(), cases);
1907 }
1908 
/// Dispatch on the name of an operation against a list of operation names
/// stored inline in the bytecode. A match on case `i` selects successor
/// `i + 1`. No match selects successor 0.
///
/// The cursor is always left just past the last inline name before
/// `selectJump` is called. The skip arithmetic below implies that each inline
/// name occupies one `ByteCodeField`.
void ByteCodeExecutor::executeSwitchOperationName() {
  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
  OperationName value = read<Operation *>()->getName();
  size_t caseCount = read();

  // The operation names are stored in-line, so to print them out for
  // debugging purposes we need to read the array before executing the
  // switch so that we can display all of the possible values.
  LLVM_DEBUG({
    const ByteCodeField *prevCodeIt = curCodeIt;
    llvm::dbgs() << "  * Value: " << value << "\n"
                 << "  * Cases: ";
    llvm::interleaveComma(
        llvm::map_range(llvm::seq<size_t>(0, caseCount),
                        [&](size_t) { return read<OperationName>(); }),
        llvm::dbgs());
    llvm::dbgs() << "\n";
    // Rewind so the real dispatch below re-reads the same names.
    curCodeIt = prevCodeIt;
  });

  // Try to find the switch value within any of the cases.
  for (size_t i = 0; i != caseCount; ++i) {
    if (read<OperationName>() == value) {
      // Skip the names that were not yet read so the cursor lands on the
      // successor list.
      curCodeIt += (caseCount - i - 1);
      return selectJump(i + 1);
    }
  }
  selectJump(size_t(0));
}
1938 
1939 void ByteCodeExecutor::executeSwitchResultCount() {
1940   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1941   Operation *op = read<Operation *>();
1942   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1943 
1944   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1945   handleSwitch(op->getNumResults(), cases);
1946 }
1947 
1948 void ByteCodeExecutor::executeSwitchType() {
1949   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1950   Type value = read<Type>();
1951   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1952   handleSwitch(value, cases);
1953 }
1954 
1955 void ByteCodeExecutor::executeSwitchTypes() {
1956   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
1957   TypeRange *value = read<TypeRange *>();
1958   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
1959   if (!value) {
1960     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
1961     return selectJump(size_t(0));
1962   }
1963   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
1964     return value == caseValue.getAsValueRange<TypeAttr>();
1965   });
1966 }
1967 
/// Run the bytecode interpreter starting at the current code position. Each
/// iteration reads one op code and dispatches to its handler. Execution stops
/// when a `Finalize` op code is reached.
///
/// \param rewriter       the rewriter used by constraint, rewrite, and IR
///                       mutation handlers.
/// \param matches        the output list for `RecordMatch`; must be non-null
///                       when running the matcher (asserted below).
/// \param mainRewriteLoc the location for operations created by
///                       `CreateOperation`. It is dereferenced unconditionally
///                       there, so it must be provided when running a
///                       rewriter.
void ByteCodeExecutor::execute(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
    Optional<Location> mainRewriteLoc) {
  while (true) {
    OpCode opCode = static_cast<OpCode>(read());
    switch (opCode) {
    case ApplyConstraint:
      executeApplyConstraint(rewriter);
      break;
    case ApplyRewrite:
      executeApplyRewrite(rewriter);
      break;
    case AreEqual:
      executeAreEqual();
      break;
    case AreRangesEqual:
      executeAreRangesEqual();
      break;
    case Branch:
      executeBranch();
      break;
    case CheckOperandCount:
      executeCheckOperandCount();
      break;
    case CheckOperationName:
      executeCheckOperationName();
      break;
    case CheckResultCount:
      executeCheckResultCount();
      break;
    case CheckTypes:
      executeCheckTypes();
      break;
    case Continue:
      executeContinue();
      break;
    case CreateOperation:
      executeCreateOperation(rewriter, *mainRewriteLoc);
      break;
    case CreateTypes:
      executeCreateTypes();
      break;
    case EraseOp:
      executeEraseOp(rewriter);
      break;
    case ExtractOp:
      executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
      break;
    case ExtractType:
      executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
      break;
    case ExtractValue:
      executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
      break;
    case Finalize:
      // The only exit from the interpreter loop.
      executeFinalize();
      LLVM_DEBUG(llvm::dbgs() << "\n");
      return;
    case ForEach:
      executeForEach();
      break;
    case GetAttribute:
      executeGetAttribute();
      break;
    case GetAttributeType:
      executeGetAttributeType();
      break;
    case GetDefiningOp:
      executeGetDefiningOp();
      break;
    // Specialized op codes for small operand indices: the index is implied by
    // the op code itself, and the op codes are assumed contiguous.
    case GetOperand0:
    case GetOperand1:
    case GetOperand2:
    case GetOperand3: {
      unsigned index = opCode - GetOperand0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
      executeGetOperand(index);
      break;
    }
    // Generic form: the operand index is encoded inline.
    case GetOperandN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
      executeGetOperand(read<uint32_t>());
      break;
    case GetOperands:
      executeGetOperands();
      break;
    // Specialized op codes for small result indices, mirroring GetOperand0-3.
    case GetResult0:
    case GetResult1:
    case GetResult2:
    case GetResult3: {
      unsigned index = opCode - GetResult0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
      executeGetResult(index);
      break;
    }
    case GetResultN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
      executeGetResult(read<uint32_t>());
      break;
    case GetResults:
      executeGetResults();
      break;
    case GetUsers:
      executeGetUsers();
      break;
    case GetValueType:
      executeGetValueType();
      break;
    case GetValueRangeTypes:
      executeGetValueRangeTypes();
      break;
    case IsNotNull:
      executeIsNotNull();
      break;
    case RecordMatch:
      assert(matches &&
             "expected matches to be provided when executing the matcher");
      executeRecordMatch(rewriter, *matches);
      break;
    case ReplaceOp:
      executeReplaceOp(rewriter);
      break;
    case SwitchAttribute:
      executeSwitchAttribute();
      break;
    case SwitchOperandCount:
      executeSwitchOperandCount();
      break;
    case SwitchOperationName:
      executeSwitchOperationName();
      break;
    case SwitchResultCount:
      executeSwitchResultCount();
      break;
    case SwitchType:
      executeSwitchType();
      break;
    case SwitchTypes:
      executeSwitchTypes();
      break;
    }
    LLVM_DEBUG(llvm::dbgs() << "\n");
  }
}
2113 
2114 /// Run the pattern matcher on the given root operation, collecting the matched
2115 /// patterns in `matches`.
2116 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2117                         SmallVectorImpl<MatchResult> &matches,
2118                         PDLByteCodeMutableState &state) const {
2119   // The first memory slot is always the root operation.
2120   state.memory[0] = op;
2121 
2122   // The matcher function always starts at code address 0.
2123   ByteCodeExecutor executor(
2124       matcherByteCode.data(), state.memory, state.opRangeMemory,
2125       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2126       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2127       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2128       constraintFunctions, rewriteFunctions);
2129   executor.execute(rewriter, &matches);
2130 
2131   // Order the found matches by benefit.
2132   std::stable_sort(matches.begin(), matches.end(),
2133                    [](const MatchResult &lhs, const MatchResult &rhs) {
2134                      return lhs.benefit > rhs.benefit;
2135                    });
2136 }
2137 
2138 /// Run the rewriter of the given pattern on the root operation `op`.
2139 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2140                           PDLByteCodeMutableState &state) const {
2141   // The arguments of the rewrite function are stored at the start of the
2142   // memory buffer.
2143   llvm::copy(match.values, state.memory.begin());
2144 
2145   ByteCodeExecutor executor(
2146       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2147       state.opRangeMemory, state.typeRangeMemory,
2148       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2149       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2150       rewriterByteCode, state.currentPatternBenefits, patterns,
2151       constraintFunctions, rewriteFunctions);
2152   executor.execute(rewriter, /*matches=*/nullptr, match.location);
2153 }
2154