//===- TensorTilingInterface.cpp - Tiling Interface models *- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8fd0c6f53SAlexander Belyaev
9fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
10fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Affine/IR/AffineOps.h"
11ead11072SRiver Riddle #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
12fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Linalg/IR/Linalg.h"
138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
14fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Tensor/IR/Tensor.h"
15fd0c6f53SAlexander Belyaev #include "mlir/Interfaces/TilingInterface.h"
16fd0c6f53SAlexander Belyaev
17fd0c6f53SAlexander Belyaev using namespace mlir;
18fd0c6f53SAlexander Belyaev using namespace mlir::tensor;
19fd0c6f53SAlexander Belyaev
20fd0c6f53SAlexander Belyaev namespace {
21fd0c6f53SAlexander Belyaev
22fd0c6f53SAlexander Belyaev struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
23fd0c6f53SAlexander Belyaev
getDestinationOperands__anon2bfc2d770111::PadOpTiling24fd0c6f53SAlexander Belyaev SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
25fd0c6f53SAlexander Belyaev ReifiedRankedShapedTypeDims reifiedShapes;
26fd0c6f53SAlexander Belyaev ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
27fd0c6f53SAlexander Belyaev dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
28fd0c6f53SAlexander Belyaev (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
29fd0c6f53SAlexander Belyaev
30fd0c6f53SAlexander Belyaev auto padOp = cast<PadOp>(op);
31fd0c6f53SAlexander Belyaev SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
32fd0c6f53SAlexander Belyaev Value initTensor = b.create<linalg::InitTensorOp>(
33fd0c6f53SAlexander Belyaev op->getLoc(), mixedSizes, padOp.getResultType().getElementType());
34fd0c6f53SAlexander Belyaev return {initTensor};
35fd0c6f53SAlexander Belyaev }
36fd0c6f53SAlexander Belyaev
getLoopIteratorTypes__anon2bfc2d770111::PadOpTiling37fd0c6f53SAlexander Belyaev SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
38fd0c6f53SAlexander Belyaev auto padOp = cast<PadOp>(op);
39fd0c6f53SAlexander Belyaev SmallVector<StringRef> iteratorTypes(padOp.getResultType().getRank(),
40fd0c6f53SAlexander Belyaev getParallelIteratorTypeName());
41fd0c6f53SAlexander Belyaev return iteratorTypes;
42fd0c6f53SAlexander Belyaev }
43fd0c6f53SAlexander Belyaev
getIterationDomain__anon2bfc2d770111::PadOpTiling44fd0c6f53SAlexander Belyaev SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
45fd0c6f53SAlexander Belyaev ReifiedRankedShapedTypeDims reifiedShapes;
46fd0c6f53SAlexander Belyaev ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
47fd0c6f53SAlexander Belyaev dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
48fd0c6f53SAlexander Belyaev (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
49fd0c6f53SAlexander Belyaev
50fd0c6f53SAlexander Belyaev Location loc = op->getLoc();
51fd0c6f53SAlexander Belyaev Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
52fd0c6f53SAlexander Belyaev Value one = b.create<arith::ConstantIndexOp>(loc, 1);
53fd0c6f53SAlexander Belyaev // Initialize all the ranges to {zero, one, one}. All the `ub`s are
54fd0c6f53SAlexander Belyaev // overwritten.
55fd0c6f53SAlexander Belyaev SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
56fd0c6f53SAlexander Belyaev for (const auto &ub : enumerate(reifiedShapes[0]))
57fd0c6f53SAlexander Belyaev loopRanges[ub.index()].size = ub.value();
58fd0c6f53SAlexander Belyaev return loopRanges;
59fd0c6f53SAlexander Belyaev }
60fd0c6f53SAlexander Belyaev
61fd0c6f53SAlexander Belyaev SmallVector<Operation *>
getTiledImplementation__anon2bfc2d770111::PadOpTiling62fd0c6f53SAlexander Belyaev getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
63fd0c6f53SAlexander Belyaev ArrayRef<OpFoldResult> offsets,
64fd0c6f53SAlexander Belyaev ArrayRef<OpFoldResult> sizes,
65fd0c6f53SAlexander Belyaev bool /*tileDestOperands*/) const {
660edb4127SLei Zhang Operation *result =
670edb4127SLei Zhang tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
680edb4127SLei Zhang if (!result)
690edb4127SLei Zhang return {};
700edb4127SLei Zhang return {result};
710edb4127SLei Zhang }
720edb4127SLei Zhang };
730edb4127SLei Zhang
740edb4127SLei Zhang } // namespace
750edb4127SLei Zhang
/// Bubbles an extract_slice (described by `offsets`/`sizes`, unit strides
/// only) through `padOp`: instead of slicing the padded result, slice the
/// source and re-pad the slice. Returns the operation producing the tiled
/// result — a tensor.cast of either a tensor.generate (tile lies entirely in
/// the padding), a new tensor.pad of an extract_slice of the source, or an
/// scf.if choosing between the two at runtime. Returns nullptr if the pad
/// does not have a constant padding value (the only supported case).
///
/// If `generateZeroSliceGuard` is true and the slice length is not statically
/// known to be nonzero, an scf.if is emitted so that no zero-sized
/// extract_slice is ever created at runtime.
Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
                                    ArrayRef<OpFoldResult> offsets,
                                    ArrayRef<OpFoldResult> sizes,
                                    bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return nullptr;

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](Value v1, Value v2) {
    return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](Value v1, Value v2) {
    return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](Value v1, Value v2) {
    return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](Value v1, Value v2) {
    return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Zero index-typed integer.
  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);

  // Helper function for filling static/dynamic low/high padding indices
  // vectors of PadOp. Constants go into `staticIndices`; dynamic values get a
  // kDynamicSize placeholder there and the SSA value goes into `dynIndices`.
  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
                         SmallVector<int64_t> &staticIndices) {
    if (auto constInt = getConstantIntValue(val)) {
      staticIndices.push_back(*constInt);
    } else {
      staticIndices.push_back(ShapedType::kDynamicSize);
      dynIndices.push_back(val);
    }
  };

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<Value> newLows, newHighs;
  SmallVector<int64_t> staticNewLows, staticNewHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low =
        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedLowPad()[dim]);
    // True unless the padding is statically known to be zero (so also true
    // for dynamic padding amounts) — used purely to skip needless arithmetic.
    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
    auto high =
        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedHighPad()[dim]);
    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
    auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
    auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
    auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    appendIndex(newLow, newLows, staticNewLows);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
                                : min(offset, srcSize);
    newOffsets.push_back(getAsOpFoldResult(newOffset));

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, set the end positition of the read should be the end of
    // the source tensor. (Similar to newOffset.)
    //
    // endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value endLoc = hasLowPad
                       ? min(max(add(sub(offset, low), length), zero), srcSize)
                       : min(add(offset, length), srcSize);
    Value newLength = sub(endLoc, newOffset);
    newLengths.push_back(getAsOpFoldResult(newLength));

    // Check if newLength is zero. In that case, no SubTensorOp should be
    // executed.
    if (auto newLengthInt = getConstantIntValue(newLength)) {
      hasZeroLen |= *newLengthInt == 0;
    } else {
      // Not statically known: accumulate a runtime "any dimension is empty"
      // condition by OR-ing the per-dimension checks together.
      Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            newLength, zero);
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    appendIndex(newHigh, newHighs, staticNewHighs);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Operation * {
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp: every element of the tile is the padding value.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return castResult(generateOp);
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows,
                                    staticNewHighs, newLows, newHighs);

    // Copy region to new PadOp.
    BlockAndValueMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return castResult(newPadOp);
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known
  // that the original data source x is not used.
  if (hasZeroLen)
    return createGenerateOp();

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    auto result = b.create<scf::IfOp>(
        loc, resultType, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createPadOfExtractSlice()->getResult(0));
        });
    return result;
  }
  return createPadOfExtractSlice();
}
283fd0c6f53SAlexander Belyaev
registerTilingOpInterfaceExternalModels(DialectRegistry & registry)284fd0c6f53SAlexander Belyaev void mlir::tensor::registerTilingOpInterfaceExternalModels(
285fd0c6f53SAlexander Belyaev DialectRegistry ®istry) {
28677eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
28777eee579SRiver Riddle tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
28877eee579SRiver Riddle });
289fd0c6f53SAlexander Belyaev }
290