//===- TensorTilingInterface.cpp - Tiling Interface models -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

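/// External model implementing the TilingInterface for tensor::PadOp. Tiling
/// a pad is delegated to `tensor::bubbleUpPadSlice` below, which materializes
/// the requested tile as a pad of a slice of the original source.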
struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);

    auto padOp = cast<PadOp>(op);
    SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
    Value initTensor = b.create<linalg::InitTensorOp>(
        op->getLoc(), mixedSizes, padOp.getResultType().getElementType());
    return {initTensor};
  }

  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<StringRef> iteratorTypes(padOp.getResultType().getRank(),
                                         getParallelIteratorTypeName());
    return iteratorTypes;
  }

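  /// The iteration domain covers the full (padded) result shape: all offsets
  /// are zero, all strides are one, and the sizes are the reified result
  /// dimensions.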
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);

    Location loc = op->getLoc();
    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

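  /// Produce the tile of the padded result described by `offsets`/`sizes` by
  /// bubbling the implied extract_slice through the pad (see
  /// `tensor::bubbleUpPadSlice` below). Returns an empty vector if the pad
  /// cannot be tiled, e.g. because the padding value is not a constant.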
  SmallVector<Operation *>
  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         bool /*tileDestOperands*/) const {
    Operation *result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (!result)
      return {};
    return {result};
  }
};

} // namespace

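/// Compute the tile of `padOp` selected by `offsets`/`sizes` by bubbling the
/// implied extract_slice through the pad: a new pad of a slice of the source
/// is created, or a tensor.generate of the padding value when the tile lies
/// entirely in the padded region. When that can only be decided at runtime
/// and `generateZeroSliceGuard` is set, an scf.if selects between the two.
/// Returns nullptr if the padding value is not a constant.
///
/// A rough sketch of the rewrite (shapes, offsets, and padding amounts are
/// illustrative only):
///
///   %p = tensor.pad %src low[3, 4] high[5, 6] { ... }
///   %t = tensor.extract_slice %p[%o0, %o1] [%s0, %s1] [1, 1]
///
/// is materialized, when the tile overlaps the source, as
///
///   %s = tensor.extract_slice %src[...] [...] [1, 1]
///   %t = tensor.pad %s low[...] high[...] { ... }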
Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
                                    ArrayRef<OpFoldResult> offsets,
                                    ArrayRef<OpFoldResult> sizes,
                                    bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return nullptr;

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](Value v1, Value v2) {
    return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](Value v1, Value v2) {
    return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](Value v1, Value v2) {
    return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](Value v1, Value v2) {
    return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Zero index-typed integer.
  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);

  // Helper function for filling static/dynamic low/high padding indices
  // vectors of PadOp.
  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
                         SmallVector<int64_t> &staticIndices) {
    if (auto constInt = getConstantIntValue(val)) {
      staticIndices.push_back(*constInt);
    } else {
      staticIndices.push_back(ShapedType::kDynamicSize);
      dynIndices.push_back(val);
    }
  };

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<Value> newLows, newHighs;
  SmallVector<int64_t> staticNewLows, staticNewHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low =
        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedLowPad()[dim]);
    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
    auto high =
        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedHighPad()[dim]);
    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
    auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
    auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
    auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    appendIndex(newLow, newLows, staticNewLows);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
                                : min(offset, srcSize);
    newOffsets.push_back(getAsOpFoldResult(newOffset));

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    // endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value endLoc = hasLowPad
                       ? min(max(add(sub(offset, low), length), zero), srcSize)
                       : min(add(offset, length), srcSize);
    Value newLength = sub(endLoc, newOffset);
    newLengths.push_back(getAsOpFoldResult(newLength));

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (auto newLengthInt = getConstantIntValue(newLength)) {
      hasZeroLen |= *newLengthInt == 0;
    } else {
      Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            newLength, zero);
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    appendIndex(newHigh, newHighs, staticNewHighs);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Operation * {
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return castResult(generateOp);
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows,
                                    staticNewHighs, newLows, newHighs);

    // Copy region to new PadOp.
    BlockAndValueMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return castResult(newPadOp);
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen)
    return createGenerateOp();

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    auto result = b.create<scf::IfOp>(
        loc, resultType, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createPadOfExtractSlice()->getResult(0));
        });
    return result;
  }
  return createPadOfExtractSlice();
}

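/// Registers the PadOpTiling external model above. The interface is attached
/// to tensor::PadOp lazily, when the tensor dialect is loaded into a context.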
void mlir::tensor::registerTilingOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
  });
}