1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Pass/PassRegistry.h" 18 19 using namespace mlir; 20 21 namespace { 22 class ConvertCstrBroadcastableOp 23 : public OpRewritePattern<shape::CstrBroadcastableOp> { 24 public: 25 using OpRewritePattern::OpRewritePattern; 26 LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, 27 PatternRewriter &rewriter) const override { 28 if (op.getType().isa<shape::ShapeType>() || 29 op.lhs().getType().isa<shape::ShapeType>() || 30 op.rhs().getType().isa<shape::ShapeType>()) { 31 return rewriter.notifyMatchFailure( 32 op, "cannot convert error-propagating shapes"); 33 } 34 35 auto loc = op.getLoc(); 36 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 37 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 38 39 // Find smaller and greater rank and extent tensor. 40 Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero); 41 Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero); 42 Value lhsSmaller = 43 rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank); 44 Type indexTy = rewriter.getIndexType(); 45 Type extentTensorTy = op.lhs().getType(); 46 auto ifOp = rewriter.create<scf::IfOp>( 47 loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy}, 48 lhsSmaller, 49 [&](OpBuilder &b, Location loc) { 50 b.create<scf::YieldOp>( 51 loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()}); 52 }, 53 [&](OpBuilder &b, Location loc) { 54 b.create<scf::YieldOp>( 55 loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()}); 56 }); 57 Value lesserRank = ifOp.getResult(0); 58 Value lesserRankOperand = ifOp.getResult(1); 59 Value greaterRank = ifOp.getResult(2); 60 Value greaterRankOperand = ifOp.getResult(3); 61 62 Value rankDiff = 63 rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank); 64 65 // Generate code to compare the shapes extent by extent, and emit errors for 66 // non-broadcast-compatible shapes. 67 // Two extents are broadcast-compatible if 68 // 1. they are both equal, or 69 // 2. at least one of them is 1. 70 71 rewriter.create<scf::ForOp>( 72 loc, rankDiff, greaterRank, one, llvm::None, 73 [&](OpBuilder &b, Location loc, Value iv, ValueRange) { 74 Value greaterRankOperandExtent = b.create<ExtractElementOp>( 75 loc, greaterRankOperand, ValueRange{iv}); 76 Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff); 77 Value lesserRankOperandExtent = b.create<ExtractElementOp>( 78 loc, lesserRankOperand, ValueRange{ivShifted}); 79 80 Value greaterRankOperandExtentIsOne = b.create<CmpIOp>( 81 loc, CmpIPredicate::eq, greaterRankOperandExtent, one); 82 Value lesserRankOperandExtentIsOne = b.create<CmpIOp>( 83 loc, CmpIPredicate::eq, lesserRankOperandExtent, one); 84 Value extentsAgree = 85 b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent, 86 lesserRankOperandExtent); 87 auto broadcastIsValid = 88 b.create<OrOp>(loc, b.getI1Type(), extentsAgree, 89 b.create<OrOp>(loc, greaterRankOperandExtentIsOne, 90 lesserRankOperandExtentIsOne)); 91 b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast"); 92 b.create<scf::YieldOp>(loc); 93 }); 94 95 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); 96 return success(); 97 } 98 }; 99 } // namespace 100 101 namespace { 102 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> { 103 public: 104 using OpRewritePattern::OpRewritePattern; 105 LogicalResult matchAndRewrite(shape::CstrRequireOp op, 106 PatternRewriter &rewriter) const override { 107 rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr()); 108 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); 109 return success(); 110 } 111 }; 112 } // namespace 113 114 void mlir::populateConvertShapeConstraintsConversionPatterns( 115 OwningRewritePatternList &patterns, MLIRContext *ctx) { 116 patterns.insert<ConvertCstrBroadcastableOp>(ctx); 117 patterns.insert<ConvertCstrRequireOp>(ctx); 118 } 119 120 namespace { 121 // This pass eliminates shape constraints from the program, converting them to 122 // eager (side-effecting) error handling code. After eager error handling code 123 // is emitted, witnesses are satisfied, so they are replace with 124 // `shape.const_witness true`. 125 class ConvertShapeConstraints 126 : public ConvertShapeConstraintsBase<ConvertShapeConstraints> { 127 void runOnOperation() override { 128 auto func = getOperation(); 129 auto *context = &getContext(); 130 131 OwningRewritePatternList patterns; 132 populateConvertShapeConstraintsConversionPatterns(patterns, context); 133 134 if (failed(applyPatternsAndFoldGreedily(func, patterns))) 135 return signalPassFailure(); 136 } 137 }; 138 } // namespace 139 140 std::unique_ptr<OperationPass<FuncOp>> 141 mlir::createConvertShapeConstraintsPass() { 142 return std::make_unique<ConvertShapeConstraints>(); 143 } 144