1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 13 #include "mlir/Dialect/SCF/IR/SCF.h" 14 #include "mlir/Dialect/Shape/IR/Shape.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Pass/PassRegistry.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 namespace { 23 #include "ShapeToStandard.cpp.inc" 24 } // namespace 25 26 namespace { 27 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> { 28 public: 29 using OpRewritePattern::OpRewritePattern; 30 LogicalResult matchAndRewrite(shape::CstrRequireOp op, 31 PatternRewriter &rewriter) const override { 32 rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr()); 33 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); 34 return success(); 35 } 36 }; 37 } // namespace 38 39 void mlir::populateConvertShapeConstraintsConversionPatterns( 40 RewritePatternSet &patterns) { 41 patterns.add<CstrBroadcastableToRequire>(patterns.getContext()); 42 patterns.add<CstrEqToRequire>(patterns.getContext()); 43 patterns.add<ConvertCstrRequireOp>(patterns.getContext()); 44 } 45 46 namespace { 47 // This pass eliminates shape constraints from the program, converting them to 48 // eager (side-effecting) error handling code. After eager error handling code 49 // is emitted, witnesses are satisfied, so they are replace with 50 // `shape.const_witness true`. 51 class ConvertShapeConstraints 52 : public ConvertShapeConstraintsBase<ConvertShapeConstraints> { 53 void runOnOperation() override { 54 auto *func = getOperation(); 55 auto *context = &getContext(); 56 57 RewritePatternSet patterns(context); 58 populateConvertShapeConstraintsConversionPatterns(patterns); 59 60 if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) 61 return signalPassFailure(); 62 } 63 }; 64 } // namespace 65 66 std::unique_ptr<Pass> mlir::createConvertShapeConstraintsPass() { 67 return std::make_unique<ConvertShapeConstraints>(); 68 } 69