1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 namespace mlir { 18 namespace { 19 20 /// Conversion patterns. 21 class SizeToIndexOpConversion 22 : public OpConversionPattern<shape::SizeToIndexOp> { 23 public: 24 using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern; 25 26 LogicalResult 27 matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands, 28 ConversionPatternRewriter &rewriter) const override { 29 shape::SizeToIndexOpOperandAdaptor transformed(operands); 30 rewriter.replaceOp(op.getOperation(), transformed.arg()); 31 return success(); 32 } 33 }; 34 35 class IndexToSizeOpConversion 36 : public OpConversionPattern<shape::IndexToSizeOp> { 37 public: 38 using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern; 39 40 LogicalResult 41 matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands, 42 ConversionPatternRewriter &rewriter) const override { 43 shape::IndexToSizeOpOperandAdaptor transformed(operands); 44 rewriter.replaceOp(op.getOperation(), transformed.arg()); 45 return success(); 46 } 47 }; 48 49 /// Type conversions. 50 class ShapeTypeConverter : public TypeConverter { 51 public: 52 using TypeConverter::convertType; 53 54 ShapeTypeConverter(MLIRContext *ctx) { 55 // Add default pass-through conversion. 56 addConversion([&](Type type) { return type; }); 57 addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); }); 58 } 59 }; 60 61 /// Conversion pass. 62 class ConvertShapeToStandardPass 63 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 64 65 void runOnOperation() override { 66 67 // Setup type conversion. 68 MLIRContext &ctx = getContext(); 69 ShapeTypeConverter typeConverter(&ctx); 70 71 // Setup target legality. 72 ConversionTarget target(ctx); 73 target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>(); 74 target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>(); 75 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 76 return typeConverter.isSignatureLegal(op.getType()); 77 }); 78 79 // Setup conversion patterns. 80 OwningRewritePatternList patterns; 81 populateShapeToStandardConversionPatterns(patterns, &ctx); 82 populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); 83 84 // Apply conversion. 85 auto module = getOperation(); 86 if (failed(applyFullConversion(module, target, patterns, &typeConverter))) 87 signalPassFailure(); 88 } 89 }; 90 91 } // namespace 92 93 void populateShapeToStandardConversionPatterns( 94 OwningRewritePatternList &patterns, MLIRContext *ctx) { 95 // clang-format off 96 patterns.insert< 97 IndexToSizeOpConversion, 98 SizeToIndexOpConversion>(ctx); 99 // clang-format on 100 } 101 102 std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass() { 103 return std::make_unique<ConvertShapeToStandardPass>(); 104 } 105 106 } // namespace mlir 107