//===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to wrap a tensor.pad op with an scf.if op
// to separate the cases where we don't need padding (all pad sizes are
// actually zeros) and where we indeed need padding.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "mlir-tensor-split-padding"

using namespace mlir;

/// Returns true if the given `attrOrValue` is a constant zero.
28e027c008SLei Zhang static bool isZero(OpFoldResult attrOrValue) { 29e027c008SLei Zhang if (Optional<int64_t> val = getConstantIntValue(attrOrValue)) 30e027c008SLei Zhang return val.getValue() == 0; 31e027c008SLei Zhang return false; 32e027c008SLei Zhang } 33e027c008SLei Zhang 34e027c008SLei Zhang /// Gets the given `attrOrValue` as a Value by creating constant ops for 35e027c008SLei Zhang /// attributes. 36e027c008SLei Zhang static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder, 37e027c008SLei Zhang Location loc) { 38e027c008SLei Zhang if (Value val = attrOrValue.dyn_cast<Value>()) 39e027c008SLei Zhang return val; 40e027c008SLei Zhang auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>(); 41e027c008SLei Zhang return builder.create<arith::ConstantIndexOp>(loc, attr.getInt()); 42e027c008SLei Zhang } 43e027c008SLei Zhang 44e027c008SLei Zhang namespace { 45e027c008SLei Zhang 46e027c008SLei Zhang struct SplitPadding final : public OpRewritePattern<tensor::PadOp> { 47e027c008SLei Zhang using OpRewritePattern::OpRewritePattern; 48e027c008SLei Zhang 49e027c008SLei Zhang LogicalResult matchAndRewrite(tensor::PadOp padOp, 50e027c008SLei Zhang PatternRewriter &rewriter) const override { 51e027c008SLei Zhang // Avoid infinitely applying this pattern. 52e027c008SLei Zhang if (padOp->getParentOfType<scf::IfOp>()) 53e027c008SLei Zhang return failure(); 54e027c008SLei Zhang 55e027c008SLei Zhang // If all padding sizes are zero, we don't need to do anything. 56e027c008SLei Zhang SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad(); 57e027c008SLei Zhang SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad(); 58e027c008SLei Zhang if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero)) 59e027c008SLei Zhang return failure(); 60e027c008SLei Zhang 61e027c008SLei Zhang // Build the condition for the scf.if op: all pad sizes are zero. 
62e027c008SLei Zhang Location loc = padOp.getLoc(); 63e027c008SLei Zhang Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 64e027c008SLei Zhang SmallVector<Value> eqZeroCmpVals; 65e027c008SLei Zhang for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) { 66e027c008SLei Zhang if (!isZero(pad)) 67e027c008SLei Zhang eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>( 68e027c008SLei Zhang loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc), 69e027c008SLei Zhang cstZero)); 70e027c008SLei Zhang } 71e027c008SLei Zhang Value ifCond = eqZeroCmpVals.front(); 72e027c008SLei Zhang for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front()) 73e027c008SLei Zhang ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp); 74e027c008SLei Zhang 75e027c008SLei Zhang // Build the scf.if op itself. For the "then" branch, we can elide the 76e027c008SLei Zhang // padding. For the "else" branch, we retain the clone op. 77e027c008SLei Zhang auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) { 78e027c008SLei Zhang builder.create<scf::YieldOp>(loc, padOp.source()); 79e027c008SLei Zhang }; 80e027c008SLei Zhang auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) { 81e027c008SLei Zhang Operation *newOp = builder.clone(*padOp); 82e027c008SLei Zhang builder.create<scf::YieldOp>(loc, newOp->getResults()); 83e027c008SLei Zhang }; 84e027c008SLei Zhang rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond, 85e027c008SLei Zhang thenBuilder, elseBuilder); 86e027c008SLei Zhang return success(); 87e027c008SLei Zhang } 88e027c008SLei Zhang }; 89e027c008SLei Zhang 90e027c008SLei Zhang } // namespace 91e027c008SLei Zhang 92e027c008SLei Zhang void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns, 93e027c008SLei Zhang PatternBenefit baseBenefit) { 94e027c008SLei Zhang patterns.add<SplitPadding>(patterns.getContext(), baseBenefit); 95e027c008SLei Zhang } 96