//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares a byte-code and interpreter for pattern rewrites in MLIR.
// The byte-code is constructed from the PDL Interpreter dialect.
//
//===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_REWRITE_BYTECODE_H_
15 #define MLIR_REWRITE_BYTECODE_H_
16 
17 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
namespace pdl_interp {
// Forward declaration only: this header merely names the op in a function
// declaration, so the full op definition is not needed here.
class RecordMatchOp;
} // namespace pdl_interp
23 
24 namespace detail {
/// Forward declaration of the bytecode class defined later in this header.
class PDLByteCode;
26 
27 /// Use generic bytecode types. ByteCodeField refers to the actual bytecode
28 /// entries. ByteCodeAddr refers to size of indices into the bytecode.
29 using ByteCodeField = uint16_t;
30 using ByteCodeAddr = uint32_t;
31 using OwningOpRange = llvm::OwningArrayRef<Operation *>;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
37 /// All of the data pertaining to a specific pattern within the bytecode.
38 class PDLByteCodePattern : public Pattern {
39 public:
40   static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
41                                    ByteCodeAddr rewriterAddr);
42 
43   /// Return the bytecode address of the rewriter for this pattern.
getRewriterAddr()44   ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
45 
46 private:
47   template <typename... Args>
PDLByteCodePattern(ByteCodeAddr rewriterAddr,Args &&...patternArgs)48   PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs)
49       : Pattern(std::forward<Args>(patternArgs)...),
50         rewriterAddr(rewriterAddr) {}
51 
52   /// The address of the rewriter for this pattern.
53   ByteCodeAddr rewriterAddr;
54 };
55 
56 //===----------------------------------------------------------------------===//
57 // PDLByteCodeMutableState
58 //===----------------------------------------------------------------------===//
59 
60 /// This class contains the mutable state of a bytecode instance. This allows
61 /// for a bytecode instance to be cached and reused across various different
62 /// threads/drivers.
63 class PDLByteCodeMutableState {
64 public:
65   /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
66   /// to the position of the pattern within the range returned by
67   /// `PDLByteCode::getPatterns`.
68   void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
69 
70   /// Cleanup any allocated state after a match/rewrite has been completed. This
71   /// method should be called irregardless of whether the match+rewrite was a
72   /// success or not.
73   void cleanupAfterMatchAndRewrite();
74 
75 private:
76   /// Allow access to data fields.
77   friend class PDLByteCode;
78 
79   /// The mutable block of memory used during the matching and rewriting phases
80   /// of the bytecode.
81   std::vector<const void *> memory;
82 
83   /// A mutable block of memory used during the matching and rewriting phase of
84   /// the bytecode to store ranges of operations. These are always stored by
85   /// owning references, because at no point in the execution of the byte code
86   /// we get an indexed range (view) of operations.
87   std::vector<OwningOpRange> opRangeMemory;
88 
89   /// A mutable block of memory used during the matching and rewriting phase of
90   /// the bytecode to store ranges of types.
91   std::vector<TypeRange> typeRangeMemory;
92   /// A set of type ranges that have been allocated by the byte code interpreter
93   /// to provide a guaranteed lifetime.
94   std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;
95 
96   /// A mutable block of memory used during the matching and rewriting phase of
97   /// the bytecode to store ranges of values.
98   std::vector<ValueRange> valueRangeMemory;
99   /// A set of value ranges that have been allocated by the byte code
100   /// interpreter to provide a guaranteed lifetime.
101   std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
102 
103   /// The current index of ranges being iterated over for each level of nesting.
104   /// These are always maintained at 0 for the loops that are not active, so we
105   /// do not need to have a separate initialization phase for each loop.
106   std::vector<unsigned> loopIndex;
107 
108   /// The up-to-date benefits of the patterns held by the bytecode. The order
109   /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
110   std::vector<PatternBenefit> currentPatternBenefits;
111 };
112 
113 //===----------------------------------------------------------------------===//
114 // PDLByteCode
115 //===----------------------------------------------------------------------===//
116 
117 /// The bytecode class is also the interpreter. Contains the bytecode itself,
118 /// the static info, addresses of the rewriter functions, the interpreter
119 /// memory buffer, and the execution context.
120 class PDLByteCode {
121 public:
122   /// Each successful match returns a MatchResult, which contains information
123   /// necessary to execute the rewriter and indicates the originating pattern.
124   struct MatchResult {
MatchResultMatchResult125     MatchResult(Location loc, const PDLByteCodePattern &pattern,
126                 PatternBenefit benefit)
127         : location(loc), pattern(&pattern), benefit(benefit) {}
128     MatchResult(const MatchResult &) = delete;
129     MatchResult &operator=(const MatchResult &) = delete;
130     MatchResult(MatchResult &&other) = default;
131     MatchResult &operator=(MatchResult &&) = default;
132 
133     /// The location of operations to be replaced.
134     Location location;
135     /// Memory values defined in the matcher that are passed to the rewriter.
136     SmallVector<const void *> values;
137     /// Memory used for the range input values.
138     SmallVector<TypeRange, 0> typeRangeValues;
139     SmallVector<ValueRange, 0> valueRangeValues;
140 
141     /// The originating pattern that was matched. This is always non-null, but
142     /// represented with a pointer to allow for assignment.
143     const PDLByteCodePattern *pattern;
144     /// The current benefit of the pattern that was matched.
145     PatternBenefit benefit;
146   };
147 
148   /// Create a ByteCode instance from the given module containing operations in
149   /// the PDL interpreter dialect.
150   PDLByteCode(ModuleOp module,
151               llvm::StringMap<PDLConstraintFunction> constraintFns,
152               llvm::StringMap<PDLRewriteFunction> rewriteFns);
153 
154   /// Return the patterns held by the bytecode.
getPatterns()155   ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
156 
157   /// Initialize the given state such that it can be used to execute the current
158   /// bytecode.
159   void initializeMutableState(PDLByteCodeMutableState &state) const;
160 
161   /// Run the pattern matcher on the given root operation, collecting the
162   /// matched patterns in `matches`.
163   void match(Operation *op, PatternRewriter &rewriter,
164              SmallVectorImpl<MatchResult> &matches,
165              PDLByteCodeMutableState &state) const;
166 
167   /// Run the rewriter of the given pattern that was previously matched in
168   /// `match`.
169   void rewrite(PatternRewriter &rewriter, const MatchResult &match,
170                PDLByteCodeMutableState &state) const;
171 
172 private:
173   /// Execute the given byte code starting at the provided instruction `inst`.
174   /// `matches` is an optional field provided when this function is executed in
175   /// a matching context.
176   void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
177                        PDLByteCodeMutableState &state,
178                        SmallVectorImpl<MatchResult> *matches) const;
179 
180   /// A vector containing pointers to uniqued data. The storage is intentionally
181   /// opaque such that we can store a wide range of data types. The types of
182   /// data stored here include:
183   ///  * Attribute, OperationName, Type
184   std::vector<const void *> uniquedData;
185 
186   /// A vector containing the generated bytecode for the matcher.
187   SmallVector<ByteCodeField, 64> matcherByteCode;
188 
189   /// A vector containing the generated bytecode for all of the rewriters.
190   SmallVector<ByteCodeField, 64> rewriterByteCode;
191 
192   /// The set of patterns contained within the bytecode.
193   SmallVector<PDLByteCodePattern, 32> patterns;
194 
195   /// A set of user defined functions invoked via PDL.
196   std::vector<PDLConstraintFunction> constraintFunctions;
197   std::vector<PDLRewriteFunction> rewriteFunctions;
198 
199   /// The maximum memory index used by a value.
200   ByteCodeField maxValueMemoryIndex = 0;
201 
202   /// The maximum number of different types of ranges.
203   ByteCodeField maxOpRangeCount = 0;
204   ByteCodeField maxTypeRangeCount = 0;
205   ByteCodeField maxValueRangeCount = 0;
206 
207   /// The maximum number of nested loops.
208   ByteCodeField maxLoopLevel = 0;
209 };
210 
211 } // namespace detail
212 } // namespace mlir
213 
214 #endif // MLIR_REWRITE_BYTECODE_H_
215