1e027c008SLei Zhang //===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===//
2e027c008SLei Zhang //
3e027c008SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e027c008SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
5e027c008SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e027c008SLei Zhang //
7e027c008SLei Zhang //===----------------------------------------------------------------------===//
8e027c008SLei Zhang //
// This file implements patterns to wrap a tensor.pad op with an scf.if op
// to separate the cases where we don't need padding (all pad sizes are
// actually zeros) and where we indeed need padding.
12e027c008SLei Zhang //
13e027c008SLei Zhang //===----------------------------------------------------------------------===//
14e027c008SLei Zhang 
15eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
17e027c008SLei Zhang #include "mlir/Dialect/Tensor/IR/Tensor.h"
18e027c008SLei Zhang #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19e027c008SLei Zhang #include "mlir/Dialect/Utils/StaticValueUtils.h"
20e027c008SLei Zhang #include "mlir/IR/PatternMatch.h"
21e027c008SLei Zhang #include "llvm/Support/Debug.h"
22e027c008SLei Zhang 
23e027c008SLei Zhang #define DEBUG_TYPE "mlir-tensor-split-padding"
24e027c008SLei Zhang 
25e027c008SLei Zhang using namespace mlir;
26e027c008SLei Zhang 
27e027c008SLei Zhang /// Returns true if the the given `attrOrValue` is a constant zero.
isZero(OpFoldResult attrOrValue)28e027c008SLei Zhang static bool isZero(OpFoldResult attrOrValue) {
29e027c008SLei Zhang   if (Optional<int64_t> val = getConstantIntValue(attrOrValue))
306d5fc1e3SKazu Hirata     return *val == 0;
31e027c008SLei Zhang   return false;
32e027c008SLei Zhang }
33e027c008SLei Zhang 
34e027c008SLei Zhang /// Gets the given `attrOrValue` as a Value by creating constant ops for
35e027c008SLei Zhang /// attributes.
getAsValue(OpFoldResult attrOrValue,OpBuilder & builder,Location loc)36e027c008SLei Zhang static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder,
37e027c008SLei Zhang                         Location loc) {
38e027c008SLei Zhang   if (Value val = attrOrValue.dyn_cast<Value>())
39e027c008SLei Zhang     return val;
40e027c008SLei Zhang   auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
41e027c008SLei Zhang   return builder.create<arith::ConstantIndexOp>(loc, attr.getInt());
42e027c008SLei Zhang }
43e027c008SLei Zhang 
44e027c008SLei Zhang namespace {
45e027c008SLei Zhang 
46e027c008SLei Zhang struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
47e027c008SLei Zhang   using OpRewritePattern::OpRewritePattern;
48e027c008SLei Zhang 
matchAndRewrite__anonffd38a090111::SplitPadding49e027c008SLei Zhang   LogicalResult matchAndRewrite(tensor::PadOp padOp,
50e027c008SLei Zhang                                 PatternRewriter &rewriter) const override {
51e027c008SLei Zhang     // Avoid infinitely applying this pattern.
52e027c008SLei Zhang     if (padOp->getParentOfType<scf::IfOp>())
53e027c008SLei Zhang       return failure();
54e027c008SLei Zhang 
55e027c008SLei Zhang     // If all padding sizes are zero, we don't need to do anything.
56e027c008SLei Zhang     SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
57e027c008SLei Zhang     SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
58e027c008SLei Zhang     if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero))
59e027c008SLei Zhang       return failure();
60e027c008SLei Zhang 
61e027c008SLei Zhang     // Build the condition for the scf.if op: all pad sizes are zero.
62e027c008SLei Zhang     Location loc = padOp.getLoc();
63e027c008SLei Zhang     Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
64e027c008SLei Zhang     SmallVector<Value> eqZeroCmpVals;
65e027c008SLei Zhang     for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) {
66e027c008SLei Zhang       if (!isZero(pad))
67e027c008SLei Zhang         eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>(
68e027c008SLei Zhang             loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc),
69e027c008SLei Zhang             cstZero));
70e027c008SLei Zhang     }
71e027c008SLei Zhang     Value ifCond = eqZeroCmpVals.front();
72e027c008SLei Zhang     for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front())
73e027c008SLei Zhang       ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp);
74e027c008SLei Zhang 
75e027c008SLei Zhang     // Build the scf.if op itself. For the "then" branch, we can elide the
76e027c008SLei Zhang     // padding. For the "else" branch, we retain the clone op.
77e027c008SLei Zhang     auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) {
78*2d70eff8SJacques Pienaar       builder.create<scf::YieldOp>(loc, padOp.getSource());
79e027c008SLei Zhang     };
80e027c008SLei Zhang     auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
81e027c008SLei Zhang       Operation *newOp = builder.clone(*padOp);
82e027c008SLei Zhang       builder.create<scf::YieldOp>(loc, newOp->getResults());
83e027c008SLei Zhang     };
84e027c008SLei Zhang     rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond,
85e027c008SLei Zhang                                            thenBuilder, elseBuilder);
86e027c008SLei Zhang     return success();
87e027c008SLei Zhang   }
88e027c008SLei Zhang };
89e027c008SLei Zhang 
90e027c008SLei Zhang } // namespace
91e027c008SLei Zhang 
populateSplitPaddingPatterns(RewritePatternSet & patterns,PatternBenefit baseBenefit)92e027c008SLei Zhang void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns,
93e027c008SLei Zhang                                           PatternBenefit baseBenefit) {
94e027c008SLei Zhang   patterns.add<SplitPadding>(patterns.getContext(), baseBenefit);
95e027c008SLei Zhang }
96