1 //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Utils/Utils.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Interfaces/TilingInterface.h" 18 19 using namespace mlir; 20 using namespace mlir::linalg; 21 22 namespace { 23 24 /// External model implementation of TilingInterface for LinalgOps. An external 25 /// model implementation is used for now till the use of `TilingInterface` is 26 /// on-par with the current Linalg tiling + fusion patterns. Once it is 27 /// maybe possible to move this into the op-definition (though there are 28 /// advantages to leaving it as an external model) 29 template <typename LinalgOpTy> 30 struct LinalgOpTilingInterface 31 : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>, 32 LinalgOpTy> { 33 /// Return the destination operands. 34 SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const { 35 return llvm::cast<LinalgOp>(op).getOutputOperands(); 36 } 37 38 /// Return the loop iterator type. 39 SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const { 40 LinalgOpTy concreteOp = cast<LinalgOpTy>(op); 41 return llvm::to_vector( 42 llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { 43 return strAttr.cast<StringAttr>().getValue(); 44 })); 45 } 46 47 /// Return the iteration domain range. 48 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 49 OpBuilder::InsertionGuard g(b); 50 b.setInsertionPoint(op); 51 Location loc = op->getLoc(); 52 LinalgOp linalgOp = cast<LinalgOp>(op); 53 auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc); 54 AffineMap map = linalgOp.getShapesToLoopsMap(); 55 Value zero = b.create<arith::ConstantIndexOp>(loc, 0); 56 Value one = b.create<arith::ConstantIndexOp>(loc, 1); 57 return llvm::to_vector(llvm::map_range( 58 applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) { 59 return Range{zero, v, one}; 60 })); 61 } 62 63 // Instantiate the tiled implementation of the operation. 64 SmallVector<Operation *> 65 getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest, 66 ArrayRef<OpFoldResult> offsets, 67 ArrayRef<OpFoldResult> sizes, 68 bool tileDestOperands) const { 69 // Leave the `sizeBounds` value empty. That is only needed when the `sizes` 70 // specified could lead to out of bounds accesses. 71 Location loc = op->getLoc(); 72 LinalgOp linalgOp = cast<LinalgOp>(op); 73 SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands(); 74 SmallVector<Value> offsetValues = 75 getValueOrCreateConstantIndexOp(b, loc, offsets); 76 SmallVector<Value, 4> tiledOperands = makeTiledShapes( 77 b, loc, linalgOp, valuesToTile, offsetValues, 78 getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); 79 80 SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range( 81 linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { 82 return tiledOperands[opOperand->getOperandNumber()].getType(); 83 })); 84 85 Operation *tiledOp = 86 linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); 87 offsetIndices(b, cast<LinalgOp>(tiledOp), offsetValues); 88 89 return {tiledOp}; 90 } 91 92 // Return the details of the output tile generated by the tiled 93 // implementation. 94 LogicalResult 95 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 96 ArrayRef<OpFoldResult> offsets, 97 ArrayRef<OpFoldResult> sizes, 98 SmallVector<OpFoldResult> &resultOffsets, 99 SmallVector<OpFoldResult> &resultSizes) const { 100 Location loc = op->getLoc(); 101 LinalgOp linalgOp = cast<LinalgOp>(op); 102 103 AffineExpr d0; 104 bindDims(b.getContext(), d0); 105 106 auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc, 107 AffineExpr expr, 108 ValueRange operands) -> Value { 109 AffineMap map = AffineMap::inferFromExprList({expr}).front(); 110 SmallVector<Value> normalizedOperands(operands.begin(), operands.end()); 111 mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands); 112 canonicalizeMapAndOperands(&map, &normalizedOperands); 113 return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands); 114 }; 115 116 SmallVector<Value> sizeVals = 117 getValueOrCreateConstantIndexOp(b, loc, sizes); 118 SmallVector<Value> subShapeSizes = 119 llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) { 120 return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v); 121 })); 122 OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); 123 Value sliceOpResult = 124 makeTiledShape(b, loc, outOperand->get(), sizeVals, 125 linalgOp.getTiedIndexingMap(outOperand), 126 getValueOrCreateConstantIndexOp(b, loc, offsets), 127 /*ubs*/ {}, subShapeSizes, true); 128 auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>(); 129 if (!sliceOp) 130 return failure(); 131 resultOffsets = sliceOp.getMixedOffsets(); 132 resultSizes = sliceOp.getMixedSizes(); 133 return success(); 134 } 135 136 FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b, 137 unsigned resultNumber, 138 ValueRange dest, 139 ArrayRef<OpFoldResult> offsets, 140 ArrayRef<OpFoldResult> sizes, 141 bool tileDestOperands) const { 142 auto linalgOp = cast<LinalgOp>(op); 143 144 // Check that the indexing map used for the output is a projected 145 // permutation. This could be relaxed with a more general approach that can 146 // map the offsets and sizes from the result to iteration space tiles 147 // (filling in full extent for dimensions not used to access the result). 148 AffineMap indexingMap = 149 linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber)); 150 if (!indexingMap.isProjectedPermutation()) { 151 return op->emitOpError( 152 "unhandled tiled implementation generation when result is not " 153 "accessed using a permuted projection"); 154 } 155 156 auto numLoops = linalgOp.getNumLoops(); 157 auto tilingInterfaceOp = cast<TilingInterface>(op); 158 SmallVector<OpFoldResult> iterationTileOffsets(numLoops), 159 iterationTileSizes(numLoops); 160 if (!indexingMap.isPermutation()) { 161 SmallVector<Range> iterationDomain = 162 tilingInterfaceOp.getIterationDomain(b); 163 for (const auto &range : llvm::enumerate(iterationDomain)) { 164 iterationTileOffsets[range.index()] = range.value().offset; 165 iterationTileSizes[range.index()] = range.value().size; 166 } 167 } 168 for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) { 169 unsigned dimPosition = 170 resultExpr.value().cast<AffineDimExpr>().getPosition(); 171 iterationTileOffsets[dimPosition] = offsets[resultExpr.index()]; 172 iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; 173 } 174 175 SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation( 176 b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands); 177 if (tiledOp.size() != 1) 178 return op->emitOpError("failed to generate tiled implementation"); 179 180 return tiledOp[0]->getResult(resultNumber); 181 } 182 }; 183 184 } // namespace 185 186 template <typename OpType> 187 static void registerOne(MLIRContext *ctx) { 188 OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx); 189 } 190 191 /// Variadic helper function. 192 template <typename... OpTypes> 193 static void registerAll(MLIRContext *ctx) { 194 // FIXME: In c++17 this can be simplified by using 'fold expressions'. 195 (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...}; 196 } 197 198 #define GET_OP_LIST 199 200 void mlir::linalg::registerTilingInterfaceExternalModels( 201 DialectRegistry ®istry) { 202 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { 203 registerOne<linalg::GenericOp>(ctx); 204 registerAll< 205 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 206 >(ctx); 207 }); 208 } 209