//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8250dcf61SAlexander Belyaev 
9250dcf61SAlexander Belyaev #include "PassDetail.h"
10a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11250dcf61SAlexander Belyaev #include "mlir/Dialect/Shape/IR/Shape.h"
12250dcf61SAlexander Belyaev #include "mlir/Dialect/Shape/Transforms/Passes.h"
13250dcf61SAlexander Belyaev #include "mlir/IR/Builders.h"
14250dcf61SAlexander Belyaev #include "mlir/IR/PatternMatch.h"
15250dcf61SAlexander Belyaev #include "mlir/Pass/Pass.h"
16250dcf61SAlexander Belyaev #include "mlir/Transforms/DialectConversion.h"
17250dcf61SAlexander Belyaev 
18250dcf61SAlexander Belyaev using namespace mlir;
19250dcf61SAlexander Belyaev using namespace mlir::shape;
20250dcf61SAlexander Belyaev 
21250dcf61SAlexander Belyaev namespace {
22250dcf61SAlexander Belyaev /// Converts `shape.num_elements` to `shape.reduce`.
23250dcf61SAlexander Belyaev struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
24250dcf61SAlexander Belyaev public:
25250dcf61SAlexander Belyaev   using OpRewritePattern::OpRewritePattern;
26250dcf61SAlexander Belyaev 
27250dcf61SAlexander Belyaev   LogicalResult matchAndRewrite(NumElementsOp op,
28250dcf61SAlexander Belyaev                                 PatternRewriter &rewriter) const final;
29250dcf61SAlexander Belyaev };
30250dcf61SAlexander Belyaev } // namespace
31250dcf61SAlexander Belyaev 
32250dcf61SAlexander Belyaev LogicalResult
matchAndRewrite(NumElementsOp op,PatternRewriter & rewriter) const33250dcf61SAlexander Belyaev NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
34250dcf61SAlexander Belyaev                                         PatternRewriter &rewriter) const {
35250dcf61SAlexander Belyaev   auto loc = op.getLoc();
366d10d317SStephan Herhut   Type valueType = op.getResult().getType();
370bf4a82aSChristian Sigg   Value init = op->getDialect()
386d10d317SStephan Herhut                    ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
396d10d317SStephan Herhut                                          valueType, loc)
406d10d317SStephan Herhut                    ->getResult(0);
41cfb72fd3SJacques Pienaar   ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init);
42250dcf61SAlexander Belyaev 
43250dcf61SAlexander Belyaev   // Generate reduce operator.
44250dcf61SAlexander Belyaev   Block *body = reduce.getBody();
45250dcf61SAlexander Belyaev   OpBuilder b = OpBuilder::atBlockEnd(body);
466d10d317SStephan Herhut   Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
476d10d317SStephan Herhut                                   body->getArgument(2));
48136eb79aSFrederik Gossen   b.create<shape::YieldOp>(loc, product);
49250dcf61SAlexander Belyaev 
50cfb72fd3SJacques Pienaar   rewriter.replaceOp(op, reduce.getResult());
51250dcf61SAlexander Belyaev   return success();
52250dcf61SAlexander Belyaev }
53250dcf61SAlexander Belyaev 
54250dcf61SAlexander Belyaev namespace {
55250dcf61SAlexander Belyaev struct ShapeToShapeLowering
56250dcf61SAlexander Belyaev     : public ShapeToShapeLoweringBase<ShapeToShapeLowering> {
5741574554SRiver Riddle   void runOnOperation() override;
58250dcf61SAlexander Belyaev };
59250dcf61SAlexander Belyaev } // namespace
60250dcf61SAlexander Belyaev 
runOnOperation()6141574554SRiver Riddle void ShapeToShapeLowering::runOnOperation() {
627a9258e9SAlexander Belyaev   MLIRContext &ctx = getContext();
637a9258e9SAlexander Belyaev 
64dc4e913bSChris Lattner   RewritePatternSet patterns(&ctx);
653a506b31SChris Lattner   populateShapeRewritePatterns(patterns);
66250dcf61SAlexander Belyaev 
67250dcf61SAlexander Belyaev   ConversionTarget target(getContext());
68*1f971e23SRiver Riddle   target.addLegalDialect<arith::ArithmeticDialect, ShapeDialect>();
69250dcf61SAlexander Belyaev   target.addIllegalOp<NumElementsOp>();
7041574554SRiver Riddle   if (failed(mlir::applyPartialConversion(getOperation(), target,
713fffffa8SRiver Riddle                                           std::move(patterns))))
72250dcf61SAlexander Belyaev     signalPassFailure();
73250dcf61SAlexander Belyaev }
74250dcf61SAlexander Belyaev 
populateShapeRewritePatterns(RewritePatternSet & patterns)75dc4e913bSChris Lattner void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
76dc4e913bSChris Lattner   patterns.add<NumElementsOpConverter>(patterns.getContext());
777a9258e9SAlexander Belyaev }
787a9258e9SAlexander Belyaev 
createShapeToShapeLowering()79250dcf61SAlexander Belyaev std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
80250dcf61SAlexander Belyaev   return std::make_unique<ShapeToShapeLowering>();
81250dcf61SAlexander Belyaev }
82