1cf6a7c19SMahesh Ravishankar //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2cf6a7c19SMahesh Ravishankar // 3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information. 5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6cf6a7c19SMahesh Ravishankar // 7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 8cf6a7c19SMahesh Ravishankar // 9cf6a7c19SMahesh Ravishankar // This file implements the tiling using TilingInterface. 10cf6a7c19SMahesh Ravishankar // 11cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 12cf6a7c19SMahesh Ravishankar 138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14cf6a7c19SMahesh Ravishankar 15cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 16cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h" 18cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/Utils/Utils.h" 19cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h" 20cf6a7c19SMahesh Ravishankar #include "mlir/IR/Matchers.h" 21cf6a7c19SMahesh Ravishankar #include "mlir/IR/PatternMatch.h" 22cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h" 23cf6a7c19SMahesh Ravishankar #include "llvm/Support/Debug.h" 24cf6a7c19SMahesh Ravishankar 25cf6a7c19SMahesh Ravishankar #define DEBUG_TYPE "tile-using-interface" 26cf6a7c19SMahesh Ravishankar 27cf6a7c19SMahesh Ravishankar using namespace mlir; 28cf6a7c19SMahesh Ravishankar 29cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions & 30cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 31cf6a7c19SMahesh Ravishankar assert(!tileSizeComputationFunction && "tile sizes already set"); 32*b8a1f00dSMahesh Ravishankar SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 33cf6a7c19SMahesh Ravishankar tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 34cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(b); 35cf6a7c19SMahesh Ravishankar b.setInsertionPointToStart( 36cf6a7c19SMahesh Ravishankar &op->getParentOfType<func::FuncOp>().getBody().front()); 37cf6a7c19SMahesh Ravishankar return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 38cf6a7c19SMahesh Ravishankar Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 39cf6a7c19SMahesh Ravishankar return v; 40cf6a7c19SMahesh Ravishankar })); 41cf6a7c19SMahesh Ravishankar }; 42cf6a7c19SMahesh Ravishankar return *this; 43cf6a7c19SMahesh Ravishankar } 44cf6a7c19SMahesh Ravishankar 45*b8a1f00dSMahesh Ravishankar /// Helper method to adjust the interchange vector to match the iteration 46*b8a1f00dSMahesh Ravishankar /// domain. 47*b8a1f00dSMahesh Ravishankar static SmallVector<unsigned> 48*b8a1f00dSMahesh Ravishankar fillInterchangeVector(ArrayRef<unsigned> interchangeVector, 49*b8a1f00dSMahesh Ravishankar size_t iterationDomainSize) { 50*b8a1f00dSMahesh Ravishankar SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector); 51*b8a1f00dSMahesh Ravishankar if (filledVector.size() < iterationDomainSize) { 52*b8a1f00dSMahesh Ravishankar auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize); 53*b8a1f00dSMahesh Ravishankar filledVector.append(range.begin(), range.end()); 54*b8a1f00dSMahesh Ravishankar } 55*b8a1f00dSMahesh Ravishankar if (filledVector.size() > iterationDomainSize) 56*b8a1f00dSMahesh Ravishankar filledVector.resize(iterationDomainSize); 57*b8a1f00dSMahesh Ravishankar return filledVector; 58*b8a1f00dSMahesh Ravishankar } 59*b8a1f00dSMahesh Ravishankar 60*b8a1f00dSMahesh Ravishankar /// Helper method to apply permutation to a vector 61*b8a1f00dSMahesh Ravishankar template <typename T> 62*b8a1f00dSMahesh Ravishankar static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector, 63*b8a1f00dSMahesh Ravishankar ArrayRef<unsigned> interchange) { 64*b8a1f00dSMahesh Ravishankar assert(interchange.size() == vector.size()); 65*b8a1f00dSMahesh Ravishankar return llvm::to_vector( 66*b8a1f00dSMahesh Ravishankar llvm::map_range(interchange, [&](unsigned val) { return vector[val]; })); 67*b8a1f00dSMahesh Ravishankar } 68*b8a1f00dSMahesh Ravishankar /// Helper method to apply to invert a permutation. 69*b8a1f00dSMahesh Ravishankar static SmallVector<unsigned> 70*b8a1f00dSMahesh Ravishankar invertPermutationVector(ArrayRef<unsigned> interchange) { 71*b8a1f00dSMahesh Ravishankar SmallVector<unsigned> inversion(interchange.size()); 72*b8a1f00dSMahesh Ravishankar for (auto pos : llvm::enumerate(interchange)) { 73*b8a1f00dSMahesh Ravishankar inversion[pos.value()] = pos.index(); 74*b8a1f00dSMahesh Ravishankar } 75*b8a1f00dSMahesh Ravishankar return inversion; 76*b8a1f00dSMahesh Ravishankar } 77*b8a1f00dSMahesh Ravishankar /// Method to check if an interchange vector is a permutation. 78*b8a1f00dSMahesh Ravishankar static bool isPermutation(ArrayRef<unsigned> interchange) { 79*b8a1f00dSMahesh Ravishankar llvm::SmallDenseSet<unsigned, 4> seenVals; 80*b8a1f00dSMahesh Ravishankar for (auto val : interchange) { 81*b8a1f00dSMahesh Ravishankar if (seenVals.count(val)) 82*b8a1f00dSMahesh Ravishankar return false; 83*b8a1f00dSMahesh Ravishankar seenVals.insert(val); 84*b8a1f00dSMahesh Ravishankar } 85*b8a1f00dSMahesh Ravishankar return seenVals.size() == interchange.size(); 86*b8a1f00dSMahesh Ravishankar } 87*b8a1f00dSMahesh Ravishankar 882f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 892f637fe7SMahesh Ravishankar // TileUsingSCFForOp pattern implementation. 902f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 912f637fe7SMahesh Ravishankar 92cf6a7c19SMahesh Ravishankar /// Generate an empty loop nest that represents the tiled loop nest shell. 93cf6a7c19SMahesh Ravishankar /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 94cf6a7c19SMahesh Ravishankar /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 95cf6a7c19SMahesh Ravishankar /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 96cf6a7c19SMahesh Ravishankar /// the 97cf6a7c19SMahesh Ravishankar /// tile processed within the inner most loop. 98cf6a7c19SMahesh Ravishankar static SmallVector<scf::ForOp> 99cf6a7c19SMahesh Ravishankar generateTileLoopNest(OpBuilder &builder, Location loc, 100cf6a7c19SMahesh Ravishankar ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 101cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &offsets, 102cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &sizes) { 103cf6a7c19SMahesh Ravishankar assert(!loopRanges.empty() && "expected at least one loop range"); 104cf6a7c19SMahesh Ravishankar assert(loopRanges.size() == tileSizeVals.size() && 105cf6a7c19SMahesh Ravishankar "expected as many tile sizes as loop ranges"); 106cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(builder); 107cf6a7c19SMahesh Ravishankar SmallVector<scf::ForOp> loops; 108cf6a7c19SMahesh Ravishankar offsets.resize(loopRanges.size()); 109cf6a7c19SMahesh Ravishankar sizes.resize(loopRanges.size()); 110cf6a7c19SMahesh Ravishankar 111cf6a7c19SMahesh Ravishankar // The tile size to use (to avoid out of bounds access) is minimum of 112cf6a7c19SMahesh Ravishankar // `tileSize` and `ub - iv`, where `iv` is the induction variable 113cf6a7c19SMahesh Ravishankar // of the tiled loop. 114cf6a7c19SMahesh Ravishankar AffineExpr s0, s1, d0; 115cf6a7c19SMahesh Ravishankar bindDims(builder.getContext(), d0); 116cf6a7c19SMahesh Ravishankar bindSymbols(builder.getContext(), s0, s1); 117cf6a7c19SMahesh Ravishankar AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 118cf6a7c19SMahesh Ravishankar 119cf6a7c19SMahesh Ravishankar for (auto loopRange : llvm::enumerate(loopRanges)) { 120cf6a7c19SMahesh Ravishankar // No loops if tile size is zero. Set offset and size to the loop 121cf6a7c19SMahesh Ravishankar // offset and size. 122cf6a7c19SMahesh Ravishankar if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 123cf6a7c19SMahesh Ravishankar offsets[loopRange.index()] = loopRange.value().offset; 124cf6a7c19SMahesh Ravishankar sizes[loopRange.index()] = loopRange.value().size; 125cf6a7c19SMahesh Ravishankar continue; 126cf6a7c19SMahesh Ravishankar } 127cf6a7c19SMahesh Ravishankar 128cf6a7c19SMahesh Ravishankar auto loop = builder.create<scf::ForOp>( 129cf6a7c19SMahesh Ravishankar loc, loopRange.value().offset, loopRange.value().size, 130cf6a7c19SMahesh Ravishankar tileSizeVals[loopRange.index()], ValueRange{}, 131cf6a7c19SMahesh Ravishankar [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 132cf6a7c19SMahesh Ravishankar ValueRange /*iterArgs*/) { 133cf6a7c19SMahesh Ravishankar Value boundedTileSize = builder.create<AffineMinOp>( 134cf6a7c19SMahesh Ravishankar bodyLoc, minMap, 135cf6a7c19SMahesh Ravishankar ValueRange{iv, tileSizeVals[loopRange.index()], 136cf6a7c19SMahesh Ravishankar loopRange.value().size}); 137cf6a7c19SMahesh Ravishankar sizes[loopRange.index()] = boundedTileSize; 138cf6a7c19SMahesh Ravishankar builder.create<scf::YieldOp>(loc); 139cf6a7c19SMahesh Ravishankar }); 140cf6a7c19SMahesh Ravishankar offsets[loopRange.index()] = loop.getInductionVar(); 141cf6a7c19SMahesh Ravishankar loops.push_back(loop); 142cf6a7c19SMahesh Ravishankar builder.setInsertionPoint(loop.getBody()->getTerminator()); 143cf6a7c19SMahesh Ravishankar } 144cf6a7c19SMahesh Ravishankar return loops; 145cf6a7c19SMahesh Ravishankar } 146cf6a7c19SMahesh Ravishankar 147cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, 148cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 149cf6a7c19SMahesh Ravishankar PatternBenefit benefit) 150cf6a7c19SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 151cf6a7c19SMahesh Ravishankar options(std::move(options)) {} 152cf6a7c19SMahesh Ravishankar 153cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, 154cf6a7c19SMahesh Ravishankar MLIRContext *context, 155cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 156cf6a7c19SMahesh Ravishankar PatternBenefit benefit) 157cf6a7c19SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 158cf6a7c19SMahesh Ravishankar options(std::move(options)) {} 159cf6a7c19SMahesh Ravishankar 160cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult> 161cf6a7c19SMahesh Ravishankar scf::TileUsingSCFForOp::returningMatchAndRewrite( 162cf6a7c19SMahesh Ravishankar TilingInterface op, PatternRewriter &rewriter) const { 163cf6a7c19SMahesh Ravishankar OpBuilder::InsertionGuard guard(rewriter); 164cf6a7c19SMahesh Ravishankar rewriter.setInsertionPointAfter(op); 165cf6a7c19SMahesh Ravishankar 166cf6a7c19SMahesh Ravishankar if (!options.tileSizeComputationFunction) { 167cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure( 168cf6a7c19SMahesh Ravishankar op, "missing tile size computation function"); 169cf6a7c19SMahesh Ravishankar } 170cf6a7c19SMahesh Ravishankar 171cf6a7c19SMahesh Ravishankar // 1. Get the range of the loops that are represented by the operation. 172cf6a7c19SMahesh Ravishankar SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 173cf6a7c19SMahesh Ravishankar size_t numLoops = iterationDomain.size(); 174cf6a7c19SMahesh Ravishankar if (numLoops == 0) { 175cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure( 176cf6a7c19SMahesh Ravishankar op, "unable to tile op with no iteration domain"); 177cf6a7c19SMahesh Ravishankar } 178cf6a7c19SMahesh Ravishankar 179cf6a7c19SMahesh Ravishankar // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 180cf6a7c19SMahesh Ravishankar // skips tiling a particular dimension. This convention is significantly 181cf6a7c19SMahesh Ravishankar // simpler to handle instead of adjusting affine maps to account for missing 182cf6a7c19SMahesh Ravishankar // dimensions. 183*b8a1f00dSMahesh Ravishankar SmallVector<Value> tileSizeVector = 184cf6a7c19SMahesh Ravishankar options.tileSizeComputationFunction(rewriter, op); 185cf6a7c19SMahesh Ravishankar if (tileSizeVector.size() < iterationDomain.size()) { 186cf6a7c19SMahesh Ravishankar auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 187cf6a7c19SMahesh Ravishankar tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 188cf6a7c19SMahesh Ravishankar } 189cf6a7c19SMahesh Ravishankar 190cf6a7c19SMahesh Ravishankar scf::SCFTilingResult tilingResult; 191cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> offsets, sizes; 192cf6a7c19SMahesh Ravishankar { 193*b8a1f00dSMahesh Ravishankar // If there is an interchange specified, permute the iteration domain and 194*b8a1f00dSMahesh Ravishankar // the tile sizes. 195*b8a1f00dSMahesh Ravishankar SmallVector<unsigned> interchangeVector; 196*b8a1f00dSMahesh Ravishankar if (!options.interchangeVector.empty()) { 197*b8a1f00dSMahesh Ravishankar interchangeVector = fillInterchangeVector(options.interchangeVector, 198*b8a1f00dSMahesh Ravishankar iterationDomain.size()); 199*b8a1f00dSMahesh Ravishankar } 200*b8a1f00dSMahesh Ravishankar if (!interchangeVector.empty()) { 201*b8a1f00dSMahesh Ravishankar if (!isPermutation(interchangeVector)) { 202*b8a1f00dSMahesh Ravishankar return rewriter.notifyMatchFailure( 203*b8a1f00dSMahesh Ravishankar op, "invalid intechange vector, not a permutation of the entire " 204*b8a1f00dSMahesh Ravishankar "iteration space"); 205*b8a1f00dSMahesh Ravishankar } 206*b8a1f00dSMahesh Ravishankar 207*b8a1f00dSMahesh Ravishankar iterationDomain = 208*b8a1f00dSMahesh Ravishankar applyPermutationToVector(iterationDomain, interchangeVector); 209*b8a1f00dSMahesh Ravishankar tileSizeVector = 210*b8a1f00dSMahesh Ravishankar applyPermutationToVector(tileSizeVector, interchangeVector); 211*b8a1f00dSMahesh Ravishankar } 212*b8a1f00dSMahesh Ravishankar 213cf6a7c19SMahesh Ravishankar // 3. Materialize an empty loop nest that iterates over the tiles. These 214cf6a7c19SMahesh Ravishankar // loops for now do not return any values even if the original operation has 215cf6a7c19SMahesh Ravishankar // results. 216cf6a7c19SMahesh Ravishankar tilingResult.loops = generateTileLoopNest( 217cf6a7c19SMahesh Ravishankar rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 218cf6a7c19SMahesh Ravishankar 219*b8a1f00dSMahesh Ravishankar if (!interchangeVector.empty()) { 220*b8a1f00dSMahesh Ravishankar auto inversePermutation = invertPermutationVector(interchangeVector); 221*b8a1f00dSMahesh Ravishankar offsets = applyPermutationToVector(offsets, inversePermutation); 222*b8a1f00dSMahesh Ravishankar sizes = applyPermutationToVector(sizes, inversePermutation); 223*b8a1f00dSMahesh Ravishankar } 224*b8a1f00dSMahesh Ravishankar 225cf6a7c19SMahesh Ravishankar LLVM_DEBUG({ 226cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) { 227cf6a7c19SMahesh Ravishankar llvm::errs() << "LoopNest shell :\n"; 228cf6a7c19SMahesh Ravishankar tilingResult.loops.front().dump(); 229cf6a7c19SMahesh Ravishankar llvm::errs() << "\n"; 230cf6a7c19SMahesh Ravishankar } 231cf6a7c19SMahesh Ravishankar }); 232cf6a7c19SMahesh Ravishankar 233cf6a7c19SMahesh Ravishankar // 4. Generate the tiled implementation within the inner most loop. 234cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) 235cf6a7c19SMahesh Ravishankar rewriter.setInsertionPoint( 236cf6a7c19SMahesh Ravishankar tilingResult.loops.back().getBody()->getTerminator()); 237cf6a7c19SMahesh Ravishankar SmallVector<Operation *> tiledImplementation = op.getTiledImplementation( 238cf6a7c19SMahesh Ravishankar rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); 239cf6a7c19SMahesh Ravishankar if (tiledImplementation.size() != 1) { 240cf6a7c19SMahesh Ravishankar return rewriter.notifyMatchFailure( 241cf6a7c19SMahesh Ravishankar op, "expected tiled implementation to return a single op"); 242cf6a7c19SMahesh Ravishankar } 243cf6a7c19SMahesh Ravishankar tilingResult.tiledOp = tiledImplementation[0]; 244cf6a7c19SMahesh Ravishankar 245cf6a7c19SMahesh Ravishankar LLVM_DEBUG({ 246cf6a7c19SMahesh Ravishankar if (!tilingResult.loops.empty()) { 247cf6a7c19SMahesh Ravishankar llvm::errs() << "After tiled implementation :\n"; 248cf6a7c19SMahesh Ravishankar tilingResult.loops.front().dump(); 249cf6a7c19SMahesh Ravishankar llvm::errs() << "\n"; 250cf6a7c19SMahesh Ravishankar } 251cf6a7c19SMahesh Ravishankar }); 252cf6a7c19SMahesh Ravishankar } 253cf6a7c19SMahesh Ravishankar 254cf6a7c19SMahesh Ravishankar if (op->getNumResults() == 0) { 255cf6a7c19SMahesh Ravishankar rewriter.eraseOp(op); 256cf6a7c19SMahesh Ravishankar return tilingResult; 257cf6a7c19SMahesh Ravishankar } 258cf6a7c19SMahesh Ravishankar 259cf6a7c19SMahesh Ravishankar // 5. If the original operations has results, modify the loop nest to yield 260cf6a7c19SMahesh Ravishankar // the replacement values. 261cf6a7c19SMahesh Ravishankar SmallVector<Value> replacements; 262cf6a7c19SMahesh Ravishankar if (tilingResult.loops.empty()) { 263cf6a7c19SMahesh Ravishankar // 5a. If there were no loops, the tiled implementation results are the 264cf6a7c19SMahesh Ravishankar // replacements. 265cf6a7c19SMahesh Ravishankar rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); 266cf6a7c19SMahesh Ravishankar return tilingResult; 267cf6a7c19SMahesh Ravishankar } 268cf6a7c19SMahesh Ravishankar 269cf6a7c19SMahesh Ravishankar // 5b. `scf.for` with tensor semantics requires the loop nest to yield the 270cf6a7c19SMahesh Ravishankar // replacement values using destructive updates. Use the `TilingInterface` 271cf6a7c19SMahesh Ravishankar // to get the position of the result tiles and use that to generate the 272cf6a7c19SMahesh Ravishankar // destructive update pattern, i.e., 273cf6a7c19SMahesh Ravishankar // 274cf6a7c19SMahesh Ravishankar // ```mlir 275cf6a7c19SMahesh Ravishankar // scf.for %iv0 = ... { 276cf6a7c19SMahesh Ravishankar // %0 = tiled_op 277cf6a7c19SMahesh Ravishankar // } 278cf6a7c19SMahesh Ravishankar // ``` 279cf6a7c19SMahesh Ravishankar // 280cf6a7c19SMahesh Ravishankar // is transformed to 281cf6a7c19SMahesh Ravishankar // 282cf6a7c19SMahesh Ravishankar // ```mlir 283cf6a7c19SMahesh Ravishankar // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { 284cf6a7c19SMahesh Ravishankar // %0 = tiled_op 285cf6a7c19SMahesh Ravishankar // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 286cf6a7c19SMahesh Ravishankar // scf.yield %1 287cf6a7c19SMahesh Ravishankar // } 288cf6a7c19SMahesh Ravishankar // ``` 289cf6a7c19SMahesh Ravishankar NewYieldValueFn yieldValueFn = 290cf6a7c19SMahesh Ravishankar [&](OpBuilder &b, Location loc, 291cf6a7c19SMahesh Ravishankar ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 292cf6a7c19SMahesh Ravishankar SmallVector<Value> yieldedValues; 293cf6a7c19SMahesh Ravishankar Attribute one = b.getIndexAttr(1); 294cf6a7c19SMahesh Ravishankar for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) { 295cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes; 296cf6a7c19SMahesh Ravishankar if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, 297cf6a7c19SMahesh Ravishankar resultTileOffsets, 298cf6a7c19SMahesh Ravishankar resultTileSizes))) { 299cf6a7c19SMahesh Ravishankar op.emitOpError("unable to get position of result ") 300cf6a7c19SMahesh Ravishankar << resultNum << " of the tiled implementation"; 301cf6a7c19SMahesh Ravishankar return {}; 302cf6a7c19SMahesh Ravishankar } 303cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(), 304cf6a7c19SMahesh Ravishankar one); 305cf6a7c19SMahesh Ravishankar Value yieldedValue = b.create<tensor::InsertSliceOp>( 306cf6a7c19SMahesh Ravishankar op->getLoc(), tilingResult.tiledOp->getResult(resultNum), 307cf6a7c19SMahesh Ravishankar newBBArgs[resultNum], resultTileOffsets, resultTileSizes, 308cf6a7c19SMahesh Ravishankar resultTileStrides); 309cf6a7c19SMahesh Ravishankar yieldedValues.push_back(yieldedValue); 310cf6a7c19SMahesh Ravishankar } 311cf6a7c19SMahesh Ravishankar return yieldedValues; 312cf6a7c19SMahesh Ravishankar }; 313cf6a7c19SMahesh Ravishankar SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields( 314cf6a7c19SMahesh Ravishankar rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), 315cf6a7c19SMahesh Ravishankar yieldValueFn); 316ca2933f3SAdrian Kuegel for (const auto &loop : llvm::enumerate(tilingResult.loops)) { 317cf6a7c19SMahesh Ravishankar rewriter.eraseOp(loop.value()); 318cf6a7c19SMahesh Ravishankar tilingResult.loops[loop.index()] = newLoops[loop.index()]; 319cf6a7c19SMahesh Ravishankar } 320cf6a7c19SMahesh Ravishankar rewriter.replaceOp(op, tilingResult.loops.front().getResults()); 321cf6a7c19SMahesh Ravishankar return tilingResult; 322cf6a7c19SMahesh Ravishankar } 3232f637fe7SMahesh Ravishankar 3242f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 3252f637fe7SMahesh Ravishankar // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. 3262f637fe7SMahesh Ravishankar //===----------------------------------------------------------------------===// 3272f637fe7SMahesh Ravishankar 3282f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp:: 3292f637fe7SMahesh Ravishankar TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, 3302f637fe7SMahesh Ravishankar scf::SCFTilingOptions options, 3312f637fe7SMahesh Ravishankar PatternBenefit benefit) 3322f637fe7SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 3332f637fe7SMahesh Ravishankar tilingPattern(context, std::move(options)) {} 3342f637fe7SMahesh Ravishankar 3352f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp:: 3362f637fe7SMahesh Ravishankar TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, 3372f637fe7SMahesh Ravishankar MLIRContext *context, 3382f637fe7SMahesh Ravishankar scf::SCFTilingOptions options, 3392f637fe7SMahesh Ravishankar PatternBenefit benefit) 3402f637fe7SMahesh Ravishankar : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 3412f637fe7SMahesh Ravishankar tilingPattern(context, std::move(options)) {} 3422f637fe7SMahesh Ravishankar 3432f637fe7SMahesh Ravishankar /// Return the `Value` that is defined by an operation that implements 3442f637fe7SMahesh Ravishankar /// the `TilingInterface`. Looks through `iter_args` of scf.for nest 3452f637fe7SMahesh Ravishankar /// if required. 3462f637fe7SMahesh Ravishankar static Optional<OpResult> getFusableProducer(Value v) { 3472f637fe7SMahesh Ravishankar while (auto blockArg = v.dyn_cast<BlockArgument>()) { 3482f637fe7SMahesh Ravishankar auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp()); 3492f637fe7SMahesh Ravishankar if (!loopOp) 3502f637fe7SMahesh Ravishankar return llvm::None; 3512f637fe7SMahesh Ravishankar v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); 3522f637fe7SMahesh Ravishankar } 3532f637fe7SMahesh Ravishankar if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp())) 3542f637fe7SMahesh Ravishankar return llvm::None; 3552f637fe7SMahesh Ravishankar return v.cast<OpResult>(); 3562f637fe7SMahesh Ravishankar } 3572f637fe7SMahesh Ravishankar 3582f637fe7SMahesh Ravishankar FailureOr<scf::SCFTileAndFuseResult> 3592f637fe7SMahesh Ravishankar scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( 3602f637fe7SMahesh Ravishankar TilingInterface op, PatternRewriter &rewriter) const { 3612f637fe7SMahesh Ravishankar // This transformation is only valid for ops that return values (i.e. not 3622f637fe7SMahesh Ravishankar // valid to use with operations that have memref operands). 3632f637fe7SMahesh Ravishankar if (!op->getNumResults()) { 3642f637fe7SMahesh Ravishankar return rewriter.notifyMatchFailure( 3652f637fe7SMahesh Ravishankar op, "invalid pattern for op with no results"); 3662f637fe7SMahesh Ravishankar } 3672f637fe7SMahesh Ravishankar 3682f637fe7SMahesh Ravishankar // 1. First tile the consumer. 3692f637fe7SMahesh Ravishankar SCFTileAndFuseResult tileAndFuseResult; 3702f637fe7SMahesh Ravishankar { 3712f637fe7SMahesh Ravishankar FailureOr<SCFTilingResult> tilingResult = 3722f637fe7SMahesh Ravishankar tilingPattern.returningMatchAndRewrite(op, rewriter); 3732f637fe7SMahesh Ravishankar if (failed(tilingResult)) { 3742f637fe7SMahesh Ravishankar return failure(); 3752f637fe7SMahesh Ravishankar } 3762f637fe7SMahesh Ravishankar tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); 3772f637fe7SMahesh Ravishankar tileAndFuseResult.loops = std::move(tilingResult->loops); 3782f637fe7SMahesh Ravishankar } 3792f637fe7SMahesh Ravishankar 3802f637fe7SMahesh Ravishankar // 2. Typically, the operands of the tiled operation are slices of the 3812f637fe7SMahesh Ravishankar // operands of the untiled operation. These are expressed in IR using 3822f637fe7SMahesh Ravishankar // `tensor.extract_slice` operations with source being the operands of the 3832f637fe7SMahesh Ravishankar // untiled operation. Create a worklist of these `tensor.extract_slice` 3842f637fe7SMahesh Ravishankar // operations. If the producers of the source of the `tensor.extract_slice` 3852f637fe7SMahesh Ravishankar // can be tiled such that the tiled value is generated in-place, that 3862f637fe7SMahesh Ravishankar // effectively tiles + fuses the operations. 3872f637fe7SMahesh Ravishankar auto addCandidateSlices = [](Operation *fusedOp, 3882f637fe7SMahesh Ravishankar std::deque<tensor::ExtractSliceOp> &candidates) { 3892f637fe7SMahesh Ravishankar for (Value operand : fusedOp->getOperands()) 3902f637fe7SMahesh Ravishankar if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 3912f637fe7SMahesh Ravishankar candidates.push_back(sliceOp); 3922f637fe7SMahesh Ravishankar }; 3932f637fe7SMahesh Ravishankar 3942f637fe7SMahesh Ravishankar std::deque<tensor::ExtractSliceOp> candidates; 3952f637fe7SMahesh Ravishankar addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 3962f637fe7SMahesh Ravishankar OpBuilder::InsertionGuard g(rewriter); 3972f637fe7SMahesh Ravishankar while (!candidates.empty()) { 3982f637fe7SMahesh Ravishankar // 2a. Traverse the slices in BFS fashion. 3992f637fe7SMahesh Ravishankar tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 4002f637fe7SMahesh Ravishankar candidates.pop_front(); 4012f637fe7SMahesh Ravishankar 4022f637fe7SMahesh Ravishankar // 2b. Get the producer of the source (potentially walking through 4032f637fe7SMahesh Ravishankar // `iter_args` of nested `scf.for`) 4042f637fe7SMahesh Ravishankar Optional<OpResult> fusableProducer = 40504235d07SJacques Pienaar getFusableProducer(candidateSliceOp.getSource()); 4062f637fe7SMahesh Ravishankar if (!fusableProducer) 4072f637fe7SMahesh Ravishankar continue; 4082f637fe7SMahesh Ravishankar 4092f637fe7SMahesh Ravishankar // 2c. Generate the tiled implementation of the producer of the source 4102f637fe7SMahesh Ravishankar rewriter.setInsertionPoint(candidateSliceOp); 4112f637fe7SMahesh Ravishankar FailureOr<Value> fusedProducerValue = 412c27d8152SKazu Hirata tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 413c27d8152SKazu Hirata fusableProducer.value()); 4142f637fe7SMahesh Ravishankar if (failed(fusedProducerValue)) 4152f637fe7SMahesh Ravishankar continue; 416c27d8152SKazu Hirata rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); 4172f637fe7SMahesh Ravishankar 4182f637fe7SMahesh Ravishankar // 2d. The operands of the fused producer might themselved be slices of 4192f637fe7SMahesh Ravishankar // values produced by operations that implement the `TilingInterface`. 4202f637fe7SMahesh Ravishankar // Add these operations to the worklist. 4212f637fe7SMahesh Ravishankar Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 4222f637fe7SMahesh Ravishankar tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); 4232f637fe7SMahesh Ravishankar addCandidateSlices(fusedProducer, candidates); 4242f637fe7SMahesh Ravishankar 4252f637fe7SMahesh Ravishankar // 2e. If the operation being fused creates a value that is used as `outs` 4262f637fe7SMahesh Ravishankar // in the tiled operation, the result of the unfused operation will be 4272f637fe7SMahesh Ravishankar // used in the `iter_args` of the tiled loop generated. When the 4282f637fe7SMahesh Ravishankar // operation is fused, this use in `iter_args` needs to be modified to 4292f637fe7SMahesh Ravishankar // use the destination of the fused operation. For example, starting 4302f637fe7SMahesh Ravishankar // with 4312f637fe7SMahesh Ravishankar // 4322f637fe7SMahesh Ravishankar // ```mlir 4332f637fe7SMahesh Ravishankar // %0 = linalg.init_tensor ... 4342f637fe7SMahesh Ravishankar // %1 = linalg.fill ... outs(%0:...)... 4352f637fe7SMahesh Ravishankar // %2 = linalg.matmul ... outs(%1:...).... 4362f637fe7SMahesh Ravishankar // ``` 4372f637fe7SMahesh Ravishankar // 4382f637fe7SMahesh Ravishankar // First the `linalg.matmul` gets tiled 4392f637fe7SMahesh Ravishankar // 4402f637fe7SMahesh Ravishankar // ```mlir 4412f637fe7SMahesh Ravishankar // %0 = linalg.init_tensor 4422f637fe7SMahesh Ravishankar // %1 = linalg.fill 4432f637fe7SMahesh Ravishankar // %2 = scf.for .... iter_args(%arg0 = %1)... 4442f637fe7SMahesh Ravishankar // ... 4452f637fe7SMahesh Ravishankar // ... = linalg.matmul ... 4462f637fe7SMahesh Ravishankar // 4472f637fe7SMahesh Ravishankar // ``` 4482f637fe7SMahesh Ravishankar // 4492f637fe7SMahesh Ravishankar // When the `linalg.fill` gets fused, the `iter_args` needs to be 4502f637fe7SMahesh Ravishankar // modified 4512f637fe7SMahesh Ravishankar // 4522f637fe7SMahesh Ravishankar // ```mlir 4532f637fe7SMahesh Ravishankar // %0 = linalg.init_tensor 4542f637fe7SMahesh Ravishankar // %1 = scf.for ... iter_args(%arg0 = %0)... 4552f637fe7SMahesh Ravishankar // ... 4562f637fe7SMahesh Ravishankar // %2 = linalg.fill ... 4572f637fe7SMahesh Ravishankar // %3 = linalg.matmul ... outs(%2: ...)... 4582f637fe7SMahesh Ravishankar // ``` 4592f637fe7SMahesh Ravishankar TilingInterface unfusedProducerOp = 4602f637fe7SMahesh Ravishankar cast<TilingInterface>(fusableProducer->getOwner()); 4612f637fe7SMahesh Ravishankar scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); 4622f637fe7SMahesh Ravishankar SmallVector<Value> unfusedProducerOpDestValues = 4632f637fe7SMahesh Ravishankar unfusedProducerOp.getDestinationOperands(rewriter); 4642f637fe7SMahesh Ravishankar for (OpOperand &uses : unfusedProducerOp->getUses()) { 4652f637fe7SMahesh Ravishankar if (uses.getOwner() == outerMostTiledLoop.getOperation()) { 4662f637fe7SMahesh Ravishankar unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber(); 4672f637fe7SMahesh Ravishankar unsigned operandNumber = uses.getOperandNumber(); 4682f637fe7SMahesh Ravishankar outerMostTiledLoop->setOperand( 4692f637fe7SMahesh Ravishankar operandNumber, unfusedProducerOpDestValues[resultNumber]); 4702f637fe7SMahesh Ravishankar } 4712f637fe7SMahesh Ravishankar } 4722f637fe7SMahesh Ravishankar } 4732f637fe7SMahesh Ravishankar return tileAndFuseResult; 4742f637fe7SMahesh Ravishankar } 475