//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// Pattern to decompose a GenericOp that has more than two statements
/// into one GenericOp with the first statement (i.e. peeled operation), and
/// a second GenericOp with the remaining statements (i.e. residual
/// operations).
///
/// - The result of the first GenericOp has the same shape as the iteration
///   space of the GenericOp. The body of the op yields as many values as the
///   original op plus all the results of the peeled operation.
/// - The second GenericOp has as many operands as the original operation plus
///   all the results of the first GenericOp. It has the same number of yields
///   as the original op.
/// - If a result of the peeled operation was yielded by the original
///   GenericOp, the uses of the corresponding result are replaced with the
///   result of the first GenericOp created.
///
/// Example
///
/// ```mlir
/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...):
///     %0 = <op1> %b0, %b1 : ...
///     %1 = <op2> %0, %b2 : ...
///     linalg.yield %0, %1 : ...
/// } -> (..., ...)
/// return %result#0, %result#1
/// ```
///
/// gets split into
///
/// ```mlir
/// %init = linalg.init_tensor ...
/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...)
///     outs(%init0, %init1, %init : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///     %0 = <op1> %b0, %b1 : ...
///     linalg.yield %0, %..., %0 : ...
/// } -> (..., ..., ...)
/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...)
///     outs(%init0, %init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...):
///     %1 = <op2> %b3, %b2 : ...
///     linalg.yield %..., %1 : ...
/// } -> (..., ...)
/// return %op0#0, %op1#1
/// ```
///
/// After canonicalization this is expected to be
///
/// ```mlir
/// %init = linalg.init_tensor ...
/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...)
///     outs(%init : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
///     %0 = <op1> %b0, %b1 : ...
///     linalg.yield %0 : ...
/// } -> ...
/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...)
///     outs(%init1 : ...) {
///   ^bb0(%b0: ... , %b1: ... , %b2: ...):
///     %1 = <op2> %b1, %b0 : ...
///     linalg.yield %1 : ...
/// } -> ...
/// return %op0, %op1
/// ```
struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create a generic op for the peeled scalar operation. The
  /// created op has an empty region.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create a generic op for the residual scalar operations.
  /// The created op has the same region as the original op.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace
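
// As a concrete illustration of the decomposition above (a sketch only; the
// ops, types, and shapes are hypothetical and not taken from an actual test),
// consider a generic op computing `(%a + %b) * %c` elementwise:
//
//   %r = linalg.generic {...} ins(%a, %b, %c : ...) outs(%init : ...) {
//   ^bb0(%b0: f32, %b1: f32, %b2: f32, %b3: f32):
//     %0 = arith.addf %b0, %b1 : f32
//     %1 = arith.mulf %0, %b2 : f32
//     linalg.yield %1 : f32
//   } -> tensor<?xf32>
//
// One application of the pattern peels the `arith.addf` into its own
// linalg.generic that additionally yields %0 through a new init tensor, and
// builds a residual linalg.generic that reads that tensor as an extra `ins`
// operand and computes the `arith.mulf`. Reapplying the pattern (e.g. under a
// greedy driver) keeps peeling until every generic op holds a single statement
// plus its yield.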

/// Helper method to compute the range of a generic op.
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
                                                       GenericOp op) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  Location loc = op.getLoc();
  auto allShapesSizes =
      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
  AffineMap map = op.getShapesToLoopsMap();
  return getAsOpFoldResult(applyMapToValues(b, loc, map, allShapesSizes));
}

/// Helper method to permute the list of `values` based on the `map`.
SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                        AffineMap map) {
  assert(map.isPermutation());
  SmallVector<OpFoldResult> permutedValues(values.size());
  for (const auto &position :
       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return expr.cast<AffineDimExpr>().getPosition();
       })))
    permutedValues[position.value()] = values[position.index()];
  return permutedValues;
}
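
// For example (a hypothetical instance, for illustration only): with
// values = [%v0, %v1, %v2] and map = affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
// the i-th value moves to the position named by the i-th result of the map:
// %v0 to position 1, %v1 to position 2, and %v2 to position 0, yielding
// [%v2, %v0, %v1].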

/// Get zero value for an element type.
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  if (elementType.isa<IntegerType>())
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Assume float.
  auto floatType = elementType.cast<FloatType>();
  return b.create<arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}

GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
                                         PatternRewriter &rewriter) const {
  Block *body = genericOp.getBody();
  Operation *peeledScalarOperation = &(*body->begin());
  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
      genericOp.getIndexingMapsArray();

  /// Compute the loop ranges for the operation. This is the shape of the
  /// result of the generic op for the peeled operation.
  Location loc = genericOp.getLoc();
  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
  SmallVector<Value> newInitValues;
  SmallVector<Type> newResultTypes;

  /// The indexing map to use for a new result is obtained by
  /// - checking if the result is yielded; if so, using the same indexing map
  ///   as the corresponding output,
  /// - using the identity indexing map if the result is not yielded.
  Operation *yieldOp = body->getTerminator();
  auto getResultIndexingMap = [&](OpResult scalarOpResult) -> AffineMap {
    OpOperand *firstUseInYield = nullptr, *identityUseInYield = nullptr;
    for (OpOperand &use : scalarOpResult.getUses()) {
      if (use.getOwner() != yieldOp)
        continue;
      if (!firstUseInYield)
        firstUseInYield = &use;
      OpResult genericOpResult =
          genericOp.getResult(use.getOperandNumber()).cast<OpResult>();
      AffineMap indexingMap =
          genericOp.getTiedIndexingMapForResult(genericOpResult);
      if (indexingMap.isIdentity())
        identityUseInYield = &use;
    }
    if (identityUseInYield || !firstUseInYield)
      return rewriter.getMultiDimIdentityMap(domain.size());
    OpResult genericOpResult =
        genericOp.getResult(firstUseInYield->getOperandNumber())
            .cast<OpResult>();
    return genericOp.getTiedIndexingMapForResult(genericOpResult);
  };

  for (auto scalarResult : peeledScalarOperation->getResults()) {
    AffineMap resultIndexingMap = getResultIndexingMap(scalarResult);
    SmallVector<OpFoldResult> initSize =
        permuteValues(domain, resultIndexingMap);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, initSize, scalarResult.getType());
    newInitValues.push_back(initTensor);
    newResultTypes.push_back(initTensor.getType());
    peeledGenericOpIndexingMaps.push_back(resultIndexingMap);
  }

  /// Create the peeled generic op with an empty body.
  SmallVector<Value> outsOperands = genericOp.getOutputOperands();
  outsOperands.append(newInitValues.begin(), newInitValues.end());
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
  auto indexingMapAttr =
      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
  return rewriter.create<GenericOp>(
      loc, resultTypes, genericOp.inputs(), outsOperands, indexingMapAttr,
      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}
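
// To make the indexing-map choice above concrete (with hypothetical maps): if
// the peeled scalar result is yielded only into an output whose indexing map
// is, say, affine_map<(d0, d1) -> (d1, d0)>, the new result reuses that
// transposed map and the init tensor sizes are permuted accordingly; if any
// yield of the result targets an identity-mapped output, or the result is not
// yielded at all, the identity map is used.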

GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                           GenericOp peeledGenericOp,
                                           PatternRewriter &rewriter) const {
  /// Append all results from the peeledGenericOp as `ins` operands for the
  /// residual generic op.
  SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
        return operand->get();
      }));
  unsigned origNumResults = genericOp.getNumResults();
  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
  SmallVector<Value> extraIns;
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
    extraIns.push_back(peeledGenericOp->getResult(resultNum));
  residualGenericOpOperands.append(extraIns);

  /// Add indexing maps for the newly added operands. Use the same maps
  /// as those used for the new results of the peeledGenericOp.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
        return genericOp.getTiedIndexingMap(operand);
      }));
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
    OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
    indexingMaps.push_back(peeledGenericOp.getTiedIndexingMapForResult(result));
  }
  for (OpOperand *outOperand : genericOp.getOutputOperands())
    indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand));

  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
  return rewriter.create<GenericOp>(
      genericOp->getLoc(), genericOp->getResultTypes(),
      residualGenericOpOperands, genericOp.outputs(), indexingMapAttr,
      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  /// For now only match on operations where the iterator types are all
  /// parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }

  // TODO: this could be generalized to handle `linalg.generic` with buffer
  // operands too, but that requires allocation for intermediates. Punt on this
  // for now.
  if (!genericOp.hasTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  // TODO: For now only decompose operations where the `outs` operand values
  // are not accessed within the payload. This might be relaxed in the future,
  // but needs a bit more reasoning to ensure that it is safe.
  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
        return genericOp.payloadUsesValueFromOperand(outOperand);
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with use of out "
                   "operand value in payload");
  }

  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
        return !genericOp.getTiedIndexingMap(outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand "
                   "not accessed using a permutation");
  }

  /// If the op has only a single statement (apart from the yield), do nothing.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has less than 3 statements");
  }

  /// Check that the peeled statement has a scalar element type.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        &(*body->getOperations().begin()),
        "expected return type to be only int, index or float");
  }

  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  /// Move the first statement of the original operation into the body of the
  /// generic op for the peeled operation, and the remaining statements
  /// (including the terminator) into the body of the residual generic op.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty region");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto yieldOp = residualGenericOpBody->getTerminator();
  {
    // Yield all the results of the peeled scalar operation.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        yieldedVals.push_back(origYield);
      } else {
        // The value is not computed by the peeled scalar operation; yield a
        // zero placeholder instead.
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
  }
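
  // Continuing the sketch from the pattern documentation: if the original op
  // yielded (%0, %1) with only %0 defined by the peeled statement, the peeled
  // generic op yields (%0, <zero>, %0). The <zero> placeholder result is never
  // used afterwards (the corresponding original result is replaced by the
  // residual generic op's result below), so it is expected to fold away during
  // canonicalization.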

  /// In the split operations, replace uses of the original operation's block
  /// arguments with the block arguments of the newly created operations.
  unsigned origNumInputs = genericOp.getNumInputs();
  for (const auto &inputBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(inputBlockArg.index());
    inputBlockArg.value().replaceUsesWithIf(
        residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg =
        peeledGenericOpBody->getArgument(inputBlockArg.index());
    inputBlockArg.value().replaceUsesWithIf(
        peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  /// Before fixing up the residual operation, track what values are yielded.
  /// If any of those are from the peeled scalar operation, the uses of the
  /// corresponding result have to be remapped to the result of the generic op
  /// for the peeled operation.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = yieldValue.value().dyn_cast<OpResult>();
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  /// Update all uses of the peeled scalar operation results in the residual op
  /// to the newly added arguments.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
                                  residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasn't expected to be");
  }

  // Replace the original operation.
  rewriter.replaceOp(genericOp, replacements);
  return success();
}

void mlir::linalg::populateDecomposeLinalgOpsPattern(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
}
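
// A minimal sketch of how this pattern might be driven from a pass (the
// surrounding pass boilerplate here is hypothetical, shown for illustration
// only):
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDecomposeLinalgOpsPattern(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();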