1cf6a7c19SMahesh Ravishankar //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2cf6a7c19SMahesh Ravishankar // 3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information. 5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6cf6a7c19SMahesh Ravishankar // 7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 8cf6a7c19SMahesh Ravishankar // 9cf6a7c19SMahesh Ravishankar // This file implements the tiling using TilingInterface. 10cf6a7c19SMahesh Ravishankar // 11cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 12cf6a7c19SMahesh Ravishankar 138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14cf6a7c19SMahesh Ravishankar 15cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 16cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h" 18cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/Utils/Utils.h" 19cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h" 20cf6a7c19SMahesh Ravishankar #include "mlir/IR/Matchers.h" 21cf6a7c19SMahesh Ravishankar #include "mlir/IR/PatternMatch.h" 22cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h" 23cf6a7c19SMahesh Ravishankar #include "llvm/Support/Debug.h" 24cf6a7c19SMahesh Ravishankar 25cf6a7c19SMahesh Ravishankar #define DEBUG_TYPE "tile-using-interface" 26cf6a7c19SMahesh Ravishankar 27cf6a7c19SMahesh Ravishankar using namespace mlir; 28cf6a7c19SMahesh Ravishankar 29cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions & 30cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 31cf6a7c19SMahesh Ravishankar assert(!tileSizeComputationFunction && "tile sizes already set"); 32cf6a7c19SMahesh Ravishankar SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 33cf6a7c19SMahesh Ravishankar tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 34cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(b); 35cf6a7c19SMahesh Ravishankar b.setInsertionPointToStart( 36cf6a7c19SMahesh Ravishankar &op->getParentOfType<func::FuncOp>().getBody().front()); 37cf6a7c19SMahesh Ravishankar return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 38cf6a7c19SMahesh Ravishankar Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 39cf6a7c19SMahesh Ravishankar return v; 40cf6a7c19SMahesh Ravishankar })); 41cf6a7c19SMahesh Ravishankar }; 42cf6a7c19SMahesh Ravishankar return *this; 43cf6a7c19SMahesh Ravishankar } 44cf6a7c19SMahesh Ravishankar 452f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 462f637fe7SMahesh Ravishankar // TileUsingSCFForOp pattern implementation. 472f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 482f637fe7SMahesh Ravishankar 49cf6a7c19SMahesh Ravishankar /// Generate an empty loop nest that represents the tiled loop nest shell. 50cf6a7c19SMahesh Ravishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 51cf6a7c19SMahesh Ravishankar /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 52cf6a7c19SMahesh Ravishankar /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 53cf6a7c19SMahesh Ravishankar /// the 54cf6a7c19SMahesh Ravishankar /// tile processed within the inner most loop. 55cf6a7c19SMahesh Ravishankar static SmallVector<scf::ForOp> 56cf6a7c19SMahesh Ravishankar generateTileLoopNest(OpBuilder &builder, Location loc, 57cf6a7c19SMahesh Ravishankar ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 58cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &offsets, 59cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &sizes) { 60cf6a7c19SMahesh Ravishankar assert(!loopRanges.empty() && "expected at least one loop range"); 61cf6a7c19SMahesh Ravishankar assert(loopRanges.size() == tileSizeVals.size() && 62cf6a7c19SMahesh Ravishankar "expected as many tile sizes as loop ranges"); 63cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(builder); 64cf6a7c19SMahesh Ravishankar SmallVector<scf::ForOp> loops; 65cf6a7c19SMahesh Ravishankar offsets.resize(loopRanges.size()); 66cf6a7c19SMahesh Ravishankar sizes.resize(loopRanges.size()); 67cf6a7c19SMahesh Ravishankar 68cf6a7c19SMahesh Ravishankar // The tile size to use (to avoid out of bounds access) is minimum of 69cf6a7c19SMahesh Ravishankar // `tileSize` and `ub - iv`, where `iv` is the induction variable 70cf6a7c19SMahesh Ravishankar // of the tiled loop. 71cf6a7c19SMahesh Ravishankar AffineExpr s0, s1, d0; 72cf6a7c19SMahesh Ravishankar bindDims(builder.getContext(), d0); 73cf6a7c19SMahesh Ravishankar bindSymbols(builder.getContext(), s0, s1); 74cf6a7c19SMahesh Ravishankar AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 75cf6a7c19SMahesh Ravishankar 76cf6a7c19SMahesh Ravishankar for (auto loopRange : llvm::enumerate(loopRanges)) { 77cf6a7c19SMahesh Ravishankar // No loops if tile size is zero. Set offset and size to the loop 78cf6a7c19SMahesh Ravishankar // offset and size. 79cf6a7c19SMahesh Ravishankar if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 80cf6a7c19SMahesh Ravishankar offsets[loopRange.index()] = loopRange.value().offset; 81cf6a7c19SMahesh Ravishankar sizes[loopRange.index()] = loopRange.value().size; 82cf6a7c19SMahesh Ravishankar continue; 83cf6a7c19SMahesh Ravishankar } 84cf6a7c19SMahesh Ravishankar 85cf6a7c19SMahesh Ravishankar auto loop = builder.create<scf::ForOp>( 86cf6a7c19SMahesh Ravishankar loc, loopRange.value().offset, loopRange.value().size, 87cf6a7c19SMahesh Ravishankar tileSizeVals[loopRange.index()], ValueRange{}, 88cf6a7c19SMahesh Ravishankar [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 89cf6a7c19SMahesh Ravishankar ValueRange /*iterArgs*/) { 90cf6a7c19SMahesh Ravishankar Value boundedTileSize = builder.create<AffineMinOp>( 91cf6a7c19SMahesh Ravishankar bodyLoc, minMap, 92cf6a7c19SMahesh Ravishankar ValueRange{iv, tileSizeVals[loopRange.index()], 93cf6a7c19SMahesh Ravishankar loopRange.value().size}); 94cf6a7c19SMahesh Ravishankar sizes[loopRange.index()] = boundedTileSize; 95cf6a7c19SMahesh Ravishankar builder.create<scf::YieldOp>(loc); 96cf6a7c19SMahesh Ravishankar }); 97cf6a7c19SMahesh Ravishankar offsets[loopRange.index()] = loop.getInductionVar(); 98cf6a7c19SMahesh Ravishankar loops.push_back(loop); 99cf6a7c19SMahesh Ravishankar builder.setInsertionPoint(loop.getBody()->getTerminator()); 100cf6a7c19SMahesh Ravishankar } 101cf6a7c19SMahesh Ravishankar return loops; 102cf6a7c19SMahesh Ravishankar } 103cf6a7c19SMahesh Ravishankar 104cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, 105cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 106cf6a7c19SMahesh Ravishankar PatternBenefit benefit) 107cf6a7c19SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 108cf6a7c19SMahesh Ravishankar options(std::move(options)) {} 109cf6a7c19SMahesh Ravishankar 110cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, 111cf6a7c19SMahesh Ravishankar MLIRContext *context, 112cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 113cf6a7c19SMahesh Ravishankar PatternBenefit benefit) 114cf6a7c19SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 115cf6a7c19SMahesh Ravishankar options(std::move(options)) {} 116cf6a7c19SMahesh Ravishankar 117cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult> 118cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::returningMatchAndRewrite( 119cf6a7c19SMahesh Ravishankar TilingInterface op, PatternRewriter &rewriter) const { 120cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(rewriter); 121cf6a7c19SMahesh Ravishankar rewriter.setInsertionPointAfter(op); 122cf6a7c19SMahesh Ravishankar 123cf6a7c19SMahesh Ravishankar if (!options.tileSizeComputationFunction) { 124cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure( 125cf6a7c19SMahesh Ravishankar op, "missing tile size computation function"); 126cf6a7c19SMahesh Ravishankar } 127cf6a7c19SMahesh Ravishankar 128cf6a7c19SMahesh Ravishankar // 1. Get the range of the loops that are represented by the operation. 129cf6a7c19SMahesh Ravishankar SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 130cf6a7c19SMahesh Ravishankar size_t numLoops = iterationDomain.size(); 131cf6a7c19SMahesh Ravishankar if (numLoops == 0) { 132cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure( 133cf6a7c19SMahesh Ravishankar op, "unable to tile op with no iteration domain"); 134cf6a7c19SMahesh Ravishankar } 135cf6a7c19SMahesh Ravishankar 136cf6a7c19SMahesh Ravishankar // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 137cf6a7c19SMahesh Ravishankar // skips tiling a particular dimension. This convention is significantly 138cf6a7c19SMahesh Ravishankar // simpler to handle instead of adjusting affine maps to account for missing 139cf6a7c19SMahesh Ravishankar // dimensions. 140cf6a7c19SMahesh Ravishankar SmallVector<Value, 4> tileSizeVector = 141cf6a7c19SMahesh Ravishankar options.tileSizeComputationFunction(rewriter, op); 142cf6a7c19SMahesh Ravishankar if (tileSizeVector.size() < iterationDomain.size()) { 143cf6a7c19SMahesh Ravishankar auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 144cf6a7c19SMahesh Ravishankar tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 145cf6a7c19SMahesh Ravishankar } 146cf6a7c19SMahesh Ravishankar 147cf6a7c19SMahesh Ravishankar scf::SCFTilingResult tilingResult; 148cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> offsets, sizes; 149cf6a7c19SMahesh Ravishankar { 150cf6a7c19SMahesh Ravishankar // 3. Materialize an empty loop nest that iterates over the tiles. These 151cf6a7c19SMahesh Ravishankar // loops for now do not return any values even if the original operation has 152cf6a7c19SMahesh Ravishankar // results. 153cf6a7c19SMahesh Ravishankar tilingResult.loops = generateTileLoopNest( 154cf6a7c19SMahesh Ravishankar rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 155cf6a7c19SMahesh Ravishankar 156cf6a7c19SMahesh Ravishankar LLVM_DEBUG({ 157cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) { 158cf6a7c19SMahesh Ravishankar llvm::errs() << "LoopNest shell :\n"; 159cf6a7c19SMahesh Ravishankar tilingResult.loops.front().dump(); 160cf6a7c19SMahesh Ravishankar llvm::errs() << "\n"; 161cf6a7c19SMahesh Ravishankar } 162cf6a7c19SMahesh Ravishankar }); 163cf6a7c19SMahesh Ravishankar 164cf6a7c19SMahesh Ravishankar // 4. Generate the tiled implementation within the inner most loop. 165cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) 166cf6a7c19SMahesh Ravishankar rewriter.setInsertionPoint( 167cf6a7c19SMahesh Ravishankar tilingResult.loops.back().getBody()->getTerminator()); 168cf6a7c19SMahesh Ravishankar SmallVector<Operation *> tiledImplementation = op.getTiledImplementation( 169cf6a7c19SMahesh Ravishankar rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); 170cf6a7c19SMahesh Ravishankar if (tiledImplementation.size() != 1) { 171cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure( 172cf6a7c19SMahesh Ravishankar op, "expected tiled implementation to return a single op"); 173cf6a7c19SMahesh Ravishankar } 174cf6a7c19SMahesh Ravishankar tilingResult.tiledOp = tiledImplementation[0]; 175cf6a7c19SMahesh Ravishankar 176cf6a7c19SMahesh Ravishankar LLVM_DEBUG({ 177cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) { 178cf6a7c19SMahesh Ravishankar llvm::errs() << "After tiled implementation :\n"; 179cf6a7c19SMahesh Ravishankar tilingResult.loops.front().dump(); 180cf6a7c19SMahesh Ravishankar llvm::errs() << "\n"; 181cf6a7c19SMahesh Ravishankar } 182cf6a7c19SMahesh Ravishankar }); 183cf6a7c19SMahesh Ravishankar } 184cf6a7c19SMahesh Ravishankar 185cf6a7c19SMahesh Ravishankar if (op->getNumResults() == 0) { 186cf6a7c19SMahesh Ravishankar rewriter.eraseOp(op); 187cf6a7c19SMahesh Ravishankar return tilingResult; 188cf6a7c19SMahesh Ravishankar } 189cf6a7c19SMahesh Ravishankar 190cf6a7c19SMahesh Ravishankar // 5. If the original operations has results, modify the loop nest to yield 191cf6a7c19SMahesh Ravishankar // the replacement values. 192cf6a7c19SMahesh Ravishankar SmallVector<Value> replacements; 193cf6a7c19SMahesh Ravishankar if (tilingResult.loops.empty()) { 194cf6a7c19SMahesh Ravishankar // 5a. If there were no loops, the tiled implementation results are the 195cf6a7c19SMahesh Ravishankar // replacements. 196cf6a7c19SMahesh Ravishankar rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); 197cf6a7c19SMahesh Ravishankar return tilingResult; 198cf6a7c19SMahesh Ravishankar } 199cf6a7c19SMahesh Ravishankar 200cf6a7c19SMahesh Ravishankar // 5b. `scf.for` with tensor semantics requires the loop nest to yield the 201cf6a7c19SMahesh Ravishankar // replacement values using destructive updates. Use the `TilingInterface` 202cf6a7c19SMahesh Ravishankar // to get the position of the result tiles and use that to generate the 203cf6a7c19SMahesh Ravishankar // destructive update pattern, i.e., 204cf6a7c19SMahesh Ravishankar // 205cf6a7c19SMahesh Ravishankar // ```mlir 206cf6a7c19SMahesh Ravishankar // scf.for %iv0 = ... { 207cf6a7c19SMahesh Ravishankar // %0 = tiled_op 208cf6a7c19SMahesh Ravishankar // } 209cf6a7c19SMahesh Ravishankar // ``` 210cf6a7c19SMahesh Ravishankar // 211cf6a7c19SMahesh Ravishankar // is transformed to 212cf6a7c19SMahesh Ravishankar // 213cf6a7c19SMahesh Ravishankar // ```mlir 214cf6a7c19SMahesh Ravishankar // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { 215cf6a7c19SMahesh Ravishankar // %0 = tiled_op 216cf6a7c19SMahesh Ravishankar // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 217cf6a7c19SMahesh Ravishankar // scf.yield %1 218cf6a7c19SMahesh Ravishankar // } 219cf6a7c19SMahesh Ravishankar // ``` 220cf6a7c19SMahesh Ravishankar NewYieldValueFn yieldValueFn = 221cf6a7c19SMahesh Ravishankar [&](OpBuilder &b, Location loc, 222cf6a7c19SMahesh Ravishankar ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 223cf6a7c19SMahesh Ravishankar SmallVector<Value> yieldedValues; 224cf6a7c19SMahesh Ravishankar Attribute one = b.getIndexAttr(1); 225cf6a7c19SMahesh Ravishankar for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) { 226cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes; 227cf6a7c19SMahesh Ravishankar if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, 228cf6a7c19SMahesh Ravishankar resultTileOffsets, 229cf6a7c19SMahesh Ravishankar resultTileSizes))) { 230cf6a7c19SMahesh Ravishankar op.emitOpError("unable to get position of result ") 231cf6a7c19SMahesh Ravishankar << resultNum << " of the tiled implementation"; 232cf6a7c19SMahesh Ravishankar return {}; 233cf6a7c19SMahesh Ravishankar } 234cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(), 235cf6a7c19SMahesh Ravishankar one); 236cf6a7c19SMahesh Ravishankar Value yieldedValue = b.create<tensor::InsertSliceOp>( 237cf6a7c19SMahesh Ravishankar op->getLoc(), tilingResult.tiledOp->getResult(resultNum), 238cf6a7c19SMahesh Ravishankar newBBArgs[resultNum], resultTileOffsets, resultTileSizes, 239cf6a7c19SMahesh Ravishankar resultTileStrides); 240cf6a7c19SMahesh Ravishankar yieldedValues.push_back(yieldedValue); 241cf6a7c19SMahesh Ravishankar } 242cf6a7c19SMahesh Ravishankar return yieldedValues; 243cf6a7c19SMahesh Ravishankar }; 244cf6a7c19SMahesh Ravishankar SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields( 245cf6a7c19SMahesh Ravishankar rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), 246cf6a7c19SMahesh Ravishankar yieldValueFn); 247*ca2933f3SAdrian Kuegel for (const auto &loop : llvm::enumerate(tilingResult.loops)) { 248cf6a7c19SMahesh Ravishankar rewriter.eraseOp(loop.value()); 249cf6a7c19SMahesh Ravishankar tilingResult.loops[loop.index()] = newLoops[loop.index()]; 250cf6a7c19SMahesh Ravishankar } 251cf6a7c19SMahesh Ravishankar rewriter.replaceOp(op, tilingResult.loops.front().getResults()); 252cf6a7c19SMahesh Ravishankar return tilingResult; 253cf6a7c19SMahesh Ravishankar } 2542f637fe7SMahesh Ravishankar 2552f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 2562f637fe7SMahesh Ravishankar // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. 2572f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 2582f637fe7SMahesh Ravishankar 2592f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp:: 2602f637fe7SMahesh Ravishankar TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, 2612f637fe7SMahesh Ravishankar scf::SCFTilingOptions options, 2622f637fe7SMahesh Ravishankar PatternBenefit benefit) 2632f637fe7SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 2642f637fe7SMahesh Ravishankar tilingPattern(context, std::move(options)) {} 2652f637fe7SMahesh Ravishankar 2662f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp:: 2672f637fe7SMahesh Ravishankar TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, 2682f637fe7SMahesh Ravishankar MLIRContext *context, 2692f637fe7SMahesh Ravishankar scf::SCFTilingOptions options, 2702f637fe7SMahesh Ravishankar PatternBenefit benefit) 2712f637fe7SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 2722f637fe7SMahesh Ravishankar tilingPattern(context, std::move(options)) {} 2732f637fe7SMahesh Ravishankar 2742f637fe7SMahesh Ravishankar /// Return the `Value` that is defined by an operation that implements 2752f637fe7SMahesh Ravishankar /// the `TilingInterface`. Looks through `iter_args` of scf.for nest 2762f637fe7SMahesh Ravishankar /// if required. 2772f637fe7SMahesh Ravishankar static Optional<OpResult> getFusableProducer(Value v) { 2782f637fe7SMahesh Ravishankar while (auto blockArg = v.dyn_cast<BlockArgument>()) { 2792f637fe7SMahesh Ravishankar auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp()); 2802f637fe7SMahesh Ravishankar if (!loopOp) 2812f637fe7SMahesh Ravishankar return llvm::None; 2822f637fe7SMahesh Ravishankar v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); 2832f637fe7SMahesh Ravishankar } 2842f637fe7SMahesh Ravishankar if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp())) 2852f637fe7SMahesh Ravishankar return llvm::None; 2862f637fe7SMahesh Ravishankar return v.cast<OpResult>(); 2872f637fe7SMahesh Ravishankar } 2882f637fe7SMahesh Ravishankar 2892f637fe7SMahesh Ravishankar FailureOr<scf::SCFTileAndFuseResult> 2902f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( 2912f637fe7SMahesh Ravishankar TilingInterface op, PatternRewriter &rewriter) const { 2922f637fe7SMahesh Ravishankar // This transformation is only valid for ops that return values (i.e. not 2932f637fe7SMahesh Ravishankar // valid to use with operations that have memref operands). 2942f637fe7SMahesh Ravishankar if (!op->getNumResults()) { 2952f637fe7SMahesh Ravishankar return rewriter.notifyMatchFailure( 2962f637fe7SMahesh Ravishankar op, "invalid pattern for op with no results"); 2972f637fe7SMahesh Ravishankar } 2982f637fe7SMahesh Ravishankar 2992f637fe7SMahesh Ravishankar // 1. First tile the consumer. 3002f637fe7SMahesh Ravishankar SCFTileAndFuseResult tileAndFuseResult; 3012f637fe7SMahesh Ravishankar { 3022f637fe7SMahesh Ravishankar FailureOr<SCFTilingResult> tilingResult = 3032f637fe7SMahesh Ravishankar tilingPattern.returningMatchAndRewrite(op, rewriter); 3042f637fe7SMahesh Ravishankar if (failed(tilingResult)) { 3052f637fe7SMahesh Ravishankar return failure(); 3062f637fe7SMahesh Ravishankar } 3072f637fe7SMahesh Ravishankar tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); 3082f637fe7SMahesh Ravishankar tileAndFuseResult.loops = std::move(tilingResult->loops); 3092f637fe7SMahesh Ravishankar } 3102f637fe7SMahesh Ravishankar 3112f637fe7SMahesh Ravishankar // 2. Typically, the operands of the tiled operation are slices of the 3122f637fe7SMahesh Ravishankar // operands of the untiled operation. These are expressed in IR using 3132f637fe7SMahesh Ravishankar // `tensor.extract_slice` operations with source being the operands of the 3142f637fe7SMahesh Ravishankar // untiled operation. Create a worklist of these `tensor.extract_slice` 3152f637fe7SMahesh Ravishankar // operations. If the producers of the source of the `tensor.extract_slice` 3162f637fe7SMahesh Ravishankar // can be tiled such that the tiled value is generated in-place, that 3172f637fe7SMahesh Ravishankar // effectively tiles + fuses the operations. 3182f637fe7SMahesh Ravishankar auto addCandidateSlices = [](Operation *fusedOp, 3192f637fe7SMahesh Ravishankar std::deque<tensor::ExtractSliceOp> &candidates) { 3202f637fe7SMahesh Ravishankar for (Value operand : fusedOp->getOperands()) 3212f637fe7SMahesh Ravishankar if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 3222f637fe7SMahesh Ravishankar candidates.push_back(sliceOp); 3232f637fe7SMahesh Ravishankar }; 3242f637fe7SMahesh Ravishankar 3252f637fe7SMahesh Ravishankar std::deque<tensor::ExtractSliceOp> candidates; 3262f637fe7SMahesh Ravishankar addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 3272f637fe7SMahesh Ravishankar OpBuilder::InsertionGuard g(rewriter); 3282f637fe7SMahesh Ravishankar while (!candidates.empty()) { 3292f637fe7SMahesh Ravishankar // 2a. Traverse the slices in BFS fashion. 3302f637fe7SMahesh Ravishankar tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 3312f637fe7SMahesh Ravishankar candidates.pop_front(); 3322f637fe7SMahesh Ravishankar 3332f637fe7SMahesh Ravishankar // 2b. Get the producer of the source (potentially walking through 3342f637fe7SMahesh Ravishankar // `iter_args` of nested `scf.for`) 3352f637fe7SMahesh Ravishankar Optional<OpResult> fusableProducer = 3362f637fe7SMahesh Ravishankar getFusableProducer(candidateSliceOp.source()); 3372f637fe7SMahesh Ravishankar if (!fusableProducer) 3382f637fe7SMahesh Ravishankar continue; 3392f637fe7SMahesh Ravishankar 3402f637fe7SMahesh Ravishankar // 2c. Generate the tiled implementation of the producer of the source 3412f637fe7SMahesh Ravishankar rewriter.setInsertionPoint(candidateSliceOp); 3422f637fe7SMahesh Ravishankar FailureOr<Value> fusedProducerValue = 3433b7c3a65SKazu Hirata tensor::replaceExtractSliceWithTiledProducer( 3443b7c3a65SKazu Hirata rewriter, candidateSliceOp, fusableProducer.getValue()); 3452f637fe7SMahesh Ravishankar if (failed(fusedProducerValue)) 3462f637fe7SMahesh Ravishankar continue; 3473b7c3a65SKazu Hirata rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue()); 3482f637fe7SMahesh Ravishankar 3492f637fe7SMahesh Ravishankar // 2d. The operands of the fused producer might themselved be slices of 3502f637fe7SMahesh Ravishankar // values produced by operations that implement the `TilingInterface`. 3512f637fe7SMahesh Ravishankar // Add these operations to the worklist. 3522f637fe7SMahesh Ravishankar Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 3532f637fe7SMahesh Ravishankar tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); 3542f637fe7SMahesh Ravishankar addCandidateSlices(fusedProducer, candidates); 3552f637fe7SMahesh Ravishankar 3562f637fe7SMahesh Ravishankar // 2e. If the operation being fused creates a value that is used as `outs` 3572f637fe7SMahesh Ravishankar // in the tiled operation, the result of the unfused operation will be 3582f637fe7SMahesh Ravishankar // used in the `iter_args` of the tiled loop generated. When the 3592f637fe7SMahesh Ravishankar // operation is fused, this use in `iter_args` needs to be modified to 3602f637fe7SMahesh Ravishankar // use the destination of the fused operation. For example, starting 3612f637fe7SMahesh Ravishankar // with 3622f637fe7SMahesh Ravishankar // 3632f637fe7SMahesh Ravishankar // ```mlir 3642f637fe7SMahesh Ravishankar // %0 = linalg.init_tensor ... 3652f637fe7SMahesh Ravishankar // %1 = linalg.fill ... outs(%0:...)... 3662f637fe7SMahesh Ravishankar // %2 = linalg.matmul ... outs(%1:...).... 3672f637fe7SMahesh Ravishankar // ``` 3682f637fe7SMahesh Ravishankar // 3692f637fe7SMahesh Ravishankar // First the `linalg.matmul` gets tiled 3702f637fe7SMahesh Ravishankar // 3712f637fe7SMahesh Ravishankar // ```mlir 3722f637fe7SMahesh Ravishankar // %0 = linalg.init_tensor 3732f637fe7SMahesh Ravishankar // %1 = linalg.fill 3742f637fe7SMahesh Ravishankar // %2 = scf.for .... iter_args(%arg0 = %1)... 3752f637fe7SMahesh Ravishankar // ... 3762f637fe7SMahesh Ravishankar // ... = linalg.matmul ... 3772f637fe7SMahesh Ravishankar // 3782f637fe7SMahesh Ravishankar // ``` 3792f637fe7SMahesh Ravishankar // 3802f637fe7SMahesh Ravishankar // When the `linalg.fill` gets fused, the `iter_args` needs to be 3812f637fe7SMahesh Ravishankar // modified 3822f637fe7SMahesh Ravishankar // 3832f637fe7SMahesh Ravishankar // ```mlir 3842f637fe7SMahesh Ravishankar // %0 = linalg.init_tensor 3852f637fe7SMahesh Ravishankar // %1 = scf.for ... iter_args(%arg0 = %0)... 3862f637fe7SMahesh Ravishankar // ... 3872f637fe7SMahesh Ravishankar // %2 = linalg.fill ... 3882f637fe7SMahesh Ravishankar // %3 = linalg.matmul ... outs(%2: ...)... 3892f637fe7SMahesh Ravishankar // ``` 3902f637fe7SMahesh Ravishankar TilingInterface unfusedProducerOp = 3912f637fe7SMahesh Ravishankar cast<TilingInterface>(fusableProducer->getOwner()); 3922f637fe7SMahesh Ravishankar scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); 3932f637fe7SMahesh Ravishankar SmallVector<Value> unfusedProducerOpDestValues = 3942f637fe7SMahesh Ravishankar unfusedProducerOp.getDestinationOperands(rewriter); 3952f637fe7SMahesh Ravishankar for (OpOperand &uses : unfusedProducerOp->getUses()) { 3962f637fe7SMahesh Ravishankar if (uses.getOwner() == outerMostTiledLoop.getOperation()) { 3972f637fe7SMahesh Ravishankar unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber(); 3982f637fe7SMahesh Ravishankar unsigned operandNumber = uses.getOperandNumber(); 3992f637fe7SMahesh Ravishankar outerMostTiledLoop->setOperand( 4002f637fe7SMahesh Ravishankar operandNumber, unfusedProducerOpDestValues[resultNumber]); 4012f637fe7SMahesh Ravishankar } 4022f637fe7SMahesh Ravishankar } 4032f637fe7SMahesh Ravishankar } 4042f637fe7SMahesh Ravishankar return tileAndFuseResult; 4052f637fe7SMahesh Ravishankar } 406