//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
// the computation for the linalg op.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Bubble up extract_slice above Linalg operation.
///
/// A sequence of operations
///
/// ```mlir
/// %0 = linalg.<op> ... arg0, arg1, ...
/// %1 = tensor.extract_slice %0 ...
/// ```
///
/// can be replaced with
///
/// ```mlir
/// %0 = tensor.extract_slice %arg0
/// %1 = tensor.extract_slice %arg1
/// %2 = linalg.<op> ... %0, %1, ...
/// ```
///
/// This results in reduced computation for the linalg operation.
///
struct BubbleUpExtractSliceOpPattern
    : OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  /// Matches `tensor.extract_slice` whose source is the sole use of a
  /// single-output, tensor-semantics linalg op, and replaces it with a clone
  /// of the linalg op computed on correspondingly sliced operands.
  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    // The slice must be taken directly from a linalg op result; otherwise
    // there is nothing to bubble the slice through.
    Value source = sliceOp.source();
    auto linalgOp = source.getDefiningOp<LinalgOp>();
    if (!linalgOp) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected source to be linalg op");
    }

    // TODO: we might relax this if we want heuristics to detect that all uses
    // are small portion of the output.
    if (!linalgOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single use of linalg op");
    }

    // A single output keeps the slice-to-loop-tile translation below
    // unambiguous.
    if (linalgOp.getNumOutputs() != 1) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single output of linalg op");
    }

    // Only tensor semantics are handled; buffer (memref) semantics would
    // require reasoning about aliasing before rewriting.
    if (!linalgOp.hasTensorSemantics()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected tensor of linalg op");
    }

    // Strided or rank-reducing slices cannot be expressed as a plain
    // offset/size tile of the linalg iteration space.
    if (!sliceOp.hasUnitStride())
      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");

    if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
      return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
    }

    // The output indexing map must be a projected permutation so that each
    // slice dimension maps 1:1 onto a distinct loop dimension.
    OpOperand *outOperand = linalgOp.getOutputOperand(0);
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
    if (!indexingMap.isProjectedPermutation()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected a projected permutation for output");
    }

    // Derive the loop upper bounds of the linalg op from its operand shapes.
    auto linalgLoc = linalgOp.getLoc();
    auto allShapeSizes =
        linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
    AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
    if (!shapeSizesToLoopsMap) {
      return rewriter.notifyMatchFailure(
          linalgOp, "failed to get loops map from shape sizes");
    }
    auto sizeBounds = applyMapToValues(rewriter, linalgLoc,
                                       shapeSizesToLoopsMap, allShapeSizes);

    // Materialize the slice's mixed (static-or-dynamic) offsets and sizes as
    // index SSA values.
    auto sliceLoc = sliceOp.getLoc();
    auto offsetVals = getValueOrCreateConstantIndexOp(
        rewriter, sliceLoc, sliceOp.getMixedOffsets());
    auto sizeVals = getValueOrCreateConstantIndexOp(rewriter, sliceLoc,
                                                    sliceOp.getMixedSizes());

    // The offsets and sizes from the slice operation only give you the tile
    // size of the output. Use that to compute the tile sizes and offsets of
    // the loops. For loops not used to access the output, set the tile sizes
    // to loop bounds and set the offset to 0.
    Value zero = rewriter.create<arith::ConstantIndexOp>(linalgLoc, 0);
    SmallVector<Value, 4> tileOffsets(sizeBounds.size(), zero);
    SmallVector<Value, 4> tileSizes = sizeBounds;
    // `indexingMap` is a projected permutation (checked above), so every
    // result is an AffineDimExpr and the cast below cannot fail.
    for (auto const &result : enumerate(indexingMap.getResults())) {
      unsigned position = result.value().cast<AffineDimExpr>().getPosition();
      tileOffsets[position] = offsetVals[result.index()];
      tileSizes[position] = sizeVals[result.index()];
    }

    // Slice every input/output operand to the computed tile. Partial-tile
    // checks are skipped (omitPartialTileCheck=true) since tile offsets and
    // sizes come straight from the extract_slice being rewritten.
    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();

    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        rewriter, linalgLoc, linalgOp, valuesToTile, tileOffsets, tileSizes,
        sizeBounds, /*omitPartialTileCheck=*/true);

    // Result types follow from the tiled output operands.
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : linalgOp.getOutputTensorOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    // Clone the linalg op onto the tiled operands and replace the slice with
    // the clone's results. The original linalg op is left without uses (the
    // slice was its only user) and can be cleaned up as dead code.
    Operation *newOp =
        linalgOp.clone(rewriter, linalgLoc, resultTensorTypes, tiledOperands);
    rewriter.replaceOp(sliceOp, newOp->getResults());
    return success();
  }
};
} // namespace

/// Adds the bubble-up-extract-slice pattern to `patterns`.
void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<BubbleUpExtractSliceOpPattern>(context);
}