1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "pdl-bytecode"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // PDLByteCodePattern
31 //===----------------------------------------------------------------------===//
32 
33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
34                                               ByteCodeAddr rewriterAddr) {
35   SmallVector<StringRef, 8> generatedOps;
36   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
37     generatedOps =
38         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
39 
40   PatternBenefit benefit = matchOp.benefit();
41   MLIRContext *ctx = matchOp.getContext();
42 
43   // Check to see if this is pattern matches a specific operation type.
44   if (Optional<StringRef> rootKind = matchOp.rootKind())
45     return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
46                               ctx);
47   return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
48                             MatchAnyOpTypeTag());
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // PDLByteCodeMutableState
53 //===----------------------------------------------------------------------===//
54 
55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
56 /// to the position of the pattern within the range returned by
57 /// `PDLByteCode::getPatterns`.
58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
59                                                    PatternBenefit benefit) {
60   currentPatternBenefits[patternIndex] = benefit;
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // Bytecode OpCodes
65 //===----------------------------------------------------------------------===//
66 
namespace {
/// The set of instructions understood by the bytecode. Each enumerator is
/// written into the bytecode stream as a single `ByteCodeField`.
///
/// NOTE: The relative order of the `GetOperand0..GetOperandN` and
/// `GetResult0..GetResultN` enumerators matters: the generator selects the
/// specialized opcode by adding a small index to `GetOperand0`/`GetResult0`.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal. The generator also emits this
  /// opcode for attribute and type checks against a constant.
  AreEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Create an operation.
  CreateOperation,
  /// Erase an operation.
  EraseOp,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. Indices 0-3 have dedicated
  /// opcodes; any other index uses `GetOperandN` with the index emitted as an
  /// additional field. These enumerators must stay contiguous and in order.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific result of an operation. Indices 0-3 have dedicated
  /// opcodes; any other index uses `GetResultN` with the index emitted as an
  /// additional field. These enumerators must stay contiguous and in order.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get the type of a value.
  GetValueType,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
};

/// The kind of a generic PDL value. The writer emits this tag in front of each
/// entry of a PDL value list so that the kind of the entry is recorded in the
/// bytecode alongside its memory index.
enum class PDLValueKind { Attribute, Operation, Type, Value };
} // end anonymous namespace
129 
130 //===----------------------------------------------------------------------===//
131 // ByteCode Generation
132 //===----------------------------------------------------------------------===//
133 
134 //===----------------------------------------------------------------------===//
135 // Generator
136 
137 namespace {
138 struct ByteCodeWriter;
139 
140 /// This class represents the main generator for the pattern bytecode.
141 class Generator {
142 public:
143   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
144             SmallVectorImpl<ByteCodeField> &matcherByteCode,
145             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
146             SmallVectorImpl<PDLByteCodePattern> &patterns,
147             ByteCodeField &maxValueMemoryIndex,
148             llvm::StringMap<PDLConstraintFunction> &constraintFns,
149             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
150       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
151         rewriterByteCode(rewriterByteCode), patterns(patterns),
152         maxValueMemoryIndex(maxValueMemoryIndex) {
153     for (auto it : llvm::enumerate(constraintFns))
154       constraintToMemIndex.try_emplace(it.value().first(), it.index());
155     for (auto it : llvm::enumerate(rewriteFns))
156       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
157   }
158 
159   /// Generate the bytecode for the given PDL interpreter module.
160   void generate(ModuleOp module);
161 
162   /// Return the memory index to use for the given value.
163   ByteCodeField &getMemIndex(Value value) {
164     assert(valueToMemIndex.count(value) &&
165            "expected memory index to be assigned");
166     return valueToMemIndex[value];
167   }
168 
169   /// Return an index to use when referring to the given data that is uniqued in
170   /// the MLIR context.
171   template <typename T>
172   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
173   getMemIndex(T val) {
174     const void *opaqueVal = val.getAsOpaquePointer();
175 
176     // Get or insert a reference to this value.
177     auto it = uniquedDataToMemIndex.try_emplace(
178         opaqueVal, maxValueMemoryIndex + uniquedData.size());
179     if (it.second)
180       uniquedData.push_back(opaqueVal);
181     return it.first->second;
182   }
183 
184 private:
185   /// Allocate memory indices for the results of operations within the matcher
186   /// and rewriters.
187   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
188 
189   /// Generate the bytecode for the given operation.
190   void generate(Operation *op, ByteCodeWriter &writer);
191   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
192   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
193   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
194   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
195   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
196   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
197   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
198   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
199   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
200   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
201   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
202   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
203   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
204   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
205   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
206   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
207   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
208   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
209   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
210   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
211   void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
212   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
213   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
214   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
215   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
216   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
217   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
218   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
219   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
220 
221   /// Mapping from value to its corresponding memory index.
222   DenseMap<Value, ByteCodeField> valueToMemIndex;
223 
224   /// Mapping from the name of an externally registered rewrite to its index in
225   /// the bytecode registry.
226   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
227 
228   /// Mapping from the name of an externally registered constraint to its index
229   /// in the bytecode registry.
230   llvm::StringMap<ByteCodeField> constraintToMemIndex;
231 
232   /// Mapping from rewriter function name to the bytecode address of the
233   /// rewriter function in byte.
234   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
235 
236   /// Mapping from a uniqued storage object to its memory index within
237   /// `uniquedData`.
238   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
239 
240   /// The current MLIR context.
241   MLIRContext *ctx;
242 
243   /// Data of the ByteCode class to be populated.
244   std::vector<const void *> &uniquedData;
245   SmallVectorImpl<ByteCodeField> &matcherByteCode;
246   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
247   SmallVectorImpl<PDLByteCodePattern> &patterns;
248   ByteCodeField &maxValueMemoryIndex;
249 };
250 
/// This class provides utilities for writing a bytecode stream. All of the
/// `append` overloads push onto the referenced `bytecode` buffer; values that
/// live in memory slots are written as the memory index handed out by the
/// generator.
struct ByteCodeWriter {
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a field to the bytecode.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode. An address is wider than a single
  /// field, so it is split into two consecutive fields.
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a successor range to the bytecode, the exact address will need to
  /// be resolved later.
  void append(SuccessorRange successors) {
    // Record the position of each placeholder so that the real address of the
    // successor block can be patched in later.
    for (Block *successor : successors) {
      unresolvedSuccessorRefs[successor].push_back(bytecode.size());
      append(ByteCodeAddr(0));
    }
  }

  /// Append a range of values that will be read as generic PDLValues. The
  /// layout is the number of values, followed by a (kind, memory index) pair
  /// for each value.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values) {
      // Append the type of the value in addition to the value itself.
      PDLValueKind kind =
          TypeSwitch<Type, PDLValueKind>(value.getType())
              .Case<pdl::AttributeType>(
                  [](Type) { return PDLValueKind::Attribute; })
              .Case<pdl::OperationType>(
                  [](Type) { return PDLValueKind::Operation; })
              .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
              .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
      bytecode.push_back(static_cast<ByteCodeField>(kind));
      append(value);
    }
  }

  /// Detection helper: well-formed only if the given class `T` provides a
  /// `getAsOpaquePointer()` method. Used to distinguish values stored in
  /// memory slots from ranges of values.
  template <typename T, typename... Args>
  using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());

  /// Append a value that will be stored in a memory slot and not inline within
  /// the bytecode. Only the memory index of the value is written.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values: the number of elements, followed by each
  /// element in turn.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode, in argument order.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Successor references in the bytecode that have yet to be resolved. Maps a
  /// block to the bytecode offsets of the placeholders that refer to it.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The generator driving this writer; it provides the memory index for
  /// values that are stored in memory slots.
  Generator &generator;
};
337 } // end anonymous namespace
338 
339 void Generator::generate(ModuleOp module) {
340   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
341       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
342   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
343       pdl_interp::PDLInterpDialect::getRewriterModuleName());
344   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
345 
346   // Allocate memory indices for the results of operations within the matcher
347   // and rewriters.
348   allocateMemoryIndices(matcherFunc, rewriterModule);
349 
350   // Generate code for the rewriter functions.
351   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
352   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
353     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
354     for (Operation &op : rewriterFunc.getOps())
355       generate(&op, rewriterByteCodeWriter);
356   }
357   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
358          "unexpected branches in rewriter function");
359 
360   // Generate code for the matcher function.
361   DenseMap<Block *, ByteCodeAddr> blockToAddr;
362   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
363   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
364   for (Block *block : rpot) {
365     // Keep track of where this block begins within the matcher function.
366     blockToAddr.try_emplace(block, matcherByteCode.size());
367     for (Operation &op : *block)
368       generate(&op, matcherByteCodeWriter);
369   }
370 
371   // Resolve successor references in the matcher.
372   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
373     ByteCodeAddr addr = blockToAddr[it.first];
374     for (unsigned offsetToFix : it.second)
375       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
376   }
377 }
378 
379 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
380                                       ModuleOp rewriterModule) {
381   // Rewriters use simplistic allocation scheme that simply assigns an index to
382   // each result.
383   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
384     ByteCodeField index = 0;
385     for (BlockArgument arg : rewriterFunc.getArguments())
386       valueToMemIndex.try_emplace(arg, index++);
387     rewriterFunc.getBody().walk([&](Operation *op) {
388       for (Value result : op->getResults())
389         valueToMemIndex.try_emplace(result, index++);
390     });
391     if (index > maxValueMemoryIndex)
392       maxValueMemoryIndex = index;
393   }
394 
395   // The matcher function uses a more sophisticated numbering that tries to
396   // minimize the number of memory indices assigned. This is done by determining
397   // a live range of the values within the matcher, then the allocation is just
398   // finding the minimal number of overlapping live ranges. This is essentially
399   // a simplified form of register allocation where we don't necessarily have a
400   // limited number of registers, but we still want to minimize the number used.
401   DenseMap<Operation *, ByteCodeField> opToIndex;
402   matcherFunc.getBody().walk([&](Operation *op) {
403     opToIndex.insert(std::make_pair(op, opToIndex.size()));
404   });
405 
406   // Liveness info for each of the defs within the matcher.
407   using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
408   LivenessSet::Allocator allocator;
409   DenseMap<Value, LivenessSet> valueDefRanges;
410 
411   // Assign the root operation being matched to slot 0.
412   BlockArgument rootOpArg = matcherFunc.getArgument(0);
413   valueToMemIndex[rootOpArg] = 0;
414 
415   // Walk each of the blocks, computing the def interval that the value is used.
416   Liveness matcherLiveness(matcherFunc);
417   for (Block &block : matcherFunc.getBody()) {
418     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
419     assert(info && "expected liveness info for block");
420     auto processValue = [&](Value value, Operation *firstUseOrDef) {
421       // We don't need to process the root op argument, this value is always
422       // assigned to the first memory slot.
423       if (value == rootOpArg)
424         return;
425 
426       // Set indices for the range of this block that the value is used.
427       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
428       defRangeIt->second.insert(
429           opToIndex[firstUseOrDef],
430           opToIndex[info->getEndOperation(value, firstUseOrDef)],
431           /*dummyValue*/ 0);
432     };
433 
434     // Process the live-ins of this block.
435     for (Value liveIn : info->in())
436       processValue(liveIn, &block.front());
437 
438     // Process any new defs within this block.
439     for (Operation &op : block)
440       for (Value result : op.getResults())
441         processValue(result, &op);
442   }
443 
444   // Greedily allocate memory slots using the computed def live ranges.
445   std::vector<LivenessSet> allocatedIndices;
446   for (auto &defIt : valueDefRanges) {
447     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
448     LivenessSet &defSet = defIt.second;
449 
450     // Try to allocate to an existing index.
451     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
452       LivenessSet &existingIndex = existingIndexIt.value();
453       llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
454           defIt.second, existingIndex);
455       if (overlaps.valid())
456         continue;
457       // Union the range of the def within the existing index.
458       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
459         existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
460       memIndex = existingIndexIt.index() + 1;
461     }
462 
463     // If no existing index could be used, add a new one.
464     if (memIndex == 0) {
465       allocatedIndices.emplace_back(allocator);
466       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
467         allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
468       memIndex = allocatedIndices.size();
469     }
470   }
471 
472   // Update the max number of indices.
473   ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
474   if (numMatcherIndices > maxValueMemoryIndex)
475     maxValueMemoryIndex = numMatcherIndices;
476 }
477 
478 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
479   TypeSwitch<Operation *>(op)
480       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
481             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
482             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
483             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
484             pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
485             pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
486             pdl_interp::EraseOp, pdl_interp::FinalizeOp,
487             pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
488             pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
489             pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp,
490             pdl_interp::InferredTypeOp, pdl_interp::IsNotNullOp,
491             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
492             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
493             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
494             pdl_interp::SwitchResultCountOp>(
495           [&](auto interpOp) { this->generate(interpOp, writer); })
496       .Default([](Operation *) {
497         llvm_unreachable("unknown `pdl_interp` operation");
498       });
499 }
500 
501 void Generator::generate(pdl_interp::ApplyConstraintOp op,
502                          ByteCodeWriter &writer) {
503   assert(constraintToMemIndex.count(op.name()) &&
504          "expected index for constraint function");
505   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
506                 op.constParamsAttr());
507   writer.appendPDLValueList(op.args());
508   writer.append(op.getSuccessors());
509 }
510 void Generator::generate(pdl_interp::ApplyRewriteOp op,
511                          ByteCodeWriter &writer) {
512   assert(externalRewriterToMemIndex.count(op.name()) &&
513          "expected index for rewrite function");
514   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
515                 op.constParamsAttr());
516   writer.appendPDLValueList(op.args());
517 
518 #ifndef NDEBUG
519   // In debug mode we also append the number of results so that we can assert
520   // that the native creation function gave us the correct number of results.
521   writer.append(ByteCodeField(op.results().size()));
522 #endif
523   for (Value result : op.results())
524     writer.append(result);
525 }
526 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
527   writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
528 }
529 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
530   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
531 }
532 void Generator::generate(pdl_interp::CheckAttributeOp op,
533                          ByteCodeWriter &writer) {
534   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
535                 op.getSuccessors());
536 }
537 void Generator::generate(pdl_interp::CheckOperandCountOp op,
538                          ByteCodeWriter &writer) {
539   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
540                 op.getSuccessors());
541 }
542 void Generator::generate(pdl_interp::CheckOperationNameOp op,
543                          ByteCodeWriter &writer) {
544   writer.append(OpCode::CheckOperationName, op.operation(),
545                 OperationName(op.name(), ctx), op.getSuccessors());
546 }
547 void Generator::generate(pdl_interp::CheckResultCountOp op,
548                          ByteCodeWriter &writer) {
549   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
550                 op.getSuccessors());
551 }
552 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
553   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
554 }
555 void Generator::generate(pdl_interp::CreateAttributeOp op,
556                          ByteCodeWriter &writer) {
557   // Simply repoint the memory index of the result to the constant.
558   getMemIndex(op.attribute()) = getMemIndex(op.value());
559 }
560 void Generator::generate(pdl_interp::CreateOperationOp op,
561                          ByteCodeWriter &writer) {
562   writer.append(OpCode::CreateOperation, op.operation(),
563                 OperationName(op.name(), ctx), op.operands());
564 
565   // Add the attributes.
566   OperandRange attributes = op.attributes();
567   writer.append(static_cast<ByteCodeField>(attributes.size()));
568   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
569     writer.append(
570         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
571         std::get<1>(it));
572   }
573   writer.append(op.types());
574 }
575 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
576   // Simply repoint the memory index of the result to the constant.
577   getMemIndex(op.result()) = getMemIndex(op.value());
578 }
579 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
580   writer.append(OpCode::EraseOp, op.operation());
581 }
/// Emit `Finalize`, which terminates the current matcher or rewrite sequence.
/// The opcode takes no additional fields.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
585 void Generator::generate(pdl_interp::GetAttributeOp op,
586                          ByteCodeWriter &writer) {
587   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
588                 Identifier::get(op.name(), ctx));
589 }
590 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
591                          ByteCodeWriter &writer) {
592   writer.append(OpCode::GetAttributeType, op.result(), op.value());
593 }
594 void Generator::generate(pdl_interp::GetDefiningOpOp op,
595                          ByteCodeWriter &writer) {
596   writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
597 }
598 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
599   uint32_t index = op.index();
600   if (index < 4)
601     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
602   else
603     writer.append(OpCode::GetOperandN, index);
604   writer.append(op.operation(), op.value());
605 }
606 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
607   uint32_t index = op.index();
608   if (index < 4)
609     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
610   else
611     writer.append(OpCode::GetResultN, index);
612   writer.append(op.operation(), op.value());
613 }
614 void Generator::generate(pdl_interp::GetValueTypeOp op,
615                          ByteCodeWriter &writer) {
616   writer.append(OpCode::GetValueType, op.result(), op.value());
617 }
618 void Generator::generate(pdl_interp::InferredTypeOp op,
619                          ByteCodeWriter &writer) {
620   // InferType maps to a null type as a marker for inferring a result type.
621   getMemIndex(op.type()) = getMemIndex(Type());
622 }
623 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
624   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
625 }
626 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
627   ByteCodeField patternIndex = patterns.size();
628   patterns.emplace_back(PDLByteCodePattern::create(
629       op, rewriterToAddr[op.rewriter().getLeafReference()]));
630   writer.append(OpCode::RecordMatch, patternIndex,
631                 SuccessorRange(op.getOperation()), op.matchedOps(),
632                 op.inputs());
633 }
634 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
635   writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
636 }
637 void Generator::generate(pdl_interp::SwitchAttributeOp op,
638                          ByteCodeWriter &writer) {
639   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
640                 op.getSuccessors());
641 }
642 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
643                          ByteCodeWriter &writer) {
644   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
645                 op.getSuccessors());
646 }
647 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
648                          ByteCodeWriter &writer) {
649   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
650     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
651   });
652   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
653                 op.getSuccessors());
654 }
655 void Generator::generate(pdl_interp::SwitchResultCountOp op,
656                          ByteCodeWriter &writer) {
657   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
658                 op.getSuccessors());
659 }
660 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
661   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
662                 op.getSuccessors());
663 }
664 
665 //===----------------------------------------------------------------------===//
666 // PDLByteCode
667 //===----------------------------------------------------------------------===//
668 
669 PDLByteCode::PDLByteCode(ModuleOp module,
670                          llvm::StringMap<PDLConstraintFunction> constraintFns,
671                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
672   Generator generator(module.getContext(), uniquedData, matcherByteCode,
673                       rewriterByteCode, patterns, maxValueMemoryIndex,
674                       constraintFns, rewriteFns);
675   generator.generate(module);
676 
677   // Initialize the external functions.
678   for (auto &it : constraintFns)
679     constraintFunctions.push_back(std::move(it.second));
680   for (auto &it : rewriteFns)
681     rewriteFunctions.push_back(std::move(it.second));
682 }
683 
684 /// Initialize the given state such that it can be used to execute the current
685 /// bytecode.
686 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
687   state.memory.resize(maxValueMemoryIndex, nullptr);
688   state.currentPatternBenefits.reserve(patterns.size());
689   for (const PDLByteCodePattern &pattern : patterns)
690     state.currentPatternBenefits.push_back(pattern.getBenefit());
691 }
692 
693 //===----------------------------------------------------------------------===//
694 // ByteCode Execution
695 
696 namespace {
697 /// This class provides support for executing a bytecode stream.
698 class ByteCodeExecutor {
699 public:
700   ByteCodeExecutor(const ByteCodeField *curCodeIt,
701                    MutableArrayRef<const void *> memory,
702                    ArrayRef<const void *> uniquedMemory,
703                    ArrayRef<ByteCodeField> code,
704                    ArrayRef<PatternBenefit> currentPatternBenefits,
705                    ArrayRef<PDLByteCodePattern> patterns,
706                    ArrayRef<PDLConstraintFunction> constraintFunctions,
707                    ArrayRef<PDLRewriteFunction> rewriteFunctions)
708       : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
709         code(code), currentPatternBenefits(currentPatternBenefits),
710         patterns(patterns), constraintFunctions(constraintFunctions),
711         rewriteFunctions(rewriteFunctions) {}
712 
713   /// Start executing the code at the current bytecode index. `matches` is an
714   /// optional field provided when this function is executed in a matching
715   /// context.
716   void execute(PatternRewriter &rewriter,
717                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
718                Optional<Location> mainRewriteLoc = {});
719 
720 private:
721   /// Internal implementation of executing each of the bytecode commands.
722   void executeApplyConstraint(PatternRewriter &rewriter);
723   void executeApplyRewrite(PatternRewriter &rewriter);
724   void executeAreEqual();
725   void executeBranch();
726   void executeCheckOperandCount();
727   void executeCheckOperationName();
728   void executeCheckResultCount();
729   void executeCreateOperation(PatternRewriter &rewriter,
730                               Location mainRewriteLoc);
731   void executeEraseOp(PatternRewriter &rewriter);
732   void executeGetAttribute();
733   void executeGetAttributeType();
734   void executeGetDefiningOp();
735   void executeGetOperand(unsigned index);
736   void executeGetResult(unsigned index);
737   void executeGetValueType();
738   void executeIsNotNull();
739   void executeRecordMatch(PatternRewriter &rewriter,
740                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
741   void executeReplaceOp(PatternRewriter &rewriter);
742   void executeSwitchAttribute();
743   void executeSwitchOperandCount();
744   void executeSwitchOperationName();
745   void executeSwitchResultCount();
746   void executeSwitchType();
747 
748   /// Read a value from the bytecode buffer, optionally skipping a certain
749   /// number of prefix values. These methods always update the buffer to point
750   /// to the next field after the read data.
751   template <typename T = ByteCodeField>
752   T read(size_t skipN = 0) {
753     curCodeIt += skipN;
754     return readImpl<T>();
755   }
756   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
757 
758   /// Read a list of values from the bytecode buffer.
759   template <typename ValueT, typename T>
760   void readList(SmallVectorImpl<T> &list) {
761     list.clear();
762     for (unsigned i = 0, e = read(); i != e; ++i)
763       list.push_back(read<ValueT>());
764   }
765 
766   /// Jump to a specific successor based on a predicate value.
767   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
768   /// Jump to a specific successor based on a destination index.
769   void selectJump(size_t destIndex) {
770     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
771   }
772 
773   /// Handle a switch operation with the provided value and cases.
774   template <typename T, typename RangeT>
775   void handleSwitch(const T &value, RangeT &&cases) {
776     LLVM_DEBUG({
777       llvm::dbgs() << "  * Value: " << value << "\n"
778                    << "  * Cases: ";
779       llvm::interleaveComma(cases, llvm::dbgs());
780       llvm::dbgs() << "\n";
781     });
782 
783     // Check to see if the attribute value is within the case list. Jump to
784     // the correct successor index based on the result.
785     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
786       if (*it == value)
787         return selectJump(size_t((it - cases.begin()) + 1));
788     selectJump(size_t(0));
789   }
790 
791   /// Internal implementation of reading various data types from the bytecode
792   /// stream.
793   template <typename T>
794   const void *readFromMemory() {
795     size_t index = *curCodeIt++;
796 
797     // If this type is an SSA value, it can only be stored in non-const memory.
798     if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
799       return memory[index];
800 
801     // Otherwise, if this index is not inbounds it is uniqued.
802     return uniquedMemory[index - memory.size()];
803   }
804   template <typename T>
805   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
806     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
807   }
808   template <typename T>
809   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
810                    T>
811   readImpl() {
812     return T(T::getFromOpaquePointer(readFromMemory<T>()));
813   }
814   template <typename T>
815   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
816     switch (static_cast<PDLValueKind>(read())) {
817     case PDLValueKind::Attribute:
818       return read<Attribute>();
819     case PDLValueKind::Operation:
820       return read<Operation *>();
821     case PDLValueKind::Type:
822       return read<Type>();
823     case PDLValueKind::Value:
824       return read<Value>();
825     }
826     llvm_unreachable("unhandled PDLValueKind");
827   }
828   template <typename T>
829   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
830     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
831                   "unexpected ByteCode address size");
832     ByteCodeAddr result;
833     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
834     curCodeIt += 2;
835     return result;
836   }
837   template <typename T>
838   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
839     return *curCodeIt++;
840   }
841 
842   /// The underlying bytecode buffer.
843   const ByteCodeField *curCodeIt;
844 
845   /// The current execution memory.
846   MutableArrayRef<const void *> memory;
847 
848   /// References to ByteCode data necessary for execution.
849   ArrayRef<const void *> uniquedMemory;
850   ArrayRef<ByteCodeField> code;
851   ArrayRef<PatternBenefit> currentPatternBenefits;
852   ArrayRef<PDLByteCodePattern> patterns;
853   ArrayRef<PDLConstraintFunction> constraintFunctions;
854   ArrayRef<PDLRewriteFunction> rewriteFunctions;
855 };
856 
857 /// This class is an instantiation of the PDLResultList that provides access to
858 /// the returned results. This API is not on `PDLResultList` to avoid
859 /// overexposing access to information specific solely to the ByteCode.
860 class ByteCodeRewriteResultList : public PDLResultList {
861 public:
862   /// Return the list of PDL results.
863   MutableArrayRef<PDLValue> getResults() { return results; }
864 };
865 } // end anonymous namespace
866 
867 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
868   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
869   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
870   ArrayAttr constParams = read<ArrayAttr>();
871   SmallVector<PDLValue, 16> args;
872   readList<PDLValue>(args);
873 
874   LLVM_DEBUG({
875     llvm::dbgs() << "  * Arguments: ";
876     llvm::interleaveComma(args, llvm::dbgs());
877     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
878   });
879 
880   // Invoke the constraint and jump to the proper destination.
881   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
882 }
883 
884 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
885   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
886   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
887   ArrayAttr constParams = read<ArrayAttr>();
888   SmallVector<PDLValue, 16> args;
889   readList<PDLValue>(args);
890 
891   LLVM_DEBUG({
892     llvm::dbgs() << "  * Arguments: ";
893     llvm::interleaveComma(args, llvm::dbgs());
894     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
895   });
896   ByteCodeRewriteResultList results;
897   rewriteFn(args, constParams, rewriter, results);
898 
899   // Store the results in the bytecode memory.
900 #ifndef NDEBUG
901   ByteCodeField expectedNumberOfResults = read();
902   assert(results.getResults().size() == expectedNumberOfResults &&
903          "native PDL rewrite function returned unexpected number of results");
904 #endif
905 
906   // Store the results in the bytecode memory.
907   for (PDLValue &result : results.getResults()) {
908     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
909     memory[read()] = result.getAsOpaquePointer();
910   }
911 }
912 
913 void ByteCodeExecutor::executeAreEqual() {
914   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
915   const void *lhs = read<const void *>();
916   const void *rhs = read<const void *>();
917 
918   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
919   selectJump(lhs == rhs);
920 }
921 
922 void ByteCodeExecutor::executeBranch() {
923   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
924   curCodeIt = &code[read<ByteCodeAddr>()];
925 }
926 
927 void ByteCodeExecutor::executeCheckOperandCount() {
928   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
929   Operation *op = read<Operation *>();
930   uint32_t expectedCount = read<uint32_t>();
931 
932   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
933                           << "  * Expected: " << expectedCount << "\n");
934   selectJump(op->getNumOperands() == expectedCount);
935 }
936 
937 void ByteCodeExecutor::executeCheckOperationName() {
938   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
939   Operation *op = read<Operation *>();
940   OperationName expectedName = read<OperationName>();
941 
942   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
943                           << "  * Expected: \"" << expectedName << "\"\n");
944   selectJump(op->getName() == expectedName);
945 }
946 
947 void ByteCodeExecutor::executeCheckResultCount() {
948   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
949   Operation *op = read<Operation *>();
950   uint32_t expectedCount = read<uint32_t>();
951 
952   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
953                           << "  * Expected: " << expectedCount << "\n");
954   selectJump(op->getNumResults() == expectedCount);
955 }
956 
957 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
958                                               Location mainRewriteLoc) {
959   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
960 
961   unsigned memIndex = read();
962   OperationState state(mainRewriteLoc, read<OperationName>());
963   readList<Value>(state.operands);
964   for (unsigned i = 0, e = read(); i != e; ++i) {
965     Identifier name = read<Identifier>();
966     if (Attribute attr = read<Attribute>())
967       state.addAttribute(name, attr);
968   }
969 
970   bool hasInferredTypes = false;
971   for (unsigned i = 0, e = read(); i != e; ++i) {
972     Type resultType = read<Type>();
973     hasInferredTypes |= !resultType;
974     state.types.push_back(resultType);
975   }
976 
977   // Handle the case where the operation has inferred types.
978   if (hasInferredTypes) {
979     InferTypeOpInterface::Concept *concept =
980         state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
981 
982     // TODO: Handle failure.
983     SmallVector<Type, 2> inferredTypes;
984     if (failed(concept->inferReturnTypes(
985             state.getContext(), state.location, state.operands,
986             state.attributes.getDictionary(state.getContext()), state.regions,
987             inferredTypes)))
988       return;
989 
990     for (unsigned i = 0, e = state.types.size(); i != e; ++i)
991       if (!state.types[i])
992         state.types[i] = inferredTypes[i];
993   }
994   Operation *resultOp = rewriter.createOperation(state);
995   memory[memIndex] = resultOp;
996 
997   LLVM_DEBUG({
998     llvm::dbgs() << "  * Attributes: "
999                  << state.attributes.getDictionary(state.getContext())
1000                  << "\n  * Operands: ";
1001     llvm::interleaveComma(state.operands, llvm::dbgs());
1002     llvm::dbgs() << "\n  * Result Types: ";
1003     llvm::interleaveComma(state.types, llvm::dbgs());
1004     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1005   });
1006 }
1007 
1008 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1009   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1010   Operation *op = read<Operation *>();
1011 
1012   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1013   rewriter.eraseOp(op);
1014 }
1015 
1016 void ByteCodeExecutor::executeGetAttribute() {
1017   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1018   unsigned memIndex = read();
1019   Operation *op = read<Operation *>();
1020   Identifier attrName = read<Identifier>();
1021   Attribute attr = op->getAttr(attrName);
1022 
1023   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1024                           << "  * Attribute: " << attrName << "\n"
1025                           << "  * Result: " << attr << "\n");
1026   memory[memIndex] = attr.getAsOpaquePointer();
1027 }
1028 
1029 void ByteCodeExecutor::executeGetAttributeType() {
1030   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1031   unsigned memIndex = read();
1032   Attribute attr = read<Attribute>();
1033   Type type = attr ? attr.getType() : Type();
1034 
1035   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1036                           << "  * Result: " << type << "\n");
1037   memory[memIndex] = type.getAsOpaquePointer();
1038 }
1039 
1040 void ByteCodeExecutor::executeGetDefiningOp() {
1041   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1042   unsigned memIndex = read();
1043   Value value = read<Value>();
1044   Operation *op = value ? value.getDefiningOp() : nullptr;
1045 
1046   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1047                           << "  * Result: " << *op << "\n");
1048   memory[memIndex] = op;
1049 }
1050 
1051 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1052   Operation *op = read<Operation *>();
1053   unsigned memIndex = read();
1054   Value operand =
1055       index < op->getNumOperands() ? op->getOperand(index) : Value();
1056 
1057   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1058                           << "  * Index: " << index << "\n"
1059                           << "  * Result: " << operand << "\n");
1060   memory[memIndex] = operand.getAsOpaquePointer();
1061 }
1062 
1063 void ByteCodeExecutor::executeGetResult(unsigned index) {
1064   Operation *op = read<Operation *>();
1065   unsigned memIndex = read();
1066   OpResult result =
1067       index < op->getNumResults() ? op->getResult(index) : OpResult();
1068 
1069   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1070                           << "  * Index: " << index << "\n"
1071                           << "  * Result: " << result << "\n");
1072   memory[memIndex] = result.getAsOpaquePointer();
1073 }
1074 
1075 void ByteCodeExecutor::executeGetValueType() {
1076   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1077   unsigned memIndex = read();
1078   Value value = read<Value>();
1079   Type type = value ? value.getType() : Type();
1080 
1081   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1082                           << "  * Result: " << type << "\n");
1083   memory[memIndex] = type.getAsOpaquePointer();
1084 }
1085 
1086 void ByteCodeExecutor::executeIsNotNull() {
1087   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1088   const void *value = read<const void *>();
1089 
1090   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1091   selectJump(value != nullptr);
1092 }
1093 
1094 void ByteCodeExecutor::executeRecordMatch(
1095     PatternRewriter &rewriter,
1096     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1097   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1098   unsigned patternIndex = read();
1099   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1100   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1101 
1102   // If the benefit of the pattern is impossible, skip the processing of the
1103   // rest of the pattern.
1104   if (benefit.isImpossibleToMatch()) {
1105     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1106     curCodeIt = dest;
1107     return;
1108   }
1109 
1110   // Create a fused location containing the locations of each of the
1111   // operations used in the match. This will be used as the location for
1112   // created operations during the rewrite that don't already have an
1113   // explicit location set.
1114   unsigned numMatchLocs = read();
1115   SmallVector<Location, 4> matchLocs;
1116   matchLocs.reserve(numMatchLocs);
1117   for (unsigned i = 0; i != numMatchLocs; ++i)
1118     matchLocs.push_back(read<Operation *>()->getLoc());
1119   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1120 
1121   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1122                           << "  * Location: " << matchLoc << "\n");
1123   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1124   readList<const void *>(matches.back().values);
1125   curCodeIt = dest;
1126 }
1127 
1128 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1129   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1130   Operation *op = read<Operation *>();
1131   SmallVector<Value, 16> args;
1132   readList<Value>(args);
1133 
1134   LLVM_DEBUG({
1135     llvm::dbgs() << "  * Operation: " << *op << "\n"
1136                  << "  * Values: ";
1137     llvm::interleaveComma(args, llvm::dbgs());
1138     llvm::dbgs() << "\n";
1139   });
1140   rewriter.replaceOp(op, args);
1141 }
1142 
1143 void ByteCodeExecutor::executeSwitchAttribute() {
1144   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1145   Attribute value = read<Attribute>();
1146   ArrayAttr cases = read<ArrayAttr>();
1147   handleSwitch(value, cases);
1148 }
1149 
1150 void ByteCodeExecutor::executeSwitchOperandCount() {
1151   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1152   Operation *op = read<Operation *>();
1153   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1154 
1155   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1156   handleSwitch(op->getNumOperands(), cases);
1157 }
1158 
1159 void ByteCodeExecutor::executeSwitchOperationName() {
1160   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1161   OperationName value = read<Operation *>()->getName();
1162   size_t caseCount = read();
1163 
1164   // The operation names are stored in-line, so to print them out for
1165   // debugging purposes we need to read the array before executing the
1166   // switch so that we can display all of the possible values.
1167   LLVM_DEBUG({
1168     const ByteCodeField *prevCodeIt = curCodeIt;
1169     llvm::dbgs() << "  * Value: " << value << "\n"
1170                  << "  * Cases: ";
1171     llvm::interleaveComma(
1172         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1173                         [&](size_t) { return read<OperationName>(); }),
1174         llvm::dbgs());
1175     llvm::dbgs() << "\n";
1176     curCodeIt = prevCodeIt;
1177   });
1178 
1179   // Try to find the switch value within any of the cases.
1180   for (size_t i = 0; i != caseCount; ++i) {
1181     if (read<OperationName>() == value) {
1182       curCodeIt += (caseCount - i - 1);
1183       return selectJump(i + 1);
1184     }
1185   }
1186   selectJump(size_t(0));
1187 }
1188 
1189 void ByteCodeExecutor::executeSwitchResultCount() {
1190   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1191   Operation *op = read<Operation *>();
1192   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1193 
1194   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1195   handleSwitch(op->getNumResults(), cases);
1196 }
1197 
1198 void ByteCodeExecutor::executeSwitchType() {
1199   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1200   Type value = read<Type>();
1201   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1202   handleSwitch(value, cases);
1203 }
1204 
1205 void ByteCodeExecutor::execute(
1206     PatternRewriter &rewriter,
1207     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
1208     Optional<Location> mainRewriteLoc) {
1209   while (true) {
1210     OpCode opCode = static_cast<OpCode>(read());
1211     switch (opCode) {
1212     case ApplyConstraint:
1213       executeApplyConstraint(rewriter);
1214       break;
1215     case ApplyRewrite:
1216       executeApplyRewrite(rewriter);
1217       break;
1218     case AreEqual:
1219       executeAreEqual();
1220       break;
1221     case Branch:
1222       executeBranch();
1223       break;
1224     case CheckOperandCount:
1225       executeCheckOperandCount();
1226       break;
1227     case CheckOperationName:
1228       executeCheckOperationName();
1229       break;
1230     case CheckResultCount:
1231       executeCheckResultCount();
1232       break;
1233     case CreateOperation:
1234       executeCreateOperation(rewriter, *mainRewriteLoc);
1235       break;
1236     case EraseOp:
1237       executeEraseOp(rewriter);
1238       break;
1239     case Finalize:
1240       LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1241       return;
1242     case GetAttribute:
1243       executeGetAttribute();
1244       break;
1245     case GetAttributeType:
1246       executeGetAttributeType();
1247       break;
1248     case GetDefiningOp:
1249       executeGetDefiningOp();
1250       break;
1251     case GetOperand0:
1252     case GetOperand1:
1253     case GetOperand2:
1254     case GetOperand3: {
1255       unsigned index = opCode - GetOperand0;
1256       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
1257       executeGetOperand(index);
1258       break;
1259     }
1260     case GetOperandN:
1261       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
1262       executeGetOperand(read<uint32_t>());
1263       break;
1264     case GetResult0:
1265     case GetResult1:
1266     case GetResult2:
1267     case GetResult3: {
1268       unsigned index = opCode - GetResult0;
1269       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
1270       executeGetResult(index);
1271       break;
1272     }
1273     case GetResultN:
1274       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
1275       executeGetResult(read<uint32_t>());
1276       break;
1277     case GetValueType:
1278       executeGetValueType();
1279       break;
1280     case IsNotNull:
1281       executeIsNotNull();
1282       break;
1283     case RecordMatch:
1284       assert(matches &&
1285              "expected matches to be provided when executing the matcher");
1286       executeRecordMatch(rewriter, *matches);
1287       break;
1288     case ReplaceOp:
1289       executeReplaceOp(rewriter);
1290       break;
1291     case SwitchAttribute:
1292       executeSwitchAttribute();
1293       break;
1294     case SwitchOperandCount:
1295       executeSwitchOperandCount();
1296       break;
1297     case SwitchOperationName:
1298       executeSwitchOperationName();
1299       break;
1300     case SwitchResultCount:
1301       executeSwitchResultCount();
1302       break;
1303     case SwitchType:
1304       executeSwitchType();
1305       break;
1306     }
1307     LLVM_DEBUG(llvm::dbgs() << "\n");
1308   }
1309 }
1310 
1311 /// Run the pattern matcher on the given root operation, collecting the matched
1312 /// patterns in `matches`.
1313 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1314                         SmallVectorImpl<MatchResult> &matches,
1315                         PDLByteCodeMutableState &state) const {
1316   // The first memory slot is always the root operation.
1317   state.memory[0] = op;
1318 
1319   // The matcher function always starts at code address 0.
1320   ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
1321                             matcherByteCode, state.currentPatternBenefits,
1322                             patterns, constraintFunctions, rewriteFunctions);
1323   executor.execute(rewriter, &matches);
1324 
1325   // Order the found matches by benefit.
1326   std::stable_sort(matches.begin(), matches.end(),
1327                    [](const MatchResult &lhs, const MatchResult &rhs) {
1328                      return lhs.benefit > rhs.benefit;
1329                    });
1330 }
1331 
1332 /// Run the rewriter of the given pattern on the root operation `op`.
1333 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1334                           PDLByteCodeMutableState &state) const {
1335   // The arguments of the rewrite function are stored at the start of the
1336   // memory buffer.
1337   llvm::copy(match.values, state.memory.begin());
1338 
1339   ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()],
1340                             state.memory, uniquedData, rewriterByteCode,
1341                             state.currentPatternBenefits, patterns,
1342                             constraintFunctions, rewriteFunctions);
1343   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1344 }
1345