xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/Pass.cpp (revision 28e665fa)
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Pass.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 
17 using namespace mlir;
18 
19 //===----------------------------------------------------------------------===//
20 // PassManager/OpPassManager APIs.
21 //===----------------------------------------------------------------------===//
22 
23 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
24   return wrap(new PassManager(unwrap(ctx)));
25 }
26 
27 void mlirPassManagerDestroy(MlirPassManager passManager) {
28   delete unwrap(passManager);
29 }
30 
31 MlirOpPassManager
32 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
33   return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
34 }
35 
36 MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
37                                      MlirModule module) {
38   return wrap(unwrap(passManager)->run(unwrap(module)));
39 }
40 
41 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
42   return unwrap(passManager)->enableIRPrinting();
43 }
44 
45 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
46   unwrap(passManager)->enableVerifier(enable);
47 }
48 
49 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
50                                                 MlirStringRef operationName) {
51   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
52 }
53 
54 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
55                                                   MlirStringRef operationName) {
56   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
57 }
58 
59 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
60   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
61 }
62 
63 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
64                                    MlirPass pass) {
65   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
66 }
67 
68 void mlirPrintPassPipeline(MlirOpPassManager passManager,
69                            MlirStringCallback callback, void *userData) {
70   detail::CallbackOstream stream(callback, userData);
71   unwrap(passManager)->printAsTextualPipeline(stream);
72 }
73 
74 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
75                                         MlirStringRef pipeline) {
76   // TODO: errors are sent to std::errs() at the moment, we should pass in a
77   // stream and redirect to a diagnostic.
78   return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
79 }
80 
81 //===----------------------------------------------------------------------===//
82 // External Pass API.
83 //===----------------------------------------------------------------------===//
84 
85 namespace mlir {
86 class ExternalPass;
87 } // namespace mlir
88 DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
89 
90 namespace mlir {
91 /// This pass class wraps external passes defined in other languages using the
92 /// MLIR C-interface
93 class ExternalPass : public Pass {
94 public:
95   ExternalPass(TypeID passID, StringRef name, StringRef argument,
96                StringRef description, Optional<StringRef> opName,
97                ArrayRef<MlirDialectHandle> dependentDialects,
98                MlirExternalPassCallbacks callbacks, void *userData)
99       : Pass(passID, opName), id(passID), name(name), argument(argument),
100         description(description), dependentDialects(dependentDialects),
101         callbacks(callbacks), userData(userData) {
102     callbacks.construct(userData);
103   }
104 
105   ~ExternalPass() override { callbacks.destruct(userData); }
106 
107   StringRef getName() const override { return name; }
108   StringRef getArgument() const override { return argument; }
109   StringRef getDescription() const override { return description; }
110 
111   void getDependentDialects(DialectRegistry &registry) const override {
112     MlirDialectRegistry cRegistry = wrap(&registry);
113     for (MlirDialectHandle dialect : dependentDialects)
114       mlirDialectHandleInsertDialect(dialect, cRegistry);
115   }
116 
117   void signalPassFailure() { Pass::signalPassFailure(); }
118 
119 protected:
120   LogicalResult initialize(MLIRContext *ctx) override {
121     if (callbacks.initialize)
122       return unwrap(callbacks.initialize(wrap(ctx), userData));
123     return success();
124   }
125 
126   bool canScheduleOn(RegisteredOperationName opName) const override {
127     if (Optional<StringRef> specifiedOpName = getOpName())
128       return opName.getStringRef() == specifiedOpName;
129     return true;
130   }
131 
132   void runOnOperation() override {
133     callbacks.run(wrap(getOperation()), wrap(this), userData);
134   }
135 
136   std::unique_ptr<Pass> clonePass() const override {
137     void *clonedUserData = callbacks.clone(userData);
138     return std::make_unique<ExternalPass>(id, name, argument, description,
139                                           getOpName(), dependentDialects,
140                                           callbacks, clonedUserData);
141   }
142 
143 private:
144   TypeID id;
145   std::string name;
146   std::string argument;
147   std::string description;
148   std::vector<MlirDialectHandle> dependentDialects;
149   MlirExternalPassCallbacks callbacks;
150   void *userData;
151 };
152 } // namespace mlir
153 
154 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
155                                 MlirStringRef argument,
156                                 MlirStringRef description, MlirStringRef opName,
157                                 intptr_t nDependentDialects,
158                                 MlirDialectHandle *dependentDialects,
159                                 MlirExternalPassCallbacks callbacks,
160                                 void *userData) {
161   return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
162       unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
163       opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
164       {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
165       userData)));
166 }
167 
168 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
169   unwrap(pass)->signalPassFailure();
170 }
171