1*cf6a7c19SMahesh Ravishankar //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// 2*cf6a7c19SMahesh Ravishankar // 3*cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information. 5*cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*cf6a7c19SMahesh Ravishankar // 7*cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 8*cf6a7c19SMahesh Ravishankar 9*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 10*cf6a7c19SMahesh Ravishankar 11*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 12*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 14*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/Linalg.h" 15*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Utils/Utils.h" 16*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h" 17*cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h" 18*cf6a7c19SMahesh Ravishankar 19*cf6a7c19SMahesh Ravishankar using namespace mlir; 20*cf6a7c19SMahesh Ravishankar using namespace mlir::linalg; 21*cf6a7c19SMahesh Ravishankar 22*cf6a7c19SMahesh Ravishankar namespace { 23*cf6a7c19SMahesh Ravishankar 24*cf6a7c19SMahesh Ravishankar /// External model implementation of TilingInterface for LinalgOps. An external 25*cf6a7c19SMahesh Ravishankar /// model implementation is used for now till the use of `TilingInterface` is 26*cf6a7c19SMahesh Ravishankar /// on-par with the current Linalg tiling + fusion patterns. Once it is 27*cf6a7c19SMahesh Ravishankar /// maybe possible to move this into the op-definition (though there are 28*cf6a7c19SMahesh Ravishankar /// advantages to leaving it as an external model) 29*cf6a7c19SMahesh Ravishankar template <typename LinalgOpTy> 30*cf6a7c19SMahesh Ravishankar struct LinalgOpTilingInterface 31*cf6a7c19SMahesh Ravishankar : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>, 32*cf6a7c19SMahesh Ravishankar LinalgOpTy> { 33*cf6a7c19SMahesh Ravishankar 34*cf6a7c19SMahesh Ravishankar /// Return the destination operands. 35*cf6a7c19SMahesh Ravishankar SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const { 36*cf6a7c19SMahesh Ravishankar return llvm::cast<LinalgOp>(op).getOutputOperands(); 37*cf6a7c19SMahesh Ravishankar } 38*cf6a7c19SMahesh Ravishankar 39*cf6a7c19SMahesh Ravishankar /// Return the loop iterator type. 40*cf6a7c19SMahesh Ravishankar SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const { 41*cf6a7c19SMahesh Ravishankar LinalgOpTy concreteOp = cast<LinalgOpTy>(op); 42*cf6a7c19SMahesh Ravishankar return llvm::to_vector( 43*cf6a7c19SMahesh Ravishankar llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { 44*cf6a7c19SMahesh Ravishankar return strAttr.cast<StringAttr>().getValue(); 45*cf6a7c19SMahesh Ravishankar })); 46*cf6a7c19SMahesh Ravishankar } 47*cf6a7c19SMahesh Ravishankar 48*cf6a7c19SMahesh Ravishankar /// Return the iteration domain range. 49*cf6a7c19SMahesh Ravishankar SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 50*cf6a7c19SMahesh Ravishankar Location loc = op->getLoc(); 51*cf6a7c19SMahesh Ravishankar LinalgOp linalgOp = cast<LinalgOp>(op); 52*cf6a7c19SMahesh Ravishankar auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc); 53*cf6a7c19SMahesh Ravishankar AffineMap map = linalgOp.getShapesToLoopsMap(); 54*cf6a7c19SMahesh Ravishankar Value zero = b.create<arith::ConstantIndexOp>(loc, 0); 55*cf6a7c19SMahesh Ravishankar Value one = b.create<arith::ConstantIndexOp>(loc, 1); 56*cf6a7c19SMahesh Ravishankar return llvm::to_vector(llvm::map_range( 57*cf6a7c19SMahesh Ravishankar applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) { 58*cf6a7c19SMahesh Ravishankar return Range{zero, v, one}; 59*cf6a7c19SMahesh Ravishankar })); 60*cf6a7c19SMahesh Ravishankar } 61*cf6a7c19SMahesh Ravishankar 62*cf6a7c19SMahesh Ravishankar // Instantiate the tiled implementation of the operation. 63*cf6a7c19SMahesh Ravishankar SmallVector<Operation *> 64*cf6a7c19SMahesh Ravishankar getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest, 65*cf6a7c19SMahesh Ravishankar ArrayRef<OpFoldResult> offsets, 66*cf6a7c19SMahesh Ravishankar ArrayRef<OpFoldResult> sizes, 67*cf6a7c19SMahesh Ravishankar bool tileDestOperands) const { 68*cf6a7c19SMahesh Ravishankar // Leave the `sizeBounds` value empty. That is only needed when the `sizes` 69*cf6a7c19SMahesh Ravishankar // specified could lead to out of bounds accesses. 70*cf6a7c19SMahesh Ravishankar Location loc = op->getLoc(); 71*cf6a7c19SMahesh Ravishankar LinalgOp linalgOp = cast<LinalgOp>(op); 72*cf6a7c19SMahesh Ravishankar SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands(); 73*cf6a7c19SMahesh Ravishankar SmallVector<Value, 4> tiledOperands = makeTiledShapes( 74*cf6a7c19SMahesh Ravishankar b, loc, linalgOp, valuesToTile, 75*cf6a7c19SMahesh Ravishankar getValueOrCreateConstantIndexOp(b, loc, offsets), 76*cf6a7c19SMahesh Ravishankar getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); 77*cf6a7c19SMahesh Ravishankar 78*cf6a7c19SMahesh Ravishankar SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range( 79*cf6a7c19SMahesh Ravishankar linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { 80*cf6a7c19SMahesh Ravishankar return tiledOperands[opOperand->getOperandNumber()].getType(); 81*cf6a7c19SMahesh Ravishankar })); 82*cf6a7c19SMahesh Ravishankar 83*cf6a7c19SMahesh Ravishankar Operation *tiledOp = 84*cf6a7c19SMahesh Ravishankar linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); 85*cf6a7c19SMahesh Ravishankar 86*cf6a7c19SMahesh Ravishankar return {tiledOp}; 87*cf6a7c19SMahesh Ravishankar } 88*cf6a7c19SMahesh Ravishankar 89*cf6a7c19SMahesh Ravishankar // Return the details of the output tile generated by the tiled 90*cf6a7c19SMahesh Ravishankar // implementation. 91*cf6a7c19SMahesh Ravishankar LogicalResult 92*cf6a7c19SMahesh Ravishankar getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 93*cf6a7c19SMahesh Ravishankar ArrayRef<OpFoldResult> offsets, 94*cf6a7c19SMahesh Ravishankar ArrayRef<OpFoldResult> sizes, 95*cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &resultOffsets, 96*cf6a7c19SMahesh Ravishankar SmallVector<OpFoldResult> &resultSizes) const { 97*cf6a7c19SMahesh Ravishankar Location loc = op->getLoc(); 98*cf6a7c19SMahesh Ravishankar LinalgOp linalgOp = cast<LinalgOp>(op); 99*cf6a7c19SMahesh Ravishankar 100*cf6a7c19SMahesh Ravishankar AffineExpr d0; 101*cf6a7c19SMahesh Ravishankar bindDims(b.getContext(), d0); 102*cf6a7c19SMahesh Ravishankar 103*cf6a7c19SMahesh Ravishankar auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc, 104*cf6a7c19SMahesh Ravishankar AffineExpr expr, 105*cf6a7c19SMahesh Ravishankar ValueRange operands) -> Value { 106*cf6a7c19SMahesh Ravishankar AffineMap map = AffineMap::inferFromExprList({expr}).front(); 107*cf6a7c19SMahesh Ravishankar SmallVector<Value> normalizedOperands(operands.begin(), operands.end()); 108*cf6a7c19SMahesh Ravishankar mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands); 109*cf6a7c19SMahesh Ravishankar canonicalizeMapAndOperands(&map, &normalizedOperands); 110*cf6a7c19SMahesh Ravishankar return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands); 111*cf6a7c19SMahesh Ravishankar }; 112*cf6a7c19SMahesh Ravishankar 113*cf6a7c19SMahesh Ravishankar SmallVector<Value> sizeVals = 114*cf6a7c19SMahesh Ravishankar getValueOrCreateConstantIndexOp(b, loc, sizes); 115*cf6a7c19SMahesh Ravishankar SmallVector<Value> subShapeSizes = 116*cf6a7c19SMahesh Ravishankar llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) { 117*cf6a7c19SMahesh Ravishankar return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v); 118*cf6a7c19SMahesh Ravishankar })); 119*cf6a7c19SMahesh Ravishankar OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); 120*cf6a7c19SMahesh Ravishankar Value sliceOpResult = 121*cf6a7c19SMahesh Ravishankar makeTiledShape(b, loc, outOperand->get(), sizeVals, 122*cf6a7c19SMahesh Ravishankar linalgOp.getTiedIndexingMap(outOperand), 123*cf6a7c19SMahesh Ravishankar getValueOrCreateConstantIndexOp(b, loc, offsets), 124*cf6a7c19SMahesh Ravishankar /*ubs*/ {}, subShapeSizes, true); 125*cf6a7c19SMahesh Ravishankar auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>(); 126*cf6a7c19SMahesh Ravishankar if (!sliceOp) 127*cf6a7c19SMahesh Ravishankar return failure(); 128*cf6a7c19SMahesh Ravishankar resultOffsets = sliceOp.getMixedOffsets(); 129*cf6a7c19SMahesh Ravishankar resultSizes = sliceOp.getMixedSizes(); 130*cf6a7c19SMahesh Ravishankar return success(); 131*cf6a7c19SMahesh Ravishankar } 132*cf6a7c19SMahesh Ravishankar }; 133*cf6a7c19SMahesh Ravishankar 134*cf6a7c19SMahesh Ravishankar } // namespace 135*cf6a7c19SMahesh Ravishankar 136*cf6a7c19SMahesh Ravishankar template <typename OpType> static void registerOne(MLIRContext *ctx) { 137*cf6a7c19SMahesh Ravishankar OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx); 138*cf6a7c19SMahesh Ravishankar } 139*cf6a7c19SMahesh Ravishankar 140*cf6a7c19SMahesh Ravishankar /// Variadic helper function. 141*cf6a7c19SMahesh Ravishankar template <typename... OpTypes> static void registerAll(MLIRContext *ctx) { 142*cf6a7c19SMahesh Ravishankar // FIXME: In c++17 this can be simplified by using 'fold expressions'. 143*cf6a7c19SMahesh Ravishankar (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...}; 144*cf6a7c19SMahesh Ravishankar } 145*cf6a7c19SMahesh Ravishankar 146*cf6a7c19SMahesh Ravishankar #define GET_OP_LIST 147*cf6a7c19SMahesh Ravishankar 148*cf6a7c19SMahesh Ravishankar void mlir::linalg::registerTilingInterfaceExternalModels( 149*cf6a7c19SMahesh Ravishankar DialectRegistry ®istry) { 150*cf6a7c19SMahesh Ravishankar registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { 151*cf6a7c19SMahesh Ravishankar registerOne<linalg::GenericOp>(ctx); 152*cf6a7c19SMahesh Ravishankar registerAll< 153*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 154*cf6a7c19SMahesh Ravishankar >(ctx); 155*cf6a7c19SMahesh Ravishankar }); 156*cf6a7c19SMahesh Ravishankar } 157