//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
    return false;

  // TODO: The conversion pattern can be made to work for `any_of` here, but
  // it's more complex as it requires tracking which operands are scalars.
  return llvm::all_of(op->getOperandTypes(),
                      [](Type type) { return type.isa<RankedTensorType>(); });
}

/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
/// the result types and return a list of values such that, for each result
/// type `t` and value `v` at the same index `idx`:
///   1. `v.getType() == t`
///   2. If an operand of `op` has type `t`, let `operand_first` be the first
///      such operand. Then `v == operand_first`.
///   3. Otherwise, v is a newly created `linalg::InitTensorOp` with:
///      a. Static and dynamic dims extracted from the first operand of `op`.
///      b. Elemental type equal to the elemental type of `t`.
///
/// This is sufficient because ElementwiseMappable guarantees that "The static
/// types of all vector (resp. tensor) operands and results must have the same
/// shape".
static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
  assert(isElementwiseMappableOpOnRankedTensors(op));
  Location loc = op->getLoc();
  ValueRange operands = op->getOperands();
  TypeRange rankedTensorTypes = op->getResultTypes();
  SmallVector<Value, 4> res;
  res.reserve(rankedTensorTypes.size());
  for (Type t : rankedTensorTypes) {
    // Try to find an operand with type matching the result tensor.
    bool found = false;
    for (Value v : operands) {
      if (v.getType() == t) {
        found = true;
        res.push_back(v);
        break;
      }
    }
    if (found)
      continue;

    // Extract static / dynamic shape mix from the first operand.
    Value firstOperand = operands.front();
    auto rankedTensorType = t.cast<RankedTensorType>();
    auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
    auto dynamicShape = getDynOperands(loc, firstOperand, b);

    res.push_back(b.create<linalg::InitTensorOp>(
        loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
  }
  return res;
}

namespace {
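/// Pattern that rewrites any `ElementwiseMappable` op whose operands are all
/// ranked tensors into a `linalg.generic` with a scalar payload. As an
/// illustrative sketch (not verbatim output; the exact printed form depends on
/// the dialect versions in use), an op such as
///
///   %0 = addf %a, %b : tensor<?x8xf32>
///
/// is rewritten to roughly
///
///   %init = linalg.init_tensor [%d0, 8] : tensor<?x8xf32>
///   %0 = linalg.generic {indexing_maps = [#id, #id, #id],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%a, %b : tensor<?x8xf32>, tensor<?x8xf32>)
///       outs(%init : tensor<?x8xf32>) {
///     ^bb0(%lhs: f32, %rhs: f32, %out: f32):
///       %sum = addf %lhs, %rhs : f32
///       linalg.yield %sum : f32
///   } -> tensor<?x8xf32>
///
/// where `%d0` stands for the dynamic extent of dimension 0 taken from the
/// first operand and `#id` for the rank-2 identity map.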
struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
  ConvertAnyElementwiseMappableOpOnRankedTensors()
      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    if (!isElementwiseMappableOpOnRankedTensors(op))
      return rewriter.notifyMatchFailure(
          op, "requires elementwise op on ranked tensors");

    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
    SmallVector<AffineMap, 3> indexingMaps(
        op->getNumResults() + op->getNumOperands(),
        rewriter.getMultiDimIdentityMap(rank));
    SmallVector<StringRef, 6> iteratorTypes(rank,
                                            getParallelIteratorTypeName());
    auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, /*resultTensorTypes=*/op->getResultTypes(),
        /*inputs=*/op->getOperands(),
        /*outputs=*/outputs,
        /*indexingMaps=*/indexingMaps,
        /*iteratorTypes=*/iteratorTypes,
        /*bodyBuilder=*/
        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
          OperationState state(loc, op->getName());
          state.addAttributes(op->getAttrs());
          // Only take the input operands in the cloned elementwise op.
          state.addOperands(regionArgs.take_front(op->getNumOperands()));
          auto resultTypes = llvm::to_vector<6>(
              llvm::map_range(op->getResultTypes(), [](Type type) {
                return type.cast<TensorType>().getElementType();
              }));
          state.addTypes(resultTypes);
          auto *scalarOp = builder.createOperation(state);
          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
        });
    return success();
  }
};
} // namespace

void mlir::populateElementwiseToLinalgConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *) {
  patterns.insert<ConvertAnyElementwiseMappableOpOnRankedTensors>();
}

namespace {
class ConvertElementwiseToLinalgPass
    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {

  void runOnFunction() final {
    auto func = getOperation();
    auto *context = &getContext();
    ConversionTarget target(*context);
    OwningRewritePatternList patterns;

    populateElementwiseToLinalgConversionPatterns(patterns, context);
    target.markUnknownOpDynamicallyLegal([](Operation *op) {
      return !isElementwiseMappableOpOnRankedTensors(op);
    });

    if (failed(applyPartialConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertElementwiseToLinalgPass() {
  return std::make_unique<ConvertElementwiseToLinalgPass>();
}
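// Note: a minimal usage sketch, assuming the standard mlir-opt driver and the
// command-line flag under which this pass is registered (the flag name is an
// assumption based on the pass class name):
//
//   mlir-opt -convert-elementwise-to-linalg input.mlir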