1 //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Utils/Utils.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Interfaces/TilingInterface.h" 18 19 using namespace mlir; 20 using namespace mlir::linalg; 21 22 namespace { 23 24 /// External model implementation of TilingInterface for LinalgOps. An external 25 /// model implementation is used for now till the use of `TilingInterface` is 26 /// on-par with the current Linalg tiling + fusion patterns. Once it is 27 /// maybe possible to move this into the op-definition (though there are 28 /// advantages to leaving it as an external model) 29 template <typename LinalgOpTy> 30 struct LinalgOpTilingInterface 31 : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>, 32 LinalgOpTy> { 33 34 /// Return the destination operands. 35 SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const { 36 return llvm::cast<LinalgOp>(op).getOutputOperands(); 37 } 38 39 /// Return the loop iterator type. 40 SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const { 41 LinalgOpTy concreteOp = cast<LinalgOpTy>(op); 42 return llvm::to_vector( 43 llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { 44 return strAttr.cast<StringAttr>().getValue(); 45 })); 46 } 47 48 /// Return the iteration domain range. 49 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 50 Location loc = op->getLoc(); 51 LinalgOp linalgOp = cast<LinalgOp>(op); 52 auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc); 53 AffineMap map = linalgOp.getShapesToLoopsMap(); 54 Value zero = b.create<arith::ConstantIndexOp>(loc, 0); 55 Value one = b.create<arith::ConstantIndexOp>(loc, 1); 56 return llvm::to_vector(llvm::map_range( 57 applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) { 58 return Range{zero, v, one}; 59 })); 60 } 61 62 // Instantiate the tiled implementation of the operation. 63 SmallVector<Operation *> 64 getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest, 65 ArrayRef<OpFoldResult> offsets, 66 ArrayRef<OpFoldResult> sizes, 67 bool tileDestOperands) const { 68 // Leave the `sizeBounds` value empty. That is only needed when the `sizes` 69 // specified could lead to out of bounds accesses. 70 Location loc = op->getLoc(); 71 LinalgOp linalgOp = cast<LinalgOp>(op); 72 SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands(); 73 SmallVector<Value, 4> tiledOperands = makeTiledShapes( 74 b, loc, linalgOp, valuesToTile, 75 getValueOrCreateConstantIndexOp(b, loc, offsets), 76 getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); 77 78 SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range( 79 linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { 80 return tiledOperands[opOperand->getOperandNumber()].getType(); 81 })); 82 83 Operation *tiledOp = 84 linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); 85 86 return {tiledOp}; 87 } 88 89 // Return the details of the output tile generated by the tiled 90 // implementation. 91 LogicalResult 92 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 93 ArrayRef<OpFoldResult> offsets, 94 ArrayRef<OpFoldResult> sizes, 95 SmallVector<OpFoldResult> &resultOffsets, 96 SmallVector<OpFoldResult> &resultSizes) const { 97 Location loc = op->getLoc(); 98 LinalgOp linalgOp = cast<LinalgOp>(op); 99 100 AffineExpr d0; 101 bindDims(b.getContext(), d0); 102 103 auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc, 104 AffineExpr expr, 105 ValueRange operands) -> Value { 106 AffineMap map = AffineMap::inferFromExprList({expr}).front(); 107 SmallVector<Value> normalizedOperands(operands.begin(), operands.end()); 108 mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands); 109 canonicalizeMapAndOperands(&map, &normalizedOperands); 110 return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands); 111 }; 112 113 SmallVector<Value> sizeVals = 114 getValueOrCreateConstantIndexOp(b, loc, sizes); 115 SmallVector<Value> subShapeSizes = 116 llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) { 117 return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v); 118 })); 119 OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); 120 Value sliceOpResult = 121 makeTiledShape(b, loc, outOperand->get(), sizeVals, 122 linalgOp.getTiedIndexingMap(outOperand), 123 getValueOrCreateConstantIndexOp(b, loc, offsets), 124 /*ubs*/ {}, subShapeSizes, true); 125 auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>(); 126 if (!sliceOp) 127 return failure(); 128 resultOffsets = sliceOp.getMixedOffsets(); 129 resultSizes = sliceOp.getMixedSizes(); 130 return success(); 131 } 132 }; 133 134 } // namespace 135 136 template <typename OpType> static void registerOne(MLIRContext *ctx) { 137 OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx); 138 } 139 140 /// Variadic helper function. 141 template <typename... OpTypes> static void registerAll(MLIRContext *ctx) { 142 // FIXME: In c++17 this can be simplified by using 'fold expressions'. 143 (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...}; 144 } 145 146 #define GET_OP_LIST 147 148 void mlir::linalg::registerTilingInterfaceExternalModels( 149 DialectRegistry ®istry) { 150 registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { 151 registerOne<linalg::GenericOp>(ctx); 152 registerAll< 153 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 154 >(ctx); 155 }); 156 } 157