1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::shape; 19 20 namespace { 21 22 /// Conversion patterns. 23 template <typename SrcOpTy, typename DstOpTy> 24 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 25 public: 26 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 30 ConversionPatternRewriter &rewriter) const override { 31 typename SrcOpTy::Adaptor adaptor(operands); 32 rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(), 33 adaptor.rhs()); 34 return success(); 35 } 36 }; 37 38 class FromExtentTensorOpConversion 39 : public OpConversionPattern<FromExtentTensorOp> { 40 public: 41 using OpConversionPattern<FromExtentTensorOp>::OpConversionPattern; 42 43 LogicalResult 44 matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands, 45 ConversionPatternRewriter &rewriter) const override { 46 FromExtentTensorOp::Adaptor transformed(operands); 47 rewriter.replaceOp(op.getOperation(), transformed.input()); 48 return success(); 49 } 50 }; 51 52 class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> { 53 public: 54 using OpConversionPattern<IndexToSizeOp>::OpConversionPattern; 55 56 LogicalResult 57 matchAndRewrite(IndexToSizeOp op, ArrayRef<Value> operands, 58 ConversionPatternRewriter &rewriter) const override { 59 IndexToSizeOp::Adaptor transformed(operands); 60 rewriter.replaceOp(op.getOperation(), transformed.arg()); 61 return success(); 62 } 63 }; 64 65 class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> { 66 public: 67 using OpConversionPattern<SizeToIndexOp>::OpConversionPattern; 68 69 LogicalResult 70 matchAndRewrite(SizeToIndexOp op, ArrayRef<Value> operands, 71 ConversionPatternRewriter &rewriter) const override { 72 SizeToIndexOp::Adaptor transformed(operands); 73 rewriter.replaceOp(op.getOperation(), transformed.arg()); 74 return success(); 75 } 76 }; 77 78 class ToExtentTensorOpConversion 79 : public OpConversionPattern<ToExtentTensorOp> { 80 public: 81 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 82 83 LogicalResult 84 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands, 85 ConversionPatternRewriter &rewriter) const override { 86 ToExtentTensorOp::Adaptor transformed(operands); 87 rewriter.replaceOp(op.getOperation(), transformed.input()); 88 return success(); 89 } 90 }; 91 92 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> { 93 public: 94 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 95 96 LogicalResult 97 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 98 ConversionPatternRewriter &rewriter) const override { 99 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), 100 op.value().getSExtValue()); 101 return success(); 102 } 103 }; 104 105 /// Type conversions. 106 class ShapeTypeConverter : public TypeConverter { 107 public: 108 using TypeConverter::convertType; 109 110 ShapeTypeConverter(MLIRContext *ctx) { 111 // Add default pass-through conversion. 112 addConversion([&](Type type) { return type; }); 113 114 addConversion([ctx](SizeType type) { return IndexType::get(ctx); }); 115 addConversion([ctx](ShapeType type) { 116 return RankedTensorType::get({ShapedType::kDynamicSize}, 117 IndexType::get(ctx)); 118 }); 119 } 120 }; 121 122 /// Conversion pass. 123 class ConvertShapeToStandardPass 124 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 125 void runOnOperation() override { 126 // Setup type conversion. 127 MLIRContext &ctx = getContext(); 128 ShapeTypeConverter typeConverter(&ctx); 129 130 // Setup target legality. 131 ConversionTarget target(ctx); 132 target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>(); 133 target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>(); 134 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 135 return typeConverter.isSignatureLegal(op.getType()); 136 }); 137 138 // Setup conversion patterns. 139 OwningRewritePatternList patterns; 140 populateShapeToStandardConversionPatterns(patterns, &ctx); 141 populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); 142 143 // Apply conversion. 144 auto module = getOperation(); 145 if (failed(applyFullConversion(module, target, patterns, &typeConverter))) 146 signalPassFailure(); 147 } 148 }; 149 150 } // namespace 151 152 void mlir::populateShapeToStandardConversionPatterns( 153 OwningRewritePatternList &patterns, MLIRContext *ctx) { 154 // clang-format off 155 patterns.insert< 156 BinaryOpConversion<AddOp, AddIOp>, 157 BinaryOpConversion<MulOp, MulIOp>, 158 ConstSizeOpConverter, 159 FromExtentTensorOpConversion, 160 IndexToSizeOpConversion, 161 SizeToIndexOpConversion, 162 ToExtentTensorOpConversion>(ctx); 163 // clang-format on 164 } 165 166 std::unique_ptr<OperationPass<ModuleOp>> 167 mlir::createConvertShapeToStandardPass() { 168 return std::make_unique<ConvertShapeToStandardPass>(); 169 } 170