1 //===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file declares various operation folding utilities. These 10 // utilities are intended to be used by passes to unify and simply their logic. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_TRANSFORMS_FOLDUTILS_H 15 #define MLIR_TRANSFORMS_FOLDUTILS_H 16 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/Dialect.h" 19 #include "mlir/IR/DialectInterface.h" 20 #include "mlir/Interfaces/FoldInterfaces.h" 21 22 namespace mlir { 23 class Operation; 24 class Value; 25 26 //===--------------------------------------------------------------------===// 27 // OperationFolder 28 //===--------------------------------------------------------------------===// 29 30 /// A utility class for folding operations, and unifying duplicated constants 31 /// generated along the way. 32 class OperationFolder { 33 public: OperationFolder(MLIRContext * ctx)34 OperationFolder(MLIRContext *ctx) : interfaces(ctx) {} 35 36 /// Tries to perform folding on the given `op`, including unifying 37 /// deduplicated constants. If successful, replaces `op`'s uses with 38 /// folded results, and returns success. `preReplaceAction` is invoked on `op` 39 /// before it is replaced. 'processGeneratedConstants' is invoked for any new 40 /// operations generated when folding. If the op was completely folded it is 41 /// erased. If it is just updated in place, `inPlaceUpdate` is set to true. 42 LogicalResult 43 tryToFold(Operation *op, 44 function_ref<void(Operation *)> processGeneratedConstants = nullptr, 45 function_ref<void(Operation *)> preReplaceAction = nullptr, 46 bool *inPlaceUpdate = nullptr); 47 48 /// Tries to fold a pre-existing constant operation. `constValue` represents 49 /// the value of the constant, and can be optionally passed if the value is 50 /// already known (e.g. if the constant was discovered by m_Constant). This is 51 /// purely an optimization opportunity for callers that already know the value 52 /// of the constant. Returns false if an existing constant for `op` already 53 /// exists in the folder, in which case `op` is replaced and erased. 54 /// Otherwise, returns true and `op` is inserted into the folder (and 55 /// hoisted if necessary). 56 bool insertKnownConstant(Operation *op, Attribute constValue = {}); 57 58 /// Notifies that the given constant `op` should be remove from this 59 /// OperationFolder's internal bookkeeping. 60 /// 61 /// Note: this method must be called if a constant op is to be deleted 62 /// externally to this OperationFolder. `op` must be a constant op. 63 void notifyRemoval(Operation *op); 64 65 /// Create an operation of specific op type with the given builder, 66 /// and immediately try to fold it. This function populates 'results' with 67 /// the results after folding the operation. 68 template <typename OpTy, typename... Args> create(OpBuilder & builder,SmallVectorImpl<Value> & results,Location location,Args &&...args)69 void create(OpBuilder &builder, SmallVectorImpl<Value> &results, 70 Location location, Args &&...args) { 71 // The op needs to be inserted only if the fold (below) fails, or the number 72 // of results produced by the successful folding is zero (which is treated 73 // as an in-place fold). Using create methods of the builder will insert the 74 // op, so not using it here. 75 OperationState state(location, OpTy::getOperationName()); 76 OpTy::build(builder, state, std::forward<Args>(args)...); 77 Operation *op = Operation::create(state); 78 79 if (failed(tryToFold(builder, op, results)) || results.empty()) { 80 builder.insert(op); 81 results.assign(op->result_begin(), op->result_end()); 82 return; 83 } 84 op->destroy(); 85 } 86 87 /// Overload to create or fold a single result operation. 88 template <typename OpTy, typename... Args> 89 typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(), 90 Value>::type create(OpBuilder & builder,Location location,Args &&...args)91 create(OpBuilder &builder, Location location, Args &&...args) { 92 SmallVector<Value, 1> results; 93 create<OpTy>(builder, results, location, std::forward<Args>(args)...); 94 return results.front(); 95 } 96 97 /// Overload to create or fold a zero result operation. 98 template <typename OpTy, typename... Args> 99 typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResults>(), 100 OpTy>::type create(OpBuilder & builder,Location location,Args &&...args)101 create(OpBuilder &builder, Location location, Args &&...args) { 102 auto op = builder.create<OpTy>(location, std::forward<Args>(args)...); 103 SmallVector<Value, 0> unused; 104 (void)tryToFold(op.getOperation(), unused); 105 106 // Folding cannot remove a zero-result operation, so for convenience we 107 // continue to return it. 108 return op; 109 } 110 111 /// Clear out any constants cached inside of the folder. 112 void clear(); 113 114 /// Get or create a constant using the given builder. On success this returns 115 /// the constant operation, nullptr otherwise. 116 Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect, 117 Attribute value, Type type, Location loc); 118 119 private: 120 /// This map keeps track of uniqued constants by dialect, attribute, and type. 121 /// A constant operation materializes an attribute with a type. Dialects may 122 /// generate different constants with the same input attribute and type, so we 123 /// also need to track per-dialect. 124 using ConstantMap = 125 DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>; 126 127 /// Returns true if the given operation is an already folded constant that is 128 /// owned by this folder. 129 bool isFolderOwnedConstant(Operation *op) const; 130 131 /// Tries to perform folding on the given `op`. If successful, populates 132 /// `results` with the results of the folding. 133 LogicalResult tryToFold( 134 OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results, 135 function_ref<void(Operation *)> processGeneratedConstants = nullptr); 136 137 /// Try to process a set of fold results, generating constants as necessary. 138 /// Populates `results` on success, otherwise leaves it unchanged. 139 LogicalResult 140 processFoldResults(OpBuilder &builder, Operation *op, 141 SmallVectorImpl<Value> &results, 142 ArrayRef<OpFoldResult> foldResults, 143 function_ref<void(Operation *)> processGeneratedConstants); 144 145 /// Try to get or create a new constant entry. On success this returns the 146 /// constant operation, nullptr otherwise. 147 Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants, 148 Dialect *dialect, OpBuilder &builder, 149 Attribute value, Type type, Location loc); 150 151 /// A mapping between an insertion region and the constants that have been 152 /// created within it. 153 DenseMap<Region *, ConstantMap> foldScopes; 154 155 /// This map tracks all of the dialects that an operation is referenced by; 156 /// given that many dialects may generate the same constant. 157 DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects; 158 159 /// A collection of dialect folder interfaces. 160 DialectInterfaceCollection<DialectFoldInterface> interfaces; 161 }; 162 163 } // namespace mlir 164 165 #endif // MLIR_TRANSFORMS_FOLDUTILS_H 166