1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "pdl-bytecode"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // PDLByteCodePattern
31 //===----------------------------------------------------------------------===//
32 
33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
34                                               ByteCodeAddr rewriterAddr) {
35   SmallVector<StringRef, 8> generatedOps;
36   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
37     generatedOps =
38         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
39 
40   PatternBenefit benefit = matchOp.benefit();
41   MLIRContext *ctx = matchOp.getContext();
42 
43   // Check to see if this is pattern matches a specific operation type.
44   if (Optional<StringRef> rootKind = matchOp.rootKind())
45     return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
46                               ctx);
47   return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
48                             MatchAnyOpTypeTag());
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // PDLByteCodeMutableState
53 //===----------------------------------------------------------------------===//
54 
55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
56 /// to the position of the pattern within the range returned by
57 /// `PDLByteCode::getPatterns`.
58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
59                                                    PatternBenefit benefit) {
60   currentPatternBenefits[patternIndex] = benefit;
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // Bytecode OpCodes
65 //===----------------------------------------------------------------------===//
66 
namespace {
/// Opcodes of the PDL bytecode. Every instruction starts with one of these
/// values stored in a single `ByteCodeField`, so the numeric values (and hence
/// the declaration order below) are part of the bytecode encoding.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal. The generator also emits this for
  /// `pdl_interp.check_attribute` and `pdl_interp.check_type`.
  AreEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Create an operation.
  CreateOperation,
  /// Erase an operation.
  EraseOp,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. GetOperand0-GetOperand3 encode the
  /// operand index in the opcode itself and must stay contiguous, because the
  /// generator computes them as `GetOperand0 + index`. GetOperandN carries the
  /// index as an explicit field instead.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific result of an operation. Same contiguity requirement as the
  /// GetOperand* opcodes (the generator computes `GetResult0 + index`).
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get the type of a value.
  GetValueType,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
};

/// The kind of each entry in a PDL value list. The writer stores it as a
/// `ByteCodeField` in front of every value so the reader can tell what kind of
/// value occupies the referenced memory slot.
enum class PDLValueKind { Attribute, Operation, Type, Value };
} // end anonymous namespace
129 
130 //===----------------------------------------------------------------------===//
131 // ByteCode Generation
132 //===----------------------------------------------------------------------===//
133 
134 //===----------------------------------------------------------------------===//
135 // Generator
136 
137 namespace {
138 struct ByteCodeWriter;
139 
140 /// This class represents the main generator for the pattern bytecode.
141 class Generator {
142 public:
143   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
144             SmallVectorImpl<ByteCodeField> &matcherByteCode,
145             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
146             SmallVectorImpl<PDLByteCodePattern> &patterns,
147             ByteCodeField &maxValueMemoryIndex,
148             llvm::StringMap<PDLConstraintFunction> &constraintFns,
149             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
150       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
151         rewriterByteCode(rewriterByteCode), patterns(patterns),
152         maxValueMemoryIndex(maxValueMemoryIndex) {
153     for (auto it : llvm::enumerate(constraintFns))
154       constraintToMemIndex.try_emplace(it.value().first(), it.index());
155     for (auto it : llvm::enumerate(rewriteFns))
156       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
157   }
158 
159   /// Generate the bytecode for the given PDL interpreter module.
160   void generate(ModuleOp module);
161 
162   /// Return the memory index to use for the given value.
163   ByteCodeField &getMemIndex(Value value) {
164     assert(valueToMemIndex.count(value) &&
165            "expected memory index to be assigned");
166     return valueToMemIndex[value];
167   }
168 
169   /// Return an index to use when referring to the given data that is uniqued in
170   /// the MLIR context.
171   template <typename T>
172   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
173   getMemIndex(T val) {
174     const void *opaqueVal = val.getAsOpaquePointer();
175 
176     // Get or insert a reference to this value.
177     auto it = uniquedDataToMemIndex.try_emplace(
178         opaqueVal, maxValueMemoryIndex + uniquedData.size());
179     if (it.second)
180       uniquedData.push_back(opaqueVal);
181     return it.first->second;
182   }
183 
184 private:
185   /// Allocate memory indices for the results of operations within the matcher
186   /// and rewriters.
187   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
188 
189   /// Generate the bytecode for the given operation.
190   void generate(Operation *op, ByteCodeWriter &writer);
191   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
192   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
193   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
194   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
195   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
196   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
197   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
198   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
199   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
200   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
201   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
202   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
203   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
204   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
205   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
206   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
207   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
208   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
209   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
210   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
211   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
212   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
213   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
214   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
215   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
216   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
217   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
218   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
219   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
220 
221   /// Mapping from value to its corresponding memory index.
222   DenseMap<Value, ByteCodeField> valueToMemIndex;
223 
224   /// Mapping from the name of an externally registered rewrite to its index in
225   /// the bytecode registry.
226   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
227 
228   /// Mapping from the name of an externally registered constraint to its index
229   /// in the bytecode registry.
230   llvm::StringMap<ByteCodeField> constraintToMemIndex;
231 
232   /// Mapping from rewriter function name to the bytecode address of the
233   /// rewriter function in byte.
234   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
235 
236   /// Mapping from a uniqued storage object to its memory index within
237   /// `uniquedData`.
238   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
239 
240   /// The current MLIR context.
241   MLIRContext *ctx;
242 
243   /// Data of the ByteCode class to be populated.
244   std::vector<const void *> &uniquedData;
245   SmallVectorImpl<ByteCodeField> &matcherByteCode;
246   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
247   SmallVectorImpl<PDLByteCodePattern> &patterns;
248   ByteCodeField &maxValueMemoryIndex;
249 };
250 
/// Helper for appending fields to a bytecode stream. Objects that live in
/// memory slots are written as the index returned by `Generator::getMemIndex`.
/// Successor blocks are written as placeholder addresses whose offsets are
/// recorded so they can be patched later.
struct ByteCodeWriter {
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a single raw field, or an opcode, to the bytecode.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode. An address spans exactly two fields,
  /// and its raw bytes are copied into them with memcpy.
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a successor range to the bytecode. A zero address is written for
  /// each successor as a placeholder, and its offset is stored in
  /// `unresolvedSuccessorRefs` so the real address can be filled in later.
  void append(SuccessorRange successors) {
    // Record a back reference for each successor so that its address can be
    // resolved later.
    for (Block *successor : successors) {
      unresolvedSuccessorRefs[successor].push_back(bytecode.size());
      append(ByteCodeAddr(0));
    }
  }

  /// Append a range of values that will be read as generic PDLValues. The
  /// layout is the value count, followed by a (PDLValueKind, memory index)
  /// pair for each value.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values) {
      // Append the type of the value in addition to the value itself.
      PDLValueKind kind =
          TypeSwitch<Type, PDLValueKind>(value.getType())
              .Case<pdl::AttributeType>(
                  [](Type) { return PDLValueKind::Attribute; })
              .Case<pdl::OperationType>(
                  [](Type) { return PDLValueKind::Operation; })
              .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
              .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
      bytecode.push_back(static_cast<ByteCodeField>(kind));
      append(value);
    }
  }

  /// Detection alias for use with `llvm::is_detected`. It is well-formed only
  /// when `T` provides a `getAsOpaquePointer()` method.
  template <typename T, typename... Args>
  using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());

  /// Append an object that is stored in a memory slot rather than inline in
  /// the bytecode. The field written is the slot index returned by
  /// `Generator::getMemIndex`.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values: first its size, then each element appended
  /// individually through the overloads above.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append several fields by appending each one individually, left to right.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Bytecode offsets of placeholder successor addresses that still need to
  /// be resolved, grouped by target block.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The generator that owns memory-slot assignment.
  Generator &generator;
};
337 } // end anonymous namespace
338 
339 void Generator::generate(ModuleOp module) {
340   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
341       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
342   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
343       pdl_interp::PDLInterpDialect::getRewriterModuleName());
344   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
345 
346   // Allocate memory indices for the results of operations within the matcher
347   // and rewriters.
348   allocateMemoryIndices(matcherFunc, rewriterModule);
349 
350   // Generate code for the rewriter functions.
351   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
352   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
353     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
354     for (Operation &op : rewriterFunc.getOps())
355       generate(&op, rewriterByteCodeWriter);
356   }
357   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
358          "unexpected branches in rewriter function");
359 
360   // Generate code for the matcher function.
361   DenseMap<Block *, ByteCodeAddr> blockToAddr;
362   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
363   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
364   for (Block *block : rpot) {
365     // Keep track of where this block begins within the matcher function.
366     blockToAddr.try_emplace(block, matcherByteCode.size());
367     for (Operation &op : *block)
368       generate(&op, matcherByteCodeWriter);
369   }
370 
371   // Resolve successor references in the matcher.
372   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
373     ByteCodeAddr addr = blockToAddr[it.first];
374     for (unsigned offsetToFix : it.second)
375       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
376   }
377 }
378 
379 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
380                                       ModuleOp rewriterModule) {
381   // Rewriters use simplistic allocation scheme that simply assigns an index to
382   // each result.
383   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
384     ByteCodeField index = 0;
385     for (BlockArgument arg : rewriterFunc.getArguments())
386       valueToMemIndex.try_emplace(arg, index++);
387     rewriterFunc.getBody().walk([&](Operation *op) {
388       for (Value result : op->getResults())
389         valueToMemIndex.try_emplace(result, index++);
390     });
391     if (index > maxValueMemoryIndex)
392       maxValueMemoryIndex = index;
393   }
394 
395   // The matcher function uses a more sophisticated numbering that tries to
396   // minimize the number of memory indices assigned. This is done by determining
397   // a live range of the values within the matcher, then the allocation is just
398   // finding the minimal number of overlapping live ranges. This is essentially
399   // a simplified form of register allocation where we don't necessarily have a
400   // limited number of registers, but we still want to minimize the number used.
401   DenseMap<Operation *, ByteCodeField> opToIndex;
402   matcherFunc.getBody().walk([&](Operation *op) {
403     opToIndex.insert(std::make_pair(op, opToIndex.size()));
404   });
405 
406   // Liveness info for each of the defs within the matcher.
407   using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
408   LivenessSet::Allocator allocator;
409   DenseMap<Value, LivenessSet> valueDefRanges;
410 
411   // Assign the root operation being matched to slot 0.
412   BlockArgument rootOpArg = matcherFunc.getArgument(0);
413   valueToMemIndex[rootOpArg] = 0;
414 
415   // Walk each of the blocks, computing the def interval that the value is used.
416   Liveness matcherLiveness(matcherFunc);
417   for (Block &block : matcherFunc.getBody()) {
418     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
419     assert(info && "expected liveness info for block");
420     auto processValue = [&](Value value, Operation *firstUseOrDef) {
421       // We don't need to process the root op argument, this value is always
422       // assigned to the first memory slot.
423       if (value == rootOpArg)
424         return;
425 
426       // Set indices for the range of this block that the value is used.
427       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
428       defRangeIt->second.insert(
429           opToIndex[firstUseOrDef],
430           opToIndex[info->getEndOperation(value, firstUseOrDef)],
431           /*dummyValue*/ 0);
432     };
433 
434     // Process the live-ins of this block.
435     for (Value liveIn : info->in())
436       processValue(liveIn, &block.front());
437 
438     // Process any new defs within this block.
439     for (Operation &op : block)
440       for (Value result : op.getResults())
441         processValue(result, &op);
442   }
443 
444   // Greedily allocate memory slots using the computed def live ranges.
445   std::vector<LivenessSet> allocatedIndices;
446   for (auto &defIt : valueDefRanges) {
447     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
448     LivenessSet &defSet = defIt.second;
449 
450     // Try to allocate to an existing index.
451     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
452       LivenessSet &existingIndex = existingIndexIt.value();
453       llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
454           defIt.second, existingIndex);
455       if (overlaps.valid())
456         continue;
457       // Union the range of the def within the existing index.
458       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
459         existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
460       memIndex = existingIndexIt.index() + 1;
461     }
462 
463     // If no existing index could be used, add a new one.
464     if (memIndex == 0) {
465       allocatedIndices.emplace_back(allocator);
466       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
467         allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
468       memIndex = allocatedIndices.size();
469     }
470   }
471 
472   // Update the max number of indices.
473   ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
474   if (numMatcherIndices > maxValueMemoryIndex)
475     maxValueMemoryIndex = numMatcherIndices;
476 }
477 
478 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
479   TypeSwitch<Operation *>(op)
480       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
481             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
482             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
483             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
484             pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
485             pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
486             pdl_interp::EraseOp, pdl_interp::FinalizeOp,
487             pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
488             pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
489             pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp,
490             pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
491             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
492             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
493             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
494             pdl_interp::SwitchResultCountOp>(
495           [&](auto interpOp) { this->generate(interpOp, writer); })
496       .Default([](Operation *) {
497         llvm_unreachable("unknown `pdl_interp` operation");
498       });
499 }
500 
501 void Generator::generate(pdl_interp::ApplyConstraintOp op,
502                          ByteCodeWriter &writer) {
503   assert(constraintToMemIndex.count(op.name()) &&
504          "expected index for constraint function");
505   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
506                 op.constParamsAttr());
507   writer.appendPDLValueList(op.args());
508   writer.append(op.getSuccessors());
509 }
510 void Generator::generate(pdl_interp::ApplyRewriteOp op,
511                          ByteCodeWriter &writer) {
512   assert(externalRewriterToMemIndex.count(op.name()) &&
513          "expected index for rewrite function");
514   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
515                 op.constParamsAttr());
516   writer.appendPDLValueList(op.args());
517 
518 #ifndef NDEBUG
519   // In debug mode we also append the number of results so that we can assert
520   // that the native creation function gave us the correct number of results.
521   writer.append(ByteCodeField(op.results().size()));
522 #endif
523   for (Value result : op.results())
524     writer.append(result);
525 }
526 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
527   writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
528 }
529 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
530   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
531 }
532 void Generator::generate(pdl_interp::CheckAttributeOp op,
533                          ByteCodeWriter &writer) {
534   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
535                 op.getSuccessors());
536 }
537 void Generator::generate(pdl_interp::CheckOperandCountOp op,
538                          ByteCodeWriter &writer) {
539   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
540                 op.getSuccessors());
541 }
542 void Generator::generate(pdl_interp::CheckOperationNameOp op,
543                          ByteCodeWriter &writer) {
544   writer.append(OpCode::CheckOperationName, op.operation(),
545                 OperationName(op.name(), ctx), op.getSuccessors());
546 }
547 void Generator::generate(pdl_interp::CheckResultCountOp op,
548                          ByteCodeWriter &writer) {
549   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
550                 op.getSuccessors());
551 }
552 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
553   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
554 }
555 void Generator::generate(pdl_interp::CreateAttributeOp op,
556                          ByteCodeWriter &writer) {
557   // Simply repoint the memory index of the result to the constant.
558   getMemIndex(op.attribute()) = getMemIndex(op.value());
559 }
560 void Generator::generate(pdl_interp::CreateOperationOp op,
561                          ByteCodeWriter &writer) {
562   writer.append(OpCode::CreateOperation, op.operation(),
563                 OperationName(op.name(), ctx), op.operands());
564 
565   // Add the attributes.
566   OperandRange attributes = op.attributes();
567   writer.append(static_cast<ByteCodeField>(attributes.size()));
568   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
569     writer.append(
570         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
571         std::get<1>(it));
572   }
573   writer.append(op.types());
574 }
575 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
576   // Simply repoint the memory index of the result to the constant.
577   getMemIndex(op.result()) = getMemIndex(op.value());
578 }
579 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
580   writer.append(OpCode::EraseOp, op.operation());
581 }
582 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
583   writer.append(OpCode::Finalize);
584 }
585 void Generator::generate(pdl_interp::GetAttributeOp op,
586                          ByteCodeWriter &writer) {
587   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
588                 Identifier::get(op.name(), ctx));
589 }
590 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
591                          ByteCodeWriter &writer) {
592   writer.append(OpCode::GetAttributeType, op.result(), op.value());
593 }
594 void Generator::generate(pdl_interp::GetDefiningOpOp op,
595                          ByteCodeWriter &writer) {
596   writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
597 }
598 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
599   uint32_t index = op.index();
600   if (index < 4)
601     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
602   else
603     writer.append(OpCode::GetOperandN, index);
604   writer.append(op.operation(), op.value());
605 }
606 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
607   uint32_t index = op.index();
608   if (index < 4)
609     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
610   else
611     writer.append(OpCode::GetResultN, index);
612   writer.append(op.operation(), op.value());
613 }
614 void Generator::generate(pdl_interp::GetValueTypeOp op,
615                          ByteCodeWriter &writer) {
616   writer.append(OpCode::GetValueType, op.result(), op.value());
617 }
618 void Generator::generate(pdl_interp::InferredTypesOp op,
619                          ByteCodeWriter &writer) {
620   // InferType maps to a null type as a marker for inferring result types.
621   getMemIndex(op.type()) = getMemIndex(Type());
622 }
623 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
624   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
625 }
626 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
627   ByteCodeField patternIndex = patterns.size();
628   patterns.emplace_back(PDLByteCodePattern::create(
629       op, rewriterToAddr[op.rewriter().getLeafReference()]));
630   writer.append(OpCode::RecordMatch, patternIndex,
631                 SuccessorRange(op.getOperation()), op.matchedOps(),
632                 op.inputs());
633 }
634 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
635   writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
636 }
637 void Generator::generate(pdl_interp::SwitchAttributeOp op,
638                          ByteCodeWriter &writer) {
639   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
640                 op.getSuccessors());
641 }
642 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
643                          ByteCodeWriter &writer) {
644   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
645                 op.getSuccessors());
646 }
647 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
648                          ByteCodeWriter &writer) {
649   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
650     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
651   });
652   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
653                 op.getSuccessors());
654 }
655 void Generator::generate(pdl_interp::SwitchResultCountOp op,
656                          ByteCodeWriter &writer) {
657   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
658                 op.getSuccessors());
659 }
660 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
661   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
662                 op.getSuccessors());
663 }
664 
665 //===----------------------------------------------------------------------===//
666 // PDLByteCode
667 //===----------------------------------------------------------------------===//
668 
669 PDLByteCode::PDLByteCode(ModuleOp module,
670                          llvm::StringMap<PDLConstraintFunction> constraintFns,
671                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
672   Generator generator(module.getContext(), uniquedData, matcherByteCode,
673                       rewriterByteCode, patterns, maxValueMemoryIndex,
674                       constraintFns, rewriteFns);
675   generator.generate(module);
676 
677   // Initialize the external functions.
678   for (auto &it : constraintFns)
679     constraintFunctions.push_back(std::move(it.second));
680   for (auto &it : rewriteFns)
681     rewriteFunctions.push_back(std::move(it.second));
682 }
683 
684 /// Initialize the given state such that it can be used to execute the current
685 /// bytecode.
686 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
687   state.memory.resize(maxValueMemoryIndex, nullptr);
688   state.currentPatternBenefits.reserve(patterns.size());
689   for (const PDLByteCodePattern &pattern : patterns)
690     state.currentPatternBenefits.push_back(pattern.getBenefit());
691 }
692 
693 //===----------------------------------------------------------------------===//
694 // ByteCode Execution
695 
namespace {
/// This class provides support for executing a bytecode stream.
///
/// The executor walks a flat buffer of `ByteCodeField`s starting at
/// `curCodeIt`. Each instruction is an opcode followed by its operands, which
/// the `execute*` handlers consume strictly in order via `read`/`readList`.
/// Operands that name runtime values are memory indices (see
/// `readFromMemory`).
class ByteCodeExecutor {
public:
  /// All array arguments are non-owning views (ArrayRef/MutableArrayRef); the
  /// caller must keep the referenced storage alive while the executor is in
  /// use. `curCodeIt` is the first instruction to execute and is expected to
  /// point into `code`, which is also the base that jump addresses index.
  ByteCodeExecutor(const ByteCodeField *curCodeIt,
                   MutableArrayRef<const void *> memory,
                   ArrayRef<const void *> uniquedMemory,
                   ArrayRef<ByteCodeField> code,
                   ArrayRef<PatternBenefit> currentPatternBenefits,
                   ArrayRef<PDLByteCodePattern> patterns,
                   ArrayRef<PDLConstraintFunction> constraintFunctions,
                   ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
        code(code), currentPatternBenefits(currentPatternBenefits),
        patterns(patterns), constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context. `mainRewriteLoc` is the location given to operations created by
  /// `CreateOperation`, so it must be set when executing rewriter code.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeEraseOp(PatternRewriter &rewriter);
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetResult(unsigned index);
  void executeGetValueType();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  /// Non-template overload so that a plain `read()` reads one ByteCodeField.
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The list is encoded as a
  /// ByteCodeField element count followed by that many `ValueT` entries.
  /// `list` is cleared before the entries are appended.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Jump to a specific successor based on a predicate value. A true predicate
  /// selects successor 0 and a false one selects successor 1.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Successor
  /// addresses are stored inline as ByteCodeAddrs, each spanning two fields,
  /// so `destIndex * 2` skips the addresses that precede the chosen one.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases.
  template <typename T, typename RangeT>
  void handleSwitch(const T &value, RangeT &&cases) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the value is within the case list. Successor 0 is the
    // default destination; a match at case position `i` jumps to successor
    // `i + 1`.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (*it == value)
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  ///
  /// Reads one field as a memory index. Indices below `memory.size()` address
  /// the mutable execution memory; larger indices address `uniquedMemory`,
  /// offset by `memory.size()`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer types are stored directly in memory.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Class types (other than PDLValue) are rebuilt from their opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// A PDLValue is encoded as a PDLValueKind field followed by the value
  /// itself, read as the corresponding concrete type.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (static_cast<PDLValueKind>(read())) {
    case PDLValueKind::Attribute:
      return read<Attribute>();
    case PDLValueKind::Operation:
      return read<Operation *>();
    case PDLValueKind::Type:
      return read<Type>();
    case PDLValueKind::Value:
      return read<Value>();
    }
    llvm_unreachable("unhandled PDLValueKind");
  }
  /// A ByteCodeAddr spans two consecutive fields. It is copied out with
  /// memcpy rather than through a pointer cast, which avoids alignment and
  /// aliasing problems.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// A raw ByteCodeField is read directly from the stream.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }

  /// The underlying bytecode buffer. This is the read cursor and is advanced
  /// by every read and jump.
  const ByteCodeField *curCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};

/// This class is an instantiation of the PDLResultList that provides access to
/// the returned results. This API is not on `PDLResultList` to avoid
/// overexposing access to information specific solely to the ByteCode.
class ByteCodeRewriteResultList : public PDLResultList {
public:
  /// Return the list of PDL results.
  MutableArrayRef<PDLValue> getResults() { return results; }
};
} // end anonymous namespace
866 
867 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
868   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
869   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
870   ArrayAttr constParams = read<ArrayAttr>();
871   SmallVector<PDLValue, 16> args;
872   readList<PDLValue>(args);
873 
874   LLVM_DEBUG({
875     llvm::dbgs() << "  * Arguments: ";
876     llvm::interleaveComma(args, llvm::dbgs());
877     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
878   });
879 
880   // Invoke the constraint and jump to the proper destination.
881   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
882 }
883 
/// Execute an `ApplyRewrite` instruction: invoke a registered native rewrite
/// function and store the values it produces into bytecode memory.
///
/// Operands, in the order they are read: <rewrite function index>
/// <constant parameters (ArrayAttr)> <PDLValue argument list>
/// [<expected result count>, only when NDEBUG is not defined]
/// <result memory index> for each returned result.
void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
  ArrayAttr constParams = read<ArrayAttr>();
  SmallVector<PDLValue, 16> args;
  readList<PDLValue>(args);

  LLVM_DEBUG({
    llvm::dbgs() << "  * Arguments: ";
    llvm::interleaveComma(args, llvm::dbgs());
    llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
  });
  // The native function reports its outputs through `results`.
  ByteCodeRewriteResultList results;
  rewriteFn(args, constParams, rewriter, results);

  // Check that the native function produced as many results as the bytecode
  // expects. NOTE(review): the count field is consumed only when NDEBUG is not
  // defined, so the stream layout relies on the bytecode generator emitting it
  // under the same condition; confirm against the ApplyRewrite generator.
#ifndef NDEBUG
  ByteCodeField expectedNumberOfResults = read();
  assert(results.getResults().size() == expectedNumberOfResults &&
         "native PDL rewrite function returned unexpected number of results");
#endif

  // Store the results in the bytecode memory.
  for (PDLValue &result : results.getResults()) {
    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
    memory[read()] = result.getAsOpaquePointer();
  }
}
912 
913 void ByteCodeExecutor::executeAreEqual() {
914   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
915   const void *lhs = read<const void *>();
916   const void *rhs = read<const void *>();
917 
918   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
919   selectJump(lhs == rhs);
920 }
921 
922 void ByteCodeExecutor::executeBranch() {
923   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
924   curCodeIt = &code[read<ByteCodeAddr>()];
925 }
926 
927 void ByteCodeExecutor::executeCheckOperandCount() {
928   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
929   Operation *op = read<Operation *>();
930   uint32_t expectedCount = read<uint32_t>();
931 
932   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
933                           << "  * Expected: " << expectedCount << "\n");
934   selectJump(op->getNumOperands() == expectedCount);
935 }
936 
937 void ByteCodeExecutor::executeCheckOperationName() {
938   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
939   Operation *op = read<Operation *>();
940   OperationName expectedName = read<OperationName>();
941 
942   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
943                           << "  * Expected: \"" << expectedName << "\"\n");
944   selectJump(op->getName() == expectedName);
945 }
946 
947 void ByteCodeExecutor::executeCheckResultCount() {
948   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
949   Operation *op = read<Operation *>();
950   uint32_t expectedCount = read<uint32_t>();
951 
952   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
953                           << "  * Expected: " << expectedCount << "\n");
954   selectJump(op->getNumResults() == expectedCount);
955 }
956 
/// Execute a `CreateOperation` instruction: build a new operation with the
/// rewriter and store it into bytecode memory.
///
/// Operands, in the order they are read: <result memory index>
/// <operation name> <operand value list> <attribute count>
/// (<attribute name> <attribute value>)* <result type count> <result type>*.
/// The new operation is created at `mainRewriteLoc`.
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
                                              Location mainRewriteLoc) {
  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");

  unsigned memIndex = read();
  OperationState state(mainRewriteLoc, read<OperationName>());
  readList<Value>(state.operands);
  // Attributes whose value is null are skipped rather than added.
  for (unsigned i = 0, e = read(); i != e; ++i) {
    Identifier name = read<Identifier>();
    if (Attribute attr = read<Attribute>())
      state.addAttribute(name, attr);
  }

  // A null result type marks a type that has to be inferred.
  bool hasInferredTypes = false;
  for (unsigned i = 0, e = read(); i != e; ++i) {
    Type resultType = read<Type>();
    hasInferredTypes |= !resultType;
    state.types.push_back(resultType);
  }

  // Handle the case where the operation has inferred types.
  // If any result type is null, every provided type is discarded and the
  // full list is recomputed through InferTypeOpInterface.
  if (hasInferredTypes) {
    // NOTE(review): this assumes the operation is registered and implements
    // InferTypeOpInterface. Neither `getAbstractOperation()` nor `concept` is
    // null-checked here.
    InferTypeOpInterface::Concept *concept =
        state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();

    // TODO: Handle failure.
    // On failure this returns early. No operation is created and
    // `memory[memIndex]` is not written, so that slot keeps its prior contents.
    state.types.clear();
    if (failed(concept->inferReturnTypes(
            state.getContext(), state.location, state.operands,
            state.attributes.getDictionary(state.getContext()), state.regions,
            state.types)))
      return;
  }
  Operation *resultOp = rewriter.createOperation(state);
  memory[memIndex] = resultOp;

  LLVM_DEBUG({
    llvm::dbgs() << "  * Attributes: "
                 << state.attributes.getDictionary(state.getContext())
                 << "\n  * Operands: ";
    llvm::interleaveComma(state.operands, llvm::dbgs());
    llvm::dbgs() << "\n  * Result Types: ";
    llvm::interleaveComma(state.types, llvm::dbgs());
    llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
  });
}
1003 
1004 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1005   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1006   Operation *op = read<Operation *>();
1007 
1008   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1009   rewriter.eraseOp(op);
1010 }
1011 
1012 void ByteCodeExecutor::executeGetAttribute() {
1013   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1014   unsigned memIndex = read();
1015   Operation *op = read<Operation *>();
1016   Identifier attrName = read<Identifier>();
1017   Attribute attr = op->getAttr(attrName);
1018 
1019   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1020                           << "  * Attribute: " << attrName << "\n"
1021                           << "  * Result: " << attr << "\n");
1022   memory[memIndex] = attr.getAsOpaquePointer();
1023 }
1024 
1025 void ByteCodeExecutor::executeGetAttributeType() {
1026   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1027   unsigned memIndex = read();
1028   Attribute attr = read<Attribute>();
1029   Type type = attr ? attr.getType() : Type();
1030 
1031   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1032                           << "  * Result: " << type << "\n");
1033   memory[memIndex] = type.getAsOpaquePointer();
1034 }
1035 
1036 void ByteCodeExecutor::executeGetDefiningOp() {
1037   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1038   unsigned memIndex = read();
1039   Value value = read<Value>();
1040   Operation *op = value ? value.getDefiningOp() : nullptr;
1041 
1042   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1043                           << "  * Result: " << *op << "\n");
1044   memory[memIndex] = op;
1045 }
1046 
1047 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1048   Operation *op = read<Operation *>();
1049   unsigned memIndex = read();
1050   Value operand =
1051       index < op->getNumOperands() ? op->getOperand(index) : Value();
1052 
1053   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1054                           << "  * Index: " << index << "\n"
1055                           << "  * Result: " << operand << "\n");
1056   memory[memIndex] = operand.getAsOpaquePointer();
1057 }
1058 
1059 void ByteCodeExecutor::executeGetResult(unsigned index) {
1060   Operation *op = read<Operation *>();
1061   unsigned memIndex = read();
1062   OpResult result =
1063       index < op->getNumResults() ? op->getResult(index) : OpResult();
1064 
1065   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1066                           << "  * Index: " << index << "\n"
1067                           << "  * Result: " << result << "\n");
1068   memory[memIndex] = result.getAsOpaquePointer();
1069 }
1070 
1071 void ByteCodeExecutor::executeGetValueType() {
1072   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1073   unsigned memIndex = read();
1074   Value value = read<Value>();
1075   Type type = value ? value.getType() : Type();
1076 
1077   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1078                           << "  * Result: " << type << "\n");
1079   memory[memIndex] = type.getAsOpaquePointer();
1080 }
1081 
1082 void ByteCodeExecutor::executeIsNotNull() {
1083   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1084   const void *value = read<const void *>();
1085 
1086   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1087   selectJump(value != nullptr);
1088 }
1089 
1090 void ByteCodeExecutor::executeRecordMatch(
1091     PatternRewriter &rewriter,
1092     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1093   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1094   unsigned patternIndex = read();
1095   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1096   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1097 
1098   // If the benefit of the pattern is impossible, skip the processing of the
1099   // rest of the pattern.
1100   if (benefit.isImpossibleToMatch()) {
1101     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1102     curCodeIt = dest;
1103     return;
1104   }
1105 
1106   // Create a fused location containing the locations of each of the
1107   // operations used in the match. This will be used as the location for
1108   // created operations during the rewrite that don't already have an
1109   // explicit location set.
1110   unsigned numMatchLocs = read();
1111   SmallVector<Location, 4> matchLocs;
1112   matchLocs.reserve(numMatchLocs);
1113   for (unsigned i = 0; i != numMatchLocs; ++i)
1114     matchLocs.push_back(read<Operation *>()->getLoc());
1115   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1116 
1117   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1118                           << "  * Location: " << matchLoc << "\n");
1119   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1120   readList<const void *>(matches.back().values);
1121   curCodeIt = dest;
1122 }
1123 
1124 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1125   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1126   Operation *op = read<Operation *>();
1127   SmallVector<Value, 16> args;
1128   readList<Value>(args);
1129 
1130   LLVM_DEBUG({
1131     llvm::dbgs() << "  * Operation: " << *op << "\n"
1132                  << "  * Values: ";
1133     llvm::interleaveComma(args, llvm::dbgs());
1134     llvm::dbgs() << "\n";
1135   });
1136   rewriter.replaceOp(op, args);
1137 }
1138 
1139 void ByteCodeExecutor::executeSwitchAttribute() {
1140   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1141   Attribute value = read<Attribute>();
1142   ArrayAttr cases = read<ArrayAttr>();
1143   handleSwitch(value, cases);
1144 }
1145 
1146 void ByteCodeExecutor::executeSwitchOperandCount() {
1147   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1148   Operation *op = read<Operation *>();
1149   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1150 
1151   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1152   handleSwitch(op->getNumOperands(), cases);
1153 }
1154 
/// Execute a `SwitchOperationName` instruction: jump according to an
/// operation's name.
///
/// Operands, in the order they are read: <operation> <case count>
/// <OperationName>* followed by the successor addresses. Successor 0 is the
/// default destination; a match against case `i` jumps to successor `i + 1`.
void ByteCodeExecutor::executeSwitchOperationName() {
  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
  OperationName value = read<Operation *>()->getName();
  size_t caseCount = read();

  // The operation names are stored in-line, so to print them out for
  // debugging purposes we need to read the array before executing the
  // switch so that we can display all of the possible values.
  // The read cursor is saved and restored so that the debug-only reads do not
  // affect the real decoding below.
  LLVM_DEBUG({
    const ByteCodeField *prevCodeIt = curCodeIt;
    llvm::dbgs() << "  * Value: " << value << "\n"
                 << "  * Cases: ";
    llvm::interleaveComma(
        llvm::map_range(llvm::seq<size_t>(0, caseCount),
                        [&](size_t) { return read<OperationName>(); }),
        llvm::dbgs());
    llvm::dbgs() << "\n";
    curCodeIt = prevCodeIt;
  });

  // Try to find the switch value within any of the cases.
  for (size_t i = 0; i != caseCount; ++i) {
    if (read<OperationName>() == value) {
      // Each case name occupies one field. Skip the names not yet read so that
      // `curCodeIt` points at the start of the successor list that
      // `selectJump` indexes.
      curCodeIt += (caseCount - i - 1);
      return selectJump(i + 1);
    }
  }
  selectJump(size_t(0));
}
1184 
1185 void ByteCodeExecutor::executeSwitchResultCount() {
1186   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1187   Operation *op = read<Operation *>();
1188   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1189 
1190   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1191   handleSwitch(op->getNumResults(), cases);
1192 }
1193 
1194 void ByteCodeExecutor::executeSwitchType() {
1195   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1196   Type value = read<Type>();
1197   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1198   handleSwitch(value, cases);
1199 }
1200 
1201 void ByteCodeExecutor::execute(
1202     PatternRewriter &rewriter,
1203     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
1204     Optional<Location> mainRewriteLoc) {
1205   while (true) {
1206     OpCode opCode = static_cast<OpCode>(read());
1207     switch (opCode) {
1208     case ApplyConstraint:
1209       executeApplyConstraint(rewriter);
1210       break;
1211     case ApplyRewrite:
1212       executeApplyRewrite(rewriter);
1213       break;
1214     case AreEqual:
1215       executeAreEqual();
1216       break;
1217     case Branch:
1218       executeBranch();
1219       break;
1220     case CheckOperandCount:
1221       executeCheckOperandCount();
1222       break;
1223     case CheckOperationName:
1224       executeCheckOperationName();
1225       break;
1226     case CheckResultCount:
1227       executeCheckResultCount();
1228       break;
1229     case CreateOperation:
1230       executeCreateOperation(rewriter, *mainRewriteLoc);
1231       break;
1232     case EraseOp:
1233       executeEraseOp(rewriter);
1234       break;
1235     case Finalize:
1236       LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1237       return;
1238     case GetAttribute:
1239       executeGetAttribute();
1240       break;
1241     case GetAttributeType:
1242       executeGetAttributeType();
1243       break;
1244     case GetDefiningOp:
1245       executeGetDefiningOp();
1246       break;
1247     case GetOperand0:
1248     case GetOperand1:
1249     case GetOperand2:
1250     case GetOperand3: {
1251       unsigned index = opCode - GetOperand0;
1252       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
1253       executeGetOperand(index);
1254       break;
1255     }
1256     case GetOperandN:
1257       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
1258       executeGetOperand(read<uint32_t>());
1259       break;
1260     case GetResult0:
1261     case GetResult1:
1262     case GetResult2:
1263     case GetResult3: {
1264       unsigned index = opCode - GetResult0;
1265       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
1266       executeGetResult(index);
1267       break;
1268     }
1269     case GetResultN:
1270       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
1271       executeGetResult(read<uint32_t>());
1272       break;
1273     case GetValueType:
1274       executeGetValueType();
1275       break;
1276     case IsNotNull:
1277       executeIsNotNull();
1278       break;
1279     case RecordMatch:
1280       assert(matches &&
1281              "expected matches to be provided when executing the matcher");
1282       executeRecordMatch(rewriter, *matches);
1283       break;
1284     case ReplaceOp:
1285       executeReplaceOp(rewriter);
1286       break;
1287     case SwitchAttribute:
1288       executeSwitchAttribute();
1289       break;
1290     case SwitchOperandCount:
1291       executeSwitchOperandCount();
1292       break;
1293     case SwitchOperationName:
1294       executeSwitchOperationName();
1295       break;
1296     case SwitchResultCount:
1297       executeSwitchResultCount();
1298       break;
1299     case SwitchType:
1300       executeSwitchType();
1301       break;
1302     }
1303     LLVM_DEBUG(llvm::dbgs() << "\n");
1304   }
1305 }
1306 
1307 /// Run the pattern matcher on the given root operation, collecting the matched
1308 /// patterns in `matches`.
1309 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1310                         SmallVectorImpl<MatchResult> &matches,
1311                         PDLByteCodeMutableState &state) const {
1312   // The first memory slot is always the root operation.
1313   state.memory[0] = op;
1314 
1315   // The matcher function always starts at code address 0.
1316   ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
1317                             matcherByteCode, state.currentPatternBenefits,
1318                             patterns, constraintFunctions, rewriteFunctions);
1319   executor.execute(rewriter, &matches);
1320 
1321   // Order the found matches by benefit.
1322   std::stable_sort(matches.begin(), matches.end(),
1323                    [](const MatchResult &lhs, const MatchResult &rhs) {
1324                      return lhs.benefit > rhs.benefit;
1325                    });
1326 }
1327 
1328 /// Run the rewriter of the given pattern on the root operation `op`.
1329 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1330                           PDLByteCodeMutableState &state) const {
1331   // The arguments of the rewrite function are stored at the start of the
1332   // memory buffer.
1333   llvm::copy(match.values, state.memory.begin());
1334 
1335   ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()],
1336                             state.memory, uniquedData, rewriterByteCode,
1337                             state.currentPatternBenefits, patterns,
1338                             constraintFunctions, rewriteFunctions);
1339   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1340 }
1341