//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
    return false;

  // TODO: The conversion pattern can be made to work for `any_of` here, but
  // it's more complex as it requires tracking which operands are scalars.
  return llvm::all_of(op->getOperandTypes(),
                      [](Type type) { return type.isa<RankedTensorType>(); });
}

/// Given `op`, assumed to satisfy `isElementwiseMappableOpOnRankedTensors`,
/// iterate over the result types and return a list of values such that, for
/// each result type `t` and value `v` at the same index `idx`:
///   1. `v.getType() == t`
///   2. If an operand of `op` has type `t`, let `operand_first` be the first
///      such operand. Then `v == operand_first`.
///   3. Otherwise, `v` is a newly created `linalg::InitTensorOp` with:
///        a. Static and dynamic dims extracted from the first operand of `op`.
///        b. Elemental type equal to the elemental type of `t`.
///
/// This is sufficient because ElementwiseMappable guarantees that "The static
/// types of all vector (resp. tensor) operands and results must have the same
/// shape".
static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
  assert(isElementwiseMappableOpOnRankedTensors(op));
  Location loc = op->getLoc();
  ValueRange operands = op->getOperands();
  TypeRange rankedTensorTypes = op->getResultTypes();
  SmallVector<Value, 4> res;
  res.reserve(rankedTensorTypes.size());
  for (Type t : rankedTensorTypes) {
    // Try to find an operand with type matching the result tensor.
    bool found = false;
    for (Value v : operands) {
      if (v.getType() == t) {
        found = true;
        res.push_back(v);
        break;
      }
    }
    if (found)
      continue;

    // Extract static / dynamic shape mix from the first operand.
    Value firstOperand = operands.front();
    auto rankedTensorType = t.cast<RankedTensorType>();
    SmallVector<Value, 8> dynamicShape;
    SmallVector<int64_t, 8> staticShape;
    dynamicShape.reserve(rankedTensorType.getRank());
    staticShape.reserve(rankedTensorType.getRank());
    unsigned idx = 0;
    for (auto shape : rankedTensorType.getShape()) {
      staticShape.push_back(shape);
      if (rankedTensorType.isDynamicDim(idx))
        dynamicShape.push_back(b.create<DimOp>(loc, firstOperand, idx));
      ++idx;
    }
    // Create init tensor.
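    // Note: `staticShape` holds one entry per dimension (keeping the
    // dynamic-size sentinel for dynamic extents), while `dynamicShape`
    // supplies one SSA value per dynamic dimension; that is the pairing
    // the InitTensorOp builder used below expects.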
    res.push_back(b.create<linalg::InitTensorOp>(
        loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
  }
  return res;
}

namespace {
struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
  ConvertAnyElementwiseMappableOpOnRankedTensors()
      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    if (!isElementwiseMappableOpOnRankedTensors(op))
      return rewriter.notifyMatchFailure(
          op, "requires elementwise op on ranked tensors");

    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
    SmallVector<AffineMap, 3> indexingMaps(
        op->getNumResults() + op->getNumOperands(),
        rewriter.getMultiDimIdentityMap(rank));
    SmallVector<StringRef, 6> iteratorTypes(rank,
                                            getParallelIteratorTypeName());
    auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, /*resultTensorTypes=*/op->getResultTypes(),
        /*inputs=*/op->getOperands(),
        /*outputs=*/outputs,
        /*indexingMaps=*/indexingMaps,
        /*iteratorTypes=*/iteratorTypes,
        /*bodyBuilder=*/
        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
          OperationState state(loc, op->getName());
          state.addAttributes(op->getAttrs());
          // Only take the input operands in the cloned elementwise op.
          state.addOperands(regionArgs.take_front(op->getNumOperands()));
          auto resultTypes = llvm::to_vector<6>(
              llvm::map_range(op->getResultTypes(), [](Type type) {
                return type.cast<TensorType>().getElementType();
              }));
          state.addTypes(resultTypes);
          auto *scalarOp = builder.createOperation(state);
          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
        });
    return success();
  }
};
} // namespace

void mlir::populateElementwiseToLinalgConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *) {
  patterns.insert<ConvertAnyElementwiseMappableOpOnRankedTensors>();
}

namespace {
class ConvertElementwiseToLinalgPass
    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {

  void runOnFunction() final {
    auto func = getOperation();
    auto *context = &getContext();
    ConversionTarget target(*context);
    OwningRewritePatternList patterns;

    populateElementwiseToLinalgConversionPatterns(patterns, context);
    target.markUnknownOpDynamicallyLegal([](Operation *op) {
      return !isElementwiseMappableOpOnRankedTensors(op);
    });

    if (failed(applyPartialConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertElementwiseToLinalgPass() {
  return std::make_unique<ConvertElementwiseToLinalgPass>();
}
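
// An illustrative sketch of the rewrite performed by the pattern above. The
// exact textual IR, op spellings, and map names such as `#id` are assumptions
// that may differ across MLIR revisions. An elementwise op on ranked tensors
// such as
//
//   %0 = addf %arg0, %arg1 : tensor<4x?xf32>
//
// is replaced by a linalg.generic that re-creates the op on scalars inside
// its region, with a linalg.init_tensor supplying the output:
//
//   %c1 = constant 1 : index
//   %d1 = dim %arg0, %c1 : tensor<4x?xf32>
//   %init = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
//   %0 = linalg.generic
//          {indexing_maps = [#id, #id, #id],
//           iterator_types = ["parallel", "parallel"]}
//          ins(%arg0, %arg1 : tensor<4x?xf32>, tensor<4x?xf32>)
//          outs(%init : tensor<4x?xf32>) {
//        ^bb0(%a: f32, %b: f32, %out: f32):
//          %r = addf %a, %b : f32
//          linalg.yield %r : f32
//        } -> tensor<4x?xf32>
//
// where #id stands for the identity map affine_map<(d0, d1) -> (d0, d1)>.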