//===- TensorTilingInterface.cpp - Tiling Interface models -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

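/// External model implementing the `TilingInterface` for `tensor::PadOp`.
///
/// Tiling a PadOp materializes a rectangular tile of the padded result
/// directly from the source tensor. Conceptually (illustrative IR; names and
/// attribute values are for exposition only, not taken from a test), a tile
/// of
///
///   %0 = tensor.pad %src low[...] high[...] { tensor.yield %cst }
///
/// becomes
///
///   %s = tensor.extract_slice %src [newOffsets] [newLengths] [1, ...]
///   %1 = tensor.pad %s low[newLows] high[newHighs] { tensor.yield %cst }
///
/// or, when the tile does not read the source at all,
///
///   %1 = tensor.generate ... { tensor.yield %cst }
///
/// guarded by an scf.if when that condition is only known at runtime. In all
/// cases a tensor.cast reconciles the result type.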
struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
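  /// Return a freshly created `linalg.init_tensor` with the reified result
  /// shape of the PadOp as the (only) destination operand.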
  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);

    auto padOp = cast<PadOp>(op);
    SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
    Value initTensor = b.create<linalg::InitTensorOp>(
        op->getLoc(), mixedSizes, padOp.getResultType().getElementType());
    return {initTensor};
  }

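  /// All loops of a PadOp are parallel: every result element is computed
  /// independently, either from the source tensor or from the padding value.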
  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<StringRef> iteratorTypes(padOp.getResultType().getRank(),
                                         getParallelIteratorTypeName());
    return iteratorTypes;
  }

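  /// The iteration domain covers the full (padded) result: one range per
  /// result dimension, each spanning [0, resultDim) with unit stride.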
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);

    Location loc = op->getLoc();
    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
    // Initialize all ranges to {offset = 0, size = 1, stride = 1}. The sizes
    // are then overwritten with the reified result dimensions.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

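  /// Materialize the tile of the padded result selected by `offsets`/`sizes`.
  /// Returns an empty vector (failure) for non-constant padding values.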
  SmallVector<Operation *>
  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         bool /*tileDestOperands*/) const {
    auto padOp = cast<PadOp>(op);
    // Only a constant padding value is supported.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return {};

    // Helper variables and functions for various arithmetic operations. These
    // are used extensively for computing new offset/length and padding values.
    Location loc = op->getLoc();
    AffineExpr dim0, dim1;
    bindDims(b.getContext(), dim0, dim1);
    // Add two integers.
    auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
    auto add = [&](Value v1, Value v2) {
      return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
    };
    // Subtract two integers.
    auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
    auto sub = [&](Value v1, Value v2) {
      return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
    };
    // Take the minimum of two integers.
    auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
    auto min = [&](Value v1, Value v2) {
      return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
    };
    // Take the maximum of two integers.
    auto max = [&](Value v1, Value v2) {
      return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
    };
    // Zero index-typed integer.
    auto zero = b.create<arith::ConstantIndexOp>(loc, 0);

    // Helper function that appends an index to the static/dynamic low/high
    // padding index vectors of the new PadOp, depending on whether the index
    // is a compile-time constant.
    auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
                           SmallVector<int64_t> &staticIndices) {
      if (auto constInt = getConstantIntValue(val)) {
        staticIndices.push_back(*constInt);
      } else {
        staticIndices.push_back(ShapedType::kDynamicSize);
        dynIndices.push_back(val);
      }
    };
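    // E.g., a value that folds to the constant 5 pushes 5 onto
    // `staticIndices`, whereas a non-constant %v pushes kDynamicSize onto
    // `staticIndices` and %v onto `dynIndices`.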

    // Compute new offsets, lengths, low padding, high padding.
    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
    SmallVector<Value> newLows, newHighs;
    SmallVector<int64_t> staticNewLows, staticNewHighs;
    // Set to true if the original data source is not read at all.
    bool hasZeroLen = false;
    // Same as hasZeroLen, but for dynamic dimension sizes. This condition
    // is true if the original data source turns out to be unused at runtime.
    Value dynHasZeroLenCond;

    int64_t rank = padOp.getSourceType().getRank();
    for (int64_t dim = 0; dim < rank; ++dim) {
      auto low =
          getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedLowPad()[dim]);
      bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
      auto high =
          getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedHighPad()[dim]);
      bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
      auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
      auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
      auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);

      // The new amount of low padding is `low - offset`, except when none of
      // the low padding is read; in that case, the new amount of low padding
      // is zero.
      //
      // Optimization: If low = 0, then newLow = 0.
      Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
      appendIndex(newLow, newLows, staticNewLows);
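      // For illustration, assume low = 2, srcSize = 10, and a tile with
      // offset = 1, length = 4: the tile starts one element into the low
      // padding zone, so newLow = max(0, 2 - 1) = 1.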

      // Start reading the data from position `offset - low`. Since the
      // original read may have started in the low padding zone, this value
      // could be negative. Therefore, start reading from:
      //
      // max(offset - low, 0)
      //
      // The original read could also have started in the high padding zone.
      // In that case, set the offset to the end of the source tensor. The new
      // ExtractSliceOp length will be zero in that case. (Effectively reading
      // no data from the source.)
      //
      // Optimization: If low = 0, then the formula can be simplified.
      Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
                                  : min(offset, srcSize);
      newOffsets.push_back(getAsOpFoldResult(newOffset));
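      // Continuing the example (low = 2, srcSize = 10, offset = 1,
      // length = 4): newOffset = min(max(1 - 2, 0), 10) = 0, i.e. the read
      // starts at the beginning of the source.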

      // The original ExtractSliceOp was reading until position `offset +
      // length`. Therefore, the corresponding position within the source
      // tensor is:
      //
      // offset + length - low
      //
      // In case the original ExtractSliceOp stopped reading within the low
      // padding zone, this value can be negative. In that case, the end
      // position of the read should be zero. (Similar to newOffset.)
      //
      // The original read could also have stopped in the high padding zone.
      // In that case, the end position of the read should be the end of the
      // source tensor. (Similar to newOffset.)
      //
      // endLoc = min(max(offset - low + length, 0), srcSize)
      //
      // The new ExtractSliceOp length is `endLoc - newOffset`.
      //
      // Optimization: If low = 0, then the formula can be simplified.
      Value endLoc =
          hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                    : min(add(offset, length), srcSize);
      Value newLength = sub(endLoc, newOffset);
      newLengths.push_back(getAsOpFoldResult(newLength));
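      // Continuing the example: endLoc = min(max(1 - 2 + 4, 0), 10) = 3 and
      // newLength = 3 - 0 = 3, i.e. three source elements are read.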

      // Check if newLength is zero. In that case, no ExtractSliceOp should
      // be executed.
      if (auto newLengthInt = getConstantIntValue(newLength)) {
        hasZeroLen |= *newLengthInt == 0;
      } else {
        Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              newLength, zero);
        dynHasZeroLenCond =
            dynHasZeroLenCond
                ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
                : check;
      }

      // The amount of high padding is simply the number of elements remaining,
      // so that the result has the same length as the original ExtractSliceOp.
      // As an optimization, if the original high padding is zero, then the new
      // high padding must also be zero.
      Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
      appendIndex(newHigh, newHighs, staticNewHighs);
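      // Continuing the example (high = 3): newHigh = (4 - 3) - 1 = 0. On the
      // slice+pad path, newLow + newLength + newHigh == length, so the tile
      // keeps its requested size.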

      // Only unit stride supported.
      newStrides.push_back(b.getIndexAttr(1));
    }

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<Value> dynDims;
    SmallVector<int64_t> shape;
    dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
    RankedTensorType resultType =
        RankedTensorType::get(shape, padOp.getResultType().getElementType());

    // Insert a cast to ensure that the types match. (May be folded away.)
    auto castResult = [&](Value val) -> Operation * {
      auto castOp = b.create<tensor::CastOp>(loc, resultType, val);
      return castOp;
    };

    // In cases where the original data source is unused: Emit a GenerateOp
    // and do not generate an ExtractSliceOp. (The result shape of the
    // ExtractSliceOp would have a dimension of size 0, the semantics of which
    // are unclear.)
    auto createGenerateOp = [&]() {
      // Create a GenerateOp that yields only the padding value.
      auto generateOp = b.create<tensor::GenerateOp>(
          loc, resultType, dynDims,
          [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
            builder.create<tensor::YieldOp>(gLoc, padValue);
          });
      return castResult(generateOp);
    };

    // Emit an ExtractSliceOp and a PadOp. Should not be used when the result
    // shape of the new ExtractSliceOp has a dimension of size zero.
    auto createPadTensorOfSubTensor = [&]() {
      // Create pad(extract_slice(x)).
      auto newSliceOp = b.create<tensor::ExtractSliceOp>(
          loc, padOp.source(), newOffsets, newLengths, newStrides);
      auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows,
                                      staticNewHighs, newLows, newHighs);

      // Copy the region to the new PadOp, preserving the (constant) yielded
      // padding value.
      BlockAndValueMapping bvm;
      padOp.region().cloneInto(&newPadOp.getRegion(), bvm);

      // Cast the result and return.
      return castResult(newPadOp);
    };

    // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically
    // known that the original data source x is not used.
    if (hasZeroLen)
      return {createGenerateOp()};

    // If there are dynamic dimensions, generate an scf.if check to avoid
    // creating ExtractSliceOps with result dimensions of size 0 at runtime.
    if (dynHasZeroLenCond) {
      auto result = b.create<scf::IfOp>(
          loc, resultType, dynHasZeroLenCond,
          /*thenBuilder=*/
          [&](OpBuilder &b, Location loc) {
            b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
          },
          /*elseBuilder=*/
          [&](OpBuilder &b, Location loc) {
            b.create<scf::YieldOp>(loc,
                                   createPadTensorOfSubTensor()->getResult(0));
          });
      return {result};
    }
    return {createPadTensorOfSubTensor()};
  }
};

} // namespace

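/// Typical usage (a minimal sketch; assumes the standard MLIR external-model
/// registration pattern and is not taken from a test):
///
///   DialectRegistry registry;
///   registry.insert<tensor::TensorDialect>();
///   tensor::registerTilingOpInterfaceExternalModels(registry);
///   MLIRContext context(registry);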
void mlir::tensor::registerTilingOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<tensor::PadOp, PadOpTiling>();
}