//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                  ValueRange allShapeSizes,
                                  ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, and erase the zeros.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.create<arith::ConstantIndexOp>(loc, 0),
                        shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

void mlir::linalg::transformIndexOps(
    RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto &en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
  offsetIndices(b, op, allIvs);
}
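
// Note on the two helpers above (illustrative): a tile size of zero means "do
// not tile this loop". For example, with tile sizes (2, 0, 4) the second entry
// is dropped from the returned ranges and `loopIndexToRangeIndex` maps
// {0 -> 0, 2 -> 1}; `transformIndexOps` then only substitutes induction
// variables for the loops that were actually materialized.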

/// Asserts that the given index-typed value is strictly positive. If the value
/// is an attribute, asserts at compile time, otherwise emits an assertion
/// checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
                                         OpFoldResult value) {
  if (auto attr = value.dyn_cast<Attribute>()) {
    assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
           "expected strictly positive tile size and divisor");
    return;
  }

  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
                                            value.get<Value>(), zero);
  b.create<cf::AssertOp>(
      condition,
      b.getStringAttr("expected strictly positive tile size and divisor"));
}

FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                    unsigned dimension, OpFoldResult targetSize,
                                    OpFoldResult divisor, bool emitAssertions) {
  // Bail out on dimension overflow.
  if (dimension >= op.getNumLoops())
    return failure();

  // The code below works only on values.
  ImplicitLocOpBuilder b(op.getLoc(), builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
    emitIsPositiveIndexAssertion(b, divisor);
  }
  Value targetSizeValue = materializeOpFoldResult(b, targetSize);
  Value divisorValue = materializeOpFoldResult(b, divisor);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  // TODO: update createFlatListOfOperandDims to return OpFoldResults and avoid
  // littering by useless constant materialization.
  SmallVector<Value, 4> allShapes =
      op.createFlatListOfOperandDims(b, b.getLoc());
  AffineMap shapesToLoops = op.getShapesToLoopsMap();
  SmallVector<Value, 4> loopRanges =
      applyMapToValues(b, op.getLoc(), shapesToLoops, allShapes);
  Value tripCount = loopRanges[dimension];

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  AffineExpr s2 = b.getAffineSymbolExpr(2);
  auto apply = [&](AffineExpr expr, ValueRange values) -> Value {
    return makeComposedAffineApply(b, b.getLoc(), expr, values);
  };
  // a: number of divisor-sized groups in the iteration space.
  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
  // t: target tile size expressed in groups, rounded up.
  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
  // d: number of tiles needed to cover `a` groups with tiles of `t` groups.
  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
  // s: low tile size, i.e. floor(a / d) groups converted back to elements.
  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
  // v: number of tiles that get one extra group (the "high" tiles).
  Value v = apply(s0 % s1, {a, d});
  // u: number of "low" tiles.
  Value u = apply(s0 - s1, {d, v});

  MultiSizeSpecification spec;
  spec.lowTileSize = s;
  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
  spec.lowTripCount = u;
  spec.highTripCount = v;
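
  // As a purely illustrative example: for tripCount = 15, targetSize = 8 and
  // divisor = 1, the values above are a = 15, t = 8, d = 2, s = 7, v = 1 and
  // u = 1, i.e. one low tile of size 7 and one high tile of size 8, covering
  // 7 * 1 + 8 * 1 = 15 iterations.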

  // If requested, emit the check that the tile sizes are computed correctly.
  // For example, for an iteration dimension of size 15 and a target size of 8,
  // it is impossible to find two tile sizes, both divisible by 8, that fully
  // cover the original iteration space.
  if (emitAssertions) {
    AffineExpr s3 = builder.getAffineSymbolExpr(3);
    Value coveredSize =
        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
                                  spec.highTileSize, spec.highTripCount});
    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                           coveredSize, tripCount);
    b.create<cf::AssertOp>(
        equals, builder.getStringAttr(
                    "could not compute dynamic multi-size tile shapes"));
  }

  return spec;
}

/// Given a `subsetExtractOp`, a `source` and a `dest`, create a new
/// `ParallelInsertSlice` op of `source` into `dest` at the same subset
/// location as `subsetExtractOp`.
static void
createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
                                     tensor::ExtractSliceOp subsetExtractOp,
                                     Value source, Value dest) {
  b.create<tensor::ParallelInsertSliceOp>(
      loc, source, dest, subsetExtractOp.getMixedOffsets(),
      subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides());
}

/// Build an `affine_max` of all the `vals`.
static OpFoldResult buildMax(OpBuilder &b, Location loc,
                             ArrayRef<OpFoldResult> vals) {
  SmallVector<Value> args = getValueOrCreateConstantIndexOp(b, loc, vals);
  return b.createOrFold<AffineMaxOp>(
      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      args);
}

/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
/// less than `iterationSize`.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
                                           OpFoldResult numThreads,
                                           OpFoldResult iterationSize) {
  Optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
  Optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
  Optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
    return false;
  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
}

/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The tiling
/// is specified by the number of tiles/threads `numThreads` and the optional
/// nominal tile size `nominalTileSizes`. If `nominalTileSizes` is not
/// specified, it is derived from `numThreads` as
/// `ceilDiv(dimSize[i], numThreads[i])`. If non-empty, `threadDimMapping` is
/// added as an attribute to the resulting `scf.foreach_thread`. A tile size of
/// zero indicates that the dimension is not tiled and can be thought of as
/// tiling by the full size of data.
/// It is the user's responsibility to ensure that `numThreads` is a valid
/// tiling specification (i.e. that it only tiles parallel dimensions, e.g. in
/// the Linalg case). If `omitTileOffsetBoundsCheck` is true, the function
/// assumes that `tileSize[i] * (numThreads[i] - 1) <= dimSize[i]` holds.
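/// For example (illustrative only): when `nominalTileSizes` is not provided,
/// `numThreads = [3]` on an iteration domain of size 100 gives each thread a
/// nominal tile size of ceilDiv(100, 3) = 34 starting at `threadId * 34`; the
/// last thread's tile is clamped to min(100 - 68, 34) = 32 so that no tile
/// extends past the end of the domain.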
static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
    Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
    ArrayRef<int64_t> threadDimMapping, bool omitTileOffsetBoundsCheck) {
  Location loc = op->getLoc();
  OpBuilder::InsertionGuard g(b);
  SmallVector<Range> loopRanges = op.getIterationDomain(b);
  if (loopRanges.empty())
    return op->emitOpError("expected non-empty loop ranges");
  auto hasNonUnitStride = [](Range r) {
    return !isConstantIntValue(r.stride, 1);
  };
  if (llvm::any_of(loopRanges, hasNonUnitStride))
    return op->emitOpError("only stride-1 supported atm");
  // TODO: support `getTiledImplementation` with >1 produced tiled ops.
  auto destOperands = op.getDestinationOperands(b);
  if (destOperands.size() != 1)
    return op->emitOpError("only single dest operand supported atm");

  SmallVector<OpFoldResult> nonZeroNumThreads =
      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
        return !isConstantIntValue(ofr, 0);
      }));
  SmallVector<Value> materializedNonZeroNumThreads =
      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
        ImplicitLocOpBuilder ilocb(loc, b);
        return materializeOpFoldResult(ilocb, ofr);
      }));

  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Operation *tiledOp = nullptr;

  // Create the ForeachThreadOp. We don't use the lambda body-builder
  // version because we require the use of RewriterBase in the body, so we
  // manually move the insertion point to the body below.
  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
      loc, op->getResultTypes(), ValueRange(materializedNonZeroNumThreads),
      threadDimMapping);

  // Fill out the ForeachThreadOp body.
  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
  ValueRange threadIds = foreachThreadOp.getThreadIndices();
  int64_t nLoops = loopRanges.size();
  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
  tiledOffsets.reserve(nLoops);
  tiledSizes.reserve(nLoops);
  for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
    bool overflow = loopIdx >= numThreads.size();
    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
    // Degenerate case: take the whole domain.
    if (overflow || isZero) {
      tiledOffsets.push_back(loopRanges[loopIdx].offset);
      tiledSizes.push_back(loopRanges[loopIdx].size);
      continue;
    }

    // Tiled case: compute the offset and size.
    AffineExpr i, j, M, N, O;
    bindDims(b.getContext(), i, j);
    bindSymbols(b.getContext(), M, N, O);
    Value size = loopRanges[loopIdx].size;
    Value offset = loopRanges[loopIdx].offset;
    Value threadId = threadIds[threadIdIdx];
    // Symbolic fixed max size per thread.
    // TODO: floor + 0/1 depending on case for better load-balancing.
    OpFoldResult tileSizePerThread =
        nominalTileSizes.has_value()
            ? (*nominalTileSizes)[loopIdx]
            : makeComposedFoldedAffineApply(
                  b, loc, M.ceilDiv(N),
                  ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});

    // Dynamic offset shifted by threadId * maxSizePerThread.
    OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
        b, loc, i + j * M, {offset, threadId, tileSizePerThread});
    // Dynamic upper-bound depending on the threadId.
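    // The min-clamp below is only materialized when the tiles may not evenly
    // cover the iteration space, i.e. when
    // `offset + numThreads * tileSizePerThread` is not statically known to
    // equal `size`.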
    OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
        b, loc, i + j * M - N,
        {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
    if (!isConstantIntValue(residualTileSize, 0)) {
      OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
          b, loc, -i + M, {offsetPerThread, size});
      tileSizePerThread = makeComposedFoldedAffineMin(
          b, loc, AffineMap::getMultiDimIdentityMap(2, b.getContext()),
          ArrayRef<OpFoldResult>{sizeMinusOffsetPerThread, tileSizePerThread});
    }

    tiledOffsets.push_back(offsetPerThread);
    // TODO: if tileSizePerThread <= 0 early exit.
    if (!omitTileOffsetBoundsCheck &&
        !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
                                        nonZeroNumThreads[threadIdIdx], size))
      tileSizePerThread = buildMax(b, loc, {zero, tileSizePerThread});

    tiledSizes.push_back(tileSizePerThread);
    ++threadIdIdx;
  }

  SmallVector<Operation *> tiledOps =
      op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
                                /*tileDestOperands=*/true);
  assert(tiledOps.size() == 1 && "expected a single produced tiled op");
  tiledOp = tiledOps.front();

  auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
  assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");

  auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);

  // Create terminator with parallel subset insert operations.
  b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
  for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
                           destOperands)) {
    createMatchingParallelSubsetInsertOp(
        b, loc, cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
        std::get<1>(it), std::get<2>(it));
  }
  return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
}

FailureOr<ForeachThreadTilingResult>
linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op,
                              ArrayRef<OpFoldResult> numThreads,
                              ArrayRef<int64_t> threadDimMapping) {
  return tileToForeachThreadOpImpl(b, op, numThreads,
                                   /*nominalTileSizes=*/None, threadDimMapping,
                                   /*omitTileOffsetBoundsCheck=*/false);
}

FailureOr<ForeachThreadTilingResult>
linalg::tileToForeachThreadOpUsingTileSizes(
    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> tileSizes,
    ArrayRef<int64_t> threadDimMapping) {
  SmallVector<Range> loopRanges = op.getIterationDomain(b);
  unsigned nLoops = loopRanges.size();
  SmallVector<OpFoldResult> numThreads;
  numThreads.reserve(nLoops);
  AffineExpr s0, s1;
  bindSymbols(b.getContext(), s0, s1);
  AffineExpr divExpr = s0.ceilDiv(s1);
  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
    OpFoldResult numTiles = std::get<0>(it);
    if (!isConstantIntValue(numTiles, 0))
      numTiles = makeComposedFoldedAffineApply(
          b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
    numThreads.push_back(numTiles);
  }
  return tileToForeachThreadOpImpl(b, op, numThreads,
                                   /*nominalTileSizes=*/tileSizes,
                                   threadDimMapping,
                                   /*omitTileOffsetBoundsCheck=*/true);
}
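
// Note (illustrative): for an iteration domain of sizes (128, 60) and tile
// sizes (32, 0), the `numThreads` derived above is (ceilDiv(128, 32), 0) =
// (4, 0), i.e. the second dimension is left untiled.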

// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as the size of the tile) is taken from a
// given ExtractSliceOp `sliceOp`.
static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
                                   tensor::ExtractSliceOp sliceOp, Value source,
                                   Value dest) {
  return b.create<tensor::InsertSliceOp>(
      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero)) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<Attribute, 4> iteratorTypes;
  for (const auto &attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Build the permutation map
  // otherwise.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile sizes), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
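    // For instance (illustrative), with `interchangeVector = [1, 0]` the
    // generated loops iterate as (j, i); applying the inverse permutation to
    // `ivs` recovers (iv_i, iv_j) in the order the op's indexing maps expect.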
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expected as many operand values to use as the op has inputs and "
           "outputs");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands =
        makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                        sizeBounds, /*omitPartialTileCheck=*/false);

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op, tiledOperands);
    res = op.clone(b, loc, resultTensorTypes, tiledOperands);
    tensorResults =
        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution,
                                 options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
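  // Missing trailing tile sizes are therefore padded with zero constants below
  // so that every loop dimension has an entry.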
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return failure();
}

/// Generate a loop nest around a given tensor::PadOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled)
/// tensor::PadOp and the loop nest.
static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
                               tensor::PadOp &newPadOp, LoopNest &loopNest,
                               const LinalgTilingOptions &options) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(op);

  // Clone tensor::PadOp so that the existing op can be replaced more easily.
  newPadOp = cast<tensor::PadOp>(builder.clone(*op.getOperation()));
  // Get rank and tile sizes.
  int64_t rank = op.getResultType().getRank();
  SmallVector<Value> tileSizes =
      options.tileSizeComputationFunction(builder, op);
  // Normalize untiled padding dimensions to 0.
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  tileSizes.append(rank - tileSizes.size(), zero);
  // Compute lower and upper bounds of the loop nest.
  TilingInterface tilingInterface =
      dyn_cast<TilingInterface>(op.getOperation());
  SmallVector<Range> ranges = tilingInterface.getIterationDomain(builder);
  SmallVector<Value> lbs, dims, allDims, steps;
  for (int64_t i = 0; i < rank; ++i) {
    allDims.push_back(ranges[i].size);
    if (!isZero(tileSizes[i])) {
      lbs.push_back(ranges[i].offset);
      dims.push_back(ranges[i].size);
      steps.push_back(tileSizes[i]);
    }
  }
  // Generate loop nest: One loop per dimension.
  SmallVector<Value> destOperand =
      tilingInterface.getDestinationOperands(builder);
  loopNest = mlir::scf::buildLoopNest(
      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange iterArgs) -> scf::ValueVector {
        // Compute offsets and sizes of ExtractSliceOp.
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
        SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
        // Note: The tensor::PadOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
        Value tiledOutput = makeTiledShape(
            b, loc, newPadOp->getResult(0), tileSizes, map, offsets, allDims,
            sizes, /*omitPartialTileCheck=*/false);
        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
        assert(sliceOp && "expected ExtractSliceOp");
        // Insert the tile into the output tensor.
        // TODO: Propagate RewriterBase everywhere.
        IRRewriter rewriter(b);
        Value yieldValue =
            insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
        return scf::ValueVector({yieldValue});
      });
  return success();
}

namespace {
struct PadOpTilingPattern : public OpRewritePattern<tensor::PadOp> {
  PadOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
      : OpRewritePattern<tensor::PadOp>(ctx), options(std::move(opt)) {}

  LogicalResult matchAndRewrite(tensor::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
      return failure();
    tensor::PadOp newPadOp;
    LoopNest loopNest;
    if (failed(tilePadOp(rewriter, op, newPadOp, loopNest, options)))
      return failure();
    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getUnitAttr());
    // Replace all uses of the original tensor::PadOp.
    rewriter.replaceOp(op, loopNest.getResults()[0]);
    return success();
  }

  LinalgTilingOptions options;
};
} // namespace

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);

  InitTensorOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
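/// Ops rewritten by these patterns are tagged with the "tiled" transformation
/// marker so that they are not picked up and tiled again.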
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
  LinalgTransformationFilter f(ArrayRef<StringAttr>{},
                               StringAttr::get(ctx, "tiled"));
  TilingPatterns<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                 >::insert(patterns, options, f);
  patterns.add<PadOpTilingPattern>(ctx, options);
}

void mlir::linalg::populatePadTensorTilingPatterns(
    RewritePatternSet &patterns, const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
  patterns.add<PadOpTilingPattern>(ctx, options);
}

static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType) {
    this->tileSizes = tileSizes;
    this->loopType = "";
    this->loopTypeEnum = loopType;
  }

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    LinalgTilingLoopType type =
        llvm::StringSwitch<LinalgTilingLoopType>(loopType)
            .Case("for", LinalgTilingLoopType::Loops)
            .Case("affine", LinalgTilingLoopType::AffineLoops)
            .Case("parallel", LinalgTilingLoopType::ParallelLoops)
            .Default(loopTypeEnum);
    auto options =
        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(type);
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    insertTilingPatterns(patterns, options);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    (void)applyPatternsAndFoldGreedily(
        funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
    // Drop the marker.
    funcOp.walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });

    // Apply swap pattern after generating loop nest and running
    // canonicalizations.
    applyExtractSliceOfPadTensorSwapPattern(funcOp);
  }

  LinalgTilingLoopType loopTypeEnum;
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
                             linalg::LinalgTilingLoopType loopType) {
  return std::make_unique<LinalgTilingPass>(tileSizes, loopType);
}
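
// Example usage (illustrative only; `pm` is a hypothetical PassManager over a
// module containing the functions to tile):
//   pm.addNestedPass<func::FuncOp>(
//       mlir::createLinalgTilingPass({32, 32},
//                                    linalg::LinalgTilingLoopType::Loops));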