1 //===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements constant folding on Linalg operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/IR/Matchers.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Support/LLVM.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 24 namespace { 25 /// Base class for constant folding linalg.generic ops with N inputs, 1 output, 26 /// and permutation indexing maps. 27 /// 28 /// `ConcreteType` should provide methods with signatures 29 /// 30 /// ```c++ 31 /// bool matchIndexingMaps(GenericOp genericOp) const; 32 /// RegionComputationFn getRegionComputeFn(GenericOp) const; 33 /// ``` 34 /// 35 /// The latter inspects the region and returns the computation inside as a 36 /// functor. The functor will be invoked with constant elements for all inputs 37 /// and should return the corresponding computed constant element for output. 38 template <typename ConcreteType> 39 class FoldConstantBase : public OpRewritePattern<GenericOp> { 40 public: 41 struct APIntOrFloat { 42 Optional<APInt> apInt; 43 Optional<APFloat> apFloat; 44 }; 45 struct APIntOrFloatArray { 46 SmallVector<APInt> apInts; 47 SmallVector<APFloat> apFloats; 48 }; 49 using RegionComputationFn = 50 std::function<APIntOrFloat(const APIntOrFloatArray &)>; 51 52 FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn, 53 PatternBenefit benefit = 1) 54 : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {} 55 56 LogicalResult matchAndRewrite(GenericOp genericOp, 57 PatternRewriter &rewriter) const override { 58 if (genericOp.hasBufferSemantics()) 59 return failure(); 60 61 // Only support ops generating one output for now. 62 if (genericOp.getNumOutputs() != 1) 63 return failure(); 64 65 auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>(); 66 // Require the output types to be static given that we are generating 67 // constants. 68 if (!outputType || !outputType.hasStaticShape()) 69 return failure(); 70 71 if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) { 72 return operand->get().getType().isa<ShapedType>(); 73 })) 74 return failure(); 75 76 // Make sure all element types are the same. 77 auto getOperandElementType = [](OpOperand *operand) { 78 return operand->get().getType().cast<ShapedType>().getElementType(); 79 }; 80 if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(), 81 getOperandElementType))) 82 return failure(); 83 84 // We can only handle the case where we have int/float elements. 85 auto elementType = outputType.getElementType(); 86 if (!elementType.isIntOrFloat()) 87 return failure(); 88 89 // Require all indexing maps to be permutations for now. This is common and 90 // it simplifies input/output access greatly: we can do the data shuffling 91 // entirely in the compiler, without needing to turn all indices into 92 // Values, and then do affine apply on them, and then match back the 93 // constant again. 94 if (!llvm::all_of(genericOp.getIndexingMaps(), 95 [](AffineMap map) { return map.isPermutation(); })) 96 return failure(); 97 98 for (OpOperand *operand : genericOp.getOutputOperands()) { 99 if (genericOp.payloadUsesValueFromOperand(operand)) 100 return failure(); 101 } 102 103 // Further check the indexing maps are okay for the ConcreteType. 104 if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp)) 105 return failure(); 106 107 // Defer to the concrete type to check the region and discover the 108 // computation inside. 109 RegionComputationFn computeFn = 110 static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp); 111 if (!computeFn) 112 return failure(); 113 114 // All inputs should be constants. 115 int numInputs = genericOp.getNumInputs(); 116 SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs); 117 for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) { 118 if (!matchPattern(operand.value()->get(), 119 m_Constant(&inputValues[operand.index()]))) 120 return failure(); 121 } 122 123 // Identified this as a potential candidate for folding. Now check the 124 // policy to see whether we are allowed to proceed. 125 for (int i = 0; i < numInputs; ++i) { 126 OpOperand *consumer = genericOp.getInputOperand(i); 127 OpResult producer = consumer->get().cast<OpResult>(); 128 if (!controlFn(producer, *consumer)) 129 return failure(); 130 } 131 132 auto linalgOp = cast<LinalgOp>(genericOp.getOperation()); 133 SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes(); 134 int64_t numElements = outputType.getNumElements(); 135 136 // Use APInt/APFloat instead of Attribute here for constructing the output. 137 // This helps to avoid blowing up compiler memory usage: Attributes would 138 // unify the following cases but they have lifetime as the MLIRContext. 139 SmallVector<APInt> intOutputValues; 140 SmallVector<APFloat> fpOutputValues; 141 if (elementType.template isa<FloatType>()) 142 fpOutputValues.resize(numElements, APFloat(0.f)); 143 else 144 intOutputValues.resize(numElements); 145 146 // Return the constant dim positions from the given permutation map. 147 auto getDimPositions = [](AffineMap map) { 148 SmallVector<unsigned> dims; 149 dims.reserve(map.getNumResults()); 150 for (AffineExpr result : map.getResults()) { 151 dims.push_back(result.cast<AffineDimExpr>().getPosition()); 152 } 153 return dims; 154 }; 155 156 SmallVector<SmallVector<unsigned>> inputDims; 157 for (int i = 0; i < numInputs; ++i) 158 inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i])); 159 auto outputDims = getDimPositions(genericOp.getIndexingMaps().back()); 160 auto outputShape = outputType.getShape(); 161 162 // Allocate small vectors for index delinearization. Initial values do not 163 // matter here as they will be overwritten later. 164 SmallVector<uint64_t> indices(loopBounds.size(), 0); 165 SmallVector<uint64_t> dstIndices(loopBounds.size(), 0); 166 SmallVector<SmallVector<uint64_t>> srcIndices( 167 numInputs, SmallVector<uint64_t>(loopBounds.size(), 0)); 168 SmallVector<uint64_t> srcLinearIndices(numInputs, 0); 169 uint64_t dstLinearIndex = 0; 170 171 // Allocate spaces for compute function inputs. Initial values do not matter 172 // here as they will be overwritten later. 173 APIntOrFloatArray computeFnInputs; 174 175 auto inputShapes = llvm::to_vector<4>( 176 llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) { 177 return operand->get().getType().cast<ShapedType>().getShape(); 178 })); 179 180 // Given a `linearIndex`, remap it to a linear index to access linalg op 181 // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`, 182 // `srcLinearIndices`, `dstLinearIndex` in place. 183 auto computeRemappedLinearIndex = [&](int linearIndex) { 184 int totalCount = linearIndex; 185 for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { 186 indices[dim] = totalCount % loopBounds[dim]; 187 totalCount /= loopBounds[dim]; 188 } 189 190 for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { 191 for (int i = 0; i < numInputs; ++i) 192 srcIndices[i][dim] = indices[inputDims[i][dim]]; 193 dstIndices[dim] = indices[outputDims[dim]]; 194 } 195 196 dstLinearIndex = dstIndices.front(); 197 for (int i = 0; i < numInputs; ++i) 198 srcLinearIndices[i] = srcIndices[i].front(); 199 200 for (int dim = 1; dim < outputType.getRank(); ++dim) { 201 dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; 202 for (int i = 0; i < numInputs; ++i) 203 srcLinearIndices[i] = 204 srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim]; 205 } 206 }; 207 208 bool isFloat = elementType.isa<FloatType>(); 209 if (isFloat) { 210 SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges; 211 for (int i = 0; i < numInputs; ++i) 212 inFpRanges.push_back(inputValues[i].getValues<APFloat>()); 213 214 computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); 215 216 // Transpose the input constant. Because we don't know its rank in 217 // advance, we need to loop over the range [0, element count) and 218 // delinearize the index. 219 for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { 220 computeRemappedLinearIndex(linearIndex); 221 222 // Collect constant elements for all inputs at this loop iteration. 223 for (int i = 0; i < numInputs; ++i) 224 computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; 225 226 // Invoke the computation to get the corresponding constant output 227 // element. 228 fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; 229 } 230 } else { 231 SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges; 232 for (int i = 0; i < numInputs; ++i) 233 inIntRanges.push_back(inputValues[i].getValues<APInt>()); 234 235 computeFnInputs.apInts.resize(numInputs); 236 237 // Transpose the input constant. Because we don't know its rank in 238 // advance, we need to loop over the range [0, element count) and 239 // delinearize the index. 240 for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { 241 computeRemappedLinearIndex(linearIndex); 242 243 // Collect constant elements for all inputs at this loop iteration. 244 for (int i = 0; i < numInputs; ++i) 245 computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; 246 247 // Invoke the computation to get the corresponding constant output 248 // element. 249 intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; 250 } 251 } 252 253 DenseElementsAttr outputAttr = 254 isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) 255 : DenseElementsAttr::get(outputType, intOutputValues); 256 257 rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr); 258 return success(); 259 } 260 261 private: 262 ControlFusionFn controlFn; 263 }; 264 265 // Folds linalg.generic ops that are actually transposes on constant values. 266 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> { 267 using FoldConstantBase::FoldConstantBase; 268 269 bool matchIndexingMaps(GenericOp genericOp) const { 270 // We should have one input and one output. 271 return genericOp.getIndexingMaps().size() == 2; 272 } 273 274 RegionComputationFn getRegionComputeFn(GenericOp genericOp) const { 275 // Make sure the region only contains a yield op. 276 Block &body = genericOp.region().front(); 277 if (!llvm::hasSingleElement(body)) 278 return nullptr; 279 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator()); 280 if (!yieldOp) 281 return nullptr; 282 283 // The yield op should return the block argument corresponds to the input. 284 for (Value yieldVal : yieldOp.values()) { 285 auto yieldArg = yieldVal.dyn_cast<BlockArgument>(); 286 if (!yieldArg || yieldArg.getOwner() != &body) 287 return nullptr; 288 if (yieldArg.getArgNumber() != 0) 289 return nullptr; 290 } 291 292 // No computation; just return the orginal value. 293 return [](const APIntOrFloatArray &inputs) { 294 if (inputs.apFloats.empty()) 295 return APIntOrFloat{inputs.apInts.front(), llvm::None}; 296 return APIntOrFloat{llvm::None, inputs.apFloats.front()}; 297 }; 298 } 299 300 ControlFusionFn controlFn; 301 }; 302 } // namespace 303 304 void mlir::linalg::populateConstantFoldLinalgOperations( 305 RewritePatternSet &patterns, const ControlFusionFn &controlFn) { 306 MLIRContext *context = patterns.getContext(); 307 patterns.insert<FoldConstantTranspose>(context, controlFn); 308 } 309