1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/Shape/IR/Shape.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Pass/PassRegistry.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 namespace { 24 #include "ShapeToStandard.cpp.inc" 25 } // namespace 26 27 namespace { 28 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> { 29 public: 30 using OpRewritePattern::OpRewritePattern; 31 LogicalResult matchAndRewrite(shape::CstrRequireOp op, 32 PatternRewriter &rewriter) const override { 33 rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr()); 34 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); 35 return success(); 36 } 37 }; 38 } // namespace 39 40 void mlir::populateConvertShapeConstraintsConversionPatterns( 41 RewritePatternSet &patterns) { 42 patterns.add<CstrBroadcastableToRequire>(patterns.getContext()); 43 patterns.add<CstrEqToRequire>(patterns.getContext()); 44 patterns.add<ConvertCstrRequireOp>(patterns.getContext()); 45 } 46 47 namespace { 48 // This pass eliminates shape constraints from the program, converting them to 49 // eager (side-effecting) error handling code. After eager error handling code 50 // is emitted, witnesses are satisfied, so they are replace with 51 // `shape.const_witness true`. 52 class ConvertShapeConstraints 53 : public ConvertShapeConstraintsBase<ConvertShapeConstraints> { 54 void runOnOperation() override { 55 auto func = getOperation(); 56 auto *context = &getContext(); 57 58 RewritePatternSet patterns(context); 59 populateConvertShapeConstraintsConversionPatterns(patterns); 60 61 if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) 62 return signalPassFailure(); 63 } 64 }; 65 } // namespace 66 67 std::unique_ptr<OperationPass<FuncOp>> 68 mlir::createConvertShapeConstraintsPass() { 69 return std::make_unique<ConvertShapeConstraints>(); 70 } 71