19e3b928eSChris Lattner //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
29e3b928eSChris Lattner //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69e3b928eSChris Lattner //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
89e3b928eSChris Lattner //
99e3b928eSChris Lattner // This transformation pass converts operations into their canonical forms by
109e3b928eSChris Lattner // folding constants, applying operation identity transformations etc.
119e3b928eSChris Lattner //
129e3b928eSChris Lattner //===----------------------------------------------------------------------===//
139e3b928eSChris Lattner 
141834ad4aSRiver Riddle #include "PassDetail.h"
1548ccae24SRiver Riddle #include "mlir/Pass/Pass.h"
16b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
179e3b928eSChris Lattner #include "mlir/Transforms/Passes.h"
18ad4b4acbSUday Bondhugula 
199e3b928eSChris Lattner using namespace mlir;
209e3b928eSChris Lattner 
219e3b928eSChris Lattner namespace {
222b61b797SRiver Riddle /// Canonicalize operations in nested regions.
231834ad4aSRiver Riddle struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
24*ebad5fb3Srkayaith   Canonicalizer() = default;
Canonicalizer__anonbfd20a390111::Canonicalizer2542ac4f3dSMogball   Canonicalizer(const GreedyRewriteConfig &config,
2642ac4f3dSMogball                 ArrayRef<std::string> disabledPatterns,
27*ebad5fb3Srkayaith                 ArrayRef<std::string> enabledPatterns) {
28*ebad5fb3Srkayaith     this->topDownProcessingEnabled = config.useTopDownTraversal;
29*ebad5fb3Srkayaith     this->enableRegionSimplification = config.enableRegionSimplification;
30*ebad5fb3Srkayaith     this->maxIterations = config.maxIterations;
3142ac4f3dSMogball     this->disabledPatterns = disabledPatterns;
3242ac4f3dSMogball     this->enabledPatterns = enabledPatterns;
3342ac4f3dSMogball   }
342f23f9e6SChris Lattner 
351ba5ea67SRiver Riddle   /// Initialize the canonicalizer by building the set of patterns used during
361ba5ea67SRiver Riddle   /// execution.
initialize__anonbfd20a390111::Canonicalizer37b1aaed02SMehdi Amini   LogicalResult initialize(MLIRContext *context) override {
38dc4e913bSChris Lattner     RewritePatternSet owningPatterns(context);
39108ca7a7SMatthias Springer     for (auto *dialect : context->getLoadedDialects())
40108ca7a7SMatthias Springer       dialect->getCanonicalizationPatterns(owningPatterns);
41edc6c0ecSRiver Riddle     for (RegisteredOperationName op : context->getRegisteredOperations())
42edc6c0ecSRiver Riddle       op.getCanonicalizationPatterns(owningPatterns, context);
430289a269SRiver Riddle 
440289a269SRiver Riddle     patterns = FrozenRewritePatternSet(std::move(owningPatterns),
450289a269SRiver Riddle                                        disabledPatterns, enabledPatterns);
46b1aaed02SMehdi Amini     return success();
479e3b928eSChris Lattner   }
runOnOperation__anonbfd20a390111::Canonicalizer481ba5ea67SRiver Riddle   void runOnOperation() override {
49*ebad5fb3Srkayaith     GreedyRewriteConfig config;
50*ebad5fb3Srkayaith     config.useTopDownTraversal = topDownProcessingEnabled;
51*ebad5fb3Srkayaith     config.enableRegionSimplification = enableRegionSimplification;
52*ebad5fb3Srkayaith     config.maxIterations = maxIterations;
53*ebad5fb3Srkayaith     (void)applyPatternsAndFoldGreedily(getOperation(), patterns, config);
541ba5ea67SRiver Riddle   }
551ba5ea67SRiver Riddle 
5679d7f618SChris Lattner   FrozenRewritePatternSet patterns;
572b61b797SRiver Riddle };
58be0a7e9fSMehdi Amini } // namespace
599e3b928eSChris Lattner 
609e3b928eSChris Lattner /// Create a Canonicalizer pass.
createCanonicalizerPass()612b61b797SRiver Riddle std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
6279f53b0cSJacques Pienaar   return std::make_unique<Canonicalizer>();
63c6c53449SRiver Riddle }
642f23f9e6SChris Lattner 
652f23f9e6SChris Lattner /// Creates an instance of the Canonicalizer pass with the specified config.
662f23f9e6SChris Lattner std::unique_ptr<Pass>
createCanonicalizerPass(const GreedyRewriteConfig & config,ArrayRef<std::string> disabledPatterns,ArrayRef<std::string> enabledPatterns)67db68e6abSMogball mlir::createCanonicalizerPass(const GreedyRewriteConfig &config,
68db68e6abSMogball                               ArrayRef<std::string> disabledPatterns,
69db68e6abSMogball                               ArrayRef<std::string> enabledPatterns) {
7042ac4f3dSMogball   return std::make_unique<Canonicalizer>(config, disabledPatterns,
7142ac4f3dSMogball                                          enabledPatterns);
722f23f9e6SChris Lattner }
73