1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.benefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.rootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62                                                    PatternBenefit benefit) {
63   currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
/// The opcodes of the pattern bytecode. No enumerator has an explicit value, so
/// the declaration order below determines the numeric encoding; the writer
/// emits each opcode as a single `ByteCodeField`.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal. The generator also emits this
  /// opcode for attribute and type checks against a constant.
  AreEqual,
  /// Check if two ranges are equal. Emitted instead of `AreEqual` when the
  /// compared values have a PDL range type.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): the 0-3 variants presumably encode the operand index in the
  /// opcode itself, with `GetOperandN` covering every other index - confirm
  /// against the interpreter.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): presumably the same index-in-opcode scheme as the
  /// `GetOperand*` family above - confirm against the interpreter.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
151 } // end anonymous namespace
152 
153 //===----------------------------------------------------------------------===//
154 // ByteCode Generation
155 //===----------------------------------------------------------------------===//
156 
157 //===----------------------------------------------------------------------===//
158 // Generator
159 
160 namespace {
161 struct ByteCodeWriter;
162 
163 /// This class represents the main generator for the pattern bytecode.
164 class Generator {
165 public:
166   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
167             SmallVectorImpl<ByteCodeField> &matcherByteCode,
168             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
169             SmallVectorImpl<PDLByteCodePattern> &patterns,
170             ByteCodeField &maxValueMemoryIndex,
171             ByteCodeField &maxTypeRangeMemoryIndex,
172             ByteCodeField &maxValueRangeMemoryIndex,
173             llvm::StringMap<PDLConstraintFunction> &constraintFns,
174             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
175       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
176         rewriterByteCode(rewriterByteCode), patterns(patterns),
177         maxValueMemoryIndex(maxValueMemoryIndex),
178         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
179         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
180     for (auto it : llvm::enumerate(constraintFns))
181       constraintToMemIndex.try_emplace(it.value().first(), it.index());
182     for (auto it : llvm::enumerate(rewriteFns))
183       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
184   }
185 
186   /// Generate the bytecode for the given PDL interpreter module.
187   void generate(ModuleOp module);
188 
189   /// Return the memory index to use for the given value.
190   ByteCodeField &getMemIndex(Value value) {
191     assert(valueToMemIndex.count(value) &&
192            "expected memory index to be assigned");
193     return valueToMemIndex[value];
194   }
195 
196   /// Return the range memory index used to store the given range value.
197   ByteCodeField &getRangeStorageIndex(Value value) {
198     assert(valueToRangeIndex.count(value) &&
199            "expected range index to be assigned");
200     return valueToRangeIndex[value];
201   }
202 
203   /// Return an index to use when referring to the given data that is uniqued in
204   /// the MLIR context.
205   template <typename T>
206   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
207   getMemIndex(T val) {
208     const void *opaqueVal = val.getAsOpaquePointer();
209 
210     // Get or insert a reference to this value.
211     auto it = uniquedDataToMemIndex.try_emplace(
212         opaqueVal, maxValueMemoryIndex + uniquedData.size());
213     if (it.second)
214       uniquedData.push_back(opaqueVal);
215     return it.first->second;
216   }
217 
218 private:
219   /// Allocate memory indices for the results of operations within the matcher
220   /// and rewriters.
221   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
222 
223   /// Generate the bytecode for the given operation.
224   void generate(Operation *op, ByteCodeWriter &writer);
225   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
226   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
227   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
228   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
229   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
230   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
231   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
232   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
233   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
234   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
235   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
236   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
237   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
238   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
239   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
240   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
241   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
242   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
243   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
244   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
245   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
246   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
247   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
248   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
249   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
250   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
251   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
252   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
253   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
254   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
255   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
259 
260   /// Mapping from value to its corresponding memory index.
261   DenseMap<Value, ByteCodeField> valueToMemIndex;
262 
263   /// Mapping from a range value to its corresponding range storage index.
264   DenseMap<Value, ByteCodeField> valueToRangeIndex;
265 
266   /// Mapping from the name of an externally registered rewrite to its index in
267   /// the bytecode registry.
268   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
269 
270   /// Mapping from the name of an externally registered constraint to its index
271   /// in the bytecode registry.
272   llvm::StringMap<ByteCodeField> constraintToMemIndex;
273 
274   /// Mapping from rewriter function name to the bytecode address of the
275   /// rewriter function in byte.
276   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
277 
278   /// Mapping from a uniqued storage object to its memory index within
279   /// `uniquedData`.
280   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
281 
282   /// The current MLIR context.
283   MLIRContext *ctx;
284 
285   /// Data of the ByteCode class to be populated.
286   std::vector<const void *> &uniquedData;
287   SmallVectorImpl<ByteCodeField> &matcherByteCode;
288   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
289   SmallVectorImpl<PDLByteCodePattern> &patterns;
290   ByteCodeField &maxValueMemoryIndex;
291   ByteCodeField &maxTypeRangeMemoryIndex;
292   ByteCodeField &maxValueRangeMemoryIndex;
293 };
294 
295 /// This class provides utilities for writing a bytecode stream.
296 struct ByteCodeWriter {
297   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
298       : bytecode(bytecode), generator(generator) {}
299 
300   /// Append a field to the bytecode.
301   void append(ByteCodeField field) { bytecode.push_back(field); }
302   void append(OpCode opCode) { bytecode.push_back(opCode); }
303 
304   /// Append an address to the bytecode.
305   void append(ByteCodeAddr field) {
306     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
307                   "unexpected ByteCode address size");
308 
309     ByteCodeField fieldParts[2];
310     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
311     bytecode.append({fieldParts[0], fieldParts[1]});
312   }
313 
314   /// Append a successor range to the bytecode, the exact address will need to
315   /// be resolved later.
316   void append(SuccessorRange successors) {
317     // Add back references to the any successors so that the address can be
318     // resolved later.
319     for (Block *successor : successors) {
320       unresolvedSuccessorRefs[successor].push_back(bytecode.size());
321       append(ByteCodeAddr(0));
322     }
323   }
324 
325   /// Append a range of values that will be read as generic PDLValues.
326   void appendPDLValueList(OperandRange values) {
327     bytecode.push_back(values.size());
328     for (Value value : values)
329       appendPDLValue(value);
330   }
331 
332   /// Append a value as a PDLValue.
333   void appendPDLValue(Value value) {
334     appendPDLValueKind(value);
335     append(value);
336   }
337 
338   /// Append the PDLValue::Kind of the given value.
339   void appendPDLValueKind(Value value) {
340     // Append the type of the value in addition to the value itself.
341     PDLValue::Kind kind =
342         TypeSwitch<Type, PDLValue::Kind>(value.getType())
343             .Case<pdl::AttributeType>(
344                 [](Type) { return PDLValue::Kind::Attribute; })
345             .Case<pdl::OperationType>(
346                 [](Type) { return PDLValue::Kind::Operation; })
347             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
348               if (rangeTy.getElementType().isa<pdl::TypeType>())
349                 return PDLValue::Kind::TypeRange;
350               return PDLValue::Kind::ValueRange;
351             })
352             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
353             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
354     bytecode.push_back(static_cast<ByteCodeField>(kind));
355   }
356 
357   /// Check if the given class `T` has an iterator type.
358   template <typename T, typename... Args>
359   using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
360 
361   /// Append a value that will be stored in a memory slot and not inline within
362   /// the bytecode.
363   template <typename T>
364   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
365                    std::is_pointer<T>::value>
366   append(T value) {
367     bytecode.push_back(generator.getMemIndex(value));
368   }
369 
370   /// Append a range of values.
371   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
372   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
373   append(T range) {
374     bytecode.push_back(llvm::size(range));
375     for (auto it : range)
376       append(it);
377   }
378 
379   /// Append a variadic number of fields to the bytecode.
380   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
381   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
382     append(field);
383     append(field2, fields...);
384   }
385 
386   /// Successor references in the bytecode that have yet to be resolved.
387   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
388 
389   /// The underlying bytecode buffer.
390   SmallVectorImpl<ByteCodeField> &bytecode;
391 
392   /// The main generator producing PDL.
393   Generator &generator;
394 };
395 
396 /// This class represents a live range of PDL Interpreter values, containing
397 /// information about when values are live within a match/rewrite.
398 struct ByteCodeLiveRange {
399   using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
400   using Allocator = Set::Allocator;
401 
402   ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
403 
404   /// Union this live range with the one provided.
405   void unionWith(const ByteCodeLiveRange &rhs) {
406     for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
407       liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
408   }
409 
410   /// Returns true if this range overlaps with the one provided.
411   bool overlaps(const ByteCodeLiveRange &rhs) const {
412     return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
413   }
414 
415   /// A map representing the ranges of the match/rewrite that a value is live in
416   /// the interpreter.
417   llvm::IntervalMap<ByteCodeField, char, 16> liveness;
418 
419   /// The type range storage index for this range.
420   Optional<unsigned> typeRangeIndex;
421 
422   /// The value range storage index for this range.
423   Optional<unsigned> valueRangeIndex;
424 };
425 } // end anonymous namespace
426 
427 void Generator::generate(ModuleOp module) {
428   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
429       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
430   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
431       pdl_interp::PDLInterpDialect::getRewriterModuleName());
432   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
433 
434   // Allocate memory indices for the results of operations within the matcher
435   // and rewriters.
436   allocateMemoryIndices(matcherFunc, rewriterModule);
437 
438   // Generate code for the rewriter functions.
439   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
440   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
441     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
442     for (Operation &op : rewriterFunc.getOps())
443       generate(&op, rewriterByteCodeWriter);
444   }
445   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
446          "unexpected branches in rewriter function");
447 
448   // Generate code for the matcher function.
449   DenseMap<Block *, ByteCodeAddr> blockToAddr;
450   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
451   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
452   for (Block *block : rpot) {
453     // Keep track of where this block begins within the matcher function.
454     blockToAddr.try_emplace(block, matcherByteCode.size());
455     for (Operation &op : *block)
456       generate(&op, matcherByteCodeWriter);
457   }
458 
459   // Resolve successor references in the matcher.
460   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
461     ByteCodeAddr addr = blockToAddr[it.first];
462     for (unsigned offsetToFix : it.second)
463       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
464   }
465 }
466 
/// Assign a memory index (and, for range values, a range storage index) to
/// every value of the rewriter functions and of the matcher function, and
/// raise the `max*MemoryIndex` counters to the largest number of slots that
/// any single function requires.
void Generator::allocateMemoryIndices(FuncOp matcherFunc,
                                      ModuleOp rewriterModule) {
  // Rewriters use a simplistic allocation scheme that assigns a fresh index to
  // each argument and result. The counters restart at zero for every rewriter
  // function.
  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
    ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
    auto processRewriterValue = [&](Value val) {
      valueToMemIndex.try_emplace(val, index++);
      // Range values additionally get a slot in the storage that matches the
      // type of their elements.
      if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
        Type elementTy = rangeType.getElementType();
        if (elementTy.isa<pdl::TypeType>())
          valueToRangeIndex.try_emplace(val, typeRangeIndex++);
        else if (elementTy.isa<pdl::ValueType>())
          valueToRangeIndex.try_emplace(val, valueRangeIndex++);
      }
    };

    for (BlockArgument arg : rewriterFunc.getArguments())
      processRewriterValue(arg);
    rewriterFunc.getBody().walk([&](Operation *op) {
      for (Value result : op->getResults())
        processRewriterValue(result);
    });
    // Keep track of the largest number of slots required so far.
    if (index > maxValueMemoryIndex)
      maxValueMemoryIndex = index;
    if (typeRangeIndex > maxTypeRangeMemoryIndex)
      maxTypeRangeMemoryIndex = typeRangeIndex;
    if (valueRangeIndex > maxValueRangeMemoryIndex)
      maxValueRangeMemoryIndex = valueRangeIndex;
  }

  // The matcher function uses a more sophisticated numbering that tries to
  // minimize the number of memory indices assigned. This is done by determining
  // a live range of the values within the matcher, then the allocation is just
  // finding the minimal number of overlapping live ranges. This is essentially
  // a simplified form of register allocation where we don't necessarily have a
  // limited number of registers, but we still want to minimize the number used.
  // Number the operations of the matcher in walk order; these numbers are the
  // endpoints of the live intervals computed below.
  DenseMap<Operation *, ByteCodeField> opToIndex;
  matcherFunc.getBody().walk([&](Operation *op) {
    opToIndex.insert(std::make_pair(op, opToIndex.size()));
  });

  // Liveness info for each of the defs within the matcher.
  ByteCodeLiveRange::Allocator allocator;
  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;

  // Assign the root operation being matched to slot 0.
  BlockArgument rootOpArg = matcherFunc.getArgument(0);
  valueToMemIndex[rootOpArg] = 0;

  // Walk each of the blocks, computing the interval within which each value is
  // live.
  Liveness matcherLiveness(matcherFunc);
  for (Block &block : matcherFunc.getBody()) {
    const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
    assert(info && "expected liveness info for block");
    auto processValue = [&](Value value, Operation *firstUseOrDef) {
      // We don't need to process the root op argument, this value is always
      // assigned to the first memory slot.
      if (value == rootOpArg)
        return;

      // Record the interval of this block within which the value is live:
      // from `firstUseOrDef` up to the end operation reported by the liveness
      // information.
      auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
      defRangeIt->second.liveness.insert(
          opToIndex[firstUseOrDef],
          opToIndex[info->getEndOperation(value, firstUseOrDef)],
          /*dummyValue*/ 0);

      // Check to see if this value is a range type. The index set here is only
      // a marker for the kind of range storage required; the actual storage
      // index is chosen by the allocation loop below.
      if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
        Type eleType = rangeTy.getElementType();
        if (eleType.isa<pdl::TypeType>())
          defRangeIt->second.typeRangeIndex = 0;
        else if (eleType.isa<pdl::ValueType>())
          defRangeIt->second.valueRangeIndex = 0;
      }
    };

    // Process the live-ins of this block.
    for (Value liveIn : info->in())
      processValue(liveIn, &block.front());

    // Process any new defs within this block.
    for (Operation &op : block)
      for (Value result : op.getResults())
        processValue(result, &op);
  }

  // Greedily allocate memory slots using the computed def live ranges. The
  // entry at position `i` of `allocatedIndices` describes memory slot `i + 1`,
  // as slot 0 is reserved for the root operation; `numIndices` starts at 1 for
  // the same reason.
  std::vector<ByteCodeLiveRange> allocatedIndices;
  ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
  for (auto &defIt : valueDefRanges) {
    // A freshly inserted entry is value-initialized to 0, which doubles as the
    // "no slot assigned yet" marker checked below.
    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
    ByteCodeLiveRange &defRange = defIt.second;

    // Try to allocate to an existing index.
    for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
      ByteCodeLiveRange &existingRange = existingIndexIt.value();
      if (!defRange.overlaps(existingRange)) {
        existingRange.unionWith(defRange);
        memIndex = existingIndexIt.index() + 1;

        // Share the range storage already attached to this slot, creating it
        // if this is the first range of that kind placed in the slot.
        if (defRange.typeRangeIndex) {
          if (!existingRange.typeRangeIndex)
            existingRange.typeRangeIndex = numTypeRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
        } else if (defRange.valueRangeIndex) {
          if (!existingRange.valueRangeIndex)
            existingRange.valueRangeIndex = numValueRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
        }
        break;
      }
    }

    // If no existing index could be used, add a new one.
    if (memIndex == 0) {
      allocatedIndices.emplace_back(allocator);
      ByteCodeLiveRange &newRange = allocatedIndices.back();
      newRange.unionWith(defRange);

      // Allocate an index for type/value ranges.
      if (defRange.typeRangeIndex) {
        newRange.typeRangeIndex = numTypeRanges;
        valueToRangeIndex[defIt.first] = numTypeRanges++;
      } else if (defRange.valueRangeIndex) {
        newRange.valueRangeIndex = numValueRanges;
        valueToRangeIndex[defIt.first] = numValueRanges++;
      }

      memIndex = allocatedIndices.size();
      ++numIndices;
    }
  }

  // Update the max number of indices.
  if (numIndices > maxValueMemoryIndex)
    maxValueMemoryIndex = numIndices;
  if (numTypeRanges > maxTypeRangeMemoryIndex)
    maxTypeRangeMemoryIndex = numTypeRanges;
  if (numValueRanges > maxValueRangeMemoryIndex)
    maxValueRangeMemoryIndex = numValueRanges;
}
610 
611 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
612   TypeSwitch<Operation *>(op)
613       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
614             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
615             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
616             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
617             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
618             pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
619             pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
620             pdl_interp::EraseOp, pdl_interp::FinalizeOp,
621             pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
622             pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
623             pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
624             pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
625             pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
626             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
627             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
628             pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
629             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
630           [&](auto interpOp) { this->generate(interpOp, writer); })
631       .Default([](Operation *) {
632         llvm_unreachable("unknown `pdl_interp` operation");
633       });
634 }
635 
636 void Generator::generate(pdl_interp::ApplyConstraintOp op,
637                          ByteCodeWriter &writer) {
638   assert(constraintToMemIndex.count(op.name()) &&
639          "expected index for constraint function");
640   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
641                 op.constParamsAttr());
642   writer.appendPDLValueList(op.args());
643   writer.append(op.getSuccessors());
644 }
645 void Generator::generate(pdl_interp::ApplyRewriteOp op,
646                          ByteCodeWriter &writer) {
647   assert(externalRewriterToMemIndex.count(op.name()) &&
648          "expected index for rewrite function");
649   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
650                 op.constParamsAttr());
651   writer.appendPDLValueList(op.args());
652 
653   ResultRange results = op.results();
654   writer.append(ByteCodeField(results.size()));
655   for (Value result : results) {
656     // In debug mode we also record the expected kind of the result, so that we
657     // can provide extra verification of the native rewrite function.
658 #ifndef NDEBUG
659     writer.appendPDLValueKind(result);
660 #endif
661 
662     // Range results also need to append the range storage index.
663     if (result.getType().isa<pdl::RangeType>())
664       writer.append(getRangeStorageIndex(result));
665     writer.append(result);
666   }
667 }
668 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
669   Value lhs = op.lhs();
670   if (lhs.getType().isa<pdl::RangeType>()) {
671     writer.append(OpCode::AreRangesEqual);
672     writer.appendPDLValueKind(lhs);
673     writer.append(op.lhs(), op.rhs(), op.getSuccessors());
674     return;
675   }
676 
677   writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
678 }
679 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
680   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
681 }
682 void Generator::generate(pdl_interp::CheckAttributeOp op,
683                          ByteCodeWriter &writer) {
684   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
685                 op.getSuccessors());
686 }
687 void Generator::generate(pdl_interp::CheckOperandCountOp op,
688                          ByteCodeWriter &writer) {
689   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
690                 static_cast<ByteCodeField>(op.compareAtLeast()),
691                 op.getSuccessors());
692 }
693 void Generator::generate(pdl_interp::CheckOperationNameOp op,
694                          ByteCodeWriter &writer) {
695   writer.append(OpCode::CheckOperationName, op.operation(),
696                 OperationName(op.name(), ctx), op.getSuccessors());
697 }
698 void Generator::generate(pdl_interp::CheckResultCountOp op,
699                          ByteCodeWriter &writer) {
700   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
701                 static_cast<ByteCodeField>(op.compareAtLeast()),
702                 op.getSuccessors());
703 }
704 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
705   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
706 }
707 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
708   writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
709 }
710 void Generator::generate(pdl_interp::CreateAttributeOp op,
711                          ByteCodeWriter &writer) {
712   // Simply repoint the memory index of the result to the constant.
713   getMemIndex(op.attribute()) = getMemIndex(op.value());
714 }
715 void Generator::generate(pdl_interp::CreateOperationOp op,
716                          ByteCodeWriter &writer) {
717   writer.append(OpCode::CreateOperation, op.operation(),
718                 OperationName(op.name(), ctx));
719   writer.appendPDLValueList(op.operands());
720 
721   // Add the attributes.
722   OperandRange attributes = op.attributes();
723   writer.append(static_cast<ByteCodeField>(attributes.size()));
724   for (auto it : llvm::zip(op.attributeNames(), op.attributes()))
725     writer.append(std::get<0>(it), std::get<1>(it));
726   writer.appendPDLValueList(op.types());
727 }
728 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
729   // Simply repoint the memory index of the result to the constant.
730   getMemIndex(op.result()) = getMemIndex(op.value());
731 }
732 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
733   writer.append(OpCode::CreateTypes, op.result(),
734                 getRangeStorageIndex(op.result()), op.value());
735 }
736 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
737   writer.append(OpCode::EraseOp, op.operation());
738 }
739 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
740   writer.append(OpCode::Finalize);
741 }
742 void Generator::generate(pdl_interp::GetAttributeOp op,
743                          ByteCodeWriter &writer) {
744   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
745                 op.nameAttr());
746 }
747 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
748                          ByteCodeWriter &writer) {
749   writer.append(OpCode::GetAttributeType, op.result(), op.value());
750 }
751 void Generator::generate(pdl_interp::GetDefiningOpOp op,
752                          ByteCodeWriter &writer) {
753   writer.append(OpCode::GetDefiningOp, op.operation());
754   writer.appendPDLValue(op.value());
755 }
756 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
757   uint32_t index = op.index();
758   if (index < 4)
759     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
760   else
761     writer.append(OpCode::GetOperandN, index);
762   writer.append(op.operation(), op.value());
763 }
764 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
765   Value result = op.value();
766   Optional<uint32_t> index = op.index();
767   writer.append(OpCode::GetOperands,
768                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
769                 op.operation());
770   if (result.getType().isa<pdl::RangeType>())
771     writer.append(getRangeStorageIndex(result));
772   else
773     writer.append(std::numeric_limits<ByteCodeField>::max());
774   writer.append(result);
775 }
776 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
777   uint32_t index = op.index();
778   if (index < 4)
779     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
780   else
781     writer.append(OpCode::GetResultN, index);
782   writer.append(op.operation(), op.value());
783 }
784 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
785   Value result = op.value();
786   Optional<uint32_t> index = op.index();
787   writer.append(OpCode::GetResults,
788                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
789                 op.operation());
790   if (result.getType().isa<pdl::RangeType>())
791     writer.append(getRangeStorageIndex(result));
792   else
793     writer.append(std::numeric_limits<ByteCodeField>::max());
794   writer.append(result);
795 }
796 void Generator::generate(pdl_interp::GetValueTypeOp op,
797                          ByteCodeWriter &writer) {
798   if (op.getType().isa<pdl::RangeType>()) {
799     Value result = op.result();
800     writer.append(OpCode::GetValueRangeTypes, result,
801                   getRangeStorageIndex(result), op.value());
802   } else {
803     writer.append(OpCode::GetValueType, op.result(), op.value());
804   }
805 }
806 
807 void Generator::generate(pdl_interp::InferredTypesOp op,
808                          ByteCodeWriter &writer) {
809   // InferType maps to a null type as a marker for inferring result types.
810   getMemIndex(op.type()) = getMemIndex(Type());
811 }
812 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
813   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
814 }
815 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
816   ByteCodeField patternIndex = patterns.size();
817   patterns.emplace_back(PDLByteCodePattern::create(
818       op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
819   writer.append(OpCode::RecordMatch, patternIndex,
820                 SuccessorRange(op.getOperation()), op.matchedOps());
821   writer.appendPDLValueList(op.inputs());
822 }
823 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
824   writer.append(OpCode::ReplaceOp, op.operation());
825   writer.appendPDLValueList(op.replValues());
826 }
827 void Generator::generate(pdl_interp::SwitchAttributeOp op,
828                          ByteCodeWriter &writer) {
829   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
830                 op.getSuccessors());
831 }
832 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
833                          ByteCodeWriter &writer) {
834   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
835                 op.getSuccessors());
836 }
837 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
838                          ByteCodeWriter &writer) {
839   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
840     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
841   });
842   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
843                 op.getSuccessors());
844 }
845 void Generator::generate(pdl_interp::SwitchResultCountOp op,
846                          ByteCodeWriter &writer) {
847   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
848                 op.getSuccessors());
849 }
850 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
851   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
852                 op.getSuccessors());
853 }
854 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
855   writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
856                 op.getSuccessors());
857 }
858 
859 //===----------------------------------------------------------------------===//
860 // PDLByteCode
861 //===----------------------------------------------------------------------===//
862 
863 PDLByteCode::PDLByteCode(ModuleOp module,
864                          llvm::StringMap<PDLConstraintFunction> constraintFns,
865                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
866   Generator generator(module.getContext(), uniquedData, matcherByteCode,
867                       rewriterByteCode, patterns, maxValueMemoryIndex,
868                       maxTypeRangeCount, maxValueRangeCount, constraintFns,
869                       rewriteFns);
870   generator.generate(module);
871 
872   // Initialize the external functions.
873   for (auto &it : constraintFns)
874     constraintFunctions.push_back(std::move(it.second));
875   for (auto &it : rewriteFns)
876     rewriteFunctions.push_back(std::move(it.second));
877 }
878 
879 /// Initialize the given state such that it can be used to execute the current
880 /// bytecode.
881 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
882   state.memory.resize(maxValueMemoryIndex, nullptr);
883   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
884   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
885   state.currentPatternBenefits.reserve(patterns.size());
886   for (const PDLByteCodePattern &pattern : patterns)
887     state.currentPatternBenefits.push_back(pattern.getBenefit());
888 }
889 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
892 
893 namespace {
894 /// This class provides support for executing a bytecode stream.
895 class ByteCodeExecutor {
896 public:
897   ByteCodeExecutor(
898       const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
899       MutableArrayRef<TypeRange> typeRangeMemory,
900       std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
901       MutableArrayRef<ValueRange> valueRangeMemory,
902       std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
903       ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
904       ArrayRef<PatternBenefit> currentPatternBenefits,
905       ArrayRef<PDLByteCodePattern> patterns,
906       ArrayRef<PDLConstraintFunction> constraintFunctions,
907       ArrayRef<PDLRewriteFunction> rewriteFunctions)
908       : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
909         allocatedTypeRangeMemory(allocatedTypeRangeMemory),
910         valueRangeMemory(valueRangeMemory),
911         allocatedValueRangeMemory(allocatedValueRangeMemory),
912         uniquedMemory(uniquedMemory), code(code),
913         currentPatternBenefits(currentPatternBenefits), patterns(patterns),
914         constraintFunctions(constraintFunctions),
915         rewriteFunctions(rewriteFunctions) {}
916 
917   /// Start executing the code at the current bytecode index. `matches` is an
918   /// optional field provided when this function is executed in a matching
919   /// context.
920   void execute(PatternRewriter &rewriter,
921                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
922                Optional<Location> mainRewriteLoc = {});
923 
924 private:
925   /// Internal implementation of executing each of the bytecode commands.
926   void executeApplyConstraint(PatternRewriter &rewriter);
927   void executeApplyRewrite(PatternRewriter &rewriter);
928   void executeAreEqual();
929   void executeAreRangesEqual();
930   void executeBranch();
931   void executeCheckOperandCount();
932   void executeCheckOperationName();
933   void executeCheckResultCount();
934   void executeCheckTypes();
935   void executeCreateOperation(PatternRewriter &rewriter,
936                               Location mainRewriteLoc);
937   void executeCreateTypes();
938   void executeEraseOp(PatternRewriter &rewriter);
939   void executeGetAttribute();
940   void executeGetAttributeType();
941   void executeGetDefiningOp();
942   void executeGetOperand(unsigned index);
943   void executeGetOperands();
944   void executeGetResult(unsigned index);
945   void executeGetResults();
946   void executeGetValueType();
947   void executeGetValueRangeTypes();
948   void executeIsNotNull();
949   void executeRecordMatch(PatternRewriter &rewriter,
950                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
951   void executeReplaceOp(PatternRewriter &rewriter);
952   void executeSwitchAttribute();
953   void executeSwitchOperandCount();
954   void executeSwitchOperationName();
955   void executeSwitchResultCount();
956   void executeSwitchType();
957   void executeSwitchTypes();
958 
959   /// Read a value from the bytecode buffer, optionally skipping a certain
960   /// number of prefix values. These methods always update the buffer to point
961   /// to the next field after the read data.
962   template <typename T = ByteCodeField>
963   T read(size_t skipN = 0) {
964     curCodeIt += skipN;
965     return readImpl<T>();
966   }
967   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
968 
969   /// Read a list of values from the bytecode buffer.
970   template <typename ValueT, typename T>
971   void readList(SmallVectorImpl<T> &list) {
972     list.clear();
973     for (unsigned i = 0, e = read(); i != e; ++i)
974       list.push_back(read<ValueT>());
975   }
976 
977   /// Read a list of values from the bytecode buffer. The values may be encoded
978   /// as either Value or ValueRange elements.
979   void readValueList(SmallVectorImpl<Value> &list) {
980     for (unsigned i = 0, e = read(); i != e; ++i) {
981       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
982         list.push_back(read<Value>());
983       } else {
984         ValueRange *values = read<ValueRange *>();
985         list.append(values->begin(), values->end());
986       }
987     }
988   }
989 
990   /// Jump to a specific successor based on a predicate value.
991   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
992   /// Jump to a specific successor based on a destination index.
993   void selectJump(size_t destIndex) {
994     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
995   }
996 
997   /// Handle a switch operation with the provided value and cases.
998   template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
999   void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1000     LLVM_DEBUG({
1001       llvm::dbgs() << "  * Value: " << value << "\n"
1002                    << "  * Cases: ";
1003       llvm::interleaveComma(cases, llvm::dbgs());
1004       llvm::dbgs() << "\n";
1005     });
1006 
1007     // Check to see if the attribute value is within the case list. Jump to
1008     // the correct successor index based on the result.
1009     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1010       if (cmp(*it, value))
1011         return selectJump(size_t((it - cases.begin()) + 1));
1012     selectJump(size_t(0));
1013   }
1014 
1015   /// Internal implementation of reading various data types from the bytecode
1016   /// stream.
1017   template <typename T>
1018   const void *readFromMemory() {
1019     size_t index = *curCodeIt++;
1020 
1021     // If this type is an SSA value, it can only be stored in non-const memory.
1022     if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1023                         Value>::value ||
1024         index < memory.size())
1025       return memory[index];
1026 
1027     // Otherwise, if this index is not inbounds it is uniqued.
1028     return uniquedMemory[index - memory.size()];
1029   }
1030   template <typename T>
1031   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1032     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1033   }
1034   template <typename T>
1035   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1036                    T>
1037   readImpl() {
1038     return T(T::getFromOpaquePointer(readFromMemory<T>()));
1039   }
1040   template <typename T>
1041   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1042     switch (read<PDLValue::Kind>()) {
1043     case PDLValue::Kind::Attribute:
1044       return read<Attribute>();
1045     case PDLValue::Kind::Operation:
1046       return read<Operation *>();
1047     case PDLValue::Kind::Type:
1048       return read<Type>();
1049     case PDLValue::Kind::Value:
1050       return read<Value>();
1051     case PDLValue::Kind::TypeRange:
1052       return read<TypeRange *>();
1053     case PDLValue::Kind::ValueRange:
1054       return read<ValueRange *>();
1055     }
1056     llvm_unreachable("unhandled PDLValue::Kind");
1057   }
1058   template <typename T>
1059   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1060     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1061                   "unexpected ByteCode address size");
1062     ByteCodeAddr result;
1063     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1064     curCodeIt += 2;
1065     return result;
1066   }
1067   template <typename T>
1068   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1069     return *curCodeIt++;
1070   }
1071   template <typename T>
1072   std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1073     return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1074   }
1075 
1076   /// The underlying bytecode buffer.
1077   const ByteCodeField *curCodeIt;
1078 
1079   /// The current execution memory.
1080   MutableArrayRef<const void *> memory;
1081   MutableArrayRef<TypeRange> typeRangeMemory;
1082   std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1083   MutableArrayRef<ValueRange> valueRangeMemory;
1084   std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1085 
1086   /// References to ByteCode data necessary for execution.
1087   ArrayRef<const void *> uniquedMemory;
1088   ArrayRef<ByteCodeField> code;
1089   ArrayRef<PatternBenefit> currentPatternBenefits;
1090   ArrayRef<PDLByteCodePattern> patterns;
1091   ArrayRef<PDLConstraintFunction> constraintFunctions;
1092   ArrayRef<PDLRewriteFunction> rewriteFunctions;
1093 };
1094 
1095 /// This class is an instantiation of the PDLResultList that provides access to
1096 /// the returned results. This API is not on `PDLResultList` to avoid
1097 /// overexposing access to information specific solely to the ByteCode.
1098 class ByteCodeRewriteResultList : public PDLResultList {
1099 public:
1100   ByteCodeRewriteResultList(unsigned maxNumResults)
1101       : PDLResultList(maxNumResults) {}
1102 
1103   /// Return the list of PDL results.
1104   MutableArrayRef<PDLValue> getResults() { return results; }
1105 
1106   /// Return the type ranges allocated by this list.
1107   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1108     return allocatedTypeRanges;
1109   }
1110 
1111   /// Return the value ranges allocated by this list.
1112   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1113     return allocatedValueRanges;
1114   }
1115 };
1116 } // end anonymous namespace
1117 
1118 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1119   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1120   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1121   ArrayAttr constParams = read<ArrayAttr>();
1122   SmallVector<PDLValue, 16> args;
1123   readList<PDLValue>(args);
1124 
1125   LLVM_DEBUG({
1126     llvm::dbgs() << "  * Arguments: ";
1127     llvm::interleaveComma(args, llvm::dbgs());
1128     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1129   });
1130 
1131   // Invoke the constraint and jump to the proper destination.
1132   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1133 }
1134 
1135 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1136   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1137   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1138   ArrayAttr constParams = read<ArrayAttr>();
1139   SmallVector<PDLValue, 16> args;
1140   readList<PDLValue>(args);
1141 
1142   LLVM_DEBUG({
1143     llvm::dbgs() << "  * Arguments: ";
1144     llvm::interleaveComma(args, llvm::dbgs());
1145     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1146   });
1147 
1148   // Execute the rewrite function.
1149   ByteCodeField numResults = read();
1150   ByteCodeRewriteResultList results(numResults);
1151   rewriteFn(args, constParams, rewriter, results);
1152 
1153   assert(results.getResults().size() == numResults &&
1154          "native PDL rewrite function returned unexpected number of results");
1155 
1156   // Store the results in the bytecode memory.
1157   for (PDLValue &result : results.getResults()) {
1158     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1159 
1160 // In debug mode we also verify the expected kind of the result.
1161 #ifndef NDEBUG
1162     assert(result.getKind() == read<PDLValue::Kind>() &&
1163            "native PDL rewrite function returned an unexpected type of result");
1164 #endif
1165 
1166     // If the result is a range, we need to copy it over to the bytecodes
1167     // range memory.
1168     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1169       unsigned rangeIndex = read();
1170       typeRangeMemory[rangeIndex] = *typeRange;
1171       memory[read()] = &typeRangeMemory[rangeIndex];
1172     } else if (Optional<ValueRange> valueRange =
1173                    result.dyn_cast<ValueRange>()) {
1174       unsigned rangeIndex = read();
1175       valueRangeMemory[rangeIndex] = *valueRange;
1176       memory[read()] = &valueRangeMemory[rangeIndex];
1177     } else {
1178       memory[read()] = result.getAsOpaquePointer();
1179     }
1180   }
1181 
1182   // Copy over any underlying storage allocated for result ranges.
1183   for (auto &it : results.getAllocatedTypeRanges())
1184     allocatedTypeRangeMemory.push_back(std::move(it));
1185   for (auto &it : results.getAllocatedValueRanges())
1186     allocatedValueRangeMemory.push_back(std::move(it));
1187 }
1188 
1189 void ByteCodeExecutor::executeAreEqual() {
1190   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1191   const void *lhs = read<const void *>();
1192   const void *rhs = read<const void *>();
1193 
1194   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1195   selectJump(lhs == rhs);
1196 }
1197 
1198 void ByteCodeExecutor::executeAreRangesEqual() {
1199   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1200   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1201   const void *lhs = read<const void *>();
1202   const void *rhs = read<const void *>();
1203 
1204   switch (valueKind) {
1205   case PDLValue::Kind::TypeRange: {
1206     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1207     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1208     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1209     selectJump(*lhsRange == *rhsRange);
1210     break;
1211   }
1212   case PDLValue::Kind::ValueRange: {
1213     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1214     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1215     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1216     selectJump(*lhsRange == *rhsRange);
1217     break;
1218   }
1219   default:
1220     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1221   }
1222 }
1223 
1224 void ByteCodeExecutor::executeBranch() {
1225   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1226   curCodeIt = &code[read<ByteCodeAddr>()];
1227 }
1228 
1229 void ByteCodeExecutor::executeCheckOperandCount() {
1230   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1231   Operation *op = read<Operation *>();
1232   uint32_t expectedCount = read<uint32_t>();
1233   bool compareAtLeast = read();
1234 
1235   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1236                           << "  * Expected: " << expectedCount << "\n"
1237                           << "  * Comparator: "
1238                           << (compareAtLeast ? ">=" : "==") << "\n");
1239   if (compareAtLeast)
1240     selectJump(op->getNumOperands() >= expectedCount);
1241   else
1242     selectJump(op->getNumOperands() == expectedCount);
1243 }
1244 
1245 void ByteCodeExecutor::executeCheckOperationName() {
1246   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1247   Operation *op = read<Operation *>();
1248   OperationName expectedName = read<OperationName>();
1249 
1250   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1251                           << "  * Expected: \"" << expectedName << "\"\n");
1252   selectJump(op->getName() == expectedName);
1253 }
1254 
1255 void ByteCodeExecutor::executeCheckResultCount() {
1256   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1257   Operation *op = read<Operation *>();
1258   uint32_t expectedCount = read<uint32_t>();
1259   bool compareAtLeast = read();
1260 
1261   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1262                           << "  * Expected: " << expectedCount << "\n"
1263                           << "  * Comparator: "
1264                           << (compareAtLeast ? ">=" : "==") << "\n");
1265   if (compareAtLeast)
1266     selectJump(op->getNumResults() >= expectedCount);
1267   else
1268     selectJump(op->getNumResults() == expectedCount);
1269 }
1270 
1271 void ByteCodeExecutor::executeCheckTypes() {
1272   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1273   TypeRange *lhs = read<TypeRange *>();
1274   Attribute rhs = read<Attribute>();
1275   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1276 
1277   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1278 }
1279 
1280 void ByteCodeExecutor::executeCreateTypes() {
1281   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1282   unsigned memIndex = read();
1283   unsigned rangeIndex = read();
1284   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1285 
1286   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1287 
1288   // Allocate a buffer for this type range.
1289   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1290   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1291   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1292 
1293   // Assign this to the range slot and use the range as the value for the
1294   // memory index.
1295   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1296   memory[memIndex] = &typeRangeMemory[rangeIndex];
1297 }
1298 
1299 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1300                                               Location mainRewriteLoc) {
1301   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1302 
1303   unsigned memIndex = read();
1304   OperationState state(mainRewriteLoc, read<OperationName>());
1305   readValueList(state.operands);
1306   for (unsigned i = 0, e = read(); i != e; ++i) {
1307     StringAttr name = read<StringAttr>();
1308     if (Attribute attr = read<Attribute>())
1309       state.addAttribute(name, attr);
1310   }
1311 
1312   for (unsigned i = 0, e = read(); i != e; ++i) {
1313     if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1314       state.types.push_back(read<Type>());
1315       continue;
1316     }
1317 
1318     // If we find a null range, this signals that the types are infered.
1319     if (TypeRange *resultTypes = read<TypeRange *>()) {
1320       state.types.append(resultTypes->begin(), resultTypes->end());
1321       continue;
1322     }
1323 
1324     // Handle the case where the operation has inferred types.
1325     InferTypeOpInterface::Concept *concept =
1326         state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
1327 
1328     // TODO: Handle failure.
1329     state.types.clear();
1330     if (failed(concept->inferReturnTypes(
1331             state.getContext(), state.location, state.operands,
1332             state.attributes.getDictionary(state.getContext()), state.regions,
1333             state.types)))
1334       return;
1335     break;
1336   }
1337 
1338   Operation *resultOp = rewriter.createOperation(state);
1339   memory[memIndex] = resultOp;
1340 
1341   LLVM_DEBUG({
1342     llvm::dbgs() << "  * Attributes: "
1343                  << state.attributes.getDictionary(state.getContext())
1344                  << "\n  * Operands: ";
1345     llvm::interleaveComma(state.operands, llvm::dbgs());
1346     llvm::dbgs() << "\n  * Result Types: ";
1347     llvm::interleaveComma(state.types, llvm::dbgs());
1348     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1349   });
1350 }
1351 
1352 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1353   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1354   Operation *op = read<Operation *>();
1355 
1356   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1357   rewriter.eraseOp(op);
1358 }
1359 
1360 void ByteCodeExecutor::executeGetAttribute() {
1361   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1362   unsigned memIndex = read();
1363   Operation *op = read<Operation *>();
1364   StringAttr attrName = read<StringAttr>();
1365   Attribute attr = op->getAttr(attrName);
1366 
1367   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1368                           << "  * Attribute: " << attrName << "\n"
1369                           << "  * Result: " << attr << "\n");
1370   memory[memIndex] = attr.getAsOpaquePointer();
1371 }
1372 
1373 void ByteCodeExecutor::executeGetAttributeType() {
1374   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1375   unsigned memIndex = read();
1376   Attribute attr = read<Attribute>();
1377   Type type = attr ? attr.getType() : Type();
1378 
1379   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1380                           << "  * Result: " << type << "\n");
1381   memory[memIndex] = type.getAsOpaquePointer();
1382 }
1383 
1384 void ByteCodeExecutor::executeGetDefiningOp() {
1385   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1386   unsigned memIndex = read();
1387   Operation *op = nullptr;
1388   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1389     Value value = read<Value>();
1390     if (value)
1391       op = value.getDefiningOp();
1392     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1393   } else {
1394     ValueRange *values = read<ValueRange *>();
1395     if (values && !values->empty()) {
1396       op = values->front().getDefiningOp();
1397     }
1398     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1399   }
1400 
1401   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1402   memory[memIndex] = op;
1403 }
1404 
1405 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1406   Operation *op = read<Operation *>();
1407   unsigned memIndex = read();
1408   Value operand =
1409       index < op->getNumOperands() ? op->getOperand(index) : Value();
1410 
1411   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1412                           << "  * Index: " << index << "\n"
1413                           << "  * Result: " << operand << "\n");
1414   memory[memIndex] = operand.getAsOpaquePointer();
1415 }
1416 
1417 /// This function is the internal implementation of `GetResults` and
1418 /// `GetOperands` that provides support for extracting a value range from the
1419 /// given operation.
1420 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1421 static void *
1422 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1423                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1424                           MutableArrayRef<ValueRange> &valueRangeMemory) {
1425   // Check for the sentinel index that signals that all values should be
1426   // returned.
1427   if (index == std::numeric_limits<uint32_t>::max()) {
1428     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1429     // `values` is already the full value range.
1430 
1431     // Otherwise, check to see if this operation uses AttrSizedSegments.
1432   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1433     LLVM_DEBUG(llvm::dbgs()
1434                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1435 
1436     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1437     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1438       return nullptr;
1439 
1440     auto segments = segmentAttr.getValues<int32_t>();
1441     unsigned startIndex =
1442         std::accumulate(segments.begin(), segments.begin() + index, 0);
1443     values = values.slice(startIndex, *std::next(segments.begin(), index));
1444 
1445     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1446                             << *std::next(segments.begin(), index) << "]\n");
1447 
1448     // Otherwise, assume this is the last operand group of the operation.
1449     // FIXME: We currently don't support operations with
1450     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1451     // have a way to detect it's presence.
1452   } else if (values.size() >= index) {
1453     LLVM_DEBUG(llvm::dbgs()
1454                << "  * Treating values as trailing variadic range\n");
1455     values = values.drop_front(index);
1456 
1457     // If we couldn't detect a way to compute the values, bail out.
1458   } else {
1459     return nullptr;
1460   }
1461 
1462   // If the range index is valid, we are returning a range.
1463   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1464     valueRangeMemory[rangeIndex] = values;
1465     return &valueRangeMemory[rangeIndex];
1466   }
1467 
1468   // If a range index wasn't provided, the range is required to be non-variadic.
1469   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1470 }
1471 
1472 void ByteCodeExecutor::executeGetOperands() {
1473   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1474   unsigned index = read<uint32_t>();
1475   Operation *op = read<Operation *>();
1476   ByteCodeField rangeIndex = read();
1477 
1478   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1479       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1480       valueRangeMemory);
1481   if (!result)
1482     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1483   memory[read()] = result;
1484 }
1485 
1486 void ByteCodeExecutor::executeGetResult(unsigned index) {
1487   Operation *op = read<Operation *>();
1488   unsigned memIndex = read();
1489   OpResult result =
1490       index < op->getNumResults() ? op->getResult(index) : OpResult();
1491 
1492   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1493                           << "  * Index: " << index << "\n"
1494                           << "  * Result: " << result << "\n");
1495   memory[memIndex] = result.getAsOpaquePointer();
1496 }
1497 
1498 void ByteCodeExecutor::executeGetResults() {
1499   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1500   unsigned index = read<uint32_t>();
1501   Operation *op = read<Operation *>();
1502   ByteCodeField rangeIndex = read();
1503 
1504   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1505       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1506       valueRangeMemory);
1507   if (!result)
1508     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1509   memory[read()] = result;
1510 }
1511 
1512 void ByteCodeExecutor::executeGetValueType() {
1513   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1514   unsigned memIndex = read();
1515   Value value = read<Value>();
1516   Type type = value ? value.getType() : Type();
1517 
1518   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1519                           << "  * Result: " << type << "\n");
1520   memory[memIndex] = type.getAsOpaquePointer();
1521 }
1522 
1523 void ByteCodeExecutor::executeGetValueRangeTypes() {
1524   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1525   unsigned memIndex = read();
1526   unsigned rangeIndex = read();
1527   ValueRange *values = read<ValueRange *>();
1528   if (!values) {
1529     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1530     memory[memIndex] = nullptr;
1531     return;
1532   }
1533 
1534   LLVM_DEBUG({
1535     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1536     llvm::interleaveComma(*values, llvm::dbgs());
1537     llvm::dbgs() << "\n  * Result: ";
1538     llvm::interleaveComma(values->getType(), llvm::dbgs());
1539     llvm::dbgs() << "\n";
1540   });
1541   typeRangeMemory[rangeIndex] = values->getType();
1542   memory[memIndex] = &typeRangeMemory[rangeIndex];
1543 }
1544 
1545 void ByteCodeExecutor::executeIsNotNull() {
1546   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1547   const void *value = read<const void *>();
1548 
1549   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1550   selectJump(value != nullptr);
1551 }
1552 
1553 void ByteCodeExecutor::executeRecordMatch(
1554     PatternRewriter &rewriter,
1555     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1556   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1557   unsigned patternIndex = read();
1558   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1559   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1560 
1561   // If the benefit of the pattern is impossible, skip the processing of the
1562   // rest of the pattern.
1563   if (benefit.isImpossibleToMatch()) {
1564     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1565     curCodeIt = dest;
1566     return;
1567   }
1568 
1569   // Create a fused location containing the locations of each of the
1570   // operations used in the match. This will be used as the location for
1571   // created operations during the rewrite that don't already have an
1572   // explicit location set.
1573   unsigned numMatchLocs = read();
1574   SmallVector<Location, 4> matchLocs;
1575   matchLocs.reserve(numMatchLocs);
1576   for (unsigned i = 0; i != numMatchLocs; ++i)
1577     matchLocs.push_back(read<Operation *>()->getLoc());
1578   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1579 
1580   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1581                           << "  * Location: " << matchLoc << "\n");
1582   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1583   PDLByteCode::MatchResult &match = matches.back();
1584 
1585   // Record all of the inputs to the match. If any of the inputs are ranges, we
1586   // will also need to remap the range pointer to memory stored in the match
1587   // state.
1588   unsigned numInputs = read();
1589   match.values.reserve(numInputs);
1590   match.typeRangeValues.reserve(numInputs);
1591   match.valueRangeValues.reserve(numInputs);
1592   for (unsigned i = 0; i < numInputs; ++i) {
1593     switch (read<PDLValue::Kind>()) {
1594     case PDLValue::Kind::TypeRange:
1595       match.typeRangeValues.push_back(*read<TypeRange *>());
1596       match.values.push_back(&match.typeRangeValues.back());
1597       break;
1598     case PDLValue::Kind::ValueRange:
1599       match.valueRangeValues.push_back(*read<ValueRange *>());
1600       match.values.push_back(&match.valueRangeValues.back());
1601       break;
1602     default:
1603       match.values.push_back(read<const void *>());
1604       break;
1605     }
1606   }
1607   curCodeIt = dest;
1608 }
1609 
1610 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1611   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1612   Operation *op = read<Operation *>();
1613   SmallVector<Value, 16> args;
1614   readValueList(args);
1615 
1616   LLVM_DEBUG({
1617     llvm::dbgs() << "  * Operation: " << *op << "\n"
1618                  << "  * Values: ";
1619     llvm::interleaveComma(args, llvm::dbgs());
1620     llvm::dbgs() << "\n";
1621   });
1622   rewriter.replaceOp(op, args);
1623 }
1624 
1625 void ByteCodeExecutor::executeSwitchAttribute() {
1626   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1627   Attribute value = read<Attribute>();
1628   ArrayAttr cases = read<ArrayAttr>();
1629   handleSwitch(value, cases);
1630 }
1631 
1632 void ByteCodeExecutor::executeSwitchOperandCount() {
1633   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1634   Operation *op = read<Operation *>();
1635   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1636 
1637   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1638   handleSwitch(op->getNumOperands(), cases);
1639 }
1640 
1641 void ByteCodeExecutor::executeSwitchOperationName() {
1642   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1643   OperationName value = read<Operation *>()->getName();
1644   size_t caseCount = read();
1645 
1646   // The operation names are stored in-line, so to print them out for
1647   // debugging purposes we need to read the array before executing the
1648   // switch so that we can display all of the possible values.
1649   LLVM_DEBUG({
1650     const ByteCodeField *prevCodeIt = curCodeIt;
1651     llvm::dbgs() << "  * Value: " << value << "\n"
1652                  << "  * Cases: ";
1653     llvm::interleaveComma(
1654         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1655                         [&](size_t) { return read<OperationName>(); }),
1656         llvm::dbgs());
1657     llvm::dbgs() << "\n";
1658     curCodeIt = prevCodeIt;
1659   });
1660 
1661   // Try to find the switch value within any of the cases.
1662   for (size_t i = 0; i != caseCount; ++i) {
1663     if (read<OperationName>() == value) {
1664       curCodeIt += (caseCount - i - 1);
1665       return selectJump(i + 1);
1666     }
1667   }
1668   selectJump(size_t(0));
1669 }
1670 
1671 void ByteCodeExecutor::executeSwitchResultCount() {
1672   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1673   Operation *op = read<Operation *>();
1674   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1675 
1676   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1677   handleSwitch(op->getNumResults(), cases);
1678 }
1679 
1680 void ByteCodeExecutor::executeSwitchType() {
1681   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1682   Type value = read<Type>();
1683   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1684   handleSwitch(value, cases);
1685 }
1686 
1687 void ByteCodeExecutor::executeSwitchTypes() {
1688   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
1689   TypeRange *value = read<TypeRange *>();
1690   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
1691   if (!value) {
1692     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
1693     return selectJump(size_t(0));
1694   }
1695   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
1696     return value == caseValue.getAsValueRange<TypeAttr>();
1697   });
1698 }
1699 
1700 void ByteCodeExecutor::execute(
1701     PatternRewriter &rewriter,
1702     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
1703     Optional<Location> mainRewriteLoc) {
1704   while (true) {
1705     OpCode opCode = static_cast<OpCode>(read());
1706     switch (opCode) {
1707     case ApplyConstraint:
1708       executeApplyConstraint(rewriter);
1709       break;
1710     case ApplyRewrite:
1711       executeApplyRewrite(rewriter);
1712       break;
1713     case AreEqual:
1714       executeAreEqual();
1715       break;
1716     case AreRangesEqual:
1717       executeAreRangesEqual();
1718       break;
1719     case Branch:
1720       executeBranch();
1721       break;
1722     case CheckOperandCount:
1723       executeCheckOperandCount();
1724       break;
1725     case CheckOperationName:
1726       executeCheckOperationName();
1727       break;
1728     case CheckResultCount:
1729       executeCheckResultCount();
1730       break;
1731     case CheckTypes:
1732       executeCheckTypes();
1733       break;
1734     case CreateOperation:
1735       executeCreateOperation(rewriter, *mainRewriteLoc);
1736       break;
1737     case CreateTypes:
1738       executeCreateTypes();
1739       break;
1740     case EraseOp:
1741       executeEraseOp(rewriter);
1742       break;
1743     case Finalize:
1744       LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1745       return;
1746     case GetAttribute:
1747       executeGetAttribute();
1748       break;
1749     case GetAttributeType:
1750       executeGetAttributeType();
1751       break;
1752     case GetDefiningOp:
1753       executeGetDefiningOp();
1754       break;
1755     case GetOperand0:
1756     case GetOperand1:
1757     case GetOperand2:
1758     case GetOperand3: {
1759       unsigned index = opCode - GetOperand0;
1760       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
1761       executeGetOperand(index);
1762       break;
1763     }
1764     case GetOperandN:
1765       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
1766       executeGetOperand(read<uint32_t>());
1767       break;
1768     case GetOperands:
1769       executeGetOperands();
1770       break;
1771     case GetResult0:
1772     case GetResult1:
1773     case GetResult2:
1774     case GetResult3: {
1775       unsigned index = opCode - GetResult0;
1776       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
1777       executeGetResult(index);
1778       break;
1779     }
1780     case GetResultN:
1781       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
1782       executeGetResult(read<uint32_t>());
1783       break;
1784     case GetResults:
1785       executeGetResults();
1786       break;
1787     case GetValueType:
1788       executeGetValueType();
1789       break;
1790     case GetValueRangeTypes:
1791       executeGetValueRangeTypes();
1792       break;
1793     case IsNotNull:
1794       executeIsNotNull();
1795       break;
1796     case RecordMatch:
1797       assert(matches &&
1798              "expected matches to be provided when executing the matcher");
1799       executeRecordMatch(rewriter, *matches);
1800       break;
1801     case ReplaceOp:
1802       executeReplaceOp(rewriter);
1803       break;
1804     case SwitchAttribute:
1805       executeSwitchAttribute();
1806       break;
1807     case SwitchOperandCount:
1808       executeSwitchOperandCount();
1809       break;
1810     case SwitchOperationName:
1811       executeSwitchOperationName();
1812       break;
1813     case SwitchResultCount:
1814       executeSwitchResultCount();
1815       break;
1816     case SwitchType:
1817       executeSwitchType();
1818       break;
1819     case SwitchTypes:
1820       executeSwitchTypes();
1821       break;
1822     }
1823     LLVM_DEBUG(llvm::dbgs() << "\n");
1824   }
1825 }
1826 
1827 /// Run the pattern matcher on the given root operation, collecting the matched
1828 /// patterns in `matches`.
1829 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1830                         SmallVectorImpl<MatchResult> &matches,
1831                         PDLByteCodeMutableState &state) const {
1832   // The first memory slot is always the root operation.
1833   state.memory[0] = op;
1834 
1835   // The matcher function always starts at code address 0.
1836   ByteCodeExecutor executor(
1837       matcherByteCode.data(), state.memory, state.typeRangeMemory,
1838       state.allocatedTypeRangeMemory, state.valueRangeMemory,
1839       state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
1840       state.currentPatternBenefits, patterns, constraintFunctions,
1841       rewriteFunctions);
1842   executor.execute(rewriter, &matches);
1843 
1844   // Order the found matches by benefit.
1845   std::stable_sort(matches.begin(), matches.end(),
1846                    [](const MatchResult &lhs, const MatchResult &rhs) {
1847                      return lhs.benefit > rhs.benefit;
1848                    });
1849 }
1850 
1851 /// Run the rewriter of the given pattern on the root operation `op`.
1852 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1853                           PDLByteCodeMutableState &state) const {
1854   // The arguments of the rewrite function are stored at the start of the
1855   // memory buffer.
1856   llvm::copy(match.values, state.memory.begin());
1857 
1858   ByteCodeExecutor executor(
1859       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1860       state.typeRangeMemory, state.allocatedTypeRangeMemory,
1861       state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
1862       rewriterByteCode, state.currentPatternBenefits, patterns,
1863       constraintFunctions, rewriteFunctions);
1864   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1865 }
1866