1*e027c008SLei Zhang //===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===// 2*e027c008SLei Zhang // 3*e027c008SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*e027c008SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 5*e027c008SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*e027c008SLei Zhang // 7*e027c008SLei Zhang //===----------------------------------------------------------------------===// 8*e027c008SLei Zhang // 9*e027c008SLei Zhang // This file implements patterns to wrap a tensor.pad op with an scf.if op 10*e027c008SLei Zhang /// to separate the cases where we don't need padding (all pad sizes are 11*e027c008SLei Zhang /// actually zeros) and where we indeed need padding. 12*e027c008SLei Zhang // 13*e027c008SLei Zhang //===----------------------------------------------------------------------===// 14*e027c008SLei Zhang 15*e027c008SLei Zhang #include "mlir/Dialect/SCF/SCF.h" 16*e027c008SLei Zhang #include "mlir/Dialect/Tensor/IR/Tensor.h" 17*e027c008SLei Zhang #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 18*e027c008SLei Zhang #include "mlir/Dialect/Utils/StaticValueUtils.h" 19*e027c008SLei Zhang #include "mlir/IR/PatternMatch.h" 20*e027c008SLei Zhang #include "llvm/Support/Debug.h" 21*e027c008SLei Zhang 22*e027c008SLei Zhang #define DEBUG_TYPE "mlir-tensor-split-padding" 23*e027c008SLei Zhang 24*e027c008SLei Zhang using namespace mlir; 25*e027c008SLei Zhang 26*e027c008SLei Zhang /// Returns true if the the given `attrOrValue` is a constant zero. 27*e027c008SLei Zhang static bool isZero(OpFoldResult attrOrValue) { 28*e027c008SLei Zhang if (Optional<int64_t> val = getConstantIntValue(attrOrValue)) 29*e027c008SLei Zhang return val.getValue() == 0; 30*e027c008SLei Zhang return false; 31*e027c008SLei Zhang } 32*e027c008SLei Zhang 33*e027c008SLei Zhang /// Gets the given `attrOrValue` as a Value by creating constant ops for 34*e027c008SLei Zhang /// attributes. 35*e027c008SLei Zhang static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder, 36*e027c008SLei Zhang Location loc) { 37*e027c008SLei Zhang if (Value val = attrOrValue.dyn_cast<Value>()) 38*e027c008SLei Zhang return val; 39*e027c008SLei Zhang auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>(); 40*e027c008SLei Zhang return builder.create<arith::ConstantIndexOp>(loc, attr.getInt()); 41*e027c008SLei Zhang } 42*e027c008SLei Zhang 43*e027c008SLei Zhang namespace { 44*e027c008SLei Zhang 45*e027c008SLei Zhang struct SplitPadding final : public OpRewritePattern<tensor::PadOp> { 46*e027c008SLei Zhang using OpRewritePattern::OpRewritePattern; 47*e027c008SLei Zhang 48*e027c008SLei Zhang LogicalResult matchAndRewrite(tensor::PadOp padOp, 49*e027c008SLei Zhang PatternRewriter &rewriter) const override { 50*e027c008SLei Zhang // Avoid infinitely applying this pattern. 51*e027c008SLei Zhang if (padOp->getParentOfType<scf::IfOp>()) 52*e027c008SLei Zhang return failure(); 53*e027c008SLei Zhang 54*e027c008SLei Zhang // If all padding sizes are zero, we don't need to do anything. 55*e027c008SLei Zhang SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad(); 56*e027c008SLei Zhang SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad(); 57*e027c008SLei Zhang if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero)) 58*e027c008SLei Zhang return failure(); 59*e027c008SLei Zhang 60*e027c008SLei Zhang // Build the condition for the scf.if op: all pad sizes are zero. 61*e027c008SLei Zhang Location loc = padOp.getLoc(); 62*e027c008SLei Zhang Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 63*e027c008SLei Zhang SmallVector<Value> eqZeroCmpVals; 64*e027c008SLei Zhang for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) { 65*e027c008SLei Zhang if (!isZero(pad)) 66*e027c008SLei Zhang eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>( 67*e027c008SLei Zhang loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc), 68*e027c008SLei Zhang cstZero)); 69*e027c008SLei Zhang } 70*e027c008SLei Zhang Value ifCond = eqZeroCmpVals.front(); 71*e027c008SLei Zhang for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front()) 72*e027c008SLei Zhang ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp); 73*e027c008SLei Zhang 74*e027c008SLei Zhang // Build the scf.if op itself. For the "then" branch, we can elide the 75*e027c008SLei Zhang // padding. For the "else" branch, we retain the clone op. 76*e027c008SLei Zhang auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) { 77*e027c008SLei Zhang builder.create<scf::YieldOp>(loc, padOp.source()); 78*e027c008SLei Zhang }; 79*e027c008SLei Zhang auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) { 80*e027c008SLei Zhang Operation *newOp = builder.clone(*padOp); 81*e027c008SLei Zhang builder.create<scf::YieldOp>(loc, newOp->getResults()); 82*e027c008SLei Zhang }; 83*e027c008SLei Zhang rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond, 84*e027c008SLei Zhang thenBuilder, elseBuilder); 85*e027c008SLei Zhang return success(); 86*e027c008SLei Zhang } 87*e027c008SLei Zhang }; 88*e027c008SLei Zhang 89*e027c008SLei Zhang } // namespace 90*e027c008SLei Zhang 91*e027c008SLei Zhang void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns, 92*e027c008SLei Zhang PatternBenefit baseBenefit) { 93*e027c008SLei Zhang patterns.add<SplitPadding>(patterns.getContext(), baseBenefit); 94*e027c008SLei Zhang } 95