1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
10 #include "ByteCode.h"
11 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
12 #include "mlir/Dialect/PDL/IR/PDLOps.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 
17 using namespace mlir;
18 
19 static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
20   // Skip the conversion if the module doesn't contain pdl.
21   if (llvm::empty(pdlModule.getOps<pdl::PatternOp>()))
22     return success();
23 
24   // Simplify the provided PDL module. Note that we can't use the canonicalizer
25   // here because it would create a cyclic dependency.
26   auto simplifyFn = [](Operation *op) {
27     // TODO: Add folding here if ever necessary.
28     if (isOpTriviallyDead(op))
29       op->erase();
30   };
31   pdlModule.getBody()->walk(simplifyFn);
32 
33   /// Lower the PDL pattern module to the interpreter dialect.
34   PassManager pdlPipeline(pdlModule.getContext());
35 #ifdef NDEBUG
36   // We don't want to incur the hit of running the verifier when in release
37   // mode.
38   pdlPipeline.enableVerifier(false);
39 #endif
40   pdlPipeline.addPass(createPDLToPDLInterpPass());
41   if (failed(pdlPipeline.run(pdlModule)))
42     return failure();
43 
44   // Simplify again after running the lowering pipeline.
45   pdlModule.getBody()->walk(simplifyFn);
46   return success();
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // FrozenRewritePatternSet
51 //===----------------------------------------------------------------------===//
52 
53 FrozenRewritePatternSet::FrozenRewritePatternSet()
54     : impl(std::make_shared<Impl>()) {}
55 
56 FrozenRewritePatternSet::FrozenRewritePatternSet(RewritePatternSet &&patterns)
57     : impl(std::make_shared<Impl>()) {
58   impl->nativePatterns = std::move(patterns.getNativePatterns());
59 
60   // Generate the bytecode for the PDL patterns if any were provided.
61   PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
62   ModuleOp pdlModule = pdlPatterns.getModule();
63   if (!pdlModule)
64     return;
65   if (failed(convertPDLToPDLInterp(pdlModule)))
66     llvm::report_fatal_error(
67         "failed to lower PDL pattern module to the PDL Interpreter");
68 
69   // Generate the pdl bytecode.
70   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
71       pdlModule, pdlPatterns.takeConstraintFunctions(),
72       pdlPatterns.takeRewriteFunctions());
73 }
74 
75 FrozenRewritePatternSet::~FrozenRewritePatternSet() {}
76