1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
create(pdl_interp::RecordMatchOp matchOp,ByteCodeAddr rewriterAddr)36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.getBenefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.getRootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62                                                    PatternBenefit benefit) {
63   currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
cleanupAfterMatchAndRewrite()69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 enum OpCode : ByteCodeField {
80   /// Apply an externally registered constraint.
81   ApplyConstraint,
82   /// Apply an externally registered rewrite.
83   ApplyRewrite,
84   /// Check if two generic values are equal.
85   AreEqual,
86   /// Check if two ranges are equal.
87   AreRangesEqual,
88   /// Unconditional branch.
89   Branch,
90   /// Compare the operand count of an operation with a constant.
91   CheckOperandCount,
92   /// Compare the name of an operation with a constant.
93   CheckOperationName,
94   /// Compare the result count of an operation with a constant.
95   CheckResultCount,
96   /// Compare a range of types to a constant range of types.
97   CheckTypes,
98   /// Continue to the next iteration of a loop.
99   Continue,
100   /// Create an operation.
101   CreateOperation,
102   /// Create a range of types.
103   CreateTypes,
104   /// Erase an operation.
105   EraseOp,
106   /// Extract the op from a range at the specified index.
107   ExtractOp,
108   /// Extract the type from a range at the specified index.
109   ExtractType,
110   /// Extract the value from a range at the specified index.
111   ExtractValue,
112   /// Terminate a matcher or rewrite sequence.
113   Finalize,
114   /// Iterate over a range of values.
115   ForEach,
116   /// Get a specific attribute of an operation.
117   GetAttribute,
118   /// Get the type of an attribute.
119   GetAttributeType,
120   /// Get the defining operation of a value.
121   GetDefiningOp,
122   /// Get a specific operand of an operation.
123   GetOperand0,
124   GetOperand1,
125   GetOperand2,
126   GetOperand3,
127   GetOperandN,
128   /// Get a specific operand group of an operation.
129   GetOperands,
130   /// Get a specific result of an operation.
131   GetResult0,
132   GetResult1,
133   GetResult2,
134   GetResult3,
135   GetResultN,
136   /// Get a specific result group of an operation.
137   GetResults,
138   /// Get the users of a value or a range of values.
139   GetUsers,
140   /// Get the type of a value.
141   GetValueType,
142   /// Get the types of a value range.
143   GetValueRangeTypes,
144   /// Check if a generic value is not null.
145   IsNotNull,
146   /// Record a successful pattern match.
147   RecordMatch,
148   /// Replace an operation.
149   ReplaceOp,
150   /// Compare an attribute with a set of constants.
151   SwitchAttribute,
152   /// Compare the operand count of an operation with a set of constants.
153   SwitchOperandCount,
154   /// Compare the name of an operation with a set of constants.
155   SwitchOperationName,
156   /// Compare the result count of an operation with a set of constants.
157   SwitchResultCount,
158   /// Compare a type with a set of constants.
159   SwitchType,
160   /// Compare a range of types with a set of constants.
161   SwitchTypes,
162 };
163 } // namespace
164 
165 /// A marker used to indicate if an operation should infer types.
166 static constexpr ByteCodeField kInferTypesMarker =
167     std::numeric_limits<ByteCodeField>::max();
168 
169 //===----------------------------------------------------------------------===//
170 // ByteCode Generation
171 //===----------------------------------------------------------------------===//
172 
173 //===----------------------------------------------------------------------===//
174 // Generator
175 
176 namespace {
177 struct ByteCodeLiveRange;
178 struct ByteCodeWriter;
179 
/// Detection alias: well-formed only when an object of type `ValueT` exposes a
/// `getAsOpaquePointer()` member, i.e. when it can be converted to an opaque
/// pointer. The trailing parameter pack is accepted but not used.
template <typename ValueT, typename... UnusedTs>
using has_pointer_traits =
    decltype(std::declval<ValueT>().getAsOpaquePointer());
183 
184 /// This class represents the main generator for the pattern bytecode.
185 class Generator {
186 public:
Generator(MLIRContext * ctx,std::vector<const void * > & uniquedData,SmallVectorImpl<ByteCodeField> & matcherByteCode,SmallVectorImpl<ByteCodeField> & rewriterByteCode,SmallVectorImpl<PDLByteCodePattern> & patterns,ByteCodeField & maxValueMemoryIndex,ByteCodeField & maxOpRangeMemoryIndex,ByteCodeField & maxTypeRangeMemoryIndex,ByteCodeField & maxValueRangeMemoryIndex,ByteCodeField & maxLoopLevel,llvm::StringMap<PDLConstraintFunction> & constraintFns,llvm::StringMap<PDLRewriteFunction> & rewriteFns)187   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
188             SmallVectorImpl<ByteCodeField> &matcherByteCode,
189             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
190             SmallVectorImpl<PDLByteCodePattern> &patterns,
191             ByteCodeField &maxValueMemoryIndex,
192             ByteCodeField &maxOpRangeMemoryIndex,
193             ByteCodeField &maxTypeRangeMemoryIndex,
194             ByteCodeField &maxValueRangeMemoryIndex,
195             ByteCodeField &maxLoopLevel,
196             llvm::StringMap<PDLConstraintFunction> &constraintFns,
197             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
198       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
199         rewriterByteCode(rewriterByteCode), patterns(patterns),
200         maxValueMemoryIndex(maxValueMemoryIndex),
201         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
202         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
203         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
204         maxLoopLevel(maxLoopLevel) {
205     for (const auto &it : llvm::enumerate(constraintFns))
206       constraintToMemIndex.try_emplace(it.value().first(), it.index());
207     for (const auto &it : llvm::enumerate(rewriteFns))
208       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
209   }
210 
211   /// Generate the bytecode for the given PDL interpreter module.
212   void generate(ModuleOp module);
213 
214   /// Return the memory index to use for the given value.
getMemIndex(Value value)215   ByteCodeField &getMemIndex(Value value) {
216     assert(valueToMemIndex.count(value) &&
217            "expected memory index to be assigned");
218     return valueToMemIndex[value];
219   }
220 
221   /// Return the range memory index used to store the given range value.
getRangeStorageIndex(Value value)222   ByteCodeField &getRangeStorageIndex(Value value) {
223     assert(valueToRangeIndex.count(value) &&
224            "expected range index to be assigned");
225     return valueToRangeIndex[value];
226   }
227 
228   /// Return an index to use when referring to the given data that is uniqued in
229   /// the MLIR context.
230   template <typename T>
231   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val)232   getMemIndex(T val) {
233     const void *opaqueVal = val.getAsOpaquePointer();
234 
235     // Get or insert a reference to this value.
236     auto it = uniquedDataToMemIndex.try_emplace(
237         opaqueVal, maxValueMemoryIndex + uniquedData.size());
238     if (it.second)
239       uniquedData.push_back(opaqueVal);
240     return it.first->second;
241   }
242 
243 private:
244   /// Allocate memory indices for the results of operations within the matcher
245   /// and rewriters.
246   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
247                              ModuleOp rewriterModule);
248 
249   /// Generate the bytecode for the given operation.
250   void generate(Region *region, ByteCodeWriter &writer);
251   void generate(Operation *op, ByteCodeWriter &writer);
252   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
253   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
254   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
255   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
259   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
260   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
261   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
262   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
263   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
264   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
265   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
266   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
267   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
268   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
269   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
270   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
271   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
272   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
273   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
274   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
275   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
276   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
277   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
278   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
279   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
280   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
281   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
282   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
283   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
284   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
285   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
286   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
287   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
288   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
289 
290   /// Mapping from value to its corresponding memory index.
291   DenseMap<Value, ByteCodeField> valueToMemIndex;
292 
293   /// Mapping from a range value to its corresponding range storage index.
294   DenseMap<Value, ByteCodeField> valueToRangeIndex;
295 
296   /// Mapping from the name of an externally registered rewrite to its index in
297   /// the bytecode registry.
298   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
299 
300   /// Mapping from the name of an externally registered constraint to its index
301   /// in the bytecode registry.
302   llvm::StringMap<ByteCodeField> constraintToMemIndex;
303 
304   /// Mapping from rewriter function name to the bytecode address of the
305   /// rewriter function in byte.
306   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
307 
308   /// Mapping from a uniqued storage object to its memory index within
309   /// `uniquedData`.
310   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
311 
312   /// The current level of the foreach loop.
313   ByteCodeField curLoopLevel = 0;
314 
315   /// The current MLIR context.
316   MLIRContext *ctx;
317 
318   /// Mapping from block to its address.
319   DenseMap<Block *, ByteCodeAddr> blockToAddr;
320 
321   /// Data of the ByteCode class to be populated.
322   std::vector<const void *> &uniquedData;
323   SmallVectorImpl<ByteCodeField> &matcherByteCode;
324   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
325   SmallVectorImpl<PDLByteCodePattern> &patterns;
326   ByteCodeField &maxValueMemoryIndex;
327   ByteCodeField &maxOpRangeMemoryIndex;
328   ByteCodeField &maxTypeRangeMemoryIndex;
329   ByteCodeField &maxValueRangeMemoryIndex;
330   ByteCodeField &maxLoopLevel;
331 };
332 
333 /// This class provides utilities for writing a bytecode stream.
334 struct ByteCodeWriter {
ByteCodeWriter__anonaa7cf1d90211::ByteCodeWriter335   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
336       : bytecode(bytecode), generator(generator) {}
337 
338   /// Append a field to the bytecode.
append__anonaa7cf1d90211::ByteCodeWriter339   void append(ByteCodeField field) { bytecode.push_back(field); }
append__anonaa7cf1d90211::ByteCodeWriter340   void append(OpCode opCode) { bytecode.push_back(opCode); }
341 
342   /// Append an address to the bytecode.
append__anonaa7cf1d90211::ByteCodeWriter343   void append(ByteCodeAddr field) {
344     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
345                   "unexpected ByteCode address size");
346 
347     ByteCodeField fieldParts[2];
348     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
349     bytecode.append({fieldParts[0], fieldParts[1]});
350   }
351 
352   /// Append a single successor to the bytecode, the exact address will need to
353   /// be resolved later.
append__anonaa7cf1d90211::ByteCodeWriter354   void append(Block *successor) {
355     // Add back a reference to the successor so that the address can be resolved
356     // later.
357     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
358     append(ByteCodeAddr(0));
359   }
360 
361   /// Append a successor range to the bytecode, the exact address will need to
362   /// be resolved later.
append__anonaa7cf1d90211::ByteCodeWriter363   void append(SuccessorRange successors) {
364     for (Block *successor : successors)
365       append(successor);
366   }
367 
368   /// Append a range of values that will be read as generic PDLValues.
appendPDLValueList__anonaa7cf1d90211::ByteCodeWriter369   void appendPDLValueList(OperandRange values) {
370     bytecode.push_back(values.size());
371     for (Value value : values)
372       appendPDLValue(value);
373   }
374 
375   /// Append a value as a PDLValue.
appendPDLValue__anonaa7cf1d90211::ByteCodeWriter376   void appendPDLValue(Value value) {
377     appendPDLValueKind(value);
378     append(value);
379   }
380 
381   /// Append the PDLValue::Kind of the given value.
appendPDLValueKind__anonaa7cf1d90211::ByteCodeWriter382   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
383 
384   /// Append the PDLValue::Kind of the given type.
appendPDLValueKind__anonaa7cf1d90211::ByteCodeWriter385   void appendPDLValueKind(Type type) {
386     PDLValue::Kind kind =
387         TypeSwitch<Type, PDLValue::Kind>(type)
388             .Case<pdl::AttributeType>(
389                 [](Type) { return PDLValue::Kind::Attribute; })
390             .Case<pdl::OperationType>(
391                 [](Type) { return PDLValue::Kind::Operation; })
392             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
393               if (rangeTy.getElementType().isa<pdl::TypeType>())
394                 return PDLValue::Kind::TypeRange;
395               return PDLValue::Kind::ValueRange;
396             })
397             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
398             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
399     bytecode.push_back(static_cast<ByteCodeField>(kind));
400   }
401 
402   /// Append a value that will be stored in a memory slot and not inline within
403   /// the bytecode.
404   template <typename T>
405   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
406                    std::is_pointer<T>::value>
append__anonaa7cf1d90211::ByteCodeWriter407   append(T value) {
408     bytecode.push_back(generator.getMemIndex(value));
409   }
410 
411   /// Append a range of values.
412   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
413   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append__anonaa7cf1d90211::ByteCodeWriter414   append(T range) {
415     bytecode.push_back(llvm::size(range));
416     for (auto it : range)
417       append(it);
418   }
419 
420   /// Append a variadic number of fields to the bytecode.
421   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
append__anonaa7cf1d90211::ByteCodeWriter422   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
423     append(field);
424     append(field2, fields...);
425   }
426 
427   /// Appends a value as a pointer, stored inline within the bytecode.
428   template <typename T>
429   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
appendInline__anonaa7cf1d90211::ByteCodeWriter430   appendInline(T value) {
431     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
432     const void *pointer = value.getAsOpaquePointer();
433     ByteCodeField fieldParts[numParts];
434     std::memcpy(fieldParts, &pointer, sizeof(const void *));
435     bytecode.append(fieldParts, fieldParts + numParts);
436   }
437 
438   /// Successor references in the bytecode that have yet to be resolved.
439   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
440 
441   /// The underlying bytecode buffer.
442   SmallVectorImpl<ByteCodeField> &bytecode;
443 
444   /// The main generator producing PDL.
445   Generator &generator;
446 };
447 
448 /// This class represents a live range of PDL Interpreter values, containing
449 /// information about when values are live within a match/rewrite.
450 struct ByteCodeLiveRange {
451   using Set = llvm::IntervalMap<uint64_t, char, 16>;
452   using Allocator = Set::Allocator;
453 
ByteCodeLiveRange__anonaa7cf1d90211::ByteCodeLiveRange454   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
455 
456   /// Union this live range with the one provided.
unionWith__anonaa7cf1d90211::ByteCodeLiveRange457   void unionWith(const ByteCodeLiveRange &rhs) {
458     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
459          ++it)
460       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
461   }
462 
463   /// Returns true if this range overlaps with the one provided.
overlaps__anonaa7cf1d90211::ByteCodeLiveRange464   bool overlaps(const ByteCodeLiveRange &rhs) const {
465     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
466         .valid();
467   }
468 
469   /// A map representing the ranges of the match/rewrite that a value is live in
470   /// the interpreter.
471   ///
472   /// We use std::unique_ptr here, because IntervalMap does not provide a
473   /// correct copy or move constructor. We can eliminate the pointer once
474   /// https://reviews.llvm.org/D113240 lands.
475   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
476 
477   /// The operation range storage index for this range.
478   Optional<unsigned> opRangeIndex;
479 
480   /// The type range storage index for this range.
481   Optional<unsigned> typeRangeIndex;
482 
483   /// The value range storage index for this range.
484   Optional<unsigned> valueRangeIndex;
485 };
486 } // namespace
487 
generate(ModuleOp module)488 void Generator::generate(ModuleOp module) {
489   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
490       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
491   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
492       pdl_interp::PDLInterpDialect::getRewriterModuleName());
493   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
494 
495   // Allocate memory indices for the results of operations within the matcher
496   // and rewriters.
497   allocateMemoryIndices(matcherFunc, rewriterModule);
498 
499   // Generate code for the rewriter functions.
500   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
501   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
502     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
503     for (Operation &op : rewriterFunc.getOps())
504       generate(&op, rewriterByteCodeWriter);
505   }
506   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
507          "unexpected branches in rewriter function");
508 
509   // Generate code for the matcher function.
510   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
511   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
512 
513   // Resolve successor references in the matcher.
514   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
515     ByteCodeAddr addr = blockToAddr[it.first];
516     for (unsigned offsetToFix : it.second)
517       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
518   }
519 }
520 
allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,ModuleOp rewriterModule)521 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
522                                       ModuleOp rewriterModule) {
523   // Rewriters use simplistic allocation scheme that simply assigns an index to
524   // each result.
525   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
526     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
527     auto processRewriterValue = [&](Value val) {
528       valueToMemIndex.try_emplace(val, index++);
529       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
530         Type elementTy = rangeType.getElementType();
531         if (elementTy.isa<pdl::TypeType>())
532           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
533         else if (elementTy.isa<pdl::ValueType>())
534           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
535       }
536     };
537 
538     for (BlockArgument arg : rewriterFunc.getArguments())
539       processRewriterValue(arg);
540     rewriterFunc.getBody().walk([&](Operation *op) {
541       for (Value result : op->getResults())
542         processRewriterValue(result);
543     });
544     if (index > maxValueMemoryIndex)
545       maxValueMemoryIndex = index;
546     if (typeRangeIndex > maxTypeRangeMemoryIndex)
547       maxTypeRangeMemoryIndex = typeRangeIndex;
548     if (valueRangeIndex > maxValueRangeMemoryIndex)
549       maxValueRangeMemoryIndex = valueRangeIndex;
550   }
551 
552   // The matcher function uses a more sophisticated numbering that tries to
553   // minimize the number of memory indices assigned. This is done by determining
554   // a live range of the values within the matcher, then the allocation is just
555   // finding the minimal number of overlapping live ranges. This is essentially
556   // a simplified form of register allocation where we don't necessarily have a
557   // limited number of registers, but we still want to minimize the number used.
558   DenseMap<Operation *, unsigned> opToFirstIndex;
559   DenseMap<Operation *, unsigned> opToLastIndex;
560 
561   // A custom walk that marks the first and the last index of each operation.
562   // The entry marks the beginning of the liveness range for this operation,
563   // followed by nested operations, followed by the end of the liveness range.
564   unsigned index = 0;
565   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
566     opToFirstIndex.try_emplace(op, index++);
567     for (Region &region : op->getRegions())
568       for (Block &block : region.getBlocks())
569         for (Operation &nested : block)
570           walk(&nested);
571     opToLastIndex.try_emplace(op, index++);
572   };
573   walk(matcherFunc);
574 
575   // Liveness info for each of the defs within the matcher.
576   ByteCodeLiveRange::Allocator allocator;
577   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
578 
579   // Assign the root operation being matched to slot 0.
580   BlockArgument rootOpArg = matcherFunc.getArgument(0);
581   valueToMemIndex[rootOpArg] = 0;
582 
583   // Walk each of the blocks, computing the def interval that the value is used.
584   Liveness matcherLiveness(matcherFunc);
585   matcherFunc->walk([&](Block *block) {
586     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
587     assert(info && "expected liveness info for block");
588     auto processValue = [&](Value value, Operation *firstUseOrDef) {
589       // We don't need to process the root op argument, this value is always
590       // assigned to the first memory slot.
591       if (value == rootOpArg)
592         return;
593 
594       // Set indices for the range of this block that the value is used.
595       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
596       defRangeIt->second.liveness->insert(
597           opToFirstIndex[firstUseOrDef],
598           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
599           /*dummyValue*/ 0);
600 
601       // Check to see if this value is a range type.
602       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
603         Type eleType = rangeTy.getElementType();
604         if (eleType.isa<pdl::OperationType>())
605           defRangeIt->second.opRangeIndex = 0;
606         else if (eleType.isa<pdl::TypeType>())
607           defRangeIt->second.typeRangeIndex = 0;
608         else if (eleType.isa<pdl::ValueType>())
609           defRangeIt->second.valueRangeIndex = 0;
610       }
611     };
612 
613     // Process the live-ins of this block.
614     for (Value liveIn : info->in()) {
615       // Only process the value if it has been defined in the current region.
616       // Other values that span across pdl_interp.foreach will be added higher
617       // up. This ensures that the we keep them alive for the entire duration
618       // of the loop.
619       if (liveIn.getParentRegion() == block->getParent())
620         processValue(liveIn, &block->front());
621     }
622 
623     // Process the block arguments for the entry block (those are not live-in).
624     if (block->isEntryBlock()) {
625       for (Value argument : block->getArguments())
626         processValue(argument, &block->front());
627     }
628 
629     // Process any new defs within this block.
630     for (Operation &op : *block)
631       for (Value result : op.getResults())
632         processValue(result, &op);
633   });
634 
635   // Greedily allocate memory slots using the computed def live ranges.
636   std::vector<ByteCodeLiveRange> allocatedIndices;
637 
638   // The number of memory indices currently allocated (and its next value).
639   // Recall that the root gets allocated memory index 0.
640   ByteCodeField numIndices = 1;
641 
642   // The number of memory ranges of various types (and their next values).
643   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
644 
645   for (auto &defIt : valueDefRanges) {
646     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
647     ByteCodeLiveRange &defRange = defIt.second;
648 
649     // Try to allocate to an existing index.
650     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
651       ByteCodeLiveRange &existingRange = existingIndexIt.value();
652       if (!defRange.overlaps(existingRange)) {
653         existingRange.unionWith(defRange);
654         memIndex = existingIndexIt.index() + 1;
655 
656         if (defRange.opRangeIndex) {
657           if (!existingRange.opRangeIndex)
658             existingRange.opRangeIndex = numOpRanges++;
659           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
660         } else if (defRange.typeRangeIndex) {
661           if (!existingRange.typeRangeIndex)
662             existingRange.typeRangeIndex = numTypeRanges++;
663           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
664         } else if (defRange.valueRangeIndex) {
665           if (!existingRange.valueRangeIndex)
666             existingRange.valueRangeIndex = numValueRanges++;
667           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
668         }
669         break;
670       }
671     }
672 
673     // If no existing index could be used, add a new one.
674     if (memIndex == 0) {
675       allocatedIndices.emplace_back(allocator);
676       ByteCodeLiveRange &newRange = allocatedIndices.back();
677       newRange.unionWith(defRange);
678 
679       // Allocate an index for op/type/value ranges.
680       if (defRange.opRangeIndex) {
681         newRange.opRangeIndex = numOpRanges;
682         valueToRangeIndex[defIt.first] = numOpRanges++;
683       } else if (defRange.typeRangeIndex) {
684         newRange.typeRangeIndex = numTypeRanges;
685         valueToRangeIndex[defIt.first] = numTypeRanges++;
686       } else if (defRange.valueRangeIndex) {
687         newRange.valueRangeIndex = numValueRanges;
688         valueToRangeIndex[defIt.first] = numValueRanges++;
689       }
690 
691       memIndex = allocatedIndices.size();
692       ++numIndices;
693     }
694   }
695 
696   // Print the index usage and ensure that we did not run out of index space.
697   LLVM_DEBUG({
698     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
699                  << "(down from initial " << valueDefRanges.size() << ").\n";
700   });
701   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
702          "Ran out of memory for allocated indices");
703 
704   // Update the max number of indices.
705   if (numIndices > maxValueMemoryIndex)
706     maxValueMemoryIndex = numIndices;
707   if (numOpRanges > maxOpRangeMemoryIndex)
708     maxOpRangeMemoryIndex = numOpRanges;
709   if (numTypeRanges > maxTypeRangeMemoryIndex)
710     maxTypeRangeMemoryIndex = numTypeRanges;
711   if (numValueRanges > maxValueRangeMemoryIndex)
712     maxValueRangeMemoryIndex = numValueRanges;
713 }
714 
generate(Region * region,ByteCodeWriter & writer)715 void Generator::generate(Region *region, ByteCodeWriter &writer) {
716   llvm::ReversePostOrderTraversal<Region *> rpot(region);
717   for (Block *block : rpot) {
718     // Keep track of where this block begins within the matcher function.
719     blockToAddr.try_emplace(block, matcherByteCode.size());
720     for (Operation &op : *block)
721       generate(&op, writer);
722   }
723 }
724 
generate(Operation * op,ByteCodeWriter & writer)725 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
726   LLVM_DEBUG({
727     // The following list must contain all the operations that do not
728     // produce any bytecode.
729     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
730       writer.appendInline(op->getLoc());
731   });
732   TypeSwitch<Operation *>(op)
733       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
734             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
735             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
736             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
737             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
738             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
739             pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
740             pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
741             pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
742             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
743             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
744             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
745             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
746             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
747             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
748             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
749             pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
750             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
751             pdl_interp::SwitchResultCountOp>(
752           [&](auto interpOp) { this->generate(interpOp, writer); })
753       .Default([](Operation *) {
754         llvm_unreachable("unknown `pdl_interp` operation");
755       });
756 }
757 
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)758 void Generator::generate(pdl_interp::ApplyConstraintOp op,
759                          ByteCodeWriter &writer) {
760   assert(constraintToMemIndex.count(op.getName()) &&
761          "expected index for constraint function");
762   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
763   writer.appendPDLValueList(op.getArgs());
764   writer.append(op.getSuccessors());
765 }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)766 void Generator::generate(pdl_interp::ApplyRewriteOp op,
767                          ByteCodeWriter &writer) {
768   assert(externalRewriterToMemIndex.count(op.getName()) &&
769          "expected index for rewrite function");
770   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
771   writer.appendPDLValueList(op.getArgs());
772 
773   ResultRange results = op.getResults();
774   writer.append(ByteCodeField(results.size()));
775   for (Value result : results) {
776     // In debug mode we also record the expected kind of the result, so that we
777     // can provide extra verification of the native rewrite function.
778 #ifndef NDEBUG
779     writer.appendPDLValueKind(result);
780 #endif
781 
782     // Range results also need to append the range storage index.
783     if (result.getType().isa<pdl::RangeType>())
784       writer.append(getRangeStorageIndex(result));
785     writer.append(result);
786   }
787 }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)788 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
789   Value lhs = op.getLhs();
790   if (lhs.getType().isa<pdl::RangeType>()) {
791     writer.append(OpCode::AreRangesEqual);
792     writer.appendPDLValueKind(lhs);
793     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
794     return;
795   }
796 
797   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
798 }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)799 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
800   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
801 }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)802 void Generator::generate(pdl_interp::CheckAttributeOp op,
803                          ByteCodeWriter &writer) {
804   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
805                 op.getSuccessors());
806 }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)807 void Generator::generate(pdl_interp::CheckOperandCountOp op,
808                          ByteCodeWriter &writer) {
809   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
810                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
811                 op.getSuccessors());
812 }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)813 void Generator::generate(pdl_interp::CheckOperationNameOp op,
814                          ByteCodeWriter &writer) {
815   writer.append(OpCode::CheckOperationName, op.getInputOp(),
816                 OperationName(op.getName(), ctx), op.getSuccessors());
817 }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)818 void Generator::generate(pdl_interp::CheckResultCountOp op,
819                          ByteCodeWriter &writer) {
820   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
821                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
822                 op.getSuccessors());
823 }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)824 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
825   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
826                 op.getSuccessors());
827 }
generate(pdl_interp::CheckTypesOp op,ByteCodeWriter & writer)828 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
829   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
830                 op.getSuccessors());
831 }
generate(pdl_interp::ContinueOp op,ByteCodeWriter & writer)832 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
833   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
834   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
835 }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)836 void Generator::generate(pdl_interp::CreateAttributeOp op,
837                          ByteCodeWriter &writer) {
838   // Simply repoint the memory index of the result to the constant.
839   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
840 }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)841 void Generator::generate(pdl_interp::CreateOperationOp op,
842                          ByteCodeWriter &writer) {
843   writer.append(OpCode::CreateOperation, op.getResultOp(),
844                 OperationName(op.getName(), ctx));
845   writer.appendPDLValueList(op.getInputOperands());
846 
847   // Add the attributes.
848   OperandRange attributes = op.getInputAttributes();
849   writer.append(static_cast<ByteCodeField>(attributes.size()));
850   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
851     writer.append(std::get<0>(it), std::get<1>(it));
852 
853   // Add the result types. If the operation has inferred results, we use a
854   // marker "size" value. Otherwise, we add the list of explicit result types.
855   if (op.getInferredResultTypes())
856     writer.append(kInferTypesMarker);
857   else
858     writer.appendPDLValueList(op.getInputResultTypes());
859 }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)860 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
861   // Simply repoint the memory index of the result to the constant.
862   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
863 }
generate(pdl_interp::CreateTypesOp op,ByteCodeWriter & writer)864 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
865   writer.append(OpCode::CreateTypes, op.getResult(),
866                 getRangeStorageIndex(op.getResult()), op.getValue());
867 }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)868 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
869   writer.append(OpCode::EraseOp, op.getInputOp());
870 }
generate(pdl_interp::ExtractOp op,ByteCodeWriter & writer)871 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
872   OpCode opCode =
873       TypeSwitch<Type, OpCode>(op.getResult().getType())
874           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
875           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
876           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
877           .Default([](Type) -> OpCode {
878             llvm_unreachable("unsupported element type");
879           });
880   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
881 }
/// Emit a Finalize instruction. The opcode carries no operands.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
generate(pdl_interp::ForEachOp op,ByteCodeWriter & writer)885 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
886   BlockArgument arg = op.getLoopVariable();
887   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
888   writer.appendPDLValueKind(arg.getType());
889   writer.append(curLoopLevel, op.getSuccessor());
890   ++curLoopLevel;
891   if (curLoopLevel > maxLoopLevel)
892     maxLoopLevel = curLoopLevel;
893   generate(&op.getRegion(), writer);
894   --curLoopLevel;
895 }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)896 void Generator::generate(pdl_interp::GetAttributeOp op,
897                          ByteCodeWriter &writer) {
898   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
899                 op.getNameAttr());
900 }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)901 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
902                          ByteCodeWriter &writer) {
903   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
904 }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)905 void Generator::generate(pdl_interp::GetDefiningOpOp op,
906                          ByteCodeWriter &writer) {
907   writer.append(OpCode::GetDefiningOp, op.getInputOp());
908   writer.appendPDLValue(op.getValue());
909 }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)910 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
911   uint32_t index = op.getIndex();
912   if (index < 4)
913     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
914   else
915     writer.append(OpCode::GetOperandN, index);
916   writer.append(op.getInputOp(), op.getValue());
917 }
generate(pdl_interp::GetOperandsOp op,ByteCodeWriter & writer)918 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
919   Value result = op.getValue();
920   Optional<uint32_t> index = op.getIndex();
921   writer.append(OpCode::GetOperands,
922                 index.value_or(std::numeric_limits<uint32_t>::max()),
923                 op.getInputOp());
924   if (result.getType().isa<pdl::RangeType>())
925     writer.append(getRangeStorageIndex(result));
926   else
927     writer.append(std::numeric_limits<ByteCodeField>::max());
928   writer.append(result);
929 }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)930 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
931   uint32_t index = op.getIndex();
932   if (index < 4)
933     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
934   else
935     writer.append(OpCode::GetResultN, index);
936   writer.append(op.getInputOp(), op.getValue());
937 }
generate(pdl_interp::GetResultsOp op,ByteCodeWriter & writer)938 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
939   Value result = op.getValue();
940   Optional<uint32_t> index = op.getIndex();
941   writer.append(OpCode::GetResults,
942                 index.value_or(std::numeric_limits<uint32_t>::max()),
943                 op.getInputOp());
944   if (result.getType().isa<pdl::RangeType>())
945     writer.append(getRangeStorageIndex(result));
946   else
947     writer.append(std::numeric_limits<ByteCodeField>::max());
948   writer.append(result);
949 }
generate(pdl_interp::GetUsersOp op,ByteCodeWriter & writer)950 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
951   Value operations = op.getOperations();
952   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
953   writer.append(OpCode::GetUsers, operations, rangeIndex);
954   writer.appendPDLValue(op.getValue());
955 }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)956 void Generator::generate(pdl_interp::GetValueTypeOp op,
957                          ByteCodeWriter &writer) {
958   if (op.getType().isa<pdl::RangeType>()) {
959     Value result = op.getResult();
960     writer.append(OpCode::GetValueRangeTypes, result,
961                   getRangeStorageIndex(result), op.getValue());
962   } else {
963     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
964   }
965 }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)966 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
967   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
968 }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)969 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
970   ByteCodeField patternIndex = patterns.size();
971   patterns.emplace_back(PDLByteCodePattern::create(
972       op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
973   writer.append(OpCode::RecordMatch, patternIndex,
974                 SuccessorRange(op.getOperation()), op.getMatchedOps());
975   writer.appendPDLValueList(op.getInputs());
976 }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)977 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
978   writer.append(OpCode::ReplaceOp, op.getInputOp());
979   writer.appendPDLValueList(op.getReplValues());
980 }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)981 void Generator::generate(pdl_interp::SwitchAttributeOp op,
982                          ByteCodeWriter &writer) {
983   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
984                 op.getCaseValuesAttr(), op.getSuccessors());
985 }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)986 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
987                          ByteCodeWriter &writer) {
988   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
989                 op.getCaseValuesAttr(), op.getSuccessors());
990 }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)991 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
992                          ByteCodeWriter &writer) {
993   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
994     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
995   });
996   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
997                 op.getSuccessors());
998 }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)999 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1000                          ByteCodeWriter &writer) {
1001   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1002                 op.getCaseValuesAttr(), op.getSuccessors());
1003 }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)1004 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1005   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1006                 op.getSuccessors());
1007 }
generate(pdl_interp::SwitchTypesOp op,ByteCodeWriter & writer)1008 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1009   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1010                 op.getSuccessors());
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // PDLByteCode
1015 //===----------------------------------------------------------------------===//
1016 
PDLByteCode(ModuleOp module,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)1017 PDLByteCode::PDLByteCode(ModuleOp module,
1018                          llvm::StringMap<PDLConstraintFunction> constraintFns,
1019                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
1020   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1021                       rewriterByteCode, patterns, maxValueMemoryIndex,
1022                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1023                       maxLoopLevel, constraintFns, rewriteFns);
1024   generator.generate(module);
1025 
1026   // Initialize the external functions.
1027   for (auto &it : constraintFns)
1028     constraintFunctions.push_back(std::move(it.second));
1029   for (auto &it : rewriteFns)
1030     rewriteFunctions.push_back(std::move(it.second));
1031 }
1032 
1033 /// Initialize the given state such that it can be used to execute the current
1034 /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const1035 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1036   state.memory.resize(maxValueMemoryIndex, nullptr);
1037   state.opRangeMemory.resize(maxOpRangeCount);
1038   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1039   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1040   state.loopIndex.resize(maxLoopLevel, 0);
1041   state.currentPatternBenefits.reserve(patterns.size());
1042   for (const PDLByteCodePattern &pattern : patterns)
1043     state.currentPatternBenefits.push_back(pattern.getBenefit());
1044 }
1045 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1048 
1049 namespace {
/// This class provides support for executing a bytecode stream. It holds a
/// cursor into the bytecode (`curCodeIt`), views of the mutable execution
/// memory, and read-only references to the data needed during execution.
class ByteCodeExecutor {
public:
  /// Construct an executor that starts reading at `curCodeIt`. Every argument
  /// is stored as-is in the member of the same name.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }

  /// Pops the most recently pushed code iterator from the stack and makes it
  /// the current bytecode position. The stack must not be empty.
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code.
  const ByteCodeField *getPrevCodeIt() const {
    // The `return` below only takes effect when the LLVM_DEBUG block is
    // executed; otherwise control falls through to the second `return`.
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The list is encoded as an
  /// element count followed by that many elements; `list` is cleared first.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// as either Value or ValueRange elements. The values read are appended to
  /// `list`; unlike `readList`, the list is not cleared first.
  void readValueList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        // Any other kind is read as a pointer to a ValueRange, whose elements
        // are flattened into the list.
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer. The pointer bytes are copied
  /// directly out of the bytecode stream, and the stream is advanced past
  /// them.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(&pointer, curCodeIt, sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Jump to a specific successor based on a predicate value: successor 0 when
  /// `isTrue` is set, successor 1 otherwise.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index.
  void selectJump(size_t destIndex) {
    // Each successor address occupies two bytecode fields, so skip
    // `destIndex` addresses before reading the destination.
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. Successor 0
  /// is taken when no case matches; a match on case `i` takes successor
  /// `i + 1`.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << "  * Value: " << value << "\n"
                   << "  * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the attribute value is within the case list. Jump to
    // the correct successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  ///
  /// Reads a memory index from the stream and returns the pointer stored at
  /// that index, taken either from `memory` or from `uniquedMemory`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Read a raw pointer value from memory.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Read a class value (other than PDLValue) from memory by rebuilding it
  /// from its opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// Read a generic PDLValue: a kind tag followed by a value of that kind.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Read a bytecode address, which is stored inline across two fields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// Read a single raw bytecode field.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// Read a PDLValue kind, which is stored as a single bytecode field.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// The current position within the underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1309 
1310 /// This class is an instantiation of the PDLResultList that provides access to
1311 /// the returned results. This API is not on `PDLResultList` to avoid
1312 /// overexposing access to information specific solely to the ByteCode.
1313 class ByteCodeRewriteResultList : public PDLResultList {
1314 public:
ByteCodeRewriteResultList(unsigned maxNumResults)1315   ByteCodeRewriteResultList(unsigned maxNumResults)
1316       : PDLResultList(maxNumResults) {}
1317 
1318   /// Return the list of PDL results.
getResults()1319   MutableArrayRef<PDLValue> getResults() { return results; }
1320 
1321   /// Return the type ranges allocated by this list.
getAllocatedTypeRanges()1322   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1323     return allocatedTypeRanges;
1324   }
1325 
1326   /// Return the value ranges allocated by this list.
getAllocatedValueRanges()1327   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1328     return allocatedValueRanges;
1329   }
1330 };
1331 } // namespace
1332 
executeApplyConstraint(PatternRewriter & rewriter)1333 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1334   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1335   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1336   SmallVector<PDLValue, 16> args;
1337   readList<PDLValue>(args);
1338 
1339   LLVM_DEBUG({
1340     llvm::dbgs() << "  * Arguments: ";
1341     llvm::interleaveComma(args, llvm::dbgs());
1342   });
1343 
1344   // Invoke the constraint and jump to the proper destination.
1345   selectJump(succeeded(constraintFn(rewriter, args)));
1346 }
1347 
executeApplyRewrite(PatternRewriter & rewriter)1348 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1349   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1350   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1351   SmallVector<PDLValue, 16> args;
1352   readList<PDLValue>(args);
1353 
1354   LLVM_DEBUG({
1355     llvm::dbgs() << "  * Arguments: ";
1356     llvm::interleaveComma(args, llvm::dbgs());
1357   });
1358 
1359   // Execute the rewrite function.
1360   ByteCodeField numResults = read();
1361   ByteCodeRewriteResultList results(numResults);
1362   rewriteFn(rewriter, results, args);
1363 
1364   assert(results.getResults().size() == numResults &&
1365          "native PDL rewrite function returned unexpected number of results");
1366 
1367   // Store the results in the bytecode memory.
1368   for (PDLValue &result : results.getResults()) {
1369     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1370 
1371 // In debug mode we also verify the expected kind of the result.
1372 #ifndef NDEBUG
1373     assert(result.getKind() == read<PDLValue::Kind>() &&
1374            "native PDL rewrite function returned an unexpected type of result");
1375 #endif
1376 
1377     // If the result is a range, we need to copy it over to the bytecodes
1378     // range memory.
1379     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1380       unsigned rangeIndex = read();
1381       typeRangeMemory[rangeIndex] = *typeRange;
1382       memory[read()] = &typeRangeMemory[rangeIndex];
1383     } else if (Optional<ValueRange> valueRange =
1384                    result.dyn_cast<ValueRange>()) {
1385       unsigned rangeIndex = read();
1386       valueRangeMemory[rangeIndex] = *valueRange;
1387       memory[read()] = &valueRangeMemory[rangeIndex];
1388     } else {
1389       memory[read()] = result.getAsOpaquePointer();
1390     }
1391   }
1392 
1393   // Copy over any underlying storage allocated for result ranges.
1394   for (auto &it : results.getAllocatedTypeRanges())
1395     allocatedTypeRangeMemory.push_back(std::move(it));
1396   for (auto &it : results.getAllocatedValueRanges())
1397     allocatedValueRangeMemory.push_back(std::move(it));
1398 }
1399 
executeAreEqual()1400 void ByteCodeExecutor::executeAreEqual() {
1401   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1402   const void *lhs = read<const void *>();
1403   const void *rhs = read<const void *>();
1404 
1405   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1406   selectJump(lhs == rhs);
1407 }
1408 
executeAreRangesEqual()1409 void ByteCodeExecutor::executeAreRangesEqual() {
1410   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1411   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1412   const void *lhs = read<const void *>();
1413   const void *rhs = read<const void *>();
1414 
1415   switch (valueKind) {
1416   case PDLValue::Kind::TypeRange: {
1417     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1418     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1419     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1420     selectJump(*lhsRange == *rhsRange);
1421     break;
1422   }
1423   case PDLValue::Kind::ValueRange: {
1424     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1425     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1426     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1427     selectJump(*lhsRange == *rhsRange);
1428     break;
1429   }
1430   default:
1431     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1432   }
1433 }
1434 
executeBranch()1435 void ByteCodeExecutor::executeBranch() {
1436   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1437   curCodeIt = &code[read<ByteCodeAddr>()];
1438 }
1439 
executeCheckOperandCount()1440 void ByteCodeExecutor::executeCheckOperandCount() {
1441   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1442   Operation *op = read<Operation *>();
1443   uint32_t expectedCount = read<uint32_t>();
1444   bool compareAtLeast = read();
1445 
1446   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1447                           << "  * Expected: " << expectedCount << "\n"
1448                           << "  * Comparator: "
1449                           << (compareAtLeast ? ">=" : "==") << "\n");
1450   if (compareAtLeast)
1451     selectJump(op->getNumOperands() >= expectedCount);
1452   else
1453     selectJump(op->getNumOperands() == expectedCount);
1454 }
1455 
executeCheckOperationName()1456 void ByteCodeExecutor::executeCheckOperationName() {
1457   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1458   Operation *op = read<Operation *>();
1459   OperationName expectedName = read<OperationName>();
1460 
1461   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1462                           << "  * Expected: \"" << expectedName << "\"\n");
1463   selectJump(op->getName() == expectedName);
1464 }
1465 
executeCheckResultCount()1466 void ByteCodeExecutor::executeCheckResultCount() {
1467   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1468   Operation *op = read<Operation *>();
1469   uint32_t expectedCount = read<uint32_t>();
1470   bool compareAtLeast = read();
1471 
1472   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1473                           << "  * Expected: " << expectedCount << "\n"
1474                           << "  * Comparator: "
1475                           << (compareAtLeast ? ">=" : "==") << "\n");
1476   if (compareAtLeast)
1477     selectJump(op->getNumResults() >= expectedCount);
1478   else
1479     selectJump(op->getNumResults() == expectedCount);
1480 }
1481 
executeCheckTypes()1482 void ByteCodeExecutor::executeCheckTypes() {
1483   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1484   TypeRange *lhs = read<TypeRange *>();
1485   Attribute rhs = read<Attribute>();
1486   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1487 
1488   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1489 }
1490 
executeContinue()1491 void ByteCodeExecutor::executeContinue() {
1492   ByteCodeField level = read();
1493   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1494                           << "  * Level: " << level << "\n");
1495   ++loopIndex[level];
1496   popCodeIt();
1497 }
1498 
executeCreateTypes()1499 void ByteCodeExecutor::executeCreateTypes() {
1500   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1501   unsigned memIndex = read();
1502   unsigned rangeIndex = read();
1503   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1504 
1505   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1506 
1507   // Allocate a buffer for this type range.
1508   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1509   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1510   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1511 
1512   // Assign this to the range slot and use the range as the value for the
1513   // memory index.
1514   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1515   memory[memIndex] = &typeRangeMemory[rangeIndex];
1516 }
1517 
executeCreateOperation(PatternRewriter & rewriter,Location mainRewriteLoc)1518 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1519                                               Location mainRewriteLoc) {
1520   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1521 
1522   unsigned memIndex = read();
1523   OperationState state(mainRewriteLoc, read<OperationName>());
1524   readValueList(state.operands);
1525   for (unsigned i = 0, e = read(); i != e; ++i) {
1526     StringAttr name = read<StringAttr>();
1527     if (Attribute attr = read<Attribute>())
1528       state.addAttribute(name, attr);
1529   }
1530 
1531   // Read in the result types. If the "size" is the sentinel value, this
1532   // indicates that the result types should be inferred.
1533   unsigned numResults = read();
1534   if (numResults == kInferTypesMarker) {
1535     InferTypeOpInterface::Concept *inferInterface =
1536         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1537     assert(inferInterface &&
1538            "expected operation to provide InferTypeOpInterface");
1539 
1540     // TODO: Handle failure.
1541     if (failed(inferInterface->inferReturnTypes(
1542             state.getContext(), state.location, state.operands,
1543             state.attributes.getDictionary(state.getContext()), state.regions,
1544             state.types)))
1545       return;
1546   } else {
1547     // Otherwise, this is a fixed number of results.
1548     for (unsigned i = 0; i != numResults; ++i) {
1549       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1550         state.types.push_back(read<Type>());
1551       } else {
1552         TypeRange *resultTypes = read<TypeRange *>();
1553         state.types.append(resultTypes->begin(), resultTypes->end());
1554       }
1555     }
1556   }
1557 
1558   Operation *resultOp = rewriter.create(state);
1559   memory[memIndex] = resultOp;
1560 
1561   LLVM_DEBUG({
1562     llvm::dbgs() << "  * Attributes: "
1563                  << state.attributes.getDictionary(state.getContext())
1564                  << "\n  * Operands: ";
1565     llvm::interleaveComma(state.operands, llvm::dbgs());
1566     llvm::dbgs() << "\n  * Result Types: ";
1567     llvm::interleaveComma(state.types, llvm::dbgs());
1568     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1569   });
1570 }
1571 
executeEraseOp(PatternRewriter & rewriter)1572 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1573   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1574   Operation *op = read<Operation *>();
1575 
1576   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1577   rewriter.eraseOp(op);
1578 }
1579 
1580 template <typename T, typename Range, PDLValue::Kind kind>
executeExtract()1581 void ByteCodeExecutor::executeExtract() {
1582   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1583   Range *range = read<Range *>();
1584   unsigned index = read<uint32_t>();
1585   unsigned memIndex = read();
1586 
1587   if (!range) {
1588     memory[memIndex] = nullptr;
1589     return;
1590   }
1591 
1592   T result = index < range->size() ? (*range)[index] : T();
1593   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
1594                           << "  * Index: " << index << "\n"
1595                           << "  * Result: " << result << "\n");
1596   storeToMemory(memIndex, result);
1597 }
1598 
executeFinalize()1599 void ByteCodeExecutor::executeFinalize() {
1600   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1601 }
1602 
executeForEach()1603 void ByteCodeExecutor::executeForEach() {
1604   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1605   const ByteCodeField *prevCodeIt = getPrevCodeIt();
1606   unsigned rangeIndex = read();
1607   unsigned memIndex = read();
1608   const void *value = nullptr;
1609 
1610   switch (read<PDLValue::Kind>()) {
1611   case PDLValue::Kind::Operation: {
1612     unsigned &index = loopIndex[read()];
1613     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1614     assert(index <= array.size() && "iterated past the end");
1615     if (index < array.size()) {
1616       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
1617       value = array[index];
1618       break;
1619     }
1620 
1621     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
1622     index = 0;
1623     selectJump(size_t(0));
1624     return;
1625   }
1626   default:
1627     llvm_unreachable("unexpected `ForEach` value kind");
1628   }
1629 
1630   // Store the iterate value and the stack address.
1631   memory[memIndex] = value;
1632   pushCodeIt(prevCodeIt);
1633 
1634   // Skip over the successor (we will enter the body of the loop).
1635   read<ByteCodeAddr>();
1636 }
1637 
executeGetAttribute()1638 void ByteCodeExecutor::executeGetAttribute() {
1639   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1640   unsigned memIndex = read();
1641   Operation *op = read<Operation *>();
1642   StringAttr attrName = read<StringAttr>();
1643   Attribute attr = op->getAttr(attrName);
1644 
1645   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1646                           << "  * Attribute: " << attrName << "\n"
1647                           << "  * Result: " << attr << "\n");
1648   memory[memIndex] = attr.getAsOpaquePointer();
1649 }
1650 
executeGetAttributeType()1651 void ByteCodeExecutor::executeGetAttributeType() {
1652   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1653   unsigned memIndex = read();
1654   Attribute attr = read<Attribute>();
1655   Type type = attr ? attr.getType() : Type();
1656 
1657   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1658                           << "  * Result: " << type << "\n");
1659   memory[memIndex] = type.getAsOpaquePointer();
1660 }
1661 
executeGetDefiningOp()1662 void ByteCodeExecutor::executeGetDefiningOp() {
1663   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1664   unsigned memIndex = read();
1665   Operation *op = nullptr;
1666   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1667     Value value = read<Value>();
1668     if (value)
1669       op = value.getDefiningOp();
1670     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1671   } else {
1672     ValueRange *values = read<ValueRange *>();
1673     if (values && !values->empty()) {
1674       op = values->front().getDefiningOp();
1675     }
1676     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1677   }
1678 
1679   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1680   memory[memIndex] = op;
1681 }
1682 
executeGetOperand(unsigned index)1683 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1684   Operation *op = read<Operation *>();
1685   unsigned memIndex = read();
1686   Value operand =
1687       index < op->getNumOperands() ? op->getOperand(index) : Value();
1688 
1689   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1690                           << "  * Index: " << index << "\n"
1691                           << "  * Result: " << operand << "\n");
1692   memory[memIndex] = operand.getAsOpaquePointer();
1693 }
1694 
1695 /// This function is the internal implementation of `GetResults` and
1696 /// `GetOperands` that provides support for extracting a value range from the
1697 /// given operation.
1698 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1699 static void *
executeGetOperandsResults(RangeT values,Operation * op,unsigned index,ByteCodeField rangeIndex,StringRef attrSizedSegments,MutableArrayRef<ValueRange> valueRangeMemory)1700 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1701                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1702                           MutableArrayRef<ValueRange> valueRangeMemory) {
1703   // Check for the sentinel index that signals that all values should be
1704   // returned.
1705   if (index == std::numeric_limits<uint32_t>::max()) {
1706     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1707     // `values` is already the full value range.
1708 
1709     // Otherwise, check to see if this operation uses AttrSizedSegments.
1710   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1711     LLVM_DEBUG(llvm::dbgs()
1712                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1713 
1714     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1715     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1716       return nullptr;
1717 
1718     auto segments = segmentAttr.getValues<int32_t>();
1719     unsigned startIndex =
1720         std::accumulate(segments.begin(), segments.begin() + index, 0);
1721     values = values.slice(startIndex, *std::next(segments.begin(), index));
1722 
1723     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1724                             << *std::next(segments.begin(), index) << "]\n");
1725 
1726     // Otherwise, assume this is the last operand group of the operation.
1727     // FIXME: We currently don't support operations with
1728     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1729     // have a way to detect it's presence.
1730   } else if (values.size() >= index) {
1731     LLVM_DEBUG(llvm::dbgs()
1732                << "  * Treating values as trailing variadic range\n");
1733     values = values.drop_front(index);
1734 
1735     // If we couldn't detect a way to compute the values, bail out.
1736   } else {
1737     return nullptr;
1738   }
1739 
1740   // If the range index is valid, we are returning a range.
1741   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1742     valueRangeMemory[rangeIndex] = values;
1743     return &valueRangeMemory[rangeIndex];
1744   }
1745 
1746   // If a range index wasn't provided, the range is required to be non-variadic.
1747   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1748 }
1749 
executeGetOperands()1750 void ByteCodeExecutor::executeGetOperands() {
1751   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1752   unsigned index = read<uint32_t>();
1753   Operation *op = read<Operation *>();
1754   ByteCodeField rangeIndex = read();
1755 
1756   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1757       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1758       valueRangeMemory);
1759   if (!result)
1760     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1761   memory[read()] = result;
1762 }
1763 
executeGetResult(unsigned index)1764 void ByteCodeExecutor::executeGetResult(unsigned index) {
1765   Operation *op = read<Operation *>();
1766   unsigned memIndex = read();
1767   OpResult result =
1768       index < op->getNumResults() ? op->getResult(index) : OpResult();
1769 
1770   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1771                           << "  * Index: " << index << "\n"
1772                           << "  * Result: " << result << "\n");
1773   memory[memIndex] = result.getAsOpaquePointer();
1774 }
1775 
executeGetResults()1776 void ByteCodeExecutor::executeGetResults() {
1777   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1778   unsigned index = read<uint32_t>();
1779   Operation *op = read<Operation *>();
1780   ByteCodeField rangeIndex = read();
1781 
1782   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1783       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1784       valueRangeMemory);
1785   if (!result)
1786     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1787   memory[read()] = result;
1788 }
1789 
executeGetUsers()1790 void ByteCodeExecutor::executeGetUsers() {
1791   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1792   unsigned memIndex = read();
1793   unsigned rangeIndex = read();
1794   OwningOpRange &range = opRangeMemory[rangeIndex];
1795   memory[memIndex] = &range;
1796 
1797   range = OwningOpRange();
1798   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1799     // Read the value.
1800     Value value = read<Value>();
1801     if (!value)
1802       return;
1803     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1804 
1805     // Extract the users of a single value.
1806     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1807     llvm::copy(value.getUsers(), range.begin());
1808   } else {
1809     // Read a range of values.
1810     ValueRange *values = read<ValueRange *>();
1811     if (!values)
1812       return;
1813     LLVM_DEBUG({
1814       llvm::dbgs() << "  * Values (" << values->size() << "): ";
1815       llvm::interleaveComma(*values, llvm::dbgs());
1816       llvm::dbgs() << "\n";
1817     });
1818 
1819     // Extract all the users of a range of values.
1820     SmallVector<Operation *> users;
1821     for (Value value : *values)
1822       users.append(value.user_begin(), value.user_end());
1823     range = OwningOpRange(users.size());
1824     llvm::copy(users, range.begin());
1825   }
1826 
1827   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
1828 }
1829 
executeGetValueType()1830 void ByteCodeExecutor::executeGetValueType() {
1831   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1832   unsigned memIndex = read();
1833   Value value = read<Value>();
1834   Type type = value ? value.getType() : Type();
1835 
1836   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1837                           << "  * Result: " << type << "\n");
1838   memory[memIndex] = type.getAsOpaquePointer();
1839 }
1840 
executeGetValueRangeTypes()1841 void ByteCodeExecutor::executeGetValueRangeTypes() {
1842   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1843   unsigned memIndex = read();
1844   unsigned rangeIndex = read();
1845   ValueRange *values = read<ValueRange *>();
1846   if (!values) {
1847     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1848     memory[memIndex] = nullptr;
1849     return;
1850   }
1851 
1852   LLVM_DEBUG({
1853     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1854     llvm::interleaveComma(*values, llvm::dbgs());
1855     llvm::dbgs() << "\n  * Result: ";
1856     llvm::interleaveComma(values->getType(), llvm::dbgs());
1857     llvm::dbgs() << "\n";
1858   });
1859   typeRangeMemory[rangeIndex] = values->getType();
1860   memory[memIndex] = &typeRangeMemory[rangeIndex];
1861 }
1862 
executeIsNotNull()1863 void ByteCodeExecutor::executeIsNotNull() {
1864   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1865   const void *value = read<const void *>();
1866 
1867   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1868   selectJump(value != nullptr);
1869 }
1870 
executeRecordMatch(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> & matches)1871 void ByteCodeExecutor::executeRecordMatch(
1872     PatternRewriter &rewriter,
1873     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1874   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1875   unsigned patternIndex = read();
1876   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1877   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1878 
1879   // If the benefit of the pattern is impossible, skip the processing of the
1880   // rest of the pattern.
1881   if (benefit.isImpossibleToMatch()) {
1882     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1883     curCodeIt = dest;
1884     return;
1885   }
1886 
1887   // Create a fused location containing the locations of each of the
1888   // operations used in the match. This will be used as the location for
1889   // created operations during the rewrite that don't already have an
1890   // explicit location set.
1891   unsigned numMatchLocs = read();
1892   SmallVector<Location, 4> matchLocs;
1893   matchLocs.reserve(numMatchLocs);
1894   for (unsigned i = 0; i != numMatchLocs; ++i)
1895     matchLocs.push_back(read<Operation *>()->getLoc());
1896   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1897 
1898   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1899                           << "  * Location: " << matchLoc << "\n");
1900   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1901   PDLByteCode::MatchResult &match = matches.back();
1902 
1903   // Record all of the inputs to the match. If any of the inputs are ranges, we
1904   // will also need to remap the range pointer to memory stored in the match
1905   // state.
1906   unsigned numInputs = read();
1907   match.values.reserve(numInputs);
1908   match.typeRangeValues.reserve(numInputs);
1909   match.valueRangeValues.reserve(numInputs);
1910   for (unsigned i = 0; i < numInputs; ++i) {
1911     switch (read<PDLValue::Kind>()) {
1912     case PDLValue::Kind::TypeRange:
1913       match.typeRangeValues.push_back(*read<TypeRange *>());
1914       match.values.push_back(&match.typeRangeValues.back());
1915       break;
1916     case PDLValue::Kind::ValueRange:
1917       match.valueRangeValues.push_back(*read<ValueRange *>());
1918       match.values.push_back(&match.valueRangeValues.back());
1919       break;
1920     default:
1921       match.values.push_back(read<const void *>());
1922       break;
1923     }
1924   }
1925   curCodeIt = dest;
1926 }
1927 
executeReplaceOp(PatternRewriter & rewriter)1928 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1929   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1930   Operation *op = read<Operation *>();
1931   SmallVector<Value, 16> args;
1932   readValueList(args);
1933 
1934   LLVM_DEBUG({
1935     llvm::dbgs() << "  * Operation: " << *op << "\n"
1936                  << "  * Values: ";
1937     llvm::interleaveComma(args, llvm::dbgs());
1938     llvm::dbgs() << "\n";
1939   });
1940   rewriter.replaceOp(op, args);
1941 }
1942 
executeSwitchAttribute()1943 void ByteCodeExecutor::executeSwitchAttribute() {
1944   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1945   Attribute value = read<Attribute>();
1946   ArrayAttr cases = read<ArrayAttr>();
1947   handleSwitch(value, cases);
1948 }
1949 
executeSwitchOperandCount()1950 void ByteCodeExecutor::executeSwitchOperandCount() {
1951   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1952   Operation *op = read<Operation *>();
1953   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1954 
1955   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1956   handleSwitch(op->getNumOperands(), cases);
1957 }
1958 
executeSwitchOperationName()1959 void ByteCodeExecutor::executeSwitchOperationName() {
1960   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1961   OperationName value = read<Operation *>()->getName();
1962   size_t caseCount = read();
1963 
1964   // The operation names are stored in-line, so to print them out for
1965   // debugging purposes we need to read the array before executing the
1966   // switch so that we can display all of the possible values.
1967   LLVM_DEBUG({
1968     const ByteCodeField *prevCodeIt = curCodeIt;
1969     llvm::dbgs() << "  * Value: " << value << "\n"
1970                  << "  * Cases: ";
1971     llvm::interleaveComma(
1972         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1973                         [&](size_t) { return read<OperationName>(); }),
1974         llvm::dbgs());
1975     llvm::dbgs() << "\n";
1976     curCodeIt = prevCodeIt;
1977   });
1978 
1979   // Try to find the switch value within any of the cases.
1980   for (size_t i = 0; i != caseCount; ++i) {
1981     if (read<OperationName>() == value) {
1982       curCodeIt += (caseCount - i - 1);
1983       return selectJump(i + 1);
1984     }
1985   }
1986   selectJump(size_t(0));
1987 }
1988 
executeSwitchResultCount()1989 void ByteCodeExecutor::executeSwitchResultCount() {
1990   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1991   Operation *op = read<Operation *>();
1992   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1993 
1994   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1995   handleSwitch(op->getNumResults(), cases);
1996 }
1997 
executeSwitchType()1998 void ByteCodeExecutor::executeSwitchType() {
1999   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2000   Type value = read<Type>();
2001   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2002   handleSwitch(value, cases);
2003 }
2004 
executeSwitchTypes()2005 void ByteCodeExecutor::executeSwitchTypes() {
2006   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2007   TypeRange *value = read<TypeRange *>();
2008   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2009   if (!value) {
2010     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2011     return selectJump(size_t(0));
2012   }
2013   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2014     return value == caseValue.getAsValueRange<TypeAttr>();
2015   });
2016 }
2017 
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,Optional<Location> mainRewriteLoc)2018 void ByteCodeExecutor::execute(
2019     PatternRewriter &rewriter,
2020     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2021     Optional<Location> mainRewriteLoc) {
2022   while (true) {
2023     // Print the location of the operation being executed.
2024     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2025 
2026     OpCode opCode = static_cast<OpCode>(read());
2027     switch (opCode) {
2028     case ApplyConstraint:
2029       executeApplyConstraint(rewriter);
2030       break;
2031     case ApplyRewrite:
2032       executeApplyRewrite(rewriter);
2033       break;
2034     case AreEqual:
2035       executeAreEqual();
2036       break;
2037     case AreRangesEqual:
2038       executeAreRangesEqual();
2039       break;
2040     case Branch:
2041       executeBranch();
2042       break;
2043     case CheckOperandCount:
2044       executeCheckOperandCount();
2045       break;
2046     case CheckOperationName:
2047       executeCheckOperationName();
2048       break;
2049     case CheckResultCount:
2050       executeCheckResultCount();
2051       break;
2052     case CheckTypes:
2053       executeCheckTypes();
2054       break;
2055     case Continue:
2056       executeContinue();
2057       break;
2058     case CreateOperation:
2059       executeCreateOperation(rewriter, *mainRewriteLoc);
2060       break;
2061     case CreateTypes:
2062       executeCreateTypes();
2063       break;
2064     case EraseOp:
2065       executeEraseOp(rewriter);
2066       break;
2067     case ExtractOp:
2068       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2069       break;
2070     case ExtractType:
2071       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2072       break;
2073     case ExtractValue:
2074       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2075       break;
2076     case Finalize:
2077       executeFinalize();
2078       LLVM_DEBUG(llvm::dbgs() << "\n");
2079       return;
2080     case ForEach:
2081       executeForEach();
2082       break;
2083     case GetAttribute:
2084       executeGetAttribute();
2085       break;
2086     case GetAttributeType:
2087       executeGetAttributeType();
2088       break;
2089     case GetDefiningOp:
2090       executeGetDefiningOp();
2091       break;
2092     case GetOperand0:
2093     case GetOperand1:
2094     case GetOperand2:
2095     case GetOperand3: {
2096       unsigned index = opCode - GetOperand0;
2097       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2098       executeGetOperand(index);
2099       break;
2100     }
2101     case GetOperandN:
2102       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2103       executeGetOperand(read<uint32_t>());
2104       break;
2105     case GetOperands:
2106       executeGetOperands();
2107       break;
2108     case GetResult0:
2109     case GetResult1:
2110     case GetResult2:
2111     case GetResult3: {
2112       unsigned index = opCode - GetResult0;
2113       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2114       executeGetResult(index);
2115       break;
2116     }
2117     case GetResultN:
2118       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2119       executeGetResult(read<uint32_t>());
2120       break;
2121     case GetResults:
2122       executeGetResults();
2123       break;
2124     case GetUsers:
2125       executeGetUsers();
2126       break;
2127     case GetValueType:
2128       executeGetValueType();
2129       break;
2130     case GetValueRangeTypes:
2131       executeGetValueRangeTypes();
2132       break;
2133     case IsNotNull:
2134       executeIsNotNull();
2135       break;
2136     case RecordMatch:
2137       assert(matches &&
2138              "expected matches to be provided when executing the matcher");
2139       executeRecordMatch(rewriter, *matches);
2140       break;
2141     case ReplaceOp:
2142       executeReplaceOp(rewriter);
2143       break;
2144     case SwitchAttribute:
2145       executeSwitchAttribute();
2146       break;
2147     case SwitchOperandCount:
2148       executeSwitchOperandCount();
2149       break;
2150     case SwitchOperationName:
2151       executeSwitchOperationName();
2152       break;
2153     case SwitchResultCount:
2154       executeSwitchResultCount();
2155       break;
2156     case SwitchType:
2157       executeSwitchType();
2158       break;
2159     case SwitchTypes:
2160       executeSwitchTypes();
2161       break;
2162     }
2163     LLVM_DEBUG(llvm::dbgs() << "\n");
2164   }
2165 }
2166 
2167 /// Run the pattern matcher on the given root operation, collecting the matched
2168 /// patterns in `matches`.
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const2169 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2170                         SmallVectorImpl<MatchResult> &matches,
2171                         PDLByteCodeMutableState &state) const {
2172   // The first memory slot is always the root operation.
2173   state.memory[0] = op;
2174 
2175   // The matcher function always starts at code address 0.
2176   ByteCodeExecutor executor(
2177       matcherByteCode.data(), state.memory, state.opRangeMemory,
2178       state.typeRangeMemory, state.allocatedTypeRangeMemory,
2179       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2180       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2181       constraintFunctions, rewriteFunctions);
2182   executor.execute(rewriter, &matches);
2183 
2184   // Order the found matches by benefit.
2185   std::stable_sort(matches.begin(), matches.end(),
2186                    [](const MatchResult &lhs, const MatchResult &rhs) {
2187                      return lhs.benefit > rhs.benefit;
2188                    });
2189 }
2190 
2191 /// Run the rewriter of the given pattern on the root operation `op`.
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const2192 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2193                           PDLByteCodeMutableState &state) const {
2194   // The arguments of the rewrite function are stored at the start of the
2195   // memory buffer.
2196   llvm::copy(match.values, state.memory.begin());
2197 
2198   ByteCodeExecutor executor(
2199       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2200       state.opRangeMemory, state.typeRangeMemory,
2201       state.allocatedTypeRangeMemory, state.valueRangeMemory,
2202       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2203       rewriterByteCode, state.currentPatternBenefits, patterns,
2204       constraintFunctions, rewriteFunctions);
2205   executor.execute(rewriter, /*matches=*/nullptr, match.location);
2206 }
2207