//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares a byte-code and interpreter for pattern rewrites in MLIR.
// The byte-code is constructed from the PDL Interpreter dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_REWRITE_BYTECODE_H_
#define MLIR_REWRITE_BYTECODE_H_

#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace pdl_interp {
class RecordMatchOp;
} // namespace pdl_interp

namespace detail {
class PDLByteCode;

/// Use generic bytecode types. ByteCodeField refers to the actual bytecode
/// entries. ByteCodeAddr refers to size of indices into the bytecode.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
/// An owning array of operations; used for the operation-range memory of the
/// interpreter state.
using OwningOpRange = llvm::OwningArrayRef<Operation *>;

//===----------------------------------------------------------------------===//
// PDLByteCodePattern
//===----------------------------------------------------------------------===//

/// All of the data pertaining to a specific pattern within the bytecode.
class PDLByteCodePattern : public Pattern {
public:
  /// Create a bytecode pattern from the given match operation and the bytecode
  /// address of its rewriter. Defined out of line; this is the only public way
  /// to build a pattern because the constructor is private.
  static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
                                   ByteCodeAddr rewriterAddr);

  /// Return the bytecode address of the rewriter for this pattern.
  ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }

private:
  /// Construct a pattern that stores `rewriterAddr` and forwards every
  /// remaining argument to the base `Pattern` constructor.
  template <typename... Args>
  PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs)
      : Pattern(std::forward<Args>(patternArgs)...),
        rewriterAddr(rewriterAddr) {}

  /// The address of the rewriter for this pattern.
  ByteCodeAddr rewriterAddr;
};

//===----------------------------------------------------------------------===//
// PDLByteCodeMutableState
//===----------------------------------------------------------------------===//

/// This class contains the mutable state of a bytecode instance. This allows
/// for a bytecode instance to be cached and reused across various different
/// threads/drivers.
class PDLByteCodeMutableState {
public:
  /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
  /// to the position of the pattern within the range returned by
  /// `PDLByteCode::getPatterns`.
  void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);

  /// Cleanup any allocated state after a match/rewrite has been completed. This
  /// method should be called regardless of whether the match+rewrite was a
  /// success or not.
  void cleanupAfterMatchAndRewrite();

private:
  /// Allow access to data fields.
  friend class PDLByteCode;

  /// The mutable block of memory used during the matching and rewriting phases
  /// of the bytecode.
  std::vector<const void *> memory;

  /// A mutable block of memory used during the matching and rewriting phase of
  /// the bytecode to store ranges of operations. These are always stored by
  /// owning references, because at no point in the execution of the byte code
  /// do we get an indexed range (view) of operations.
  std::vector<OwningOpRange> opRangeMemory;

  /// A mutable block of memory used during the matching and rewriting phase of
  /// the bytecode to store ranges of types.
  std::vector<TypeRange> typeRangeMemory;
  /// A set of type ranges that have been allocated by the byte code interpreter
  /// to provide a guaranteed lifetime.
  std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;

  /// A mutable block of memory used during the matching and rewriting phase of
  /// the bytecode to store ranges of values.
  std::vector<ValueRange> valueRangeMemory;
  /// A set of value ranges that have been allocated by the byte code
  /// interpreter to provide a guaranteed lifetime.
  std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;

  /// The current index of ranges being iterated over for each level of nesting.
  /// These are always maintained at 0 for the loops that are not active, so we
  /// do not need to have a separate initialization phase for each loop.
  std::vector<unsigned> loopIndex;

  /// The up-to-date benefits of the patterns held by the bytecode. The order
  /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
  std::vector<PatternBenefit> currentPatternBenefits;
};

//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//

/// The bytecode class is also the interpreter. Contains the bytecode itself,
/// the static info, addresses of the rewriter functions, the interpreter
/// memory buffer, and the execution context.
class PDLByteCode {
public:
  /// Each successful match returns a MatchResult, which contains information
  /// necessary to execute the rewriter and indicates the originating pattern.
  /// Results are move-only: copying is explicitly deleted.
  struct MatchResult {
    /// Construct a result for `pattern` with the given location and benefit.
    /// The value and range vectors start out empty.
    MatchResult(Location loc, const PDLByteCodePattern &pattern,
                PatternBenefit benefit)
        : location(loc), pattern(&pattern), benefit(benefit) {}
    MatchResult(const MatchResult &) = delete;
    MatchResult &operator=(const MatchResult &) = delete;
    MatchResult(MatchResult &&other) = default;
    MatchResult &operator=(MatchResult &&) = default;

    /// The location of operations to be replaced.
    Location location;
    /// Memory values defined in the matcher that are passed to the rewriter.
    SmallVector<const void *> values;
    /// Memory used for the range input values.
    SmallVector<TypeRange, 0> typeRangeValues;
    SmallVector<ValueRange, 0> valueRangeValues;

    /// The originating pattern that was matched. This is always non-null, but
    /// represented with a pointer to allow for assignment.
    const PDLByteCodePattern *pattern;
    /// The current benefit of the pattern that was matched.
    PatternBenefit benefit;
  };

  /// Create a ByteCode instance from the given module containing operations in
  /// the PDL interpreter dialect.
  PDLByteCode(ModuleOp module,
              llvm::StringMap<PDLConstraintFunction> constraintFns,
              llvm::StringMap<PDLRewriteFunction> rewriteFns);

  /// Return the patterns held by the bytecode.
  ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }

  /// Initialize the given state such that it can be used to execute the current
  /// bytecode.
  void initializeMutableState(PDLByteCodeMutableState &state) const;

  /// Run the pattern matcher on the given root operation, collecting the
  /// matched patterns in `matches`.
  void match(Operation *op, PatternRewriter &rewriter,
             SmallVectorImpl<MatchResult> &matches,
             PDLByteCodeMutableState &state) const;

  /// Run the rewriter of the given pattern that was previously matched in
  /// `match`.
  void rewrite(PatternRewriter &rewriter, const MatchResult &match,
               PDLByteCodeMutableState &state) const;

private:
  /// Execute the given byte code starting at the provided instruction `inst`.
  /// `matches` is an optional field provided when this function is executed in
  /// a matching context.
  void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
                       PDLByteCodeMutableState &state,
                       SmallVectorImpl<MatchResult> *matches) const;

  /// A vector containing pointers to uniqued data. The storage is intentionally
  /// opaque such that we can store a wide range of data types. The types of
  /// data stored here include:
  ///  * Attribute, OperationName, Type
  std::vector<const void *> uniquedData;

  /// A vector containing the generated bytecode for the matcher.
  SmallVector<ByteCodeField, 64> matcherByteCode;

  /// A vector containing the generated bytecode for all of the rewriters.
  SmallVector<ByteCodeField, 64> rewriterByteCode;

  /// The set of patterns contained within the bytecode.
  SmallVector<PDLByteCodePattern, 32> patterns;

  /// A set of user defined functions invoked via PDL.
  std::vector<PDLConstraintFunction> constraintFunctions;
  std::vector<PDLRewriteFunction> rewriteFunctions;

  /// The maximum memory index used by a value.
  ByteCodeField maxValueMemoryIndex = 0;

  /// The maximum number of different types of ranges.
  ByteCodeField maxOpRangeCount = 0;
  ByteCodeField maxTypeRangeCount = 0;
  ByteCodeField maxValueRangeCount = 0;

  /// The maximum number of nested loops.
  ByteCodeField maxLoopLevel = 0;
};

} // namespace detail
} // namespace mlir

#endif // MLIR_REWRITE_BYTECODE_H_