1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Pass/PassRegistry.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 namespace { 23 #include "ShapeToStandard.cpp.inc" 24 } // namespace 25 26 namespace { 27 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> { 28 public: 29 using OpRewritePattern::OpRewritePattern; 30 LogicalResult matchAndRewrite(shape::CstrRequireOp op, 31 PatternRewriter &rewriter) const override { 32 rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr()); 33 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); 34 return success(); 35 } 36 }; 37 } // namespace 38 39 void mlir::populateConvertShapeConstraintsConversionPatterns( 40 OwningRewritePatternList &patterns, MLIRContext *ctx) { 41 patterns.insert<CstrBroadcastableToRequire>(ctx); 42 patterns.insert<ConvertCstrRequireOp>(ctx); 43 } 44 45 namespace { 46 // This pass eliminates shape constraints from the program, converting them to 47 // eager (side-effecting) error handling code. After eager error handling code 48 // is emitted, witnesses are satisfied, so they are replace with 49 // `shape.const_witness true`. 50 class ConvertShapeConstraints 51 : public ConvertShapeConstraintsBase<ConvertShapeConstraints> { 52 void runOnOperation() override { 53 auto func = getOperation(); 54 auto *context = &getContext(); 55 56 OwningRewritePatternList patterns; 57 populateConvertShapeConstraintsConversionPatterns(patterns, context); 58 59 if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) 60 return signalPassFailure(); 61 } 62 }; 63 } // namespace 64 65 std::unique_ptr<OperationPass<FuncOp>> 66 mlir::createConvertShapeConstraintsPass() { 67 return std::make_unique<ConvertShapeConstraints>(); 68 } 69