//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/TilingInterface.h" using namespace mlir; using namespace mlir::linalg; namespace { /// External model implementation of TilingInterface for LinalgOps. An external /// model implementation is used for now till the use of `TilingInterface` is /// on-par with the current Linalg tiling + fusion patterns. Once it is /// maybe possible to move this into the op-definition (though there are /// advantages to leaving it as an external model) template struct LinalgOpTilingInterface : public TilingInterface::ExternalModel, LinalgOpTy> { /// Return the destination operands. SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { return llvm::cast(op).getOutputOperands(); } /// Return the loop iterator type. SmallVector getLoopIteratorTypes(Operation *op) const { LinalgOpTy concreteOp = cast(op); return llvm::to_vector( llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { return strAttr.cast().getValue(); })); } /// Return the iteration domain range. SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc); AffineMap map = linalgOp.getShapesToLoopsMap(); Value zero = b.create(loc, 0); Value one = b.create(loc, 1); return llvm::to_vector(llvm::map_range( applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) { return Range{zero, v, one}; })); } // Instantiate the tiled implementation of the operation. SmallVector getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest, ArrayRef offsets, ArrayRef sizes, bool tileDestOperands) const { // Leave the `sizeBounds` value empty. That is only needed when the `sizes` // specified could lead to out of bounds accesses. Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); SmallVector tiledOperands = makeTiledShapes( b, loc, linalgOp, valuesToTile, getValueOrCreateConstantIndexOp(b, loc, offsets), getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); SmallVector resultTensorTypes = llvm::to_vector(llvm::map_range( linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { return tiledOperands[opOperand->getOperandNumber()].getType(); })); Operation *tiledOp = linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); return {tiledOp}; } // Return the details of the output tile generated by the tiled // implementation. LogicalResult getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes, SmallVector &resultOffsets, SmallVector &resultSizes) const { Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); AffineExpr d0; bindDims(b.getContext(), d0); auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc, AffineExpr expr, ValueRange operands) -> Value { AffineMap map = AffineMap::inferFromExprList({expr}).front(); SmallVector normalizedOperands(operands.begin(), operands.end()); mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands); canonicalizeMapAndOperands(&map, &normalizedOperands); return builder.createOrFold(loc, map, normalizedOperands); }; SmallVector sizeVals = getValueOrCreateConstantIndexOp(b, loc, sizes); SmallVector subShapeSizes = llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) { return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v); })); OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); Value sliceOpResult = makeTiledShape(b, loc, outOperand->get(), sizeVals, linalgOp.getTiedIndexingMap(outOperand), getValueOrCreateConstantIndexOp(b, loc, offsets), /*ubs*/ {}, subShapeSizes, true); auto sliceOp = sliceOpResult.getDefiningOp(); if (!sliceOp) return failure(); resultOffsets = sliceOp.getMixedOffsets(); resultSizes = sliceOp.getMixedSizes(); return success(); } }; } // namespace template static void registerOne(MLIRContext *ctx) { OpType::template attachInterface>(*ctx); } /// Variadic helper function. template static void registerAll(MLIRContext *ctx) { // FIXME: In c++17 this can be simplified by using 'fold expressions'. (void)std::initializer_list{0, (registerOne(ctx), 0)...}; } #define GET_OP_LIST void mlir::linalg::registerTilingInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { registerOne(ctx); registerAll< #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" >(ctx); }); }