//===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "ByteCode.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

/// Lower the PDL patterns nested in `pdlModule` to the PDL interpreter
/// dialect, rewriting the module in place. Trivially dead operations are erased
/// both before and after the lowering pipeline runs. Returns success without
/// modifying the module when it holds no `pdl.pattern` operations, and failure
/// only when the lowering pipeline itself fails.
static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
  // Without any pdl patterns there is nothing to convert.
  auto pdlPatternOps = pdlModule.getOps<pdl::PatternOp>();
  if (llvm::empty(pdlPatternOps))
    return success();

  // Lightweight cleanup applied to every operation nested in the module. The
  // canonicalizer can't be used for this, as depending on it here would
  // introduce a cyclic dependency.
  // TODO: Add folding here if ever necessary.
  auto eraseIfDead = [](Operation *nestedOp) {
    if (!isOpTriviallyDead(nestedOp))
      return;
    nestedOp->erase();
  };
  pdlModule.getBody()->walk(eraseIfDead);

  // Build and run the pipeline that lowers pdl to the interpreter dialect.
  PassManager loweringPipeline(pdlModule.getContext());
#ifdef NDEBUG
  // Release builds skip the verifier so that its cost is not incurred.
  loweringPipeline.enableVerifier(false);
#endif
  loweringPipeline.addPass(createPDLToPDLInterpPass());
  LogicalResult loweringResult = loweringPipeline.run(pdlModule);
  if (failed(loweringResult))
    return failure();

  // Clean up once more now that the lowering has run.
  pdlModule.getBody()->walk(eraseIfDead);
  return success();
}

//===----------------------------------------------------------------------===//
// FrozenRewritePatternSet
//===----------------------------------------------------------------------===//

FrozenRewritePatternSet::FrozenRewritePatternSet()
    : impl(std::make_shared<Impl>()) {}

/// Freeze `patterns`: take ownership of its native patterns, bucketing each one
/// by the operations it may be rooted on, and compile any attached PDL
/// patterns into bytecode.
///
/// A native pattern is not added when `enabledPatternLabels` is non-empty and
/// neither its debug name nor any of its debug labels is listed there, or when
/// its debug name or any of its debug labels is listed in
/// `disabledPatternLabels`.
///
/// Aborts via `llvm::report_fatal_error` if the PDL module cannot be lowered
/// to the PDL interpreter dialect.
FrozenRewritePatternSet::FrozenRewritePatternSet(
    RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
    ArrayRef<std::string> enabledPatternLabels)
    : impl(std::make_shared<Impl>()) {
  // Gather the user-provided labels into sets for membership queries.
  DenseSet<StringRef> enabledLabelSet, disabledLabelSet;
  enabledLabelSet.insert(enabledPatternLabels.begin(),
                         enabledPatternLabels.end());
  disabledLabelSet.insert(disabledPatternLabels.begin(),
                          disabledPatternLabels.end());

  // Returns true when the debug name, or any of the debug labels, of `pattern`
  // is contained in `labels`.
  auto hasLabelIn = [](const RewritePattern &pattern,
                       const DenseSet<StringRef> &labels) {
    auto isListed = [&](StringRef label) { return labels.count(label) != 0; };
    return isListed(pattern.getDebugName()) ||
           llvm::any_of(pattern.getDebugLabels(), isListed);
  };

  // Interface and trait based patterns may apply to many operations, so they
  // are recorded against every operation registered in the context that is
  // accepted by `shouldAdd`. The list of registered operations is fetched from
  // the context whenever the cached copy is empty.
  std::vector<RegisteredOperationName> registeredOps;
  auto addForMatchingOps =
      [&](std::unique_ptr<RewritePattern> &pattern,
          function_ref<bool(RegisteredOperationName)> shouldAdd) {
        if (registeredOps.empty())
          registeredOps = pattern->getContext()->getRegisteredOperations();
        RewritePattern *rawPattern = pattern.get();
        for (RegisteredOperationName opName : registeredOps) {
          if (!shouldAdd(opName))
            continue;
          impl->nativeOpSpecificPatternMap[opName].push_back(rawPattern);
        }
        // Ownership moves to the frozen set only after the raw pointer has
        // been recorded for each matching operation.
        impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
      };

  for (std::unique_ptr<RewritePattern> &nativePattern :
       patterns.getNativePatterns()) {
    // When an enable list was given, skip patterns that are not on it.
    if (!enabledLabelSet.empty() &&
        !hasLabelIn(*nativePattern, enabledLabelSet))
      continue;
    // Skip patterns that the user explicitly disabled.
    if (!disabledLabelSet.empty() &&
        hasLabelIn(*nativePattern, disabledLabelSet))
      continue;

    // Bucket the pattern by its root: a specific operation takes precedence,
    // then an interface, then a trait; anything else may match any operation.
    if (Optional<OperationName> rootOp = nativePattern->getRootKind()) {
      impl->nativeOpSpecificPatternMap[*rootOp].push_back(nativePattern.get());
      impl->nativeOpSpecificPatternList.push_back(std::move(nativePattern));
    } else if (Optional<TypeID> rootInterface =
                   nativePattern->getRootInterfaceID()) {
      addForMatchingOps(nativePattern, [&](RegisteredOperationName candidate) {
        return candidate.hasInterface(*rootInterface);
      });
    } else if (Optional<TypeID> rootTrait = nativePattern->getRootTraitID()) {
      addForMatchingOps(nativePattern, [&](RegisteredOperationName candidate) {
        return candidate.hasTrait(*rootTrait);
      });
    } else {
      impl->nativeAnyOpPatterns.push_back(std::move(nativePattern));
    }
  }

  // The remainder only applies when a PDL pattern module was provided.
  PDLPatternModule &pdlPatternModule = patterns.getPDLPatterns();
  ModuleOp pdlModuleOp = pdlPatternModule.getModule();
  if (!pdlModuleOp)
    return;

  // Lower the PDL patterns to the interpreter dialect; a failure here is fatal.
  if (failed(convertPDLToPDLInterp(pdlModuleOp)))
    llvm::report_fatal_error(
        "failed to lower PDL pattern module to the PDL Interpreter");

  // Compile the lowered module into bytecode, handing over the constraint and
  // rewrite functions taken from the pattern module.
  impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
      pdlModuleOp, pdlPatternModule.takeConstraintFunctions(),
      pdlPatternModule.takeRewriteFunctions());
}

/// Out-of-line destructor; performs no work beyond destroying the members.
FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;
