153a0d45dSSean Silva //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
253a0d45dSSean Silva //
353a0d45dSSean Silva // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
453a0d45dSSean Silva // See https://llvm.org/LICENSE.txt for license information.
553a0d45dSSean Silva // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
653a0d45dSSean Silva //
753a0d45dSSean Silva //===----------------------------------------------------------------------===//
853a0d45dSSean Silva
953a0d45dSSean Silva #include "mlir/Dialect/Linalg/Passes.h"
1053a0d45dSSean Silva
1153a0d45dSSean Silva #include "PassDetail.h"
12ead11072SRiver Riddle #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
13b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
14ea069aebSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15f75f391fSRob Suderman #include "mlir/Dialect/Linalg/Utils/Utils.h"
1653a0d45dSSean Silva #include "mlir/Transforms/DialectConversion.h"
1753a0d45dSSean Silva
1853a0d45dSSean Silva using namespace mlir;
1953a0d45dSSean Silva
isElementwiseMappableOpOnRankedTensors(Operation * op)2053a0d45dSSean Silva static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
21bcc9b371SFrederik Gossen if (!OpTrait::hasElementwiseMappableTraits(op))
2253a0d45dSSean Silva return false;
2353a0d45dSSean Silva
2453a0d45dSSean Silva // TODO: The conversion pattern can be made to work for `any_of` here, but
2553a0d45dSSean Silva // it's more complex as it requires tracking which operands are scalars.
2653a0d45dSSean Silva return llvm::all_of(op->getOperandTypes(),
2753a0d45dSSean Silva [](Type type) { return type.isa<RankedTensorType>(); });
2853a0d45dSSean Silva }
2953a0d45dSSean Silva
30b7ae1d3dSnicolasvasilache /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
31b7ae1d3dSnicolasvasilache /// the result types and return a list of values such that, for each result type
32b7ae1d3dSnicolasvasilache /// `t` and value `v` at the same index `idx`:
33b7ae1d3dSnicolasvasilache /// 1. `v.getType() == t`
34b7ae1d3dSnicolasvasilache /// 2. If an operand of `op` has type `t`, let `operand_first` be the first
35b7ae1d3dSnicolasvasilache /// such operand. Then`v == operand_first`.
36b7ae1d3dSnicolasvasilache /// 3. Otherwise, v is a newly created `linalg::InitTensorOp` with:
37b7ae1d3dSnicolasvasilache /// a. Static and dynamic dims extracted from the first operand of `op`.
38b7ae1d3dSnicolasvasilache /// b. Elemental type equal to the elemental type of `t`.
39b7ae1d3dSnicolasvasilache ///
40b7ae1d3dSnicolasvasilache /// This is sufficient because ElementwiseMappable guarantees that "The static
41b7ae1d3dSnicolasvasilache /// types of all vector (resp. tensor) operands and results must have the same
42b7ae1d3dSnicolasvasilache /// shape".
43b7ae1d3dSnicolasvasilache static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder & b,Operation * op)44b7ae1d3dSnicolasvasilache getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
45b7ae1d3dSnicolasvasilache assert(isElementwiseMappableOpOnRankedTensors(op));
46b7ae1d3dSnicolasvasilache Location loc = op->getLoc();
47b7ae1d3dSnicolasvasilache ValueRange operands = op->getOperands();
48b7ae1d3dSnicolasvasilache TypeRange rankedTensorTypes = op->getResultTypes();
49b7ae1d3dSnicolasvasilache SmallVector<Value, 4> res;
50b7ae1d3dSnicolasvasilache res.reserve(rankedTensorTypes.size());
51b7ae1d3dSnicolasvasilache for (Type t : rankedTensorTypes) {
52b7ae1d3dSnicolasvasilache // Try to find an operand with type matching the result tensor.
53b7ae1d3dSnicolasvasilache bool found = false;
54b7ae1d3dSnicolasvasilache for (Value v : operands) {
55b7ae1d3dSnicolasvasilache if (v.getType() == t) {
56b7ae1d3dSnicolasvasilache found = true;
57b7ae1d3dSnicolasvasilache res.push_back(v);
58b7ae1d3dSnicolasvasilache break;
59b7ae1d3dSnicolasvasilache }
60b7ae1d3dSnicolasvasilache }
61b7ae1d3dSnicolasvasilache if (found)
62b7ae1d3dSnicolasvasilache continue;
63b7ae1d3dSnicolasvasilache
64b7ae1d3dSnicolasvasilache // Extract static / dynamic shape mix from the first operand.
65b7ae1d3dSnicolasvasilache Value firstOperand = operands.front();
66b7ae1d3dSnicolasvasilache auto rankedTensorType = t.cast<RankedTensorType>();
67f75f391fSRob Suderman auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
682c115eccSMatthias Springer auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b);
69f75f391fSRob Suderman
70b7ae1d3dSnicolasvasilache res.push_back(b.create<linalg::InitTensorOp>(
71b7ae1d3dSnicolasvasilache loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
72b7ae1d3dSnicolasvasilache }
73b7ae1d3dSnicolasvasilache return res;
74b7ae1d3dSnicolasvasilache }
75b7ae1d3dSnicolasvasilache
7653a0d45dSSean Silva namespace {
775488a6b0SSean Silva struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
ConvertAnyElementwiseMappableOpOnRankedTensors__anon7a43419e0211::ConvertAnyElementwiseMappableOpOnRankedTensors7876f3c2f3SRiver Riddle ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
7976f3c2f3SRiver Riddle : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
matchAndRewrite__anon7a43419e0211::ConvertAnyElementwiseMappableOpOnRankedTensors8053a0d45dSSean Silva LogicalResult matchAndRewrite(Operation *op,
8153a0d45dSSean Silva PatternRewriter &rewriter) const final {
8253a0d45dSSean Silva if (!isElementwiseMappableOpOnRankedTensors(op))
8353a0d45dSSean Silva return rewriter.notifyMatchFailure(
8453a0d45dSSean Silva op, "requires elementwise op on ranked tensors");
8553a0d45dSSean Silva
8653a0d45dSSean Silva auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
8753a0d45dSSean Silva SmallVector<AffineMap, 3> indexingMaps(
8853a0d45dSSean Silva op->getNumResults() + op->getNumOperands(),
8953a0d45dSSean Silva rewriter.getMultiDimIdentityMap(rank));
9053a0d45dSSean Silva SmallVector<StringRef, 6> iteratorTypes(rank,
9153a0d45dSSean Silva getParallelIteratorTypeName());
92b7ae1d3dSnicolasvasilache auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
9353a0d45dSSean Silva rewriter.replaceOpWithNewOp<linalg::GenericOp>(
9453a0d45dSSean Silva op, /*resultTensorTypes=*/op->getResultTypes(),
9553a0d45dSSean Silva /*inputs=*/op->getOperands(),
96b7ae1d3dSnicolasvasilache /*outputs=*/outputs,
9753a0d45dSSean Silva /*indexingMaps=*/indexingMaps,
9853a0d45dSSean Silva /*iteratorTypes=*/iteratorTypes,
9953a0d45dSSean Silva /*bodyBuilder=*/
10053a0d45dSSean Silva [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
10153a0d45dSSean Silva auto resultTypes = llvm::to_vector<6>(
10253a0d45dSSean Silva llvm::map_range(op->getResultTypes(), [](Type type) {
10353a0d45dSSean Silva return type.cast<TensorType>().getElementType();
10453a0d45dSSean Silva }));
105*14ecafd0SChia-hung Duan auto *scalarOp =
106*14ecafd0SChia-hung Duan builder.create(loc, op->getName().getIdentifier(),
107*14ecafd0SChia-hung Duan regionArgs.take_front(op->getNumOperands()),
108*14ecafd0SChia-hung Duan resultTypes, op->getAttrs());
10953a0d45dSSean Silva builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
11053a0d45dSSean Silva });
11153a0d45dSSean Silva return success();
11253a0d45dSSean Silva }
11353a0d45dSSean Silva };
11453a0d45dSSean Silva } // namespace
11553a0d45dSSean Silva
populateElementwiseToLinalgConversionPatterns(RewritePatternSet & patterns)116ea069aebSMaheshRavishankar void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
117dc4e913bSChris Lattner RewritePatternSet &patterns) {
11876f3c2f3SRiver Riddle patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
11976f3c2f3SRiver Riddle patterns.getContext());
12053a0d45dSSean Silva }
12153a0d45dSSean Silva
12253a0d45dSSean Silva namespace {
12353a0d45dSSean Silva class ConvertElementwiseToLinalgPass
12453a0d45dSSean Silva : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
12553a0d45dSSean Silva
runOnOperation()126c10995a8SStella Laurenzo void runOnOperation() final {
12702b6fb21SMehdi Amini auto *func = getOperation();
12853a0d45dSSean Silva auto *context = &getContext();
12953a0d45dSSean Silva ConversionTarget target(*context);
130dc4e913bSChris Lattner RewritePatternSet patterns(context);
13153a0d45dSSean Silva
132ea069aebSMaheshRavishankar mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
13353a0d45dSSean Silva target.markUnknownOpDynamicallyLegal([](Operation *op) {
13453a0d45dSSean Silva return !isElementwiseMappableOpOnRankedTensors(op);
13553a0d45dSSean Silva });
13653a0d45dSSean Silva
13753a0d45dSSean Silva if (failed(applyPartialConversion(func, target, std::move(patterns))))
13853a0d45dSSean Silva signalPassFailure();
13953a0d45dSSean Silva }
14053a0d45dSSean Silva };
14153a0d45dSSean Silva } // namespace
14253a0d45dSSean Silva
createConvertElementwiseToLinalgPass()143c10995a8SStella Laurenzo std::unique_ptr<Pass> mlir::createConvertElementwiseToLinalgPass() {
14453a0d45dSSean Silva return std::make_unique<ConvertElementwiseToLinalgPass>();
14553a0d45dSSean Silva }
146