1 //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Passes.h" 10 11 #include "PassDetail.h" 12 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 14 #include "mlir/Dialect/Linalg/Utils/Utils.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 19 using namespace mlir; 20 21 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { 22 if (!OpTrait::hasElementwiseMappableTraits(op)) 23 return false; 24 25 // TODO: The conversion pattern can be made to work for `any_of` here, but 26 // it's more complex as it requires tracking which operands are scalars. 27 return llvm::all_of(op->getOperandTypes(), 28 [](Type type) { return type.isa<RankedTensorType>(); }); 29 } 30 31 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over 32 /// the result types and return a list of values such that, for each result type 33 /// `t` and value `v` at the same index `idx`: 34 /// 1. `v.getType() == t` 35 /// 2. If an operand of `op` has type `t`, let `operand_first` be the first 36 /// such operand. Then`v == operand_first`. 37 /// 3. Otherwise, v is a newly created `linalg::InitTensorOp` with: 38 /// a. Static and dynamic dims extracted from the first operand of `op`. 39 /// b. Elemental type equal to the elemental type of `t`. 40 /// 41 /// This is sufficient because ElementwiseMappable guarantees that "The static 42 /// types of all vector (resp. tensor) operands and results must have the same 43 /// shape". 44 static SmallVector<Value, 4> 45 getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) { 46 assert(isElementwiseMappableOpOnRankedTensors(op)); 47 Location loc = op->getLoc(); 48 ValueRange operands = op->getOperands(); 49 TypeRange rankedTensorTypes = op->getResultTypes(); 50 SmallVector<Value, 4> res; 51 res.reserve(rankedTensorTypes.size()); 52 for (Type t : rankedTensorTypes) { 53 // Try to find an operand with type matching the result tensor. 54 bool found = false; 55 for (Value v : operands) { 56 if (v.getType() == t) { 57 found = true; 58 res.push_back(v); 59 break; 60 } 61 } 62 if (found) 63 continue; 64 65 // Extract static / dynamic shape mix from the first operand. 66 Value firstOperand = operands.front(); 67 auto rankedTensorType = t.cast<RankedTensorType>(); 68 auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape()); 69 auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b); 70 71 res.push_back(b.create<linalg::InitTensorOp>( 72 loc, dynamicShape, staticShape, rankedTensorType.getElementType())); 73 } 74 return res; 75 } 76 77 namespace { 78 struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { 79 ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context) 80 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 81 LogicalResult matchAndRewrite(Operation *op, 82 PatternRewriter &rewriter) const final { 83 if (!isElementwiseMappableOpOnRankedTensors(op)) 84 return rewriter.notifyMatchFailure( 85 op, "requires elementwise op on ranked tensors"); 86 87 auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank(); 88 SmallVector<AffineMap, 3> indexingMaps( 89 op->getNumResults() + op->getNumOperands(), 90 rewriter.getMultiDimIdentityMap(rank)); 91 SmallVector<StringRef, 6> iteratorTypes(rank, 92 getParallelIteratorTypeName()); 93 auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); 94 rewriter.replaceOpWithNewOp<linalg::GenericOp>( 95 op, /*resultTensorTypes=*/op->getResultTypes(), 96 /*inputs=*/op->getOperands(), 97 /*outputs=*/outputs, 98 /*indexingMaps=*/indexingMaps, 99 /*iteratorTypes=*/iteratorTypes, 100 /*bodyBuilder=*/ 101 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { 102 OperationState state(loc, op->getName()); 103 state.addAttributes(op->getAttrs()); 104 // Only take the input operands in the cloned elementwise op. 105 state.addOperands(regionArgs.take_front(op->getNumOperands())); 106 auto resultTypes = llvm::to_vector<6>( 107 llvm::map_range(op->getResultTypes(), [](Type type) { 108 return type.cast<TensorType>().getElementType(); 109 })); 110 state.addTypes(resultTypes); 111 auto *scalarOp = builder.createOperation(state); 112 builder.create<linalg::YieldOp>(loc, scalarOp->getResults()); 113 }); 114 return success(); 115 } 116 }; 117 } // namespace 118 119 void mlir::linalg::populateElementwiseToLinalgConversionPatterns( 120 RewritePatternSet &patterns) { 121 patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>( 122 patterns.getContext()); 123 } 124 125 namespace { 126 class ConvertElementwiseToLinalgPass 127 : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> { 128 129 void runOnFunction() final { 130 auto func = getOperation(); 131 auto *context = &getContext(); 132 ConversionTarget target(*context); 133 RewritePatternSet patterns(context); 134 135 mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns); 136 target.markUnknownOpDynamicallyLegal([](Operation *op) { 137 return !isElementwiseMappableOpOnRankedTensors(op); 138 }); 139 140 if (failed(applyPartialConversion(func, target, std::move(patterns)))) 141 signalPassFailure(); 142 } 143 }; 144 } // namespace 145 146 std::unique_ptr<OperationPass<FuncOp>> 147 mlir::createConvertElementwiseToLinalgPass() { 148 return std::make_unique<ConvertElementwiseToLinalgPass>(); 149 } 150