1 //===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various operation folding utilities. These
// utilities are intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TRANSFORMS_FOLDUTILS_H
15 #define MLIR_TRANSFORMS_FOLDUTILS_H
16 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/DialectInterface.h"
20 #include "mlir/Interfaces/FoldInterfaces.h"
21 
22 namespace mlir {
23 class Operation;
24 class Value;
25 
26 //===--------------------------------------------------------------------===//
27 // OperationFolder
28 //===--------------------------------------------------------------------===//
29 
30 /// A utility class for folding operations, and unifying duplicated constants
31 /// generated along the way.
32 class OperationFolder {
33 public:
OperationFolder(MLIRContext * ctx)34   OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
35 
36   /// Tries to perform folding on the given `op`, including unifying
37   /// deduplicated constants. If successful, replaces `op`'s uses with
38   /// folded results, and returns success. `preReplaceAction` is invoked on `op`
39   /// before it is replaced. 'processGeneratedConstants' is invoked for any new
40   /// operations generated when folding. If the op was completely folded it is
41   /// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
42   LogicalResult
43   tryToFold(Operation *op,
44             function_ref<void(Operation *)> processGeneratedConstants = nullptr,
45             function_ref<void(Operation *)> preReplaceAction = nullptr,
46             bool *inPlaceUpdate = nullptr);
47 
48   /// Tries to fold a pre-existing constant operation. `constValue` represents
49   /// the value of the constant, and can be optionally passed if the value is
50   /// already known (e.g. if the constant was discovered by m_Constant). This is
51   /// purely an optimization opportunity for callers that already know the value
52   /// of the constant. Returns false if an existing constant for `op` already
53   /// exists in the folder, in which case `op` is replaced and erased.
54   /// Otherwise, returns true and `op` is inserted into the folder (and
55   /// hoisted if necessary).
56   bool insertKnownConstant(Operation *op, Attribute constValue = {});
57 
58   /// Notifies that the given constant `op` should be remove from this
59   /// OperationFolder's internal bookkeeping.
60   ///
61   /// Note: this method must be called if a constant op is to be deleted
62   /// externally to this OperationFolder. `op` must be a constant op.
63   void notifyRemoval(Operation *op);
64 
65   /// Create an operation of specific op type with the given builder,
66   /// and immediately try to fold it. This function populates 'results' with
67   /// the results after folding the operation.
68   template <typename OpTy, typename... Args>
create(OpBuilder & builder,SmallVectorImpl<Value> & results,Location location,Args &&...args)69   void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
70               Location location, Args &&...args) {
71     // The op needs to be inserted only if the fold (below) fails, or the number
72     // of results produced by the successful folding is zero (which is treated
73     // as an in-place fold). Using create methods of the builder will insert the
74     // op, so not using it here.
75     OperationState state(location, OpTy::getOperationName());
76     OpTy::build(builder, state, std::forward<Args>(args)...);
77     Operation *op = Operation::create(state);
78 
79     if (failed(tryToFold(builder, op, results)) || results.empty()) {
80       builder.insert(op);
81       results.assign(op->result_begin(), op->result_end());
82       return;
83     }
84     op->destroy();
85   }
86 
87   /// Overload to create or fold a single result operation.
88   template <typename OpTy, typename... Args>
89   typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
90                           Value>::type
create(OpBuilder & builder,Location location,Args &&...args)91   create(OpBuilder &builder, Location location, Args &&...args) {
92     SmallVector<Value, 1> results;
93     create<OpTy>(builder, results, location, std::forward<Args>(args)...);
94     return results.front();
95   }
96 
97   /// Overload to create or fold a zero result operation.
98   template <typename OpTy, typename... Args>
99   typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResults>(),
100                           OpTy>::type
create(OpBuilder & builder,Location location,Args &&...args)101   create(OpBuilder &builder, Location location, Args &&...args) {
102     auto op = builder.create<OpTy>(location, std::forward<Args>(args)...);
103     SmallVector<Value, 0> unused;
104     (void)tryToFold(op.getOperation(), unused);
105 
106     // Folding cannot remove a zero-result operation, so for convenience we
107     // continue to return it.
108     return op;
109   }
110 
111   /// Clear out any constants cached inside of the folder.
112   void clear();
113 
114   /// Get or create a constant using the given builder. On success this returns
115   /// the constant operation, nullptr otherwise.
116   Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
117                             Attribute value, Type type, Location loc);
118 
119 private:
120   /// This map keeps track of uniqued constants by dialect, attribute, and type.
121   /// A constant operation materializes an attribute with a type. Dialects may
122   /// generate different constants with the same input attribute and type, so we
123   /// also need to track per-dialect.
124   using ConstantMap =
125       DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
126 
127   /// Returns true if the given operation is an already folded constant that is
128   /// owned by this folder.
129   bool isFolderOwnedConstant(Operation *op) const;
130 
131   /// Tries to perform folding on the given `op`. If successful, populates
132   /// `results` with the results of the folding.
133   LogicalResult tryToFold(
134       OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
135       function_ref<void(Operation *)> processGeneratedConstants = nullptr);
136 
137   /// Try to process a set of fold results, generating constants as necessary.
138   /// Populates `results` on success, otherwise leaves it unchanged.
139   LogicalResult
140   processFoldResults(OpBuilder &builder, Operation *op,
141                      SmallVectorImpl<Value> &results,
142                      ArrayRef<OpFoldResult> foldResults,
143                      function_ref<void(Operation *)> processGeneratedConstants);
144 
145   /// Try to get or create a new constant entry. On success this returns the
146   /// constant operation, nullptr otherwise.
147   Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
148                                     Dialect *dialect, OpBuilder &builder,
149                                     Attribute value, Type type, Location loc);
150 
151   /// A mapping between an insertion region and the constants that have been
152   /// created within it.
153   DenseMap<Region *, ConstantMap> foldScopes;
154 
155   /// This map tracks all of the dialects that an operation is referenced by;
156   /// given that many dialects may generate the same constant.
157   DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects;
158 
159   /// A collection of dialect folder interfaces.
160   DialectInterfaceCollection<DialectFoldInterface> interfaces;
161 };
162 
163 } // namespace mlir
164 
165 #endif // MLIR_TRANSFORMS_FOLDUTILS_H
166