1 //===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns to wrap a tensor.pad op with an scf.if op
// to separate the cases where we don't need padding (all pad sizes are
// actually zeros) and where we indeed need padding.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19 #include "mlir/Dialect/Utils/StaticValueUtils.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/Support/Debug.h"
22
23 #define DEBUG_TYPE "mlir-tensor-split-padding"
24
25 using namespace mlir;
26
27 /// Returns true if the the given `attrOrValue` is a constant zero.
isZero(OpFoldResult attrOrValue)28 static bool isZero(OpFoldResult attrOrValue) {
29 if (Optional<int64_t> val = getConstantIntValue(attrOrValue))
30 return *val == 0;
31 return false;
32 }
33
34 /// Gets the given `attrOrValue` as a Value by creating constant ops for
35 /// attributes.
getAsValue(OpFoldResult attrOrValue,OpBuilder & builder,Location loc)36 static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder,
37 Location loc) {
38 if (Value val = attrOrValue.dyn_cast<Value>())
39 return val;
40 auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
41 return builder.create<arith::ConstantIndexOp>(loc, attr.getInt());
42 }
43
44 namespace {
45
46 struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
47 using OpRewritePattern::OpRewritePattern;
48
matchAndRewrite__anonffd38a090111::SplitPadding49 LogicalResult matchAndRewrite(tensor::PadOp padOp,
50 PatternRewriter &rewriter) const override {
51 // Avoid infinitely applying this pattern.
52 if (padOp->getParentOfType<scf::IfOp>())
53 return failure();
54
55 // If all padding sizes are zero, we don't need to do anything.
56 SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
57 SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
58 if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero))
59 return failure();
60
61 // Build the condition for the scf.if op: all pad sizes are zero.
62 Location loc = padOp.getLoc();
63 Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
64 SmallVector<Value> eqZeroCmpVals;
65 for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) {
66 if (!isZero(pad))
67 eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>(
68 loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc),
69 cstZero));
70 }
71 Value ifCond = eqZeroCmpVals.front();
72 for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front())
73 ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp);
74
75 // Build the scf.if op itself. For the "then" branch, we can elide the
76 // padding. For the "else" branch, we retain the clone op.
77 auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) {
78 builder.create<scf::YieldOp>(loc, padOp.getSource());
79 };
80 auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
81 Operation *newOp = builder.clone(*padOp);
82 builder.create<scf::YieldOp>(loc, newOp->getResults());
83 };
84 rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond,
85 thenBuilder, elseBuilder);
86 return success();
87 }
88 };
89
90 } // namespace
91
populateSplitPaddingPatterns(RewritePatternSet & patterns,PatternBenefit baseBenefit)92 void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns,
93 PatternBenefit baseBenefit) {
94 patterns.add<SplitPadding>(patterns.getContext(), baseBenefit);
95 }
96