//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// Pattern to decompose a GenericOp that has more than two statements
/// into one GenericOp with the first statement (i.e. peeled operation), and
/// a second GenericOp with the remaining statements (i.e. residual operations).

/// - The result of the first GenericOp has the same shape as the iteration
///   space of the GenericOp. The body of the op yields as many values as the
///   original op plus all the results of the peeled operation.
/// - The second GenericOp has as many operands as the original operation plus
///   all the results of the first Generic Op. It has the same number of yields
///   as the original op.
/// - If the result of the peeled operation was yielded by the original
///   GenericOp the uses of the corresponding results will be replaced with the
///   result of the first GenericOp created.
///
/// Example
///
/// ```mlir
///  %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///      outs(%init0, %init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
///      %0 = <s0> %b0, %b1 : ...
///      %1 = <s1> %0, %b2 : ...
///      linalg.yield %0, %1 : ...
///  } -> (..., ...)
///  return %result#0, %result#1
/// ```
///
/// gets split into
///
/// ```mlir
///  %init = linalg.init_tensor ...
///  %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///      outs(%init0, %init1, %init : ...)
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///      %0 = <s0> %b0, %b1 : ...
///      linalg.yield %0, %..., %0 : ...
///  } -> (..., ..., ...)
///  %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
///      outs(%init0, %init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///      %1 = <s1> %b3, %b2 : ...
///      linalg.yield %..., %1 : ...
///  } -> (..., ...)
///  return %op0#0, %op1#1
/// ```
///
/// After canonicalization this is expected to be
///
/// ```mlir
///  %init = linalg.init_tensor ...
///  %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
///      outs(%init : ...)
///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
///      %0 = <s0> %b0, %b1 : ...
///      linalg.yield %0 : ...
///  } -> ...
///  %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
///      outs(%init1 : ...) {
///    ^bb0(%b0: ... , %b1: ... , %b2: ...):
///      %1 = <s1> %b1, %b0 : ...
///      linalg.yield %1 : ...
///  } -> ...
///  return %op0, %op1
/// ```
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create a generic op for the peeled scalar operation. The
  /// created op has an empty region.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create a generic op for the residual scalar operation.
  /// The created op has the same region as the original op.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace

/// Helper method to compute the range of a generic op.
104 static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b, 105 GenericOp op) { 106 OpBuilder::InsertionGuard g(b); 107 b.setInsertionPoint(op); 108 Location loc = op.getLoc(); 109 auto allShapesSizes = 110 cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc); 111 AffineMap map = op.getShapesToLoopsMap(); 112 return getAsOpFoldResult(applyMapToValues(b, loc, map, allShapesSizes)); 113 } 114 115 /// Helper method to permute the list of `values` based on the `map`. 116 SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values, 117 AffineMap map) { 118 assert(map.isPermutation()); 119 SmallVector<OpFoldResult> permutedValues(values.size()); 120 for (const auto &position : 121 llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) { 122 return expr.cast<AffineDimExpr>().getPosition(); 123 }))) 124 permutedValues[position.value()] = values[position.index()]; 125 return permutedValues; 126 } 127 128 /// Get zero value for an element type. 129 static Value getZero(OpBuilder &b, Location loc, Type elementType) { 130 assert(elementType.isIntOrIndexOrFloat() && 131 "expected scalar type while computing zero value"); 132 if (elementType.isa<IntegerType>()) 133 return b.create<arith::ConstantIntOp>(loc, 0, elementType); 134 if (elementType.isIndex()) 135 return b.create<arith::ConstantIndexOp>(loc, 0); 136 // Assume float. 137 auto floatType = elementType.cast<FloatType>(); 138 return b.create<arith::ConstantFloatOp>( 139 loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); 140 } 141 142 GenericOp 143 DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, 144 PatternRewriter &rewriter) const { 145 Block *body = genericOp.getBody(); 146 Operation *peeledScalarOperation = &(*body->begin()); 147 SmallVector<AffineMap> peeledGenericOpIndexingMaps = 148 genericOp.getIndexingMapsArray(); 149 150 /// Compute the loop ranges for operation. 
This is the shape of the result of 151 /// the generic op for the peeled operation. 152 Location loc = genericOp.getLoc(); 153 SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp); 154 SmallVector<Value> newInitValues; 155 SmallVector<Type> newResultTypes; 156 157 /// The indexing map to use for the new results is obtained by 158 /// - Check if the result is yielded. If so use the same indexing map as the 159 /// corresponding output 160 /// - Identity indexing map if the result is not yielded. 161 Operation *yieldOp = body->getTerminator(); 162 auto getResultIndexingMap = [&](OpResult scalarOpResult) -> AffineMap { 163 OpOperand *firstUseInYield = nullptr, *identityUseInYield = nullptr; 164 for (OpOperand &use : scalarOpResult.getUses()) { 165 if (use.getOwner() != yieldOp) 166 continue; 167 if (!firstUseInYield) 168 firstUseInYield = &use; 169 OpResult genericOpResult = 170 genericOp.getResult(use.getOperandNumber()).cast<OpResult>(); 171 AffineMap indexingMap = 172 genericOp.getTiedIndexingMapForResult(genericOpResult); 173 if (indexingMap.isIdentity()) 174 identityUseInYield = &use; 175 } 176 if (identityUseInYield || !firstUseInYield) 177 return rewriter.getMultiDimIdentityMap(domain.size()); 178 OpResult genericOpResult = 179 genericOp.getResult(firstUseInYield->getOperandNumber()) 180 .cast<OpResult>(); 181 return genericOp.getTiedIndexingMapForResult(genericOpResult); 182 }; 183 184 for (auto scalarResult : peeledScalarOperation->getResults()) { 185 AffineMap resultIndexingMap = getResultIndexingMap(scalarResult); 186 SmallVector<OpFoldResult> initSize = 187 permuteValues(domain, resultIndexingMap); 188 Value initTensor = rewriter.create<linalg::InitTensorOp>( 189 loc, initSize, scalarResult.getType()); 190 newInitValues.push_back(initTensor); 191 newResultTypes.push_back(initTensor.getType()); 192 peeledGenericOpIndexingMaps.push_back(resultIndexingMap); 193 } 194 195 /// Create the peeled generic op with an empty body. 
196 SmallVector<Value> outsOperands = genericOp.getOutputOperands(); 197 outsOperands.append(newInitValues.begin(), newInitValues.end()); 198 SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes()); 199 resultTypes.append(newResultTypes.begin(), newResultTypes.end()); 200 auto indexingMapAttr = 201 rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps); 202 return rewriter.create<GenericOp>( 203 loc, resultTypes, genericOp.inputs(), outsOperands, indexingMapAttr, 204 genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr, 205 [](OpBuilder, Location, ValueRange) {}); 206 } 207 208 GenericOp 209 DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, 210 GenericOp peeledGenericOp, 211 PatternRewriter &rewriter) const { 212 /// Append all results from the peeledGenericOps as `ins` operand for the 213 /// residual generic op. 214 SmallVector<Value> residualGenericOpOperands = llvm::to_vector( 215 llvm::map_range(genericOp.getInputOperands(), 216 [](OpOperand *operand) { return operand->get(); })); 217 unsigned origNumResults = genericOp.getNumResults(); 218 unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults(); 219 SmallVector<Value> extraIns; 220 for (auto resultNum : 221 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) 222 extraIns.push_back(peeledGenericOp->getResult(resultNum)); 223 residualGenericOpOperands.append(extraIns); 224 225 /// Add indexing maps for the newly added operands. Use the same map 226 /// as those used for the new results of the peeledGenericOp. 
227 auto indexingMaps = llvm::to_vector( 228 llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) { 229 return genericOp.getTiedIndexingMap(operand); 230 })); 231 for (auto resultNum : 232 llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) { 233 OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>(); 234 indexingMaps.push_back(peeledGenericOp.getTiedIndexingMapForResult(result)); 235 } 236 for (OpOperand *outOperand : genericOp.getOutputOperands()) 237 indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand)); 238 239 auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps); 240 return rewriter.create<GenericOp>( 241 genericOp->getLoc(), genericOp->getResultTypes(), 242 residualGenericOpOperands, genericOp.outputs(), indexingMapAttr, 243 genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr, 244 [](OpBuilder, Location, ValueRange) {}); 245 } 246 247 LogicalResult 248 DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, 249 PatternRewriter &rewriter) const { 250 /// For now only match on operations where the iterator types are all parallel 251 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) { 252 return rewriter.notifyMatchFailure(genericOp, 253 "unhandled decomposition of operation " 254 "with non-parallel iterator types"); 255 } 256 // TODO: this could be generalized to handle `linalg.generic` with buffer 257 // operands too but requires allocation for intermediates. Punt on this for 258 // now. 259 if (!genericOp.hasTensorSemantics()) { 260 return rewriter.notifyMatchFailure( 261 genericOp, "only operations with tensor semantics are handled"); 262 } 263 264 // TODO: For now only decompose operations where the `outs` operands values 265 // are not accessed within the payload. This might be relaxed in future, but 266 // needs a bit more reasoning to ensure that it is safe. 
267 if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) { 268 return genericOp.payloadUsesValueFromOperand(outOperand); 269 })) { 270 return rewriter.notifyMatchFailure( 271 genericOp, "unhandled decomposition of generic op with use of out " 272 "operand value in payload"); 273 } 274 275 if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) { 276 return !genericOp.getTiedIndexingMap(outOperand).isPermutation(); 277 })) { 278 return rewriter.notifyMatchFailure( 279 genericOp, "unhandled decomposition of generic op with out operand not " 280 "accessed using a permutation"); 281 } 282 283 /// If the op has only a single statement (apart from the yield), do nothing. 284 Block *body = genericOp.getBody(); 285 if (body->getOperations().size() <= 2) { 286 return rewriter.notifyMatchFailure(genericOp, 287 "operation has less than 3 statements"); 288 } 289 290 /// Check that the peeled statement has a scalar element type. 291 if (llvm::any_of(body->getOperations().begin()->getResultTypes(), 292 [](Type t) { return !t.isIntOrIndexOrFloat(); })) { 293 return rewriter.notifyMatchFailure( 294 &(*body->getOperations().begin()), 295 "expected return type to be only int, index or float"); 296 } 297 298 GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter); 299 GenericOp residualGenericOp = 300 createResidualGenericOp(genericOp, peeledGenericOp, rewriter); 301 302 /// Move the first statement of the original operation into the body of the 303 /// generic op for the peeled operation. 
304 Block *peeledGenericOpBody = peeledGenericOp.getBody(); 305 Block *residualGenericOpBody = residualGenericOp.getBody(); 306 assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() && 307 "expected split generic ops to have empty region"); 308 peeledGenericOpBody->getOperations().splice( 309 peeledGenericOpBody->begin(), body->getOperations(), body->begin()); 310 residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(), 311 body->getOperations()); 312 313 Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin()); 314 auto yieldOp = residualGenericOpBody->getTerminator(); 315 { 316 // Yield all the result of the peeled scalar operation. 317 OpBuilder::InsertionGuard g(rewriter); 318 rewriter.setInsertionPointToEnd(peeledGenericOpBody); 319 SmallVector<Value> yieldedVals; 320 for (auto origYield : yieldOp->getOperands()) { 321 if (origYield.getDefiningOp() == peeledScalarOperation) { 322 yieldedVals.push_back(origYield); 323 } else { 324 yieldedVals.push_back( 325 getZero(rewriter, genericOp.getLoc(), origYield.getType())); 326 } 327 } 328 yieldedVals.append(llvm::to_vector( 329 llvm::map_range(peeledScalarOperation->getResults(), 330 [](OpResult opr) -> Value { return opr; }))); 331 rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals); 332 } 333 334 /// In the split operations, replace block arguments uses that refer to 335 /// original operation to the block arguments of the newly created operation. 
336 unsigned origNumInputs = genericOp.getNumInputs(); 337 for (const auto &inputBlockArg : 338 llvm::enumerate(genericOp.getBody()->getArguments())) { 339 Value residualOpReplacementArg = 340 residualGenericOpBody->getArgument(inputBlockArg.index()); 341 inputBlockArg.value().replaceUsesWithIf( 342 residualOpReplacementArg, [&](OpOperand &use) { 343 return use.getOwner()->getBlock() == residualGenericOpBody; 344 }); 345 346 Value peeledOpReplacementArg = 347 peeledGenericOpBody->getArgument(inputBlockArg.index()); 348 inputBlockArg.value().replaceUsesWithIf( 349 peeledOpReplacementArg, [&](OpOperand &use) { 350 return use.getOwner()->getBlock() == peeledGenericOpBody; 351 }); 352 } 353 354 /// Before fixing up the residual operation, track what values are yielded. If 355 /// any of those are from the peeled scalar operation, the uses of the 356 /// corresponding result have to be remapped to result of the generic op for 357 /// the peeled operation. 358 SmallVector<Value> replacements; 359 for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) { 360 OpResult opr = yieldValue.value().dyn_cast<OpResult>(); 361 if (!opr || opr.getOwner() != peeledScalarOperation) 362 replacements.push_back(residualGenericOp.getResult(yieldValue.index())); 363 else 364 replacements.push_back(peeledGenericOp->getResult(yieldValue.index())); 365 } 366 367 /// Update all uses of the peeled scalar operation results in the residual op 368 /// to the newly added arguments. 
369 { 370 SmallVector<Value> scalarReplacements; 371 unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults(); 372 scalarReplacements.reserve(peeledScalarOpNumResults); 373 for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults)) 374 scalarReplacements.push_back( 375 residualGenericOpBody->getArgument(num + origNumInputs)); 376 bool allUsesReplaced = false; 377 rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements, 378 residualGenericOpBody, &allUsesReplaced); 379 assert(!allUsesReplaced && 380 "peeled scalar operation is erased when it wasnt expected to be"); 381 } 382 383 // Replace the original operation 384 rewriter.replaceOp(genericOp, replacements); 385 return success(); 386 } 387 388 void mlir::linalg::populateDecomposeLinalgOpsPattern( 389 RewritePatternSet &patterns) { 390 patterns.insert<DecomposeLinalgOp>(patterns.getContext()); 391 } 392