19e39a5d9SLei Zhang //===- Generalization.cpp - linalg named ops to generic ops  --------------===//
29e39a5d9SLei Zhang //
39e39a5d9SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49e39a5d9SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
59e39a5d9SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69e39a5d9SLei Zhang //
79e39a5d9SLei Zhang //===----------------------------------------------------------------------===//
89e39a5d9SLei Zhang //
99e39a5d9SLei Zhang // This file implements the Linalg generalization pass. It converts named
109e39a5d9SLei Zhang // Linalg ops to linalg.generic ops.
119e39a5d9SLei Zhang //
129e39a5d9SLei Zhang //===----------------------------------------------------------------------===//
139e39a5d9SLei Zhang 
149e39a5d9SLei Zhang #include "PassDetail.h"
15b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
169e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Passes.h"
179e39a5d9SLei Zhang #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
189e39a5d9SLei Zhang #include "mlir/IR/AffineMap.h"
199e39a5d9SLei Zhang #include "mlir/IR/Attributes.h"
209e39a5d9SLei Zhang #include "mlir/IR/Builders.h"
214519ca3dSNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
229e39a5d9SLei Zhang #include "mlir/IR/PatternMatch.h"
239e39a5d9SLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
249e39a5d9SLei Zhang #include "llvm/ADT/SmallVector.h"
259e39a5d9SLei Zhang #include "llvm/Support/Debug.h"
269e39a5d9SLei Zhang 
279e39a5d9SLei Zhang #define DEBUG_TYPE "linalg-generalization"
289e39a5d9SLei Zhang 
299e39a5d9SLei Zhang using namespace mlir;
305a451e48STobias Gysi using namespace mlir::linalg;
319e39a5d9SLei Zhang 
generalizeNamedOpPrecondition(LinalgOp linalgOp)32c05db638SNicolas Vasilache static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
33e826db62STobias Gysi   // Check if the operation is a LinalgOp but not a GenericOp.
34c05db638SNicolas Vasilache   if (isa<GenericOp>(linalgOp))
35e826db62STobias Gysi     return failure();
36e826db62STobias Gysi   // Check if the operation has a region builder.
37c05db638SNicolas Vasilache   if (!linalgOp.getRegionBuilder())
38e826db62STobias Gysi     return failure();
39e826db62STobias Gysi   return success();
40e826db62STobias Gysi }
41e826db62STobias Gysi 
generalizeNamedOp(RewriterBase & rewriter,LinalgOp linalgOp)429a7d111fSNicolas Vasilache FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
43c05db638SNicolas Vasilache                                                      LinalgOp linalgOp) {
44c05db638SNicolas Vasilache   if (failed(generalizeNamedOpPrecondition(linalgOp)))
45c05db638SNicolas Vasilache     return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
469a7d111fSNicolas Vasilache 
47c05db638SNicolas Vasilache   SmallVector<Value> inputOperands = linalgOp.getInputOperands();
48c05db638SNicolas Vasilache   SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
49*d2c0572bSJacques Pienaar   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
505a451e48STobias Gysi   SmallVector<StringRef> iterators = llvm::to_vector<4>(
51c05db638SNicolas Vasilache       linalgOp.iterator_types().getAsValueRange<StringAttr>());
52c05db638SNicolas Vasilache   SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
535a451e48STobias Gysi   SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
545a451e48STobias Gysi 
55eaa52750STobias Gysi   // All named ops have a region attached that can be inlined.
56c05db638SNicolas Vasilache   assert(linalgOp->getNumRegions() == 1 &&
57eaa52750STobias Gysi          "expect named op to have one region attached");
58ad10d965STobias Gysi   GenericOp genericOp =
59c05db638SNicolas Vasilache       rewriter.create<GenericOp>(linalgOp.getLoc(), types, inputOperands,
60ad10d965STobias Gysi                                  outputOperands, indexingMaps, iterators);
61c05db638SNicolas Vasilache   rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.region(),
625a451e48STobias Gysi                               genericOp.region().begin());
63c05db638SNicolas Vasilache   rewriter.replaceOp(linalgOp, genericOp->getResults());
645a451e48STobias Gysi   return genericOp;
655a451e48STobias Gysi }
665a451e48STobias Gysi 
679e39a5d9SLei Zhang namespace {
689e39a5d9SLei Zhang 
699e39a5d9SLei Zhang struct LinalgGeneralizationPass
709e39a5d9SLei Zhang     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
7141574554SRiver Riddle   void runOnOperation() override;
729e39a5d9SLei Zhang };
739e39a5d9SLei Zhang 
749e39a5d9SLei Zhang } // namespace
759e39a5d9SLei Zhang 
runOnOperation()7641574554SRiver Riddle void LinalgGeneralizationPass::runOnOperation() {
7758ceae95SRiver Riddle   func::FuncOp func = getOperation();
78dc4e913bSChris Lattner   RewritePatternSet patterns(&getContext());
795a451e48STobias Gysi   populateLinalgNamedOpsGeneralizationPatterns(patterns);
80e21adfa3SRiver Riddle   (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
819e39a5d9SLei Zhang }
829e39a5d9SLei Zhang 
populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet & patterns,const LinalgTransformationFilter & marker)839e39a5d9SLei Zhang void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
841fc096afSMehdi Amini     RewritePatternSet &patterns, const LinalgTransformationFilter &marker) {
85e826db62STobias Gysi   patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
869e39a5d9SLei Zhang }
879e39a5d9SLei Zhang 
8858ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createLinalgGeneralizationPass()8958ceae95SRiver Riddle mlir::createLinalgGeneralizationPass() {
909e39a5d9SLei Zhang   return std::make_unique<LinalgGeneralizationPass>();
919e39a5d9SLei Zhang }
92