1*cf6a7c19SMahesh Ravishankar //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2*cf6a7c19SMahesh Ravishankar // 3*cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information. 5*cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*cf6a7c19SMahesh Ravishankar // 7*cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 8*cf6a7c19SMahesh Ravishankar // 9*cf6a7c19SMahesh Ravishankar // This file implements the tiling using TilingInterface. 10*cf6a7c19SMahesh Ravishankar // 11*cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 12*cf6a7c19SMahesh Ravishankar 13*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/TileUsingInterface.h" 14*cf6a7c19SMahesh Ravishankar 15*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 16*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h" 18*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/Utils/Utils.h" 19*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h" 20*cf6a7c19SMahesh Ravishankar #include "mlir/IR/Matchers.h" 21*cf6a7c19SMahesh Ravishankar #include "mlir/IR/PatternMatch.h" 22*cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h" 23*cf6a7c19SMahesh Ravishankar #include "llvm/Support/Debug.h" 24*cf6a7c19SMahesh Ravishankar 25*cf6a7c19SMahesh Ravishankar #define DEBUG_TYPE "tile-using-interface" 26*cf6a7c19SMahesh Ravishankar 27*cf6a7c19SMahesh Ravishankar using namespace mlir; 28*cf6a7c19SMahesh Ravishankar 29*cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions & 30*cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 31*cf6a7c19SMahesh Ravishankar assert(!tileSizeComputationFunction && "tile sizes already set"); 32*cf6a7c19SMahesh Ravishankar SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 33*cf6a7c19SMahesh Ravishankar tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 34*cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(b); 35*cf6a7c19SMahesh Ravishankar b.setInsertionPointToStart( 36*cf6a7c19SMahesh Ravishankar &op->getParentOfType<func::FuncOp>().getBody().front()); 37*cf6a7c19SMahesh Ravishankar return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 38*cf6a7c19SMahesh Ravishankar Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 39*cf6a7c19SMahesh Ravishankar return v; 40*cf6a7c19SMahesh Ravishankar })); 41*cf6a7c19SMahesh Ravishankar }; 42*cf6a7c19SMahesh Ravishankar return *this; 43*cf6a7c19SMahesh Ravishankar } 44*cf6a7c19SMahesh Ravishankar 45*cf6a7c19SMahesh Ravishankar /// Generate an empty loop nest that represents the tiled loop nest shell. 46*cf6a7c19SMahesh Ravishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 47*cf6a7c19SMahesh Ravishankar /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 48*cf6a7c19SMahesh Ravishankar /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 49*cf6a7c19SMahesh Ravishankar /// the 50*cf6a7c19SMahesh Ravishankar /// tile processed within the inner most loop. 51*cf6a7c19SMahesh Ravishankar static SmallVector<scf::ForOp> 52*cf6a7c19SMahesh Ravishankar generateTileLoopNest(OpBuilder &builder, Location loc, 53*cf6a7c19SMahesh Ravishankar ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 54*cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &offsets, 55*cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &sizes) { 56*cf6a7c19SMahesh Ravishankar assert(!loopRanges.empty() && "expected at least one loop range"); 57*cf6a7c19SMahesh Ravishankar assert(loopRanges.size() == tileSizeVals.size() && 58*cf6a7c19SMahesh Ravishankar "expected as many tile sizes as loop ranges"); 59*cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(builder); 60*cf6a7c19SMahesh Ravishankar SmallVector<scf::ForOp> loops; 61*cf6a7c19SMahesh Ravishankar offsets.resize(loopRanges.size()); 62*cf6a7c19SMahesh Ravishankar sizes.resize(loopRanges.size()); 63*cf6a7c19SMahesh Ravishankar 64*cf6a7c19SMahesh Ravishankar // The tile size to use (to avoid out of bounds access) is minimum of 65*cf6a7c19SMahesh Ravishankar // `tileSize` and `ub - iv`, where `iv` is the induction variable 66*cf6a7c19SMahesh Ravishankar // of the tiled loop. 67*cf6a7c19SMahesh Ravishankar AffineExpr s0, s1, d0; 68*cf6a7c19SMahesh Ravishankar bindDims(builder.getContext(), d0); 69*cf6a7c19SMahesh Ravishankar bindSymbols(builder.getContext(), s0, s1); 70*cf6a7c19SMahesh Ravishankar AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 71*cf6a7c19SMahesh Ravishankar 72*cf6a7c19SMahesh Ravishankar for (auto loopRange : llvm::enumerate(loopRanges)) { 73*cf6a7c19SMahesh Ravishankar // No loops if tile size is zero. Set offset and size to the loop 74*cf6a7c19SMahesh Ravishankar // offset and size. 75*cf6a7c19SMahesh Ravishankar if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 76*cf6a7c19SMahesh Ravishankar offsets[loopRange.index()] = loopRange.value().offset; 77*cf6a7c19SMahesh Ravishankar sizes[loopRange.index()] = loopRange.value().size; 78*cf6a7c19SMahesh Ravishankar continue; 79*cf6a7c19SMahesh Ravishankar } 80*cf6a7c19SMahesh Ravishankar 81*cf6a7c19SMahesh Ravishankar auto loop = builder.create<scf::ForOp>( 82*cf6a7c19SMahesh Ravishankar loc, loopRange.value().offset, loopRange.value().size, 83*cf6a7c19SMahesh Ravishankar tileSizeVals[loopRange.index()], ValueRange{}, 84*cf6a7c19SMahesh Ravishankar [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 85*cf6a7c19SMahesh Ravishankar ValueRange /*iterArgs*/) { 86*cf6a7c19SMahesh Ravishankar Value boundedTileSize = builder.create<AffineMinOp>( 87*cf6a7c19SMahesh Ravishankar bodyLoc, minMap, 88*cf6a7c19SMahesh Ravishankar ValueRange{iv, tileSizeVals[loopRange.index()], 89*cf6a7c19SMahesh Ravishankar loopRange.value().size}); 90*cf6a7c19SMahesh Ravishankar sizes[loopRange.index()] = boundedTileSize; 91*cf6a7c19SMahesh Ravishankar builder.create<scf::YieldOp>(loc); 92*cf6a7c19SMahesh Ravishankar }); 93*cf6a7c19SMahesh Ravishankar offsets[loopRange.index()] = loop.getInductionVar(); 94*cf6a7c19SMahesh Ravishankar loops.push_back(loop); 95*cf6a7c19SMahesh Ravishankar builder.setInsertionPoint(loop.getBody()->getTerminator()); 96*cf6a7c19SMahesh Ravishankar } 97*cf6a7c19SMahesh Ravishankar return loops; 98*cf6a7c19SMahesh Ravishankar } 99*cf6a7c19SMahesh Ravishankar 100*cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, 101*cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 102*cf6a7c19SMahesh Ravishankar PatternBenefit benefit) 103*cf6a7c19SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 104*cf6a7c19SMahesh Ravishankar options(std::move(options)) {} 105*cf6a7c19SMahesh Ravishankar 106*cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, 107*cf6a7c19SMahesh Ravishankar MLIRContext *context, 108*cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 109*cf6a7c19SMahesh Ravishankar PatternBenefit benefit) 110*cf6a7c19SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 111*cf6a7c19SMahesh Ravishankar options(std::move(options)) {} 112*cf6a7c19SMahesh Ravishankar 113*cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult> 114*cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::returningMatchAndRewrite( 115*cf6a7c19SMahesh Ravishankar TilingInterface op, PatternRewriter &rewriter) const { 116*cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(rewriter); 117*cf6a7c19SMahesh Ravishankar rewriter.setInsertionPointAfter(op); 118*cf6a7c19SMahesh Ravishankar 119*cf6a7c19SMahesh Ravishankar if (!options.tileSizeComputationFunction) { 120*cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure( 121*cf6a7c19SMahesh Ravishankar op, "missing tile size computation function"); 122*cf6a7c19SMahesh Ravishankar } 123*cf6a7c19SMahesh Ravishankar 124*cf6a7c19SMahesh Ravishankar // 1. Get the range of the loops that are represented by the operation. 125*cf6a7c19SMahesh Ravishankar SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 126*cf6a7c19SMahesh Ravishankar size_t numLoops = iterationDomain.size(); 127*cf6a7c19SMahesh Ravishankar if (numLoops == 0) { 128*cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure( 129*cf6a7c19SMahesh Ravishankar op, "unable to tile op with no iteration domain"); 130*cf6a7c19SMahesh Ravishankar } 131*cf6a7c19SMahesh Ravishankar 132*cf6a7c19SMahesh Ravishankar // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 133*cf6a7c19SMahesh Ravishankar // skips tiling a particular dimension. This convention is significantly 134*cf6a7c19SMahesh Ravishankar // simpler to handle instead of adjusting affine maps to account for missing 135*cf6a7c19SMahesh Ravishankar // dimensions. 136*cf6a7c19SMahesh Ravishankar SmallVector<Value, 4> tileSizeVector = 137*cf6a7c19SMahesh Ravishankar options.tileSizeComputationFunction(rewriter, op); 138*cf6a7c19SMahesh Ravishankar if (tileSizeVector.size() < iterationDomain.size()) { 139*cf6a7c19SMahesh Ravishankar auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 140*cf6a7c19SMahesh Ravishankar tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 141*cf6a7c19SMahesh Ravishankar } 142*cf6a7c19SMahesh Ravishankar 143*cf6a7c19SMahesh Ravishankar scf::SCFTilingResult tilingResult; 144*cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> offsets, sizes; 145*cf6a7c19SMahesh Ravishankar { 146*cf6a7c19SMahesh Ravishankar // 3. Materialize an empty loop nest that iterates over the tiles. These 147*cf6a7c19SMahesh Ravishankar // loops for now do not return any values even if the original operation has 148*cf6a7c19SMahesh Ravishankar // results. 149*cf6a7c19SMahesh Ravishankar tilingResult.loops = generateTileLoopNest( 150*cf6a7c19SMahesh Ravishankar rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 151*cf6a7c19SMahesh Ravishankar 152*cf6a7c19SMahesh Ravishankar LLVM_DEBUG({ 153*cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) { 154*cf6a7c19SMahesh Ravishankar llvm::errs() << "LoopNest shell :\n"; 155*cf6a7c19SMahesh Ravishankar tilingResult.loops.front().dump(); 156*cf6a7c19SMahesh Ravishankar llvm::errs() << "\n"; 157*cf6a7c19SMahesh Ravishankar } 158*cf6a7c19SMahesh Ravishankar }); 159*cf6a7c19SMahesh Ravishankar 160*cf6a7c19SMahesh Ravishankar // 4. Generate the tiled implementation within the inner most loop. 161*cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) 162*cf6a7c19SMahesh Ravishankar rewriter.setInsertionPoint( 163*cf6a7c19SMahesh Ravishankar tilingResult.loops.back().getBody()->getTerminator()); 164*cf6a7c19SMahesh Ravishankar SmallVector<Operation *> tiledImplementation = op.getTiledImplementation( 165*cf6a7c19SMahesh Ravishankar rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); 166*cf6a7c19SMahesh Ravishankar if (tiledImplementation.size() != 1) { 167*cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure( 168*cf6a7c19SMahesh Ravishankar op, "expected tiled implementation to return a single op"); 169*cf6a7c19SMahesh Ravishankar } 170*cf6a7c19SMahesh Ravishankar tilingResult.tiledOp = tiledImplementation[0]; 171*cf6a7c19SMahesh Ravishankar 172*cf6a7c19SMahesh Ravishankar LLVM_DEBUG({ 173*cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) { 174*cf6a7c19SMahesh Ravishankar llvm::errs() << "After tiled implementation :\n"; 175*cf6a7c19SMahesh Ravishankar tilingResult.loops.front().dump(); 176*cf6a7c19SMahesh Ravishankar llvm::errs() << "\n"; 177*cf6a7c19SMahesh Ravishankar } 178*cf6a7c19SMahesh Ravishankar }); 179*cf6a7c19SMahesh Ravishankar } 180*cf6a7c19SMahesh Ravishankar 181*cf6a7c19SMahesh Ravishankar if (op->getNumResults() == 0) { 182*cf6a7c19SMahesh Ravishankar rewriter.eraseOp(op); 183*cf6a7c19SMahesh Ravishankar return tilingResult; 184*cf6a7c19SMahesh Ravishankar } 185*cf6a7c19SMahesh Ravishankar 186*cf6a7c19SMahesh Ravishankar // 5. If the original operations has results, modify the loop nest to yield 187*cf6a7c19SMahesh Ravishankar // the replacement values. 188*cf6a7c19SMahesh Ravishankar SmallVector<Value> replacements; 189*cf6a7c19SMahesh Ravishankar if (tilingResult.loops.empty()) { 190*cf6a7c19SMahesh Ravishankar // 5a. If there were no loops, the tiled implementation results are the 191*cf6a7c19SMahesh Ravishankar // replacements. 192*cf6a7c19SMahesh Ravishankar rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); 193*cf6a7c19SMahesh Ravishankar return tilingResult; 194*cf6a7c19SMahesh Ravishankar } 195*cf6a7c19SMahesh Ravishankar 196*cf6a7c19SMahesh Ravishankar // 5b. `scf.for` with tensor semantics requires the loop nest to yield the 197*cf6a7c19SMahesh Ravishankar // replacement values using destructive updates. Use the `TilingInterface` 198*cf6a7c19SMahesh Ravishankar // to get the position of the result tiles and use that to generate the 199*cf6a7c19SMahesh Ravishankar // destructive update pattern, i.e., 200*cf6a7c19SMahesh Ravishankar // 201*cf6a7c19SMahesh Ravishankar // ```mlir 202*cf6a7c19SMahesh Ravishankar // scf.for %iv0 = ... { 203*cf6a7c19SMahesh Ravishankar // %0 = tiled_op 204*cf6a7c19SMahesh Ravishankar // } 205*cf6a7c19SMahesh Ravishankar // ``` 206*cf6a7c19SMahesh Ravishankar // 207*cf6a7c19SMahesh Ravishankar // is transformed to 208*cf6a7c19SMahesh Ravishankar // 209*cf6a7c19SMahesh Ravishankar // ```mlir 210*cf6a7c19SMahesh Ravishankar // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { 211*cf6a7c19SMahesh Ravishankar // %0 = tiled_op 212*cf6a7c19SMahesh Ravishankar // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 213*cf6a7c19SMahesh Ravishankar // scf.yield %1 214*cf6a7c19SMahesh Ravishankar // } 215*cf6a7c19SMahesh Ravishankar // ``` 216*cf6a7c19SMahesh Ravishankar NewYieldValueFn yieldValueFn = 217*cf6a7c19SMahesh Ravishankar [&](OpBuilder &b, Location loc, 218*cf6a7c19SMahesh Ravishankar ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 219*cf6a7c19SMahesh Ravishankar SmallVector<Value> yieldedValues; 220*cf6a7c19SMahesh Ravishankar Attribute one = b.getIndexAttr(1); 221*cf6a7c19SMahesh Ravishankar for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) { 222*cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes; 223*cf6a7c19SMahesh Ravishankar if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, 224*cf6a7c19SMahesh Ravishankar resultTileOffsets, 225*cf6a7c19SMahesh Ravishankar resultTileSizes))) { 226*cf6a7c19SMahesh Ravishankar op.emitOpError("unable to get position of result ") 227*cf6a7c19SMahesh Ravishankar << resultNum << " of the tiled implementation"; 228*cf6a7c19SMahesh Ravishankar return {}; 229*cf6a7c19SMahesh Ravishankar } 230*cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(), 231*cf6a7c19SMahesh Ravishankar one); 232*cf6a7c19SMahesh Ravishankar Value yieldedValue = b.create<tensor::InsertSliceOp>( 233*cf6a7c19SMahesh Ravishankar op->getLoc(), tilingResult.tiledOp->getResult(resultNum), 234*cf6a7c19SMahesh Ravishankar newBBArgs[resultNum], resultTileOffsets, resultTileSizes, 235*cf6a7c19SMahesh Ravishankar resultTileStrides); 236*cf6a7c19SMahesh Ravishankar yieldedValues.push_back(yieldedValue); 237*cf6a7c19SMahesh Ravishankar } 238*cf6a7c19SMahesh Ravishankar return yieldedValues; 239*cf6a7c19SMahesh Ravishankar }; 240*cf6a7c19SMahesh Ravishankar SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields( 241*cf6a7c19SMahesh Ravishankar rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), 242*cf6a7c19SMahesh Ravishankar yieldValueFn); 243*cf6a7c19SMahesh Ravishankar for (auto loop : llvm::enumerate(tilingResult.loops)) { 244*cf6a7c19SMahesh Ravishankar rewriter.eraseOp(loop.value()); 245*cf6a7c19SMahesh Ravishankar tilingResult.loops[loop.index()] = newLoops[loop.index()]; 246*cf6a7c19SMahesh Ravishankar } 247*cf6a7c19SMahesh Ravishankar rewriter.replaceOp(op, tilingResult.loops.front().getResults()); 248*cf6a7c19SMahesh Ravishankar return tilingResult; 249*cf6a7c19SMahesh Ravishankar } 250