1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 19 namespace { 20 21 /// Conversion patterns. 22 class FromExtentTensorOpConversion 23 : public OpConversionPattern<shape::FromExtentTensorOp> { 24 public: 25 using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern; 26 27 LogicalResult 28 matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands, 29 ConversionPatternRewriter &rewriter) const override { 30 shape::FromExtentTensorOpOperandAdaptor transformed(operands); 31 rewriter.replaceOp(op.getOperation(), transformed.input()); 32 return success(); 33 } 34 }; 35 36 class IndexToSizeOpConversion 37 : public OpConversionPattern<shape::IndexToSizeOp> { 38 public: 39 using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern; 40 41 LogicalResult 42 matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands, 43 ConversionPatternRewriter &rewriter) const override { 44 shape::IndexToSizeOpOperandAdaptor transformed(operands); 45 rewriter.replaceOp(op.getOperation(), transformed.arg()); 46 return success(); 47 } 48 }; 49 50 class SizeToIndexOpConversion 51 : public OpConversionPattern<shape::SizeToIndexOp> { 52 public: 53 using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern; 54 55 LogicalResult 56 matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands, 57 ConversionPatternRewriter &rewriter) const override { 58 shape::SizeToIndexOpOperandAdaptor transformed(operands); 59 rewriter.replaceOp(op.getOperation(), transformed.arg()); 60 return success(); 61 } 62 }; 63 64 class ToExtentTensorOpConversion 65 : public OpConversionPattern<shape::ToExtentTensorOp> { 66 public: 67 using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern; 68 69 LogicalResult 70 matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands, 71 ConversionPatternRewriter &rewriter) const override { 72 shape::ToExtentTensorOpOperandAdaptor transformed(operands); 73 rewriter.replaceOp(op.getOperation(), transformed.input()); 74 return success(); 75 } 76 }; 77 78 /// Type conversions. 79 class ShapeTypeConverter : public TypeConverter { 80 public: 81 using TypeConverter::convertType; 82 83 ShapeTypeConverter(MLIRContext *ctx) { 84 // Add default pass-through conversion. 85 addConversion([&](Type type) { return type; }); 86 87 addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); }); 88 addConversion([ctx](shape::ShapeType type) { 89 return RankedTensorType::get({ShapedType::kDynamicSize}, 90 IndexType::get(ctx)); 91 }); 92 } 93 }; 94 95 /// Conversion pass. 96 class ConvertShapeToStandardPass 97 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 98 99 void runOnOperation() override { 100 101 // Setup type conversion. 102 MLIRContext &ctx = getContext(); 103 ShapeTypeConverter typeConverter(&ctx); 104 105 // Setup target legality. 106 ConversionTarget target(ctx); 107 target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>(); 108 target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>(); 109 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 110 return typeConverter.isSignatureLegal(op.getType()); 111 }); 112 113 // Setup conversion patterns. 114 OwningRewritePatternList patterns; 115 populateShapeToStandardConversionPatterns(patterns, &ctx); 116 populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); 117 118 // Apply conversion. 119 auto module = getOperation(); 120 if (failed(applyFullConversion(module, target, patterns, &typeConverter))) 121 signalPassFailure(); 122 } 123 }; 124 125 } // namespace 126 127 void mlir::populateShapeToStandardConversionPatterns( 128 OwningRewritePatternList &patterns, MLIRContext *ctx) { 129 // clang-format off 130 patterns.insert< 131 FromExtentTensorOpConversion, 132 IndexToSizeOpConversion, 133 SizeToIndexOpConversion, 134 ToExtentTensorOpConversion>(ctx); 135 // clang-format on 136 } 137 138 std::unique_ptr<OperationPass<ModuleOp>> 139 mlir::createConvertShapeToStandardPass() { 140 return std::make_unique<ConvertShapeToStandardPass>(); 141 } 142