1 //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Utils/Utils.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Interfaces/TilingInterface.h" 18 19 using namespace mlir; 20 using namespace mlir::linalg; 21 22 namespace { 23 24 /// External model implementation of TilingInterface for LinalgOps. An external 25 /// model implementation is used for now till the use of `TilingInterface` is 26 /// on-par with the current Linalg tiling + fusion patterns. Once it is 27 /// maybe possible to move this into the op-definition (though there are 28 /// advantages to leaving it as an external model) 29 template <typename LinalgOpTy> 30 struct LinalgOpTilingInterface 31 : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>, 32 LinalgOpTy> { 33 /// Return the destination operands. 34 SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const { 35 return llvm::cast<LinalgOp>(op).getOutputOperands(); 36 } 37 38 /// Return the loop iterator type. 39 SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const { 40 LinalgOpTy concreteOp = cast<LinalgOpTy>(op); 41 return llvm::to_vector( 42 llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { 43 return strAttr.cast<StringAttr>().getValue(); 44 })); 45 } 46 47 /// Return the iteration domain range. 48 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 49 OpBuilder::InsertionGuard g(b); 50 b.setInsertionPoint(op); 51 Location loc = op->getLoc(); 52 LinalgOp linalgOp = cast<LinalgOp>(op); 53 auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc); 54 AffineMap map = linalgOp.getShapesToLoopsMap(); 55 Value zero = b.create<arith::ConstantIndexOp>(loc, 0); 56 Value one = b.create<arith::ConstantIndexOp>(loc, 1); 57 return llvm::to_vector(llvm::map_range( 58 applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) { 59 return Range{zero, v, one}; 60 })); 61 } 62 63 // Instantiate the tiled implementation of the operation. 64 SmallVector<Operation *> 65 getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest, 66 ArrayRef<OpFoldResult> offsets, 67 ArrayRef<OpFoldResult> sizes, 68 bool tileDestOperands) const { 69 // Leave the `sizeBounds` value empty. That is only needed when the `sizes` 70 // specified could lead to out of bounds accesses. 71 Location loc = op->getLoc(); 72 LinalgOp linalgOp = cast<LinalgOp>(op); 73 SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands(); 74 SmallVector<Value, 4> tiledOperands = makeTiledShapes( 75 b, loc, linalgOp, valuesToTile, 76 getValueOrCreateConstantIndexOp(b, loc, offsets), 77 getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); 78 79 SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range( 80 linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { 81 return tiledOperands[opOperand->getOperandNumber()].getType(); 82 })); 83 84 Operation *tiledOp = 85 linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); 86 87 return {tiledOp}; 88 } 89 90 // Return the details of the output tile generated by the tiled 91 // implementation. 92 LogicalResult 93 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 94 ArrayRef<OpFoldResult> offsets, 95 ArrayRef<OpFoldResult> sizes, 96 SmallVector<OpFoldResult> &resultOffsets, 97 SmallVector<OpFoldResult> &resultSizes) const { 98 Location loc = op->getLoc(); 99 LinalgOp linalgOp = cast<LinalgOp>(op); 100 101 AffineExpr d0; 102 bindDims(b.getContext(), d0); 103 104 auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc, 105 AffineExpr expr, 106 ValueRange operands) -> Value { 107 AffineMap map = AffineMap::inferFromExprList({expr}).front(); 108 SmallVector<Value> normalizedOperands(operands.begin(), operands.end()); 109 mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands); 110 canonicalizeMapAndOperands(&map, &normalizedOperands); 111 return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands); 112 }; 113 114 SmallVector<Value> sizeVals = 115 getValueOrCreateConstantIndexOp(b, loc, sizes); 116 SmallVector<Value> subShapeSizes = 117 llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) { 118 return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v); 119 })); 120 OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); 121 Value sliceOpResult = 122 makeTiledShape(b, loc, outOperand->get(), sizeVals, 123 linalgOp.getTiedIndexingMap(outOperand), 124 getValueOrCreateConstantIndexOp(b, loc, offsets), 125 /*ubs*/ {}, subShapeSizes, true); 126 auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>(); 127 if (!sliceOp) 128 return failure(); 129 resultOffsets = sliceOp.getMixedOffsets(); 130 resultSizes = sliceOp.getMixedSizes(); 131 return success(); 132 } 133 134 FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b, 135 unsigned resultNumber, 136 ValueRange dest, 137 ArrayRef<OpFoldResult> offsets, 138 ArrayRef<OpFoldResult> sizes, 139 bool tileDestOperands) const { 140 auto linalgOp = cast<LinalgOp>(op); 141 142 // Check that the indexing map used for the output is a projected 143 // permutation. This could be relaxed with a more general approach that can 144 // map the offsets and sizes from the result to iteration space tiles 145 // (filling in full extent for dimensions not used to access the result). 146 AffineMap indexingMap = 147 linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber)); 148 if (!indexingMap.isProjectedPermutation()) { 149 return op->emitOpError( 150 "unhandled tiled implementation generation when result is not " 151 "accessed using a permuted projection"); 152 } 153 154 auto numLoops = linalgOp.getNumLoops(); 155 auto tilingInterfaceOp = cast<TilingInterface>(op); 156 SmallVector<OpFoldResult> iterationTileOffsets(numLoops), 157 iterationTileSizes(numLoops); 158 if (!indexingMap.isPermutation()) { 159 SmallVector<Range> iterationDomain = 160 tilingInterfaceOp.getIterationDomain(b); 161 for (auto range : llvm::enumerate(iterationDomain)) { 162 iterationTileOffsets[range.index()] = range.value().offset; 163 iterationTileSizes[range.index()] = range.value().size; 164 } 165 } 166 for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) { 167 unsigned dimPosition = 168 resultExpr.value().cast<AffineDimExpr>().getPosition(); 169 iterationTileOffsets[dimPosition] = offsets[resultExpr.index()]; 170 iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; 171 } 172 173 SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation( 174 b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands); 175 if (tiledOp.size() != 1) 176 return op->emitOpError("failed to generate tiled implementation"); 177 178 return tiledOp[0]->getResult(resultNumber); 179 } 180 }; 181 182 } // namespace 183 184 template <typename OpType> 185 static void registerOne(MLIRContext *ctx) { 186 OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx); 187 } 188 189 /// Variadic helper function. 190 template <typename... OpTypes> 191 static void registerAll(MLIRContext *ctx) { 192 // FIXME: In c++17 this can be simplified by using 'fold expressions'. 193 (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...}; 194 } 195 196 #define GET_OP_LIST 197 198 void mlir::linalg::registerTilingInterfaceExternalModels( 199 DialectRegistry ®istry) { 200 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { 201 registerOne<linalg::GenericOp>(ctx); 202 registerAll< 203 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 204 >(ctx); 205 }); 206 } 207