1 //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Passes.h"
10 
11 #include "PassDetail.h"
12 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 using namespace mlir;
17 
18 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
19   if (!op->hasTrait<OpTrait::ElementwiseMappable>())
20     return false;
21 
22   // TODO: The conversion pattern can be made to work for `any_of` here, but
23   // it's more complex as it requires tracking which operands are scalars.
24   return llvm::all_of(op->getOperandTypes(),
25                       [](Type type) { return type.isa<RankedTensorType>(); });
26 }
27 
28 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
29 /// the result types and return a list of values such that, for each result type
30 /// `t` and value `v` at the same index `idx`:
31 ///   1. `v.getType() == t`
32 ///   2. If an operand of `op` has type `t`, let `operand_first` be the first
33 ///      such operand. Then`v == operand_first`.
34 ///   3. Otherwise, v is a newly created `linalg::InitTensorOp` with:
35 ///        a. Static and dynamic dims extracted from the first operand of `op`.
36 ///        b. Elemental type equal to the elemental type of `t`.
37 ///
38 /// This is sufficient because ElementwiseMappable guarantees that "The static
39 /// types of all vector (resp. tensor) operands and results must have the same
40 /// shape".
41 static SmallVector<Value, 4>
42 getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
43   assert(isElementwiseMappableOpOnRankedTensors(op));
44   Location loc = op->getLoc();
45   ValueRange operands = op->getOperands();
46   TypeRange rankedTensorTypes = op->getResultTypes();
47   SmallVector<Value, 4> res;
48   res.reserve(rankedTensorTypes.size());
49   for (Type t : rankedTensorTypes) {
50     // Try to find an operand with type matching the result tensor.
51     bool found = false;
52     for (Value v : operands) {
53       if (v.getType() == t) {
54         found = true;
55         res.push_back(v);
56         break;
57       }
58     }
59     if (found)
60       continue;
61 
62     // Extract static / dynamic shape mix from the first operand.
63     Value firstOperand = operands.front();
64     auto rankedTensorType = t.cast<RankedTensorType>();
65     SmallVector<Value, 8> dynamicShape;
66     SmallVector<int64_t, 8> staticShape;
67     dynamicShape.reserve(rankedTensorType.getRank());
68     staticShape.reserve(rankedTensorType.getRank());
69     unsigned idx = 0;
70     for (auto shape : rankedTensorType.getShape()) {
71       staticShape.push_back(shape);
72       if (rankedTensorType.isDynamicDim(idx))
73         dynamicShape.push_back(b.create<DimOp>(loc, firstOperand, idx));
74       ++idx;
75     }
76     // Create init tensor.
77     res.push_back(b.create<linalg::InitTensorOp>(
78         loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
79   }
80   return res;
81 }
82 
83 namespace {
84 struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
85   ConvertAnyElementwiseMappableOpOnRankedTensors()
86       : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
87   LogicalResult matchAndRewrite(Operation *op,
88                                 PatternRewriter &rewriter) const final {
89     if (!isElementwiseMappableOpOnRankedTensors(op))
90       return rewriter.notifyMatchFailure(
91           op, "requires elementwise op on ranked tensors");
92 
93     auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
94     SmallVector<AffineMap, 3> indexingMaps(
95         op->getNumResults() + op->getNumOperands(),
96         rewriter.getMultiDimIdentityMap(rank));
97     SmallVector<StringRef, 6> iteratorTypes(rank,
98                                             getParallelIteratorTypeName());
99     auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
100     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
101         op, /*resultTensorTypes=*/op->getResultTypes(),
102         /*inputs=*/op->getOperands(),
103         /*outputs=*/outputs,
104         /*indexingMaps=*/indexingMaps,
105         /*iteratorTypes=*/iteratorTypes,
106         /*bodyBuilder=*/
107         [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
108           OperationState state(loc, op->getName());
109           state.addAttributes(op->getAttrs());
110           // Only take the input operands in the cloned elementwise op.
111           state.addOperands(regionArgs.take_front(op->getNumOperands()));
112           auto resultTypes = llvm::to_vector<6>(
113               llvm::map_range(op->getResultTypes(), [](Type type) {
114                 return type.cast<TensorType>().getElementType();
115               }));
116           state.addTypes(resultTypes);
117           auto *scalarOp = builder.createOperation(state);
118           builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
119         });
120     return success();
121   }
122 };
123 } // namespace
124 
125 void mlir::populateElementwiseToLinalgConversionPatterns(
126     OwningRewritePatternList &patterns, MLIRContext *) {
127   patterns.insert<ConvertAnyElementwiseMappableOpOnRankedTensors>();
128 }
129 
130 namespace {
131 class ConvertElementwiseToLinalgPass
132     : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
133 
134   void runOnFunction() final {
135     auto func = getOperation();
136     auto *context = &getContext();
137     ConversionTarget target(*context);
138     OwningRewritePatternList patterns;
139 
140     populateElementwiseToLinalgConversionPatterns(patterns, context);
141     target.markUnknownOpDynamicallyLegal([](Operation *op) {
142       return !isElementwiseMappableOpOnRankedTensors(op);
143     });
144 
145     if (failed(applyPartialConversion(func, target, std::move(patterns))))
146       signalPassFailure();
147   }
148 };
149 } // namespace
150 
151 std::unique_ptr<OperationPass<FuncOp>>
152 mlir::createConvertElementwiseToLinalgPass() {
153   return std::make_unique<ConvertElementwiseToLinalgPass>();
154 }
155