//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

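// Compute the tiled loop ranges for the `map`-transformed shape sizes,
// dropping loops whose tile size is 0 (i.e., loops that are not tiled).
// Illustrative example (chosen for exposition, not taken from a test): with
// three loops and tile sizes {4, 0, 8}, loop 1 is untiled, so two ranges are
// returned (for loops 0 and 2) and the resulting map is {0 -> 0, 2 -> 1}.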
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                  ValueRange allShapeSizes,
                                  ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, and erase zeros
  // everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.create<arith::ConstantIndexOp>(loc, 0),
                        shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

void mlir::linalg::transformIndexOps(
    RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto &en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
  addTileLoopIvsToIndexOpResults(b, op, allIvs);
}

// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as the size of the tile) is taken from
// a given ExtractSliceOp `sliceOp`.
static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
                                   tensor::ExtractSliceOp sliceOp, Value source,
                                   Value dest) {
  return b.create<tensor::InsertSliceOp>(
      loc, sliceOp.source().getType(), source, dest, sliceOp.offsets(),
      sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
      sliceOp.static_sizes(), sliceOp.static_strides());
}

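// Tiling rewrites the op into a loop nest that iterates over tiles and
// applies the op to one tile per iteration. As a rough sketch (illustrative
// only, not verbatim compiler output), tiling a 2-d op on tensors by {2, 4}
// yields something like:
//
//   scf.for %i = %c0 to %dim0 step %c2 iter_args(%out0 = %init) {
//     scf.for %j = %c0 to %dim1 step %c4 iter_args(%out1 = %out0) {
//       %in  = tensor.extract_slice ...   // tile of the input
//       %out = tensor.extract_slice ...   // tile of the output
//       %res = <tiled linalg op>(%in, %out)
//       tensor.insert_slice %res into %out1 ...
//     }
//   }
//
// Partial tiles at the boundary are clamped by `makeTiledShapes` (see the
// omitPartialTileCheck flag below).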
template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // There may be more tile sizes than loops; only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero)) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<Attribute, 4> iteratorTypes;
  for (const auto &attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity map. Otherwise, build the
  // permutation map.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile sizes), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
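    // For example (illustrative): with interchangeVector = {1, 0} on a 2-d
    // op, the loop nest iterates the second loop outermost, so the `ivs`
    // arrive as (iv1, iv0); applying `invPermutationMap` restores (iv0, iv1),
    // the order the op's indexing maps expect.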
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expect the number of operands and inputs and outputs to match");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands =
        makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                        sizeBounds, /*omitPartialTileCheck=*/false);

    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : op.getOutputTensorOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    res = op.clone(b, loc, resultTensorTypes, tiledOperands);

    // Insert an insert_slice for each output tensor.
    unsigned resultIdx = 0;
    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
      // TODO: use an interface/adaptor to avoid leaking position in
      // `tiledOperands`.
      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
      // TODO: Propagate RewriterBase everywhere.
      IRRewriter rewriter(b);
      if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
        tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
                                                      res->getResult(resultIdx),
                                                      sliceOp.source()));
      } else {
        tensorResults.push_back(res->getResult(resultIdx));
      }
      ++resultIdx;
    }
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution,
                                 options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

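// Compute the tile sizes via the callback in `options`, then delegate to the
// range-based implementation above. For example (illustrative): if `op` has
// three loops and the callback returns {2}, the sizes are padded with zeros
// to {2, 0, 0}, i.e., only the outermost loop is tiled.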
template <typename LoopTy>
FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
    RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return failure();
}

/// Generate a loop nest around a given tensor::PadOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled)
/// tensor::PadOp and the loop nest.
static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
                               tensor::PadOp &newPadOp, LoopNest &loopNest,
                               const LinalgTilingOptions &options) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(op);

  // Clone tensor::PadOp so that the existing op can be replaced more easily.
  newPadOp = cast<tensor::PadOp>(builder.clone(*op.getOperation()));
  // Get rank and tile sizes.
  int64_t rank = op.getResultType().getRank();
  SmallVector<Value> tileSizes =
      options.tileSizeComputationFunction(builder, op);
  // Normalize untiled padding dimensions to 0.
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  tileSizes.append(rank - tileSizes.size(), zero);
  // Compute lower and upper bounds of the loop nest.
  TilingInterface tilingInterface =
      dyn_cast<TilingInterface>(op.getOperation());
  SmallVector<Range> ranges = tilingInterface.getIterationDomain(builder);
  SmallVector<Value> lbs, dims, allDims, steps;
  for (int64_t i = 0; i < rank; ++i) {
    allDims.push_back(ranges[i].size);
    if (!isZero(tileSizes[i])) {
      lbs.push_back(ranges[i].offset);
      dims.push_back(ranges[i].size);
      steps.push_back(tileSizes[i]);
    }
  }
  // Generate loop nest: one loop per tiled dimension.
  SmallVector<Value> destOperand =
      tilingInterface.getDestinationOperands(builder);
  loopNest = mlir::scf::buildLoopNest(
      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange iterArgs) -> scf::ValueVector {
        // Compute offsets and sizes of ExtractSliceOp.
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
        SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
        // Note: The tensor::PadOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
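        // For example (illustrative): for a 2-d pad tiled by {2, 4}, each
        // iteration extracts one 2x4 tile (clamped at the boundary) of the
        // padded result and inserts it into the loop-carried output tensor.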
        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
        Value tiledOutput = makeTiledShape(
            b, loc, newPadOp->getResult(0), tileSizes, map, offsets, allDims,
            sizes, /*omitPartialTileCheck=*/false);
        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
        assert(sliceOp && "expected ExtractSliceOp");
        // Insert the tile into the output tensor.
        // TODO: Propagate RewriterBase everywhere.
        IRRewriter rewriter(b);
        Value yieldValue =
            insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
        return scf::ValueVector({yieldValue});
      });
  return success();
}

namespace {
struct PadOpTilingPattern : public OpRewritePattern<tensor::PadOp> {
  PadOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
      : OpRewritePattern<tensor::PadOp>(ctx), options(std::move(opt)) {}

  LogicalResult matchAndRewrite(tensor::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
      return failure();
    tensor::PadOp newPadOp;
    LoopNest loopNest;
    if (failed(tilePadOp(rewriter, op, newPadOp, loopNest, options)))
      return failure();
    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getUnitAttr());
    // Replace all uses of the original tensor::PadOp.
    rewriter.replaceOp(op, loopNest.getResults()[0]);
    return success();
  }

  LinalgTilingOptions options;
};
} // namespace

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);

  InitTensorOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
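/// Tiled ops are tagged with the Linalg transform marker (here, "tiled") so
/// that the patterns do not fire a second time on already-tiled ops.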
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
  LinalgTransformationFilter f(ArrayRef<StringAttr>{},
                               StringAttr::get(ctx, "tiled"));
  TilingPatterns<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                 >::insert(patterns, options, f);
  patterns.add<PadOpTilingPattern>(ctx, options);
}

void mlir::linalg::populatePadTensorTilingPatterns(
    RewritePatternSet &patterns, const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
  patterns.add<PadOpTilingPattern>(ctx, options);
}

static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType) {
    this->tileSizes = tileSizes;
    this->loopType = "";
    this->loopTypeEnum = loopType;
  }

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    LinalgTilingLoopType type =
        llvm::StringSwitch<LinalgTilingLoopType>(loopType)
            .Case("for", LinalgTilingLoopType::Loops)
            .Case("affine", LinalgTilingLoopType::AffineLoops)
            .Case("parallel", LinalgTilingLoopType::ParallelLoops)
            .Default(loopTypeEnum);
    auto options =
        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(type);
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    insertTilingPatterns(patterns, options);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    (void)applyPatternsAndFoldGreedily(
        funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
    // Drop the marker.
    funcOp.walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });

    // Apply the swap pattern after generating the loop nest and running
    // canonicalizations.
    applyExtractSliceOfPadTensorSwapPattern(funcOp);
  }

  LinalgTilingLoopType loopTypeEnum;
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
                             linalg::LinalgTilingLoopType loopType) {
  return std::make_unique<LinalgTilingPass>(tileSizes, loopType);
}