1 //===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to wrap a tensor.pad op with an scf.if op 10 /// to separate the cases where we don't need padding (all pad sizes are 11 /// actually zeros) and where we indeed need padding. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 18 #include "mlir/Dialect/Utils/StaticValueUtils.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "llvm/Support/Debug.h" 21 22 #define DEBUG_TYPE "mlir-tensor-split-padding" 23 24 using namespace mlir; 25 26 /// Returns true if the the given `attrOrValue` is a constant zero. 27 static bool isZero(OpFoldResult attrOrValue) { 28 if (Optional<int64_t> val = getConstantIntValue(attrOrValue)) 29 return val.getValue() == 0; 30 return false; 31 } 32 33 /// Gets the given `attrOrValue` as a Value by creating constant ops for 34 /// attributes. 35 static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder, 36 Location loc) { 37 if (Value val = attrOrValue.dyn_cast<Value>()) 38 return val; 39 auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>(); 40 return builder.create<arith::ConstantIndexOp>(loc, attr.getInt()); 41 } 42 43 namespace { 44 45 struct SplitPadding final : public OpRewritePattern<tensor::PadOp> { 46 using OpRewritePattern::OpRewritePattern; 47 48 LogicalResult matchAndRewrite(tensor::PadOp padOp, 49 PatternRewriter &rewriter) const override { 50 // Avoid infinitely applying this pattern. 51 if (padOp->getParentOfType<scf::IfOp>()) 52 return failure(); 53 54 // If all padding sizes are zero, we don't need to do anything. 55 SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad(); 56 SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad(); 57 if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero)) 58 return failure(); 59 60 // Build the condition for the scf.if op: all pad sizes are zero. 61 Location loc = padOp.getLoc(); 62 Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 63 SmallVector<Value> eqZeroCmpVals; 64 for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) { 65 if (!isZero(pad)) 66 eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>( 67 loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc), 68 cstZero)); 69 } 70 Value ifCond = eqZeroCmpVals.front(); 71 for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front()) 72 ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp); 73 74 // Build the scf.if op itself. For the "then" branch, we can elide the 75 // padding. For the "else" branch, we retain the clone op. 76 auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) { 77 builder.create<scf::YieldOp>(loc, padOp.source()); 78 }; 79 auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) { 80 Operation *newOp = builder.clone(*padOp); 81 builder.create<scf::YieldOp>(loc, newOp->getResults()); 82 }; 83 rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond, 84 thenBuilder, elseBuilder); 85 return success(); 86 } 87 }; 88 89 } // namespace 90 91 void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns, 92 PatternBenefit baseBenefit) { 93 patterns.add<SplitPadding>(patterns.getContext(), baseBenefit); 94 } 95