1 //===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 #include "mlir/Dialect/Shape/Transforms/Passes.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 18 using namespace mlir; 19 using namespace mlir::shape; 20 21 namespace { 22 /// Converts `shape.num_elements` to `shape.reduce`. 23 struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> { 24 public: 25 using OpRewritePattern::OpRewritePattern; 26 27 LogicalResult matchAndRewrite(NumElementsOp op, 28 PatternRewriter &rewriter) const final; 29 }; 30 } // namespace 31 32 LogicalResult 33 NumElementsOpConverter::matchAndRewrite(NumElementsOp op, 34 PatternRewriter &rewriter) const { 35 auto loc = op.getLoc(); 36 Type valueType = op.getResult().getType(); 37 Value init = op->getDialect() 38 ->materializeConstant(rewriter, rewriter.getIndexAttr(1), 39 valueType, loc) 40 ->getResult(0); 41 ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init); 42 43 // Generate reduce operator. 44 Block *body = reduce.getBody(); 45 OpBuilder b = OpBuilder::atBlockEnd(body); 46 Value product = b.create<MulOp>(loc, valueType, body->getArgument(1), 47 body->getArgument(2)); 48 b.create<shape::YieldOp>(loc, product); 49 50 rewriter.replaceOp(op, reduce.getResult()); 51 return success(); 52 } 53 54 namespace { 55 struct ShapeToShapeLowering 56 : public ShapeToShapeLoweringBase<ShapeToShapeLowering> { 57 void runOnOperation() override; 58 }; 59 } // namespace 60 61 void ShapeToShapeLowering::runOnOperation() { 62 MLIRContext &ctx = getContext(); 63 64 RewritePatternSet patterns(&ctx); 65 populateShapeRewritePatterns(patterns); 66 67 ConversionTarget target(getContext()); 68 target.addLegalDialect<arith::ArithmeticDialect, ShapeDialect>(); 69 target.addIllegalOp<NumElementsOp>(); 70 if (failed(mlir::applyPartialConversion(getOperation(), target, 71 std::move(patterns)))) 72 signalPassFailure(); 73 } 74 75 void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) { 76 patterns.add<NumElementsOpConverter>(patterns.getContext()); 77 } 78 79 std::unique_ptr<Pass> mlir::createShapeToShapeLowering() { 80 return std::make_unique<ShapeToShapeLowering>(); 81 } 82