//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
class ConvertCstrBroadcastableOp
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getType().isa<shape::ShapeType>() ||
        op.lhs().getType().isa<shape::ShapeType>() ||
        op.rhs().getType().isa<shape::ShapeType>()) {
      return rewriter.notifyMatchFailure(
          op, "cannot convert error-propagating shapes");
    }

    auto loc = op.getLoc();
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);

    // Find smaller and greater rank and extent tensor.
    Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
    Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
    Value lhsRankULE =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
    Type indexTy = rewriter.getIndexType();
    Value lesserRank =
        rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
    Value greaterRank =
        rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
    Value lesserRankOperand =
        rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
    Value greaterRankOperand =
        rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());

    Value rankDiff =
        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);

    // Generate code to compare the shapes extent by extent, and emit errors
    // for non-broadcast-compatible shapes.
    // Two extents are broadcast-compatible if
    // 1. they are both equal, or
    // 2. at least one of them is 1.
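    // Illustrative example (not part of the original source): for shapes
    // [1, 4] and [3, 4] the extent pairs are (1, 3) and (4, 4); each pair
    // meets one of the two conditions above, so no assertion fires. For
    // shapes [2, 4] and [3, 4] the pair (2, 3) meets neither condition, and
    // the assert generated below aborts at runtime with "invalid broadcast".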

    rewriter.create<scf::ForOp>(
        loc, rankDiff, greaterRank, one, llvm::None,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
          Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
              loc, greaterRankOperand, ValueRange{iv});
          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
          Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
              loc, lesserRankOperand, ValueRange{ivShifted});

          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
          Value extentsAgree =
              b.create<CmpIOp>(loc, CmpIPredicate::eq,
                               greaterRankOperandExtent,
                               lesserRankOperandExtent);
          auto broadcastIsValid =
              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
                                            lesserRankOperandExtentIsOne));
          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
          b.create<scf::YieldOp>(loc);
        });

    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace

namespace {
class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace

void mlir::populateConvertShapeConstraintsConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ConvertCstrBroadcastableOp>(ctx);
  patterns.insert<ConvertCstrRequireOp>(ctx);
}

namespace {
// This pass eliminates shape constraints from the program, converting them to
// eager (side-effecting) error handling code. After the eager error handling
// code is emitted, all witnesses are trivially satisfied, so they are replaced
// with `shape.const_witness true`.
class ConvertShapeConstraints
    : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    OwningRewritePatternList patterns;
    populateConvertShapeConstraintsConversionPatterns(patterns, context);

    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertShapeConstraintsPass() {
  return std::make_unique<ConvertShapeConstraints>();
}
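
// Example usage of the factory above (illustrative sketch only; `context` and
// `module` are assumed to be set up elsewhere in the embedding code):
//
//   PassManager pm(context);
//   pm.addNestedPass<FuncOp>(createConvertShapeConstraintsPass());
//   if (failed(pm.run(module)))
//     return failure();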