1250dcf61SAlexander Belyaev //===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2250dcf61SAlexander Belyaev //
3250dcf61SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4250dcf61SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5250dcf61SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6250dcf61SAlexander Belyaev //
7250dcf61SAlexander Belyaev //===----------------------------------------------------------------------===//
8250dcf61SAlexander Belyaev
9250dcf61SAlexander Belyaev #include "PassDetail.h"
10a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11250dcf61SAlexander Belyaev #include "mlir/Dialect/Shape/IR/Shape.h"
12250dcf61SAlexander Belyaev #include "mlir/Dialect/Shape/Transforms/Passes.h"
13250dcf61SAlexander Belyaev #include "mlir/IR/Builders.h"
14250dcf61SAlexander Belyaev #include "mlir/IR/PatternMatch.h"
15250dcf61SAlexander Belyaev #include "mlir/Pass/Pass.h"
16250dcf61SAlexander Belyaev #include "mlir/Transforms/DialectConversion.h"
17250dcf61SAlexander Belyaev
18250dcf61SAlexander Belyaev using namespace mlir;
19250dcf61SAlexander Belyaev using namespace mlir::shape;
20250dcf61SAlexander Belyaev
21250dcf61SAlexander Belyaev namespace {
22250dcf61SAlexander Belyaev /// Converts `shape.num_elements` to `shape.reduce`.
23250dcf61SAlexander Belyaev struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
24250dcf61SAlexander Belyaev public:
25250dcf61SAlexander Belyaev using OpRewritePattern::OpRewritePattern;
26250dcf61SAlexander Belyaev
27250dcf61SAlexander Belyaev LogicalResult matchAndRewrite(NumElementsOp op,
28250dcf61SAlexander Belyaev PatternRewriter &rewriter) const final;
29250dcf61SAlexander Belyaev };
30250dcf61SAlexander Belyaev } // namespace
31250dcf61SAlexander Belyaev
32250dcf61SAlexander Belyaev LogicalResult
matchAndRewrite(NumElementsOp op,PatternRewriter & rewriter) const33250dcf61SAlexander Belyaev NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
34250dcf61SAlexander Belyaev PatternRewriter &rewriter) const {
35250dcf61SAlexander Belyaev auto loc = op.getLoc();
366d10d317SStephan Herhut Type valueType = op.getResult().getType();
370bf4a82aSChristian Sigg Value init = op->getDialect()
386d10d317SStephan Herhut ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
396d10d317SStephan Herhut valueType, loc)
406d10d317SStephan Herhut ->getResult(0);
41cfb72fd3SJacques Pienaar ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init);
42250dcf61SAlexander Belyaev
43250dcf61SAlexander Belyaev // Generate reduce operator.
44250dcf61SAlexander Belyaev Block *body = reduce.getBody();
45250dcf61SAlexander Belyaev OpBuilder b = OpBuilder::atBlockEnd(body);
466d10d317SStephan Herhut Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
476d10d317SStephan Herhut body->getArgument(2));
48136eb79aSFrederik Gossen b.create<shape::YieldOp>(loc, product);
49250dcf61SAlexander Belyaev
50cfb72fd3SJacques Pienaar rewriter.replaceOp(op, reduce.getResult());
51250dcf61SAlexander Belyaev return success();
52250dcf61SAlexander Belyaev }
53250dcf61SAlexander Belyaev
54250dcf61SAlexander Belyaev namespace {
55250dcf61SAlexander Belyaev struct ShapeToShapeLowering
56250dcf61SAlexander Belyaev : public ShapeToShapeLoweringBase<ShapeToShapeLowering> {
5741574554SRiver Riddle void runOnOperation() override;
58250dcf61SAlexander Belyaev };
59250dcf61SAlexander Belyaev } // namespace
60250dcf61SAlexander Belyaev
runOnOperation()6141574554SRiver Riddle void ShapeToShapeLowering::runOnOperation() {
627a9258e9SAlexander Belyaev MLIRContext &ctx = getContext();
637a9258e9SAlexander Belyaev
64dc4e913bSChris Lattner RewritePatternSet patterns(&ctx);
653a506b31SChris Lattner populateShapeRewritePatterns(patterns);
66250dcf61SAlexander Belyaev
67250dcf61SAlexander Belyaev ConversionTarget target(getContext());
68*1f971e23SRiver Riddle target.addLegalDialect<arith::ArithmeticDialect, ShapeDialect>();
69250dcf61SAlexander Belyaev target.addIllegalOp<NumElementsOp>();
7041574554SRiver Riddle if (failed(mlir::applyPartialConversion(getOperation(), target,
713fffffa8SRiver Riddle std::move(patterns))))
72250dcf61SAlexander Belyaev signalPassFailure();
73250dcf61SAlexander Belyaev }
74250dcf61SAlexander Belyaev
populateShapeRewritePatterns(RewritePatternSet & patterns)75dc4e913bSChris Lattner void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
76dc4e913bSChris Lattner patterns.add<NumElementsOpConverter>(patterns.getContext());
777a9258e9SAlexander Belyaev }
787a9258e9SAlexander Belyaev
createShapeToShapeLowering()79250dcf61SAlexander Belyaev std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
80250dcf61SAlexander Belyaev return std::make_unique<ShapeToShapeLowering>();
81250dcf61SAlexander Belyaev }
82