1 //===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 #include "mlir/Dialect/Shape/Transforms/Passes.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 19 using namespace mlir; 20 using namespace mlir::shape; 21 22 namespace { 23 /// Converts `shape.num_elements` to `shape.reduce`. 24 struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> { 25 public: 26 using OpRewritePattern::OpRewritePattern; 27 28 LogicalResult matchAndRewrite(NumElementsOp op, 29 PatternRewriter &rewriter) const final; 30 }; 31 } // namespace 32 33 LogicalResult 34 NumElementsOpConverter::matchAndRewrite(NumElementsOp op, 35 PatternRewriter &rewriter) const { 36 auto loc = op.getLoc(); 37 Type valueType = op.getResult().getType(); 38 Value init = op->getDialect() 39 ->materializeConstant(rewriter, rewriter.getIndexAttr(1), 40 valueType, loc) 41 ->getResult(0); 42 ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init); 43 44 // Generate reduce operator. 45 Block *body = reduce.getBody(); 46 OpBuilder b = OpBuilder::atBlockEnd(body); 47 Value product = b.create<MulOp>(loc, valueType, body->getArgument(1), 48 body->getArgument(2)); 49 b.create<shape::YieldOp>(loc, product); 50 51 rewriter.replaceOp(op, reduce.getResult()); 52 return success(); 53 } 54 55 namespace { 56 struct ShapeToShapeLowering 57 : public ShapeToShapeLoweringBase<ShapeToShapeLowering> { 58 void runOnOperation() override; 59 }; 60 } // namespace 61 62 void ShapeToShapeLowering::runOnOperation() { 63 MLIRContext &ctx = getContext(); 64 65 RewritePatternSet patterns(&ctx); 66 populateShapeRewritePatterns(patterns); 67 68 ConversionTarget target(getContext()); 69 target.addLegalDialect<arith::ArithmeticDialect, ShapeDialect, 70 StandardOpsDialect>(); 71 target.addIllegalOp<NumElementsOp>(); 72 if (failed(mlir::applyPartialConversion(getOperation(), target, 73 std::move(patterns)))) 74 signalPassFailure(); 75 } 76 77 void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) { 78 patterns.add<NumElementsOpConverter>(patterns.getContext()); 79 } 80 81 std::unique_ptr<Pass> mlir::createShapeToShapeLowering() { 82 return std::make_unique<ShapeToShapeLowering>(); 83 } 84