1 //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Passes.h" 10 11 #include "PassDetail.h" 12 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 16 using namespace mlir; 17 18 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { 19 if (!op->hasTrait<OpTrait::ElementwiseMappable>()) 20 return false; 21 22 // TODO: The conversion pattern can be made to work for `any_of` here, but 23 // it's more complex as it requires tracking which operands are scalars. 24 return llvm::all_of(op->getOperandTypes(), 25 [](Type type) { return type.isa<RankedTensorType>(); }); 26 } 27 28 namespace { 29 struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern { 30 ConvertStdElementwiseOpOnRankedTensors() 31 : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} 32 LogicalResult matchAndRewrite(Operation *op, 33 PatternRewriter &rewriter) const final { 34 if (!isElementwiseMappableOpOnRankedTensors(op)) 35 return rewriter.notifyMatchFailure( 36 op, "requires elementwise op on ranked tensors"); 37 38 auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank(); 39 SmallVector<AffineMap, 3> indexingMaps( 40 op->getNumResults() + op->getNumOperands(), 41 rewriter.getMultiDimIdentityMap(rank)); 42 SmallVector<StringRef, 6> iteratorTypes(rank, 43 getParallelIteratorTypeName()); 44 rewriter.replaceOpWithNewOp<linalg::GenericOp>( 45 op, /*resultTensorTypes=*/op->getResultTypes(), 46 /*inputs=*/op->getOperands(), 47 /*outputBuffers=*/ValueRange(), 48 /*initTensors=*/ValueRange(), 49 /*indexingMaps=*/indexingMaps, 50 /*iteratorTypes=*/iteratorTypes, 51 /*bodyBuilder=*/ 52 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { 53 OperationState state(loc, op->getName()); 54 state.addAttributes(op->getAttrs()); 55 state.addOperands(regionArgs); 56 auto resultTypes = llvm::to_vector<6>( 57 llvm::map_range(op->getResultTypes(), [](Type type) { 58 return type.cast<TensorType>().getElementType(); 59 })); 60 state.addTypes(resultTypes); 61 auto *scalarOp = builder.createOperation(state); 62 builder.create<linalg::YieldOp>(loc, scalarOp->getResults()); 63 }); 64 return success(); 65 } 66 }; 67 } // namespace 68 69 void mlir::populateElementwiseToLinalgConversionPatterns( 70 OwningRewritePatternList &patterns, MLIRContext *) { 71 patterns.insert<ConvertStdElementwiseOpOnRankedTensors>(); 72 } 73 74 namespace { 75 class ConvertElementwiseToLinalgPass 76 : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> { 77 78 void runOnFunction() final { 79 auto func = getOperation(); 80 auto *context = &getContext(); 81 ConversionTarget target(*context); 82 OwningRewritePatternList patterns; 83 84 populateElementwiseToLinalgConversionPatterns(patterns, context); 85 target.markUnknownOpDynamicallyLegal([](Operation *op) { 86 return !isElementwiseMappableOpOnRankedTensors(op); 87 }); 88 89 if (failed(applyPartialConversion(func, target, std::move(patterns)))) 90 signalPassFailure(); 91 } 92 }; 93 } // namespace 94 95 std::unique_ptr<OperationPass<FuncOp>> 96 mlir::createConvertElementwiseToLinalgPass() { 97 return std::make_unique<ConvertElementwiseToLinalgPass>(); 98 } 99