1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#include <memory>
23 
24 #define DEBUG_TYPE "pdl-bytecode"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // PDLByteCodePattern
31 //===----------------------------------------------------------------------===//
32 
33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
34                                               ByteCodeAddr rewriterAddr) {
35   SmallVector<StringRef, 8> generatedOps;
36   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
37     generatedOps =
38         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
39 
40   PatternBenefit benefit = matchOp.benefit();
41   MLIRContext *ctx = matchOp.getContext();
42 
43   // Check to see if this is pattern matches a specific operation type.
44   if (Optional<StringRef> rootKind = matchOp.rootKind())
45     return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
46                               ctx);
47   return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
48                             MatchAnyOpTypeTag());
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // PDLByteCodeMutableState
53 //===----------------------------------------------------------------------===//
54 
55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
56 /// to the position of the pattern within the range returned by
57 /// `PDLByteCode::getPatterns`.
58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
59                                                    PatternBenefit benefit) {
60   currentPatternBenefits[patternIndex] = benefit;
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // Bytecode OpCodes
65 //===----------------------------------------------------------------------===//
66 
namespace {
/// Opcodes of the PDL bytecode. Every instruction emitted by `Generator` starts
/// with one of these values, followed by instruction-specific fields written by
/// the corresponding `Generator::generate` overload.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Invoke a native creation method.
  CreateNative,
  /// Create an operation.
  CreateOperation,
  /// Erase an operation.
  EraseOp,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// `GetOperand0`..`GetOperand3` must stay contiguous and in this order: the
  /// generator selects them as `GetOperand0 + index`. `GetOperandN` covers
  /// larger indices, with the index written inline after the opcode.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific result of an operation.
  /// `GetResult0`..`GetResult3` must stay contiguous and in this order: the
  /// generator selects them as `GetResult0 + index`. `GetResultN` covers
  /// larger indices, with the index written inline after the opcode.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get the type of a value.
  GetValueType,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
};

/// The kind of a generic PDL value. `ByteCodeWriter::appendPDLValueList`
/// writes this tag (cast to a `ByteCodeField`) before each value in a list,
/// presumably so the executor knows how to decode the slot; confirm against
/// the executor.
enum class PDLValueKind { Attribute, Operation, Type, Value };
} // end anonymous namespace
131 
132 //===----------------------------------------------------------------------===//
133 // ByteCode Generation
134 //===----------------------------------------------------------------------===//
135 
136 //===----------------------------------------------------------------------===//
137 // Generator
138 
139 namespace {
140 struct ByteCodeWriter;
141 
142 /// This class represents the main generator for the pattern bytecode.
143 class Generator {
144 public:
145   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
146             SmallVectorImpl<ByteCodeField> &matcherByteCode,
147             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
148             SmallVectorImpl<PDLByteCodePattern> &patterns,
149             ByteCodeField &maxValueMemoryIndex,
150             llvm::StringMap<PDLConstraintFunction> &constraintFns,
151             llvm::StringMap<PDLCreateFunction> &createFns,
152             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
153       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
154         rewriterByteCode(rewriterByteCode), patterns(patterns),
155         maxValueMemoryIndex(maxValueMemoryIndex) {
156     for (auto it : llvm::enumerate(constraintFns))
157       constraintToMemIndex.try_emplace(it.value().first(), it.index());
158     for (auto it : llvm::enumerate(createFns))
159       nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
160     for (auto it : llvm::enumerate(rewriteFns))
161       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
162   }
163 
164   /// Generate the bytecode for the given PDL interpreter module.
165   void generate(ModuleOp module);
166 
167   /// Return the memory index to use for the given value.
168   ByteCodeField &getMemIndex(Value value) {
169     assert(valueToMemIndex.count(value) &&
170            "expected memory index to be assigned");
171     return valueToMemIndex[value];
172   }
173 
174   /// Return an index to use when referring to the given data that is uniqued in
175   /// the MLIR context.
176   template <typename T>
177   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
178   getMemIndex(T val) {
179     const void *opaqueVal = val.getAsOpaquePointer();
180 
181     // Get or insert a reference to this value.
182     auto it = uniquedDataToMemIndex.try_emplace(
183         opaqueVal, maxValueMemoryIndex + uniquedData.size());
184     if (it.second)
185       uniquedData.push_back(opaqueVal);
186     return it.first->second;
187   }
188 
189 private:
190   /// Allocate memory indices for the results of operations within the matcher
191   /// and rewriters.
192   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
193 
194   /// Generate the bytecode for the given operation.
195   void generate(Operation *op, ByteCodeWriter &writer);
196   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
197   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
198   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
199   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
200   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
201   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
202   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
203   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
204   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
205   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
206   void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
207   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
208   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
209   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
210   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
211   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
212   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
213   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
214   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
215   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
216   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
217   void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
218   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
219   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
220   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
221   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
222   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
223   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
224   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
225   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
226 
227   /// Mapping from value to its corresponding memory index.
228   DenseMap<Value, ByteCodeField> valueToMemIndex;
229 
230   /// Mapping from the name of an externally registered rewrite to its index in
231   /// the bytecode registry.
232   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
233 
234   /// Mapping from the name of an externally registered constraint to its index
235   /// in the bytecode registry.
236   llvm::StringMap<ByteCodeField> constraintToMemIndex;
237 
238   /// Mapping from the name of an externally registered creation method to its
239   /// index in the bytecode registry.
240   llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
241 
242   /// Mapping from rewriter function name to the bytecode address of the
243   /// rewriter function in byte.
244   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
245 
246   /// Mapping from a uniqued storage object to its memory index within
247   /// `uniquedData`.
248   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
249 
250   /// The current MLIR context.
251   MLIRContext *ctx;
252 
253   /// Data of the ByteCode class to be populated.
254   std::vector<const void *> &uniquedData;
255   SmallVectorImpl<ByteCodeField> &matcherByteCode;
256   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
257   SmallVectorImpl<PDLByteCodePattern> &patterns;
258   ByteCodeField &maxValueMemoryIndex;
259 };
260 
/// This class provides utilities for writing a bytecode stream.
///
/// The `append` overload set decides how each argument is encoded:
///  - `ByteCodeField` / `OpCode`: written inline as a single field.
///  - `ByteCodeAddr`: written inline, split across two fields.
///  - `SuccessorRange`: one placeholder address per successor, patched later.
///  - types exposing `getAsOpaquePointer()` (e.g. `Value`, attributes, types):
///    written as the memory slot index returned by `Generator::getMemIndex`.
///  - other ranges: a count followed by each element.
struct ByteCodeWriter {
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a field to the bytecode.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode. The address is copied byte-wise into
  /// two consecutive fields (host byte order).
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a successor range to the bytecode. The exact addresses are not yet
  /// known, so a zero placeholder is written for each one and resolved later.
  void append(SuccessorRange successors) {
    // Record a back reference to each placeholder so that the address can be
    // patched once the successor block has been emitted.
    for (Block *successor : successors) {
      unresolvedSuccessorRefs[successor].push_back(bytecode.size());
      append(ByteCodeAddr(0));
    }
  }

  /// Append a range of values that will be read as generic PDLValues: the
  /// value count, then for each value its `PDLValueKind` tag followed by its
  /// memory slot.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values) {
      // Append the type of the value in addition to the value itself.
      // NOTE(review): there is no `.Default` case; values of any other type
      // are presumably rejected by the op verifiers -- confirm.
      PDLValueKind kind =
          TypeSwitch<Type, PDLValueKind>(value.getType())
              .Case<pdl::AttributeType>(
                  [](Type) { return PDLValueKind::Attribute; })
              .Case<pdl::OperationType>(
                  [](Type) { return PDLValueKind::Operation; })
              .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
              .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
      bytecode.push_back(static_cast<ByteCodeField>(kind));
      append(value);
    }
  }

  /// Detects whether `T` provides `getAsOpaquePointer()`, i.e. whether it is a
  /// handle that should be stored in a memory slot rather than treated as a
  /// range. Used with `llvm::is_detected` to select the overloads below.
  template <typename T, typename... Args>
  using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());

  /// Append a value that will be stored in a memory slot and not inline within
  /// the bytecode; only its slot index is written.
  /// NOTE(review): the `std::is_pointer` alternative forwards to
  /// `Generator::getMemIndex`, which in this file only accepts `Value` or
  /// types with `getAsOpaquePointer()` -- presumably unused for raw pointers.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values: the element count, then each element encoded
  /// via the matching `append` overload.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode, in argument order.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Successor references in the bytecode that have yet to be resolved, mapped
  /// to the offsets of their placeholder addresses.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The main generator producing PDL.
  Generator &generator;
};
347 } // end anonymous namespace
348 
349 void Generator::generate(ModuleOp module) {
350   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
351       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
352   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
353       pdl_interp::PDLInterpDialect::getRewriterModuleName());
354   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
355 
356   // Allocate memory indices for the results of operations within the matcher
357   // and rewriters.
358   allocateMemoryIndices(matcherFunc, rewriterModule);
359 
360   // Generate code for the rewriter functions.
361   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
362   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
363     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
364     for (Operation &op : rewriterFunc.getOps())
365       generate(&op, rewriterByteCodeWriter);
366   }
367   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
368          "unexpected branches in rewriter function");
369 
370   // Generate code for the matcher function.
371   DenseMap<Block *, ByteCodeAddr> blockToAddr;
372   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
373   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
374   for (Block *block : rpot) {
375     // Keep track of where this block begins within the matcher function.
376     blockToAddr.try_emplace(block, matcherByteCode.size());
377     for (Operation &op : *block)
378       generate(&op, matcherByteCodeWriter);
379   }
380 
381   // Resolve successor references in the matcher.
382   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
383     ByteCodeAddr addr = blockToAddr[it.first];
384     for (unsigned offsetToFix : it.second)
385       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
386   }
387 }
388 
389 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
390                                       ModuleOp rewriterModule) {
391   // Rewriters use simplistic allocation scheme that simply assigns an index to
392   // each result.
393   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
394     ByteCodeField index = 0;
395     for (BlockArgument arg : rewriterFunc.getArguments())
396       valueToMemIndex.try_emplace(arg, index++);
397     rewriterFunc.getBody().walk([&](Operation *op) {
398       for (Value result : op->getResults())
399         valueToMemIndex.try_emplace(result, index++);
400     });
401     if (index > maxValueMemoryIndex)
402       maxValueMemoryIndex = index;
403   }
404 
405   // The matcher function uses a more sophisticated numbering that tries to
406   // minimize the number of memory indices assigned. This is done by determining
407   // a live range of the values within the matcher, then the allocation is just
408   // finding the minimal number of overlapping live ranges. This is essentially
409   // a simplified form of register allocation where we don't necessarily have a
410   // limited number of registers, but we still want to minimize the number used.
411   DenseMap<Operation *, ByteCodeField> opToIndex;
412   matcherFunc.getBody().walk([&](Operation *op) {
413     opToIndex.insert(std::make_pair(op, opToIndex.size()));
414   });
415 
416   // Liveness info for each of the defs within the matcher.
417   using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
418   LivenessSet::Allocator allocator;
419   DenseMap<Value, LivenessSet> valueDefRanges;
420 
421   // Assign the root operation being matched to slot 0.
422   BlockArgument rootOpArg = matcherFunc.getArgument(0);
423   valueToMemIndex[rootOpArg] = 0;
424 
425   // Walk each of the blocks, computing the def interval that the value is used.
426   Liveness matcherLiveness(matcherFunc);
427   for (Block &block : matcherFunc.getBody()) {
428     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
429     assert(info && "expected liveness info for block");
430     auto processValue = [&](Value value, Operation *firstUseOrDef) {
431       // We don't need to process the root op argument, this value is always
432       // assigned to the first memory slot.
433       if (value == rootOpArg)
434         return;
435 
436       // Set indices for the range of this block that the value is used.
437       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
438       defRangeIt->second.insert(
439           opToIndex[firstUseOrDef],
440           opToIndex[info->getEndOperation(value, firstUseOrDef)],
441           /*dummyValue*/ 0);
442     };
443 
444     // Process the live-ins of this block.
445     for (Value liveIn : info->in())
446       processValue(liveIn, &block.front());
447 
448     // Process any new defs within this block.
449     for (Operation &op : block)
450       for (Value result : op.getResults())
451         processValue(result, &op);
452   }
453 
454   // Greedily allocate memory slots using the computed def live ranges.
455   std::vector<LivenessSet> allocatedIndices;
456   for (auto &defIt : valueDefRanges) {
457     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
458     LivenessSet &defSet = defIt.second;
459 
460     // Try to allocate to an existing index.
461     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
462       LivenessSet &existingIndex = existingIndexIt.value();
463       llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
464           defIt.second, existingIndex);
465       if (overlaps.valid())
466         continue;
467       // Union the range of the def within the existing index.
468       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
469         existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
470       memIndex = existingIndexIt.index() + 1;
471     }
472 
473     // If no existing index could be used, add a new one.
474     if (memIndex == 0) {
475       allocatedIndices.emplace_back(allocator);
476       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
477         allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
478       memIndex = allocatedIndices.size();
479     }
480   }
481 
482   // Update the max number of indices.
483   ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
484   if (numMatcherIndices > maxValueMemoryIndex)
485     maxValueMemoryIndex = numMatcherIndices;
486 }
487 
/// Emit bytecode for `op` by dispatching to the overload for its concrete
/// `pdl_interp` operation class. The case list must name every op-specific
/// `generate` overload; any other operation hits `llvm_unreachable`.
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
            pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
            pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
            pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
            pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
            pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
            pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          // The generic lambda is instantiated per case, selecting the
          // matching `generate` overload at compile time.
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}
510 
511 void Generator::generate(pdl_interp::ApplyConstraintOp op,
512                          ByteCodeWriter &writer) {
513   assert(constraintToMemIndex.count(op.name()) &&
514          "expected index for constraint function");
515   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
516                 op.constParamsAttr());
517   writer.appendPDLValueList(op.args());
518   writer.append(op.getSuccessors());
519 }
520 void Generator::generate(pdl_interp::ApplyRewriteOp op,
521                          ByteCodeWriter &writer) {
522   assert(externalRewriterToMemIndex.count(op.name()) &&
523          "expected index for rewrite function");
524   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
525                 op.constParamsAttr(), op.root());
526   writer.appendPDLValueList(op.args());
527 }
528 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
529   writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
530 }
531 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
532   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
533 }
534 void Generator::generate(pdl_interp::CheckAttributeOp op,
535                          ByteCodeWriter &writer) {
536   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
537                 op.getSuccessors());
538 }
539 void Generator::generate(pdl_interp::CheckOperandCountOp op,
540                          ByteCodeWriter &writer) {
541   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
542                 op.getSuccessors());
543 }
544 void Generator::generate(pdl_interp::CheckOperationNameOp op,
545                          ByteCodeWriter &writer) {
546   writer.append(OpCode::CheckOperationName, op.operation(),
547                 OperationName(op.name(), ctx), op.getSuccessors());
548 }
549 void Generator::generate(pdl_interp::CheckResultCountOp op,
550                          ByteCodeWriter &writer) {
551   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
552                 op.getSuccessors());
553 }
554 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
555   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
556 }
557 void Generator::generate(pdl_interp::CreateAttributeOp op,
558                          ByteCodeWriter &writer) {
559   // Simply repoint the memory index of the result to the constant.
560   getMemIndex(op.attribute()) = getMemIndex(op.value());
561 }
562 void Generator::generate(pdl_interp::CreateNativeOp op,
563                          ByteCodeWriter &writer) {
564   assert(nativeCreateToMemIndex.count(op.name()) &&
565          "expected index for creation function");
566   writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
567                 op.result(), op.constParamsAttr());
568   writer.appendPDLValueList(op.args());
569 }
570 void Generator::generate(pdl_interp::CreateOperationOp op,
571                          ByteCodeWriter &writer) {
572   writer.append(OpCode::CreateOperation, op.operation(),
573                 OperationName(op.name(), ctx), op.operands());
574 
575   // Add the attributes.
576   OperandRange attributes = op.attributes();
577   writer.append(static_cast<ByteCodeField>(attributes.size()));
578   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
579     writer.append(
580         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
581         std::get<1>(it));
582   }
583   writer.append(op.types());
584 }
585 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
586   // Simply repoint the memory index of the result to the constant.
587   getMemIndex(op.result()) = getMemIndex(op.value());
588 }
589 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
590   writer.append(OpCode::EraseOp, op.operation());
591 }
592 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
593   writer.append(OpCode::Finalize);
594 }
595 void Generator::generate(pdl_interp::GetAttributeOp op,
596                          ByteCodeWriter &writer) {
597   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
598                 Identifier::get(op.name(), ctx));
599 }
600 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
601                          ByteCodeWriter &writer) {
602   writer.append(OpCode::GetAttributeType, op.result(), op.value());
603 }
604 void Generator::generate(pdl_interp::GetDefiningOpOp op,
605                          ByteCodeWriter &writer) {
606   writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
607 }
608 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
609   uint32_t index = op.index();
610   if (index < 4)
611     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
612   else
613     writer.append(OpCode::GetOperandN, index);
614   writer.append(op.operation(), op.value());
615 }
616 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
617   uint32_t index = op.index();
618   if (index < 4)
619     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
620   else
621     writer.append(OpCode::GetResultN, index);
622   writer.append(op.operation(), op.value());
623 }
624 void Generator::generate(pdl_interp::GetValueTypeOp op,
625                          ByteCodeWriter &writer) {
626   writer.append(OpCode::GetValueType, op.result(), op.value());
627 }
628 void Generator::generate(pdl_interp::InferredTypeOp op,
629                          ByteCodeWriter &writer) {
630   // InferType maps to a null type as a marker for inferring a result type.
631   getMemIndex(op.type()) = getMemIndex(Type());
632 }
633 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
634   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
635 }
636 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
637   ByteCodeField patternIndex = patterns.size();
638   patterns.emplace_back(PDLByteCodePattern::create(
639       op, rewriterToAddr[op.rewriter().getLeafReference()]));
640   writer.append(OpCode::RecordMatch, patternIndex,
641                 SuccessorRange(op.getOperation()), op.matchedOps(),
642                 op.inputs());
643 }
644 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
645   writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
646 }
647 void Generator::generate(pdl_interp::SwitchAttributeOp op,
648                          ByteCodeWriter &writer) {
649   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
650                 op.getSuccessors());
651 }
652 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
653                          ByteCodeWriter &writer) {
654   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
655                 op.getSuccessors());
656 }
657 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
658                          ByteCodeWriter &writer) {
659   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
660     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
661   });
662   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
663                 op.getSuccessors());
664 }
665 void Generator::generate(pdl_interp::SwitchResultCountOp op,
666                          ByteCodeWriter &writer) {
667   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
668                 op.getSuccessors());
669 }
670 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
671   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
672                 op.getSuccessors());
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // PDLByteCode
677 //===----------------------------------------------------------------------===//
678 
679 PDLByteCode::PDLByteCode(ModuleOp module,
680                          llvm::StringMap<PDLConstraintFunction> constraintFns,
681                          llvm::StringMap<PDLCreateFunction> createFns,
682                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
683   Generator generator(module.getContext(), uniquedData, matcherByteCode,
684                       rewriterByteCode, patterns, maxValueMemoryIndex,
685                       constraintFns, createFns, rewriteFns);
686   generator.generate(module);
687 
688   // Initialize the external functions.
689   for (auto &it : constraintFns)
690     constraintFunctions.push_back(std::move(it.second));
691   for (auto &it : createFns)
692     createFunctions.push_back(std::move(it.second));
693   for (auto &it : rewriteFns)
694     rewriteFunctions.push_back(std::move(it.second));
695 }
696 
697 /// Initialize the given state such that it can be used to execute the current
698 /// bytecode.
699 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
700   state.memory.resize(maxValueMemoryIndex, nullptr);
701   state.currentPatternBenefits.reserve(patterns.size());
702   for (const PDLByteCodePattern &pattern : patterns)
703     state.currentPatternBenefits.push_back(pattern.getBenefit());
704 }
705 
706 //===----------------------------------------------------------------------===//
707 // ByteCode Execution
708 
namespace {
/// This class provides support for executing a bytecode stream.
///
/// The executor does not own the data it operates on. It keeps a cursor
/// (`curCodeIt`) into the bytecode being run, plus non-owning views of:
///   - the value memory;
///   - the uniqued constant data;
///   - the full code buffer, used to resolve jump addresses;
///   - the per-pattern benefits and the patterns themselves;
///   - the registered native functions.
/// Every read advances the cursor, so each handler must read its operands in
/// exactly the order in which they were encoded.
class ByteCodeExecutor {
public:
  /// Create an executor that starts decoding at `curCodeIt`. That pointer is
  /// expected to point into `code`, since jumps are resolved relative to
  /// `code`.
  ByteCodeExecutor(const ByteCodeField *curCodeIt,
                   MutableArrayRef<const void *> memory,
                   ArrayRef<const void *> uniquedMemory,
                   ArrayRef<ByteCodeField> code,
                   ArrayRef<PatternBenefit> currentPatternBenefits,
                   ArrayRef<PDLByteCodePattern> patterns,
                   ArrayRef<PDLConstraintFunction> constraintFunctions,
                   ArrayRef<PDLCreateFunction> createFunctions,
                   ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
        code(code), currentPatternBenefits(currentPatternBenefits),
        patterns(patterns), constraintFunctions(constraintFunctions),
        createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context. `mainRewriteLoc` is the location given to operations created
  /// while executing a rewriter.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCreateNative(PatternRewriter &rewriter);
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeEraseOp(PatternRewriter &rewriter);
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetResult(unsigned index);
  void executeGetValueType();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  // Non-template overload so that a plain `read()` reads a single field.
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The list is encoded as a
  /// length field followed by that many `ValueT` entries. `list` is cleared
  /// before it is filled.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Jump to a specific successor based on a predicate value. `true` selects
  /// the first successor and `false` selects the second.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Successor
  /// addresses are stored consecutively at the cursor, and each one occupies
  /// two fields (see the static_assert in the ByteCodeAddr `readImpl`). That
  /// is why the skip is `destIndex * 2`.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. Successor 0
  /// is the default destination; a match against case `i` jumps to successor
  /// `i + 1`.
  template <typename T, typename RangeT>
  void handleSwitch(const T &value, RangeT &&cases) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the attribute value is within the case list. Jump to
    // the correct successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (*it == value)
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  ///
  /// Read a memory index from the stream and return the slot it refers to.
  /// An index below `memory.size()` addresses the mutable memory. A larger
  /// index addresses `uniquedMemory`, offset by `memory.size()`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer types are returned directly as the opaque memory entry.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Class types other than PDLValue are rebuilt from their opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// A PDLValue is encoded as a kind field followed by the value itself.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (static_cast<PDLValueKind>(read())) {
    case PDLValueKind::Attribute:
      return read<Attribute>();
    case PDLValueKind::Operation:
      return read<Operation *>();
    case PDLValueKind::Type:
      return read<Type>();
    case PDLValueKind::Value:
      return read<Value>();
    }
    llvm_unreachable("unhandled PDLValueKind");
  }
  /// A ByteCodeAddr spans two consecutive fields. It is copied out with
  /// memcpy rather than through a type-punning pointer cast.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// A plain field is read directly from the stream.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }

  /// The underlying bytecode buffer. This is the read cursor, which every
  /// read advances.
  const ByteCodeField *curCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLCreateFunction> createFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
} // end anonymous namespace
873 
874 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
875   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
876   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
877   ArrayAttr constParams = read<ArrayAttr>();
878   SmallVector<PDLValue, 16> args;
879   readList<PDLValue>(args);
880 
881   LLVM_DEBUG({
882     llvm::dbgs() << "  * Arguments: ";
883     llvm::interleaveComma(args, llvm::dbgs());
884     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
885   });
886 
887   // Invoke the constraint and jump to the proper destination.
888   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
889 }
890 
891 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
892   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
893   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
894   ArrayAttr constParams = read<ArrayAttr>();
895   Operation *root = read<Operation *>();
896   SmallVector<PDLValue, 16> args;
897   readList<PDLValue>(args);
898 
899   LLVM_DEBUG({
900     llvm::dbgs() << "  * Root: " << *root << "\n  * Arguments: ";
901     llvm::interleaveComma(args, llvm::dbgs());
902     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
903   });
904 
905   // Invoke the native rewrite function.
906   rewriteFn(root, args, constParams, rewriter);
907 }
908 
909 void ByteCodeExecutor::executeAreEqual() {
910   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
911   const void *lhs = read<const void *>();
912   const void *rhs = read<const void *>();
913 
914   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
915   selectJump(lhs == rhs);
916 }
917 
918 void ByteCodeExecutor::executeBranch() {
919   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
920   curCodeIt = &code[read<ByteCodeAddr>()];
921 }
922 
923 void ByteCodeExecutor::executeCheckOperandCount() {
924   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
925   Operation *op = read<Operation *>();
926   uint32_t expectedCount = read<uint32_t>();
927 
928   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
929                           << "  * Expected: " << expectedCount << "\n");
930   selectJump(op->getNumOperands() == expectedCount);
931 }
932 
933 void ByteCodeExecutor::executeCheckOperationName() {
934   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
935   Operation *op = read<Operation *>();
936   OperationName expectedName = read<OperationName>();
937 
938   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
939                           << "  * Expected: \"" << expectedName << "\"\n");
940   selectJump(op->getName() == expectedName);
941 }
942 
943 void ByteCodeExecutor::executeCheckResultCount() {
944   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
945   Operation *op = read<Operation *>();
946   uint32_t expectedCount = read<uint32_t>();
947 
948   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
949                           << "  * Expected: " << expectedCount << "\n");
950   selectJump(op->getNumResults() == expectedCount);
951 }
952 
953 void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) {
954   LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
955   const PDLCreateFunction &createFn = createFunctions[read()];
956   ByteCodeField resultIndex = read();
957   ArrayAttr constParams = read<ArrayAttr>();
958   SmallVector<PDLValue, 16> args;
959   readList<PDLValue>(args);
960 
961   LLVM_DEBUG({
962     llvm::dbgs() << "  * Arguments: ";
963     llvm::interleaveComma(args, llvm::dbgs());
964     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
965   });
966 
967   PDLValue result = createFn(args, constParams, rewriter);
968   memory[resultIndex] = result.getAsOpaquePointer();
969 
970   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
971 }
972 
973 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
974                                               Location mainRewriteLoc) {
975   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
976 
977   unsigned memIndex = read();
978   OperationState state(mainRewriteLoc, read<OperationName>());
979   readList<Value>(state.operands);
980   for (unsigned i = 0, e = read(); i != e; ++i) {
981     Identifier name = read<Identifier>();
982     if (Attribute attr = read<Attribute>())
983       state.addAttribute(name, attr);
984   }
985 
986   bool hasInferredTypes = false;
987   for (unsigned i = 0, e = read(); i != e; ++i) {
988     Type resultType = read<Type>();
989     hasInferredTypes |= !resultType;
990     state.types.push_back(resultType);
991   }
992 
993   // Handle the case where the operation has inferred types.
994   if (hasInferredTypes) {
995     InferTypeOpInterface::Concept *concept =
996         state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
997 
998     // TODO: Handle failure.
999     SmallVector<Type, 2> inferredTypes;
1000     if (failed(concept->inferReturnTypes(
1001             state.getContext(), state.location, state.operands,
1002             state.attributes.getDictionary(state.getContext()), state.regions,
1003             inferredTypes)))
1004       return;
1005 
1006     for (unsigned i = 0, e = state.types.size(); i != e; ++i)
1007       if (!state.types[i])
1008         state.types[i] = inferredTypes[i];
1009   }
1010   Operation *resultOp = rewriter.createOperation(state);
1011   memory[memIndex] = resultOp;
1012 
1013   LLVM_DEBUG({
1014     llvm::dbgs() << "  * Attributes: "
1015                  << state.attributes.getDictionary(state.getContext())
1016                  << "\n  * Operands: ";
1017     llvm::interleaveComma(state.operands, llvm::dbgs());
1018     llvm::dbgs() << "\n  * Result Types: ";
1019     llvm::interleaveComma(state.types, llvm::dbgs());
1020     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1021   });
1022 }
1023 
1024 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1025   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1026   Operation *op = read<Operation *>();
1027 
1028   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1029   rewriter.eraseOp(op);
1030 }
1031 
1032 void ByteCodeExecutor::executeGetAttribute() {
1033   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1034   unsigned memIndex = read();
1035   Operation *op = read<Operation *>();
1036   Identifier attrName = read<Identifier>();
1037   Attribute attr = op->getAttr(attrName);
1038 
1039   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1040                           << "  * Attribute: " << attrName << "\n"
1041                           << "  * Result: " << attr << "\n");
1042   memory[memIndex] = attr.getAsOpaquePointer();
1043 }
1044 
1045 void ByteCodeExecutor::executeGetAttributeType() {
1046   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1047   unsigned memIndex = read();
1048   Attribute attr = read<Attribute>();
1049   Type type = attr ? attr.getType() : Type();
1050 
1051   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1052                           << "  * Result: " << type << "\n");
1053   memory[memIndex] = type.getAsOpaquePointer();
1054 }
1055 
1056 void ByteCodeExecutor::executeGetDefiningOp() {
1057   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1058   unsigned memIndex = read();
1059   Value value = read<Value>();
1060   Operation *op = value ? value.getDefiningOp() : nullptr;
1061 
1062   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1063                           << "  * Result: " << *op << "\n");
1064   memory[memIndex] = op;
1065 }
1066 
1067 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1068   Operation *op = read<Operation *>();
1069   unsigned memIndex = read();
1070   Value operand =
1071       index < op->getNumOperands() ? op->getOperand(index) : Value();
1072 
1073   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1074                           << "  * Index: " << index << "\n"
1075                           << "  * Result: " << operand << "\n");
1076   memory[memIndex] = operand.getAsOpaquePointer();
1077 }
1078 
1079 void ByteCodeExecutor::executeGetResult(unsigned index) {
1080   Operation *op = read<Operation *>();
1081   unsigned memIndex = read();
1082   OpResult result =
1083       index < op->getNumResults() ? op->getResult(index) : OpResult();
1084 
1085   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1086                           << "  * Index: " << index << "\n"
1087                           << "  * Result: " << result << "\n");
1088   memory[memIndex] = result.getAsOpaquePointer();
1089 }
1090 
1091 void ByteCodeExecutor::executeGetValueType() {
1092   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1093   unsigned memIndex = read();
1094   Value value = read<Value>();
1095   Type type = value ? value.getType() : Type();
1096 
1097   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1098                           << "  * Result: " << type << "\n");
1099   memory[memIndex] = type.getAsOpaquePointer();
1100 }
1101 
1102 void ByteCodeExecutor::executeIsNotNull() {
1103   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1104   const void *value = read<const void *>();
1105 
1106   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1107   selectJump(value != nullptr);
1108 }
1109 
1110 void ByteCodeExecutor::executeRecordMatch(
1111     PatternRewriter &rewriter,
1112     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1113   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1114   unsigned patternIndex = read();
1115   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1116   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1117 
1118   // If the benefit of the pattern is impossible, skip the processing of the
1119   // rest of the pattern.
1120   if (benefit.isImpossibleToMatch()) {
1121     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1122     curCodeIt = dest;
1123     return;
1124   }
1125 
1126   // Create a fused location containing the locations of each of the
1127   // operations used in the match. This will be used as the location for
1128   // created operations during the rewrite that don't already have an
1129   // explicit location set.
1130   unsigned numMatchLocs = read();
1131   SmallVector<Location, 4> matchLocs;
1132   matchLocs.reserve(numMatchLocs);
1133   for (unsigned i = 0; i != numMatchLocs; ++i)
1134     matchLocs.push_back(read<Operation *>()->getLoc());
1135   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1136 
1137   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1138                           << "  * Location: " << matchLoc << "\n");
1139   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1140   readList<const void *>(matches.back().values);
1141   curCodeIt = dest;
1142 }
1143 
1144 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1145   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1146   Operation *op = read<Operation *>();
1147   SmallVector<Value, 16> args;
1148   readList<Value>(args);
1149 
1150   LLVM_DEBUG({
1151     llvm::dbgs() << "  * Operation: " << *op << "\n"
1152                  << "  * Values: ";
1153     llvm::interleaveComma(args, llvm::dbgs());
1154     llvm::dbgs() << "\n";
1155   });
1156   rewriter.replaceOp(op, args);
1157 }
1158 
1159 void ByteCodeExecutor::executeSwitchAttribute() {
1160   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1161   Attribute value = read<Attribute>();
1162   ArrayAttr cases = read<ArrayAttr>();
1163   handleSwitch(value, cases);
1164 }
1165 
1166 void ByteCodeExecutor::executeSwitchOperandCount() {
1167   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1168   Operation *op = read<Operation *>();
1169   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1170 
1171   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1172   handleSwitch(op->getNumOperands(), cases);
1173 }
1174 
/// Dispatch on the name of an operation. Unlike the other switch handlers,
/// the case names are encoded inline. A case count is followed by `caseCount`
/// single-field OperationName entries, which are followed by the successor
/// addresses. Successor 0 is the default destination.
void ByteCodeExecutor::executeSwitchOperationName() {
  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
  OperationName value = read<Operation *>()->getName();
  size_t caseCount = read();

  // The operation names are stored in-line, so to print them out for
  // debugging purposes we need to read the array before executing the
  // switch so that we can display all of the possible values.
  LLVM_DEBUG({
    // Save the cursor so it can be rewound after these debug-only reads.
    const ByteCodeField *prevCodeIt = curCodeIt;
    llvm::dbgs() << "  * Value: " << value << "\n"
                 << "  * Cases: ";
    llvm::interleaveComma(
        llvm::map_range(llvm::seq<size_t>(0, caseCount),
                        [&](size_t) { return read<OperationName>(); }),
        llvm::dbgs());
    llvm::dbgs() << "\n";
    curCodeIt = prevCodeIt;
  });

  // Try to find the switch value within any of the cases.
  for (size_t i = 0; i != caseCount; ++i) {
    if (read<OperationName>() == value) {
      // Skip the remaining unread case names so the cursor lands on the
      // successor list, then take successor `i + 1` (0 is the default).
      curCodeIt += (caseCount - i - 1);
      return selectJump(i + 1);
    }
  }
  // No case matched. Every case name has been read, so the cursor already
  // sits on the successor list; take the default destination.
  selectJump(size_t(0));
}
1204 
1205 void ByteCodeExecutor::executeSwitchResultCount() {
1206   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1207   Operation *op = read<Operation *>();
1208   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1209 
1210   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1211   handleSwitch(op->getNumResults(), cases);
1212 }
1213 
1214 void ByteCodeExecutor::executeSwitchType() {
1215   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1216   Type value = read<Type>();
1217   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1218   handleSwitch(value, cases);
1219 }
1220 
/// Main interpreter loop. It repeatedly decodes the opcode at the cursor and
/// dispatches to the matching `execute*` handler until a `Finalize` opcode is
/// reached.
///
/// `matches` must be non-null if the bytecode contains `RecordMatch`, as when
/// called from PDLByteCode::match. `mainRewriteLoc` must hold a value if the
/// bytecode contains `CreateOperation`, as when called from
/// PDLByteCode::rewrite, because it is dereferenced unconditionally there.
void ByteCodeExecutor::execute(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
    Optional<Location> mainRewriteLoc) {
  while (true) {
    OpCode opCode = static_cast<OpCode>(read());
    switch (opCode) {
    case ApplyConstraint:
      executeApplyConstraint(rewriter);
      break;
    case ApplyRewrite:
      executeApplyRewrite(rewriter);
      break;
    case AreEqual:
      executeAreEqual();
      break;
    case Branch:
      executeBranch();
      break;
    case CheckOperandCount:
      executeCheckOperandCount();
      break;
    case CheckOperationName:
      executeCheckOperationName();
      break;
    case CheckResultCount:
      executeCheckResultCount();
      break;
    case CreateNative:
      executeCreateNative(rewriter);
      break;
    case CreateOperation:
      executeCreateOperation(rewriter, *mainRewriteLoc);
      break;
    case EraseOp:
      executeEraseOp(rewriter);
      break;
    // Finalize is the only way to leave the interpreter loop.
    case Finalize:
      LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
      return;
    case GetAttribute:
      executeGetAttribute();
      break;
    case GetAttributeType:
      executeGetAttributeType();
      break;
    case GetDefiningOp:
      executeGetDefiningOp();
      break;
    // Specialized opcodes for the first four operand indices. The index is
    // derived from the opcode offset, which assumes GetOperand0..GetOperand3
    // are declared consecutively in OpCode.
    case GetOperand0:
    case GetOperand1:
    case GetOperand2:
    case GetOperand3: {
      unsigned index = opCode - GetOperand0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
      executeGetOperand(index);
      break;
    }
    // General form: the operand index is encoded in the stream.
    case GetOperandN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
      executeGetOperand(read<uint32_t>());
      break;
    // Same scheme as GetOperand0..3: this assumes GetResult0..GetResult3 are
    // declared consecutively in OpCode.
    case GetResult0:
    case GetResult1:
    case GetResult2:
    case GetResult3: {
      unsigned index = opCode - GetResult0;
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
      executeGetResult(index);
      break;
    }
    case GetResultN:
      LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
      executeGetResult(read<uint32_t>());
      break;
    case GetValueType:
      executeGetValueType();
      break;
    case IsNotNull:
      executeIsNotNull();
      break;
    case RecordMatch:
      assert(matches &&
             "expected matches to be provided when executing the matcher");
      executeRecordMatch(rewriter, *matches);
      break;
    case ReplaceOp:
      executeReplaceOp(rewriter);
      break;
    case SwitchAttribute:
      executeSwitchAttribute();
      break;
    case SwitchOperandCount:
      executeSwitchOperandCount();
      break;
    case SwitchOperationName:
      executeSwitchOperationName();
      break;
    case SwitchResultCount:
      executeSwitchResultCount();
      break;
    case SwitchType:
      executeSwitchType();
      break;
    }
    LLVM_DEBUG(llvm::dbgs() << "\n");
  }
}
1329 
1330 /// Run the pattern matcher on the given root operation, collecting the matched
1331 /// patterns in `matches`.
1332 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1333                         SmallVectorImpl<MatchResult> &matches,
1334                         PDLByteCodeMutableState &state) const {
1335   // The first memory slot is always the root operation.
1336   state.memory[0] = op;
1337 
1338   // The matcher function always starts at code address 0.
1339   ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
1340                             matcherByteCode, state.currentPatternBenefits,
1341                             patterns, constraintFunctions, createFunctions,
1342                             rewriteFunctions);
1343   executor.execute(rewriter, &matches);
1344 
1345   // Order the found matches by benefit.
1346   std::stable_sort(matches.begin(), matches.end(),
1347                    [](const MatchResult &lhs, const MatchResult &rhs) {
1348                      return lhs.benefit > rhs.benefit;
1349                    });
1350 }
1351 
1352 /// Run the rewriter of the given pattern on the root operation `op`.
1353 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1354                           PDLByteCodeMutableState &state) const {
1355   // The arguments of the rewrite function are stored at the start of the
1356   // memory buffer.
1357   llvm::copy(match.values, state.memory.begin());
1358 
1359   ByteCodeExecutor executor(
1360       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1361       uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
1362       constraintFunctions, createFunctions, rewriteFunctions);
1363   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1364 }
1365