//===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to wrap a tensor.pad op with an scf.if op
// to separate the case where we don't need padding (all pad sizes are
// actually zero) from the case where we indeed need padding.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "mlir-tensor-split-padding"

using namespace mlir;

/// Returns true if the given `attrOrValue` is a constant zero.
static bool isZero(OpFoldResult attrOrValue) {
  if (Optional<int64_t> val = getConstantIntValue(attrOrValue))
    return val.getValue() == 0;
  return false;
}

/// Gets the given `attrOrValue` as a Value by creating constant ops for
/// attributes.
static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder,
                        Location loc) {
  if (Value val = attrOrValue.dyn_cast<Value>())
    return val;
  auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
  return builder.create<arith::ConstantIndexOp>(loc, attr.getInt());
}

namespace {

struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // Avoid infinitely applying this pattern.
    if (padOp->getParentOfType<scf::IfOp>())
      return failure();

    // If all padding sizes are statically zero, we don't need to do anything.
    SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
    if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero))
      return failure();

    // Build the condition for the scf.if op: all pad sizes are zero.
    Location loc = padOp.getLoc();
    Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> eqZeroCmpVals;
    for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) {
      if (!isZero(pad))
        eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc),
            cstZero));
    }
    Value ifCond = eqZeroCmpVals.front();
    for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front())
      ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp);

    // Build the scf.if op itself. For the "then" branch, we can elide the
    // padding. For the "else" branch, we clone the original pad op.
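    // A rough sketch of the IR produced below (SSA names and types are
    // illustrative only, not taken from this file):
    //
    //   %res = scf.if %all_pads_zero -> (tensor<...>) {
    //     scf.yield %source : tensor<...>
    //   } else {
    //     %padded = tensor.pad %source low[...] high[...] { ... }
    //     scf.yield %padded : tensor<...>
    //   }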
    auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) {
      builder.create<scf::YieldOp>(loc, padOp.source());
    };
    auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
      Operation *newOp = builder.clone(*padOp);
      builder.create<scf::YieldOp>(loc, newOp->getResults());
    };
    rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond,
                                           thenBuilder, elseBuilder);
    return success();
  }
};

} // namespace

void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns,
                                          PatternBenefit baseBenefit) {
  patterns.add<SplitPadding>(patterns.getContext(), baseBenefit);
}
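// A minimal usage sketch (the pass name below is hypothetical and not part of
// this file): populate the pattern set and apply it with the greedy driver
// from "mlir/Transforms/GreedyPatternRewriteDriver.h".
//
//   void MySplitPaddingPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     tensor::populateSplitPaddingPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }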