19e39a5d9SLei Zhang //===- Generalization.cpp - linalg named ops to generic ops --------------===//
29e39a5d9SLei Zhang //
39e39a5d9SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49e39a5d9SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
59e39a5d9SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69e39a5d9SLei Zhang //
79e39a5d9SLei Zhang //===----------------------------------------------------------------------===//
89e39a5d9SLei Zhang //
99e39a5d9SLei Zhang // This file implements the Linalg generalization pass. It converts named
109e39a5d9SLei Zhang // Linalg ops to linalg.generic ops.
119e39a5d9SLei Zhang //
129e39a5d9SLei Zhang //===----------------------------------------------------------------------===//
139e39a5d9SLei Zhang
149e39a5d9SLei Zhang #include "PassDetail.h"
15b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
169e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Passes.h"
179e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
189e39a5d9SLei Zhang #include "mlir/IR/AffineMap.h"
199e39a5d9SLei Zhang #include "mlir/IR/Attributes.h"
209e39a5d9SLei Zhang #include "mlir/IR/Builders.h"
214519ca3dSNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
229e39a5d9SLei Zhang #include "mlir/IR/PatternMatch.h"
239e39a5d9SLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
249e39a5d9SLei Zhang #include "llvm/ADT/SmallVector.h"
259e39a5d9SLei Zhang #include "llvm/Support/Debug.h"
269e39a5d9SLei Zhang
279e39a5d9SLei Zhang #define DEBUG_TYPE "linalg-generalization"
289e39a5d9SLei Zhang
299e39a5d9SLei Zhang using namespace mlir;
305a451e48STobias Gysi using namespace mlir::linalg;
319e39a5d9SLei Zhang
generalizeNamedOpPrecondition(LinalgOp linalgOp)32c05db638SNicolas Vasilache static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
33e826db62STobias Gysi // Check if the operation is a LinalgOp but not a GenericOp.
34c05db638SNicolas Vasilache if (isa<GenericOp>(linalgOp))
35e826db62STobias Gysi return failure();
36e826db62STobias Gysi // Check if the operation has a region builder.
37c05db638SNicolas Vasilache if (!linalgOp.getRegionBuilder())
38e826db62STobias Gysi return failure();
39e826db62STobias Gysi return success();
40e826db62STobias Gysi }
41e826db62STobias Gysi
generalizeNamedOp(RewriterBase & rewriter,LinalgOp linalgOp)429a7d111fSNicolas Vasilache FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
43c05db638SNicolas Vasilache LinalgOp linalgOp) {
44c05db638SNicolas Vasilache if (failed(generalizeNamedOpPrecondition(linalgOp)))
45c05db638SNicolas Vasilache return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
469a7d111fSNicolas Vasilache
47c05db638SNicolas Vasilache SmallVector<Value> inputOperands = linalgOp.getInputOperands();
48c05db638SNicolas Vasilache SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
49*d2c0572bSJacques Pienaar SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
505a451e48STobias Gysi SmallVector<StringRef> iterators = llvm::to_vector<4>(
51c05db638SNicolas Vasilache linalgOp.iterator_types().getAsValueRange<StringAttr>());
52c05db638SNicolas Vasilache SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
535a451e48STobias Gysi SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
545a451e48STobias Gysi
55eaa52750STobias Gysi // All named ops have a region attached that can be inlined.
56c05db638SNicolas Vasilache assert(linalgOp->getNumRegions() == 1 &&
57eaa52750STobias Gysi "expect named op to have one region attached");
58ad10d965STobias Gysi GenericOp genericOp =
59c05db638SNicolas Vasilache rewriter.create<GenericOp>(linalgOp.getLoc(), types, inputOperands,
60ad10d965STobias Gysi outputOperands, indexingMaps, iterators);
61c05db638SNicolas Vasilache rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.region(),
625a451e48STobias Gysi genericOp.region().begin());
63c05db638SNicolas Vasilache rewriter.replaceOp(linalgOp, genericOp->getResults());
645a451e48STobias Gysi return genericOp;
655a451e48STobias Gysi }
665a451e48STobias Gysi
679e39a5d9SLei Zhang namespace {
689e39a5d9SLei Zhang
699e39a5d9SLei Zhang struct LinalgGeneralizationPass
709e39a5d9SLei Zhang : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
7141574554SRiver Riddle void runOnOperation() override;
729e39a5d9SLei Zhang };
739e39a5d9SLei Zhang
749e39a5d9SLei Zhang } // namespace
759e39a5d9SLei Zhang
runOnOperation()7641574554SRiver Riddle void LinalgGeneralizationPass::runOnOperation() {
7758ceae95SRiver Riddle func::FuncOp func = getOperation();
78dc4e913bSChris Lattner RewritePatternSet patterns(&getContext());
795a451e48STobias Gysi populateLinalgNamedOpsGeneralizationPatterns(patterns);
80e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
819e39a5d9SLei Zhang }
829e39a5d9SLei Zhang
populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet & patterns,const LinalgTransformationFilter & marker)839e39a5d9SLei Zhang void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
841fc096afSMehdi Amini RewritePatternSet &patterns, const LinalgTransformationFilter &marker) {
85e826db62STobias Gysi patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
869e39a5d9SLei Zhang }
879e39a5d9SLei Zhang
8858ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createLinalgGeneralizationPass()8958ceae95SRiver Riddle mlir::createLinalgGeneralizationPass() {
909e39a5d9SLei Zhang return std::make_unique<LinalgGeneralizationPass>();
919e39a5d9SLei Zhang }
92