//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                  ValueRange allShapeSizes,
                                  ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.create<arith::ConstantIndexOp>(loc, 0),
                        shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

void mlir::linalg::transformIndexOps(
    RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto &en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
  offsetIndices(b, op, allIvs);
}

/// Asserts that the given index-typed value is strictly positive. If the value
/// is an attribute, asserts at compile time, otherwise emits an assertion
/// checked at runtime.
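/// For example, a size that folds to the constant attribute 4 is verified with
/// a host-side `assert`, whereas a dynamic SSA value is checked by emitting an
/// `arith.cmpi sgt` against zero followed by a `cf.assert` (the two branches
/// below).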
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
                                         OpFoldResult value) {
  if (auto attr = value.dyn_cast<Attribute>()) {
    assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
           "expected strictly positive tile size and divisor");
    return;
  }

  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
                                            value.get<Value>(), zero);
  b.create<cf::AssertOp>(
      condition,
      b.getStringAttr("expected strictly positive tile size and divisor"));
}

FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                    unsigned dimension, OpFoldResult targetSize,
                                    OpFoldResult divisor, bool emitAssertions) {
  // Bail out on dimension overflow.
  if (dimension >= op.getNumLoops())
    return failure();

  // The code below works only on values.
  ImplicitLocOpBuilder b(op.getLoc(), builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
    emitIsPositiveIndexAssertion(b, divisor);
  }
  Value targetSizeValue = materializeOpFoldResult(b, targetSize);
  Value divisorValue = materializeOpFoldResult(b, divisor);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  // TODO: update createFlatListOfOperandDims to return OpFoldResults and avoid
  // littering by useless constant materialization.
  SmallVector<Value, 4> allShapes =
      op.createFlatListOfOperandDims(b, b.getLoc());
  AffineMap shapesToLoops = op.getShapesToLoopsMap();
  SmallVector<Value, 4> loopRanges =
      applyMapToValues(b, op.getLoc(), shapesToLoops, allShapes);
  Value tripCount = loopRanges[dimension];

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  AffineExpr s2 = b.getAffineSymbolExpr(2);
  auto apply = [&](AffineExpr expr, ValueRange values) -> Value {
    return makeComposedAffineApply(b, b.getLoc(), expr, values);
  };
  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
  Value v = apply(s0 % s1, {a, d});
  Value u = apply(s0 - s1, {d, v});

  MultiSizeSpecification spec;
  spec.lowTileSize = s;
  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
  spec.lowTripCount = u;
  spec.highTripCount = v;

  // If requested, emit the check that the tile sizes are computed correctly.
  // For example, for an iteration dimension of size 15 and a divisor of 8, it
  // is impossible to find two tile sizes, both multiples of 8, that fully
  // cover the original iteration space.
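  // As a small worked example of the computation above: for tripCount = 10,
  // targetSize = 3, and divisor = 1, we get a = 10, t = 3, d = 4, and hence
  // lowTileSize = 2, highTileSize = 3, lowTripCount = 2, highTripCount = 2;
  // indeed 2 * 2 + 3 * 2 = 10 covers the whole dimension.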
  if (emitAssertions) {
    AffineExpr s3 = builder.getAffineSymbolExpr(3);
    Value coveredSize =
        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
                                  spec.highTileSize, spec.highTripCount});
    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                           coveredSize, tripCount);
    b.create<cf::AssertOp>(
        equals, builder.getStringAttr(
                    "could not compute dynamic multi-size tile shapes"));
  }

  return spec;
}

/// Given a `subsetExtractOp`, a `source` and a `dest`, create a new
/// `ParallelInsertSlice` op of `source` into `dest` at the same subset
/// location as `subsetExtractOp`.
static void
createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
                                     tensor::ExtractSliceOp subsetExtractOp,
                                     Value source, Value dest) {
  b.create<tensor::ParallelInsertSliceOp>(
      loc, source, dest, subsetExtractOp.getMixedOffsets(),
      subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides());
}

/// Build an `affine_max` of all the `vals`.
static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) {
  return b.createOrFold<AffineMaxOp>(
      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

/// Build an `affine_min` of all the `vals`.
static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) {
  return b.createOrFold<AffineMinOp>(
      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

FailureOr<ForeachThreadTilingResult>
linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
                              ArrayRef<OpFoldResult> numThreads,
                              ArrayRef<int64_t> threadDimMapping) {
  Location loc = op->getLoc();
  OpBuilder::InsertionGuard g(b);
  SmallVector<Range> loopRanges = op.getIterationDomain(b);
  if (loopRanges.empty())
    return op->emitOpError("expected non-empty loop ranges");
  auto hasNonUnitStride = [](Range r) {
    return !isConstantIntValue(r.stride, 1);
  };
  if (llvm::any_of(loopRanges, hasNonUnitStride))
    return op->emitOpError("only stride-1 supported atm");
  // TODO: support `getTiledImplementation` with >1 produced tiled ops.
  auto destOperands = op.getDestinationOperands(b);
  if (destOperands.size() != 1)
    return op->emitOpError("only single dest operand supported atm");

  SmallVector<OpFoldResult> nonZeroNumThreads =
      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
        return !isConstantIntValue(ofr, 0);
      }));
  SmallVector<Value> materializedNonZeroNumThreads =
      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
        ImplicitLocOpBuilder ilocb(loc, b);
        return materializeOpFoldResult(ilocb, ofr);
      }));

  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Operation *tiledOp = nullptr;
  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
      loc, materializedNonZeroNumThreads, threadDimMapping,
      [&](OpBuilder &b, Location loc, ValueRange threadIds) {
        int64_t nLoops = loopRanges.size();
        SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
        tiledOffsets.reserve(nLoops);
        tiledSizes.reserve(nLoops);
        for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops;
             ++loopIdx) {
          bool overflow = loopIdx >= numThreads.size();
          bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
          // Degenerate case: take the whole domain.
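          // A zero entry in `numThreads`, or a loop index past the end of
          // `numThreads`, means this loop is not distributed across threads.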
          if (overflow || isZero) {
            tiledOffsets.push_back(loopRanges[loopIdx].offset);
            tiledSizes.push_back(loopRanges[loopIdx].size);
            continue;
          }

          // Tiled case: compute the offset and size.
          AffineExpr i, j, M, N, O;
          bindDims(b.getContext(), i, j);
          bindSymbols(b.getContext(), M, N, O);
          Value size = loopRanges[loopIdx].size;
          Value offset = loopRanges[loopIdx].offset;
          Value threadId = threadIds[threadIdIdx];
          // TODO: more aggressive foldings.
          // Symbolic fixed max size per thread.
          // TODO: floor + 0/1 depending on case for better load-balancing.
          Value maxSizePerThread = b.createOrFold<AffineApplyOp>(
              loc, M.ceilDiv(N),
              ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]});
          // Dynamic offset shifted by threadId * maxSizePerThread.
          Value offsetPerThread = b.createOrFold<AffineApplyOp>(
              loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread});
          // Dynamic upper-bound depending on the threadId.
          Value sizeMinusOffsetPerThread = b.createOrFold<AffineApplyOp>(
              loc, -i + M, ValueRange{offsetPerThread, size});
          Value tileSizePerThread = buildMin(
              b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread});
          tiledOffsets.push_back(offsetPerThread);
          // TODO: if tileSizePerThread <= 0 early exit.
          tiledSizes.push_back(
              buildMax(b, loc, ValueRange{zero, tileSizePerThread}));
          ++threadIdIdx;
        }

        SmallVector<Operation *> tiledOps =
            op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
                                      /*tileDestOperands=*/true);
        assert(tiledOps.size() == 1 && "expected a single produced tiled op");
        tiledOp = tiledOps.front();

        auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
        assert(tilingInterfaceOp &&
               "Tiled op does not implement TilingInterface");

        auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);

        // Create terminator with parallel subset insert operations.
        auto performConcurrentlyOp = b.create<scf::PerformConcurrentlyOp>(loc);
        OpBuilder::InsertionGuard g(b);
        b.setInsertionPointToStart(performConcurrentlyOp.getBody());
        for (auto it :
             llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
                       destOperands)) {
          createMatchingParallelSubsetInsertOp(
              b, loc,
              cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
              std::get<1>(it), std::get<2>(it));
        }
      });
  return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
}

// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as the size of the tile) is taken from a
// given ExtractSliceOp `sliceOp`.
static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
                                   tensor::ExtractSliceOp sliceOp, Value source,
                                   Value dest) {
  return b.create<tensor::InsertSliceOp>(
      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big; only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero)) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<Attribute, 4> iteratorTypes;
  for (const auto &attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Build the permutation map
  // otherwise.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
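    // (When tiling on tensors, GenerateLoopNest forwards the output tensors
    // as loop iter_args, so these values may be loop-carried block arguments
    // rather than the op's original operands.)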
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expected the number of operands to match the number of inputs "
           "and outputs");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands =
        makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                        sizeBounds, /*omitPartialTileCheck=*/false);

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op, tiledOperands);
    res = op.clone(b, loc, resultTensorTypes, tiledOperands);
    tensorResults =
        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution,
                                 options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return failure();
}

/// Generate a loop nest around a given tensor::PadOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled)
/// tensor::PadOp and the loop nest.
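/// For example, with non-zero tile sizes in every dimension of a rank-2 pad,
/// this builds a two-deep scf.for nest whose body extracts one tile of the
/// cloned pad result and inserts it into the loop-carried result tensor; the
/// pad itself remains outside the nest and is only moved inside later by
/// ExtractSliceOfPadTensorSwapPattern.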
static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
                               tensor::PadOp &newPadOp, LoopNest &loopNest,
                               const LinalgTilingOptions &options) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(op);

  // Clone tensor::PadOp so that the existing op can be replaced more easily.
  newPadOp = cast<tensor::PadOp>(builder.clone(*op.getOperation()));
  // Get rank and tile sizes.
  int64_t rank = op.getResultType().getRank();
  SmallVector<Value> tileSizes =
      options.tileSizeComputationFunction(builder, op);
  // Normalize untiled padding dimensions to 0.
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  tileSizes.append(rank - tileSizes.size(), zero);
  // Compute lower and upper bounds of the loop nest.
  TilingInterface tilingInterface =
      dyn_cast<TilingInterface>(op.getOperation());
  SmallVector<Range> ranges = tilingInterface.getIterationDomain(builder);
  SmallVector<Value> lbs, dims, allDims, steps;
  for (int64_t i = 0; i < rank; ++i) {
    allDims.push_back(ranges[i].size);
    if (!isZero(tileSizes[i])) {
      lbs.push_back(ranges[i].offset);
      dims.push_back(ranges[i].size);
      steps.push_back(tileSizes[i]);
    }
  }
  // Generate the loop nest: one loop per tiled dimension.
  SmallVector<Value> destOperand =
      tilingInterface.getDestinationOperands(builder);
  loopNest = mlir::scf::buildLoopNest(
      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange iterArgs) -> scf::ValueVector {
        // Compute offsets and sizes of ExtractSliceOp.
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
        SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
        // Note: The tensor::PadOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
        Value tiledOutput = makeTiledShape(
            b, loc, newPadOp->getResult(0), tileSizes, map, offsets, allDims,
            sizes, /*omitPartialTileCheck=*/false);
        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
        assert(sliceOp && "expected ExtractSliceOp");
        // Insert the tile into the output tensor.
        // TODO: Propagate RewriterBase everywhere.
        IRRewriter rewriter(b);
        Value yieldValue =
            insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
        return scf::ValueVector({yieldValue});
      });
  return success();
}

namespace {
struct PadOpTilingPattern : public OpRewritePattern<tensor::PadOp> {
  PadOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
      : OpRewritePattern<tensor::PadOp>(ctx), options(std::move(opt)) {}

  LogicalResult matchAndRewrite(tensor::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
      return failure();
    tensor::PadOp newPadOp;
    LoopNest loopNest;
    if (failed(tilePadOp(rewriter, op, newPadOp, loopNest, options)))
      return failure();
    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getUnitAttr());
    // Replace all uses of the original tensor::PadOp.
    rewriter.replaceOp(op, loopNest.getResults()[0]);
    return success();
  }

  LinalgTilingOptions options;
};
} // namespace

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);

  InitTensorOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
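/// The structured-op patterns are guarded by a LinalgTransformationFilter that
/// tags the ops they produce with a "tiled" marker, and PadOpTilingPattern
/// sets the kLinalgTransformMarker attribute for the same purpose, so the
/// greedy driver does not tile the same op twice.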
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
  LinalgTransformationFilter f(ArrayRef<StringAttr>{},
                               StringAttr::get(ctx, "tiled"));
  TilingPatterns<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                 >::insert(patterns, options, f);
  patterns.add<PadOpTilingPattern>(ctx, options);
}

void mlir::linalg::populatePadTensorTilingPatterns(
    RewritePatternSet &patterns, const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
  patterns.add<PadOpTilingPattern>(ctx, options);
}

static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType) {
    this->tileSizes = tileSizes;
    this->loopType = "";
    this->loopTypeEnum = loopType;
  }

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    LinalgTilingLoopType type =
        llvm::StringSwitch<LinalgTilingLoopType>(loopType)
            .Case("for", LinalgTilingLoopType::Loops)
            .Case("affine", LinalgTilingLoopType::AffineLoops)
            .Case("parallel", LinalgTilingLoopType::ParallelLoops)
            .Default(loopTypeEnum);
    auto options =
        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(type);
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    insertTilingPatterns(patterns, options);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    (void)applyPatternsAndFoldGreedily(
        funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
    // Drop the marker.
    funcOp.walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });

    // Apply swap pattern after generating loop nest and running
    // canonicalizations.
    applyExtractSliceOfPadTensorSwapPattern(funcOp);
  }

  LinalgTilingLoopType loopTypeEnum;
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
                             linalg::LinalgTilingLoopType loopType) {
  return std::make_unique<LinalgTilingPass>(tileSizes, loopType);
}
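
// A minimal usage sketch (assuming a pre-built `mlir::PassManager pm` over a
// module containing func.func ops): tile by {2, 4} using scf.for loops.
//
//   pm.addNestedPass<func::FuncOp>(
//       mlir::createLinalgTilingPass({2, 4},
//                                    linalg::LinalgTilingLoopType::Loops));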