1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::shape; 19 20 namespace { 21 22 /// Conversion patterns. 23 template <typename SrcOpTy, typename DstOpTy> 24 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 25 public: 26 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 30 ConversionPatternRewriter &rewriter) const override { 31 typename SrcOpTy::OperandAdaptor adaptor(operands); 32 rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(), 33 adaptor.rhs()); 34 return success(); 35 } 36 }; 37 38 class FromExtentTensorOpConversion 39 : public OpConversionPattern<shape::FromExtentTensorOp> { 40 public: 41 using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern; 42 43 LogicalResult 44 matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands, 45 ConversionPatternRewriter &rewriter) const override { 46 shape::FromExtentTensorOpOperandAdaptor transformed(operands); 47 rewriter.replaceOp(op.getOperation(), transformed.input()); 48 return success(); 49 } 50 }; 51 52 class IndexToSizeOpConversion 53 : public OpConversionPattern<shape::IndexToSizeOp> { 54 public: 55 using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern; 56 57 LogicalResult 58 matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands, 59 ConversionPatternRewriter &rewriter) const override { 60 shape::IndexToSizeOpOperandAdaptor transformed(operands); 61 rewriter.replaceOp(op.getOperation(), transformed.arg()); 62 return success(); 63 } 64 }; 65 66 class SizeToIndexOpConversion 67 : public OpConversionPattern<shape::SizeToIndexOp> { 68 public: 69 using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern; 70 71 LogicalResult 72 matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands, 73 ConversionPatternRewriter &rewriter) const override { 74 shape::SizeToIndexOpOperandAdaptor transformed(operands); 75 rewriter.replaceOp(op.getOperation(), transformed.arg()); 76 return success(); 77 } 78 }; 79 80 class ToExtentTensorOpConversion 81 : public OpConversionPattern<shape::ToExtentTensorOp> { 82 public: 83 using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern; 84 85 LogicalResult 86 matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands, 87 ConversionPatternRewriter &rewriter) const override { 88 shape::ToExtentTensorOpOperandAdaptor transformed(operands); 89 rewriter.replaceOp(op.getOperation(), transformed.input()); 90 return success(); 91 } 92 }; 93 94 /// Type conversions. 95 class ShapeTypeConverter : public TypeConverter { 96 public: 97 using TypeConverter::convertType; 98 99 ShapeTypeConverter(MLIRContext *ctx) { 100 // Add default pass-through conversion. 101 addConversion([&](Type type) { return type; }); 102 103 addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); }); 104 addConversion([ctx](shape::ShapeType type) { 105 return RankedTensorType::get({ShapedType::kDynamicSize}, 106 IndexType::get(ctx)); 107 }); 108 } 109 }; 110 111 /// Conversion pass. 112 class ConvertShapeToStandardPass 113 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 114 115 void runOnOperation() override { 116 117 // Setup type conversion. 118 MLIRContext &ctx = getContext(); 119 ShapeTypeConverter typeConverter(&ctx); 120 121 // Setup target legality. 122 ConversionTarget target(ctx); 123 target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>(); 124 target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>(); 125 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 126 return typeConverter.isSignatureLegal(op.getType()); 127 }); 128 129 // Setup conversion patterns. 130 OwningRewritePatternList patterns; 131 populateShapeToStandardConversionPatterns(patterns, &ctx); 132 populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); 133 134 // Apply conversion. 135 auto module = getOperation(); 136 if (failed(applyFullConversion(module, target, patterns, &typeConverter))) 137 signalPassFailure(); 138 } 139 }; 140 141 } // namespace 142 143 void mlir::populateShapeToStandardConversionPatterns( 144 OwningRewritePatternList &patterns, MLIRContext *ctx) { 145 // clang-format off 146 patterns.insert< 147 BinaryOpConversion<AddOp, AddIOp>, 148 BinaryOpConversion<MulOp, MulIOp>, 149 FromExtentTensorOpConversion, 150 IndexToSizeOpConversion, 151 SizeToIndexOpConversion, 152 ToExtentTensorOpConversion>(ctx); 153 // clang-format on 154 } 155 156 std::unique_ptr<OperationPass<ModuleOp>> 157 mlir::createConvertShapeToStandardPass() { 158 return std::make_unique<ConvertShapeToStandardPass>(); 159 } 160