1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::shape; 19 20 namespace { 21 22 /// Generated conversion patterns. 23 #include "ShapeToStandardPatterns.inc" 24 25 /// Conversion patterns. 26 template <typename SrcOpTy, typename DstOpTy> 27 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 28 public: 29 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 30 31 LogicalResult 32 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 33 ConversionPatternRewriter &rewriter) const override { 34 typename SrcOpTy::Adaptor adaptor(operands); 35 rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(), 36 adaptor.rhs()); 37 return success(); 38 } 39 }; 40 41 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> { 42 public: 43 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 44 45 LogicalResult 46 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 47 ConversionPatternRewriter &rewriter) const override { 48 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), 49 op.value().getSExtValue()); 50 return success(); 51 } 52 }; 53 54 /// Type conversions. 55 class ShapeTypeConverter : public TypeConverter { 56 public: 57 using TypeConverter::convertType; 58 59 ShapeTypeConverter(MLIRContext *ctx) { 60 // Add default pass-through conversion. 61 addConversion([&](Type type) { return type; }); 62 63 addConversion([ctx](SizeType type) { return IndexType::get(ctx); }); 64 addConversion([ctx](ShapeType type) { 65 return RankedTensorType::get({ShapedType::kDynamicSize}, 66 IndexType::get(ctx)); 67 }); 68 } 69 }; 70 71 /// Conversion pass. 72 class ConvertShapeToStandardPass 73 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 74 75 void runOnOperation() override { 76 // Setup type conversion. 77 MLIRContext &ctx = getContext(); 78 ShapeTypeConverter typeConverter(&ctx); 79 80 // Setup target legality. 81 ConversionTarget target(ctx); 82 target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>(); 83 target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>(); 84 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 85 return typeConverter.isSignatureLegal(op.getType()) && 86 typeConverter.isLegal(&op.getBody()); 87 }); 88 89 // Setup conversion patterns. 90 OwningRewritePatternList patterns; 91 populateShapeToStandardConversionPatterns(patterns, &ctx); 92 populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); 93 94 // Apply conversion. 95 auto module = getOperation(); 96 if (failed(applyFullConversion(module, target, patterns))) 97 signalPassFailure(); 98 } 99 }; 100 101 } // namespace 102 103 void mlir::populateShapeToStandardConversionPatterns( 104 OwningRewritePatternList &patterns, MLIRContext *ctx) { 105 populateWithGenerated(ctx, &patterns); 106 // clang-format off 107 patterns.insert< 108 BinaryOpConversion<AddOp, AddIOp>, 109 BinaryOpConversion<MulOp, MulIOp>, 110 ConstSizeOpConverter>(ctx); 111 // clang-format on 112 } 113 114 std::unique_ptr<OperationPass<ModuleOp>> 115 mlir::createConvertShapeToStandardPass() { 116 return std::make_unique<ConvertShapeToStandardPass>(); 117 } 118