153a0d45dSSean Silva //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
253a0d45dSSean Silva //
353a0d45dSSean Silva // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
453a0d45dSSean Silva // See https://llvm.org/LICENSE.txt for license information.
553a0d45dSSean Silva // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
653a0d45dSSean Silva //
753a0d45dSSean Silva //===----------------------------------------------------------------------===//
853a0d45dSSean Silva 
953a0d45dSSean Silva #include "mlir/Dialect/Linalg/Passes.h"
1053a0d45dSSean Silva 
1153a0d45dSSean Silva #include "PassDetail.h"
12ead11072SRiver Riddle #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
13b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
14ea069aebSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15f75f391fSRob Suderman #include "mlir/Dialect/Linalg/Utils/Utils.h"
1653a0d45dSSean Silva #include "mlir/Transforms/DialectConversion.h"
1753a0d45dSSean Silva 
1853a0d45dSSean Silva using namespace mlir;
1953a0d45dSSean Silva 
isElementwiseMappableOpOnRankedTensors(Operation * op)2053a0d45dSSean Silva static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
21bcc9b371SFrederik Gossen   if (!OpTrait::hasElementwiseMappableTraits(op))
2253a0d45dSSean Silva     return false;
2353a0d45dSSean Silva 
2453a0d45dSSean Silva   // TODO: The conversion pattern can be made to work for `any_of` here, but
2553a0d45dSSean Silva   // it's more complex as it requires tracking which operands are scalars.
2653a0d45dSSean Silva   return llvm::all_of(op->getOperandTypes(),
2753a0d45dSSean Silva                       [](Type type) { return type.isa<RankedTensorType>(); });
2853a0d45dSSean Silva }
2953a0d45dSSean Silva 
30b7ae1d3dSnicolasvasilache /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
31b7ae1d3dSnicolasvasilache /// the result types and return a list of values such that, for each result type
32b7ae1d3dSnicolasvasilache /// `t` and value `v` at the same index `idx`:
33b7ae1d3dSnicolasvasilache ///   1. `v.getType() == t`
34b7ae1d3dSnicolasvasilache ///   2. If an operand of `op` has type `t`, let `operand_first` be the first
35b7ae1d3dSnicolasvasilache ///      such operand. Then`v == operand_first`.
36b7ae1d3dSnicolasvasilache ///   3. Otherwise, v is a newly created `linalg::InitTensorOp` with:
37b7ae1d3dSnicolasvasilache ///        a. Static and dynamic dims extracted from the first operand of `op`.
38b7ae1d3dSnicolasvasilache ///        b. Elemental type equal to the elemental type of `t`.
39b7ae1d3dSnicolasvasilache ///
40b7ae1d3dSnicolasvasilache /// This is sufficient because ElementwiseMappable guarantees that "The static
41b7ae1d3dSnicolasvasilache /// types of all vector (resp. tensor) operands and results must have the same
42b7ae1d3dSnicolasvasilache /// shape".
43b7ae1d3dSnicolasvasilache static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder & b,Operation * op)44b7ae1d3dSnicolasvasilache getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
45b7ae1d3dSnicolasvasilache   assert(isElementwiseMappableOpOnRankedTensors(op));
46b7ae1d3dSnicolasvasilache   Location loc = op->getLoc();
47b7ae1d3dSnicolasvasilache   ValueRange operands = op->getOperands();
48b7ae1d3dSnicolasvasilache   TypeRange rankedTensorTypes = op->getResultTypes();
49b7ae1d3dSnicolasvasilache   SmallVector<Value, 4> res;
50b7ae1d3dSnicolasvasilache   res.reserve(rankedTensorTypes.size());
51b7ae1d3dSnicolasvasilache   for (Type t : rankedTensorTypes) {
52b7ae1d3dSnicolasvasilache     // Try to find an operand with type matching the result tensor.
53b7ae1d3dSnicolasvasilache     bool found = false;
54b7ae1d3dSnicolasvasilache     for (Value v : operands) {
55b7ae1d3dSnicolasvasilache       if (v.getType() == t) {
56b7ae1d3dSnicolasvasilache         found = true;
57b7ae1d3dSnicolasvasilache         res.push_back(v);
58b7ae1d3dSnicolasvasilache         break;
59b7ae1d3dSnicolasvasilache       }
60b7ae1d3dSnicolasvasilache     }
61b7ae1d3dSnicolasvasilache     if (found)
62b7ae1d3dSnicolasvasilache       continue;
63b7ae1d3dSnicolasvasilache 
64b7ae1d3dSnicolasvasilache     // Extract static / dynamic shape mix from the first operand.
65b7ae1d3dSnicolasvasilache     Value firstOperand = operands.front();
66b7ae1d3dSnicolasvasilache     auto rankedTensorType = t.cast<RankedTensorType>();
67f75f391fSRob Suderman     auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
682c115eccSMatthias Springer     auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b);
69f75f391fSRob Suderman 
70b7ae1d3dSnicolasvasilache     res.push_back(b.create<linalg::InitTensorOp>(
71b7ae1d3dSnicolasvasilache         loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
72b7ae1d3dSnicolasvasilache   }
73b7ae1d3dSnicolasvasilache   return res;
74b7ae1d3dSnicolasvasilache }
75b7ae1d3dSnicolasvasilache 
7653a0d45dSSean Silva namespace {
775488a6b0SSean Silva struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
ConvertAnyElementwiseMappableOpOnRankedTensors__anon7a43419e0211::ConvertAnyElementwiseMappableOpOnRankedTensors7876f3c2f3SRiver Riddle   ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
7976f3c2f3SRiver Riddle       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
matchAndRewrite__anon7a43419e0211::ConvertAnyElementwiseMappableOpOnRankedTensors8053a0d45dSSean Silva   LogicalResult matchAndRewrite(Operation *op,
8153a0d45dSSean Silva                                 PatternRewriter &rewriter) const final {
8253a0d45dSSean Silva     if (!isElementwiseMappableOpOnRankedTensors(op))
8353a0d45dSSean Silva       return rewriter.notifyMatchFailure(
8453a0d45dSSean Silva           op, "requires elementwise op on ranked tensors");
8553a0d45dSSean Silva 
8653a0d45dSSean Silva     auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
8753a0d45dSSean Silva     SmallVector<AffineMap, 3> indexingMaps(
8853a0d45dSSean Silva         op->getNumResults() + op->getNumOperands(),
8953a0d45dSSean Silva         rewriter.getMultiDimIdentityMap(rank));
9053a0d45dSSean Silva     SmallVector<StringRef, 6> iteratorTypes(rank,
9153a0d45dSSean Silva                                             getParallelIteratorTypeName());
92b7ae1d3dSnicolasvasilache     auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
9353a0d45dSSean Silva     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
9453a0d45dSSean Silva         op, /*resultTensorTypes=*/op->getResultTypes(),
9553a0d45dSSean Silva         /*inputs=*/op->getOperands(),
96b7ae1d3dSnicolasvasilache         /*outputs=*/outputs,
9753a0d45dSSean Silva         /*indexingMaps=*/indexingMaps,
9853a0d45dSSean Silva         /*iteratorTypes=*/iteratorTypes,
9953a0d45dSSean Silva         /*bodyBuilder=*/
10053a0d45dSSean Silva         [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
10153a0d45dSSean Silva           auto resultTypes = llvm::to_vector<6>(
10253a0d45dSSean Silva               llvm::map_range(op->getResultTypes(), [](Type type) {
10353a0d45dSSean Silva                 return type.cast<TensorType>().getElementType();
10453a0d45dSSean Silva               }));
105*14ecafd0SChia-hung Duan           auto *scalarOp =
106*14ecafd0SChia-hung Duan               builder.create(loc, op->getName().getIdentifier(),
107*14ecafd0SChia-hung Duan                              regionArgs.take_front(op->getNumOperands()),
108*14ecafd0SChia-hung Duan                              resultTypes, op->getAttrs());
10953a0d45dSSean Silva           builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
11053a0d45dSSean Silva         });
11153a0d45dSSean Silva     return success();
11253a0d45dSSean Silva   }
11353a0d45dSSean Silva };
11453a0d45dSSean Silva } // namespace
11553a0d45dSSean Silva 
populateElementwiseToLinalgConversionPatterns(RewritePatternSet & patterns)116ea069aebSMaheshRavishankar void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
117dc4e913bSChris Lattner     RewritePatternSet &patterns) {
11876f3c2f3SRiver Riddle   patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
11976f3c2f3SRiver Riddle       patterns.getContext());
12053a0d45dSSean Silva }
12153a0d45dSSean Silva 
12253a0d45dSSean Silva namespace {
12353a0d45dSSean Silva class ConvertElementwiseToLinalgPass
12453a0d45dSSean Silva     : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
12553a0d45dSSean Silva 
runOnOperation()126c10995a8SStella Laurenzo   void runOnOperation() final {
12702b6fb21SMehdi Amini     auto *func = getOperation();
12853a0d45dSSean Silva     auto *context = &getContext();
12953a0d45dSSean Silva     ConversionTarget target(*context);
130dc4e913bSChris Lattner     RewritePatternSet patterns(context);
13153a0d45dSSean Silva 
132ea069aebSMaheshRavishankar     mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
13353a0d45dSSean Silva     target.markUnknownOpDynamicallyLegal([](Operation *op) {
13453a0d45dSSean Silva       return !isElementwiseMappableOpOnRankedTensors(op);
13553a0d45dSSean Silva     });
13653a0d45dSSean Silva 
13753a0d45dSSean Silva     if (failed(applyPartialConversion(func, target, std::move(patterns))))
13853a0d45dSSean Silva       signalPassFailure();
13953a0d45dSSean Silva   }
14053a0d45dSSean Silva };
14153a0d45dSSean Silva } // namespace
14253a0d45dSSean Silva 
createConvertElementwiseToLinalgPass()143c10995a8SStella Laurenzo std::unique_ptr<Pass> mlir::createConvertElementwiseToLinalgPass() {
14453a0d45dSSean Silva   return std::make_unique<ConvertElementwiseToLinalgPass>();
14553a0d45dSSean Silva }
146