//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return false;

  // TODO: The conversion pattern can be made to work for `any_of` here, but
  // it's more complex as it requires tracking which operands are scalars.
  return llvm::all_of(op->getOperandTypes(),
                      [](Type type) { return type.isa<RankedTensorType>(); });
}

/// Given `op`, assumed to satisfy `isElementwiseMappableOpOnRankedTensors`,
/// iterate over the result types and return a list of values such that, for
/// each result type `t` and value `v` at the same index `idx`:
///   1. `v.getType() == t`.
///   2. If an operand of `op` has type `t`, let `operand_first` be the first
///      such operand. Then `v == operand_first`.
///   3. Otherwise, `v` is a newly created `linalg::InitTensorOp` with:
///        a. static and dynamic dims extracted from the first operand of `op`;
///        b. element type equal to the element type of `t`.
///
/// This is sufficient because ElementwiseMappable guarantees that "The static
/// types of all vector (resp. tensor) operands and results must have the same
/// shape".
static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
  assert(isElementwiseMappableOpOnRankedTensors(op));
  Location loc = op->getLoc();
  ValueRange operands = op->getOperands();
  TypeRange rankedTensorTypes = op->getResultTypes();
  SmallVector<Value, 4> res;
  res.reserve(rankedTensorTypes.size());
  for (Type t : rankedTensorTypes) {
    // Try to find an operand with type matching the result tensor.
    bool found = false;
    for (Value v : operands) {
      if (v.getType() == t) {
        found = true;
        res.push_back(v);
        break;
      }
    }
    if (found)
      continue;

    // Extract static / dynamic shape mix from the first operand.
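    // Illustrative sketch (assumed example, e.g. a result type
    // `tensor<4x?xf32>` with no type-matching operand): the code below then
    // materializes IR roughly of the form
    //   %d1 = tensor.dim %first_operand, %c1 : tensor<4x?xf32>
    //   %init = linalg.init_tensor [4, %d1] : tensor<4x?xf32>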
    Value firstOperand = operands.front();
    auto rankedTensorType = t.cast<RankedTensorType>();
    auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
    auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b);

    res.push_back(b.create<linalg::InitTensorOp>(
        loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
  }
  return res;
}

namespace {
struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    if (!isElementwiseMappableOpOnRankedTensors(op))
      return rewriter.notifyMatchFailure(
          op, "requires elementwise op on ranked tensors");

    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
    SmallVector<AffineMap, 3> indexingMaps(
        op->getNumResults() + op->getNumOperands(),
        rewriter.getMultiDimIdentityMap(rank));
    SmallVector<StringRef, 6> iteratorTypes(rank,
                                            getParallelIteratorTypeName());
    auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, /*resultTensorTypes=*/op->getResultTypes(),
        /*inputs=*/op->getOperands(),
        /*outputs=*/outputs,
        /*indexingMaps=*/indexingMaps,
        /*iteratorTypes=*/iteratorTypes,
        /*bodyBuilder=*/
        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
          auto resultTypes = llvm::to_vector<6>(
              llvm::map_range(op->getResultTypes(), [](Type type) {
                return type.cast<TensorType>().getElementType();
              }));
          auto *scalarOp =
              builder.create(loc, op->getName().getIdentifier(),
                             regionArgs.take_front(op->getNumOperands()),
                             resultTypes, op->getAttrs());
          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
        });
    return success();
  }
};
} // namespace

void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
      patterns.getContext());
}

namespace {
class ConvertElementwiseToLinalgPass
    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {

  void runOnOperation() final {
    auto *func = getOperation();
    auto *context = &getContext();
    ConversionTarget target(*context);
    RewritePatternSet patterns(context);

    mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
    target.markUnknownOpDynamicallyLegal([](Operation *op) {
      return !isElementwiseMappableOpOnRankedTensors(op);
    });

    if (failed(applyPartialConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createConvertElementwiseToLinalgPass() {
  return std::make_unique<ConvertElementwiseToLinalgPass>();
}