xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/Pass.cpp (revision 2387fade)
1f61d1028SMehdi Amini //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2f61d1028SMehdi Amini //
3f61d1028SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f61d1028SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
5f61d1028SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f61d1028SMehdi Amini //
7f61d1028SMehdi Amini //===----------------------------------------------------------------------===//
8f61d1028SMehdi Amini 
9f61d1028SMehdi Amini #include "mlir-c/Pass.h"
10f61d1028SMehdi Amini 
11f61d1028SMehdi Amini #include "mlir/CAPI/IR.h"
12f61d1028SMehdi Amini #include "mlir/CAPI/Pass.h"
13f61d1028SMehdi Amini #include "mlir/CAPI/Support.h"
14f61d1028SMehdi Amini #include "mlir/CAPI/Utils.h"
15f61d1028SMehdi Amini #include "mlir/Pass/PassManager.h"
16f61d1028SMehdi Amini 
17f61d1028SMehdi Amini using namespace mlir;
18f61d1028SMehdi Amini 
19c7994bd9SMehdi Amini //===----------------------------------------------------------------------===//
20c7994bd9SMehdi Amini // PassManager/OpPassManager APIs.
21c7994bd9SMehdi Amini //===----------------------------------------------------------------------===//
22f61d1028SMehdi Amini 
mlirPassManagerCreate(MlirContext ctx)23f61d1028SMehdi Amini MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
24f61d1028SMehdi Amini   return wrap(new PassManager(unwrap(ctx)));
25f61d1028SMehdi Amini }
26f61d1028SMehdi Amini 
mlirPassManagerDestroy(MlirPassManager passManager)27f61d1028SMehdi Amini void mlirPassManagerDestroy(MlirPassManager passManager) {
28f61d1028SMehdi Amini   delete unwrap(passManager);
29f61d1028SMehdi Amini }
30f61d1028SMehdi Amini 
31aeb4b1a9SMehdi Amini MlirOpPassManager
mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)32aeb4b1a9SMehdi Amini mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
33aeb4b1a9SMehdi Amini   return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
34aeb4b1a9SMehdi Amini }
35aeb4b1a9SMehdi Amini 
mlirPassManagerRun(MlirPassManager passManager,MlirModule module)36f61d1028SMehdi Amini MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
37f61d1028SMehdi Amini                                      MlirModule module) {
38f61d1028SMehdi Amini   return wrap(unwrap(passManager)->run(unwrap(module)));
39f61d1028SMehdi Amini }
40f61d1028SMehdi Amini 
mlirPassManagerEnableIRPrinting(MlirPassManager passManager)41caa159f0SNicolas Vasilache void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
42caa159f0SNicolas Vasilache   return unwrap(passManager)->enableIRPrinting();
43caa159f0SNicolas Vasilache }
44caa159f0SNicolas Vasilache 
mlirPassManagerEnableVerifier(MlirPassManager passManager,bool enable)45caa159f0SNicolas Vasilache void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
46caa159f0SNicolas Vasilache   unwrap(passManager)->enableVerifier(enable);
47caa159f0SNicolas Vasilache }
48caa159f0SNicolas Vasilache 
mlirPassManagerGetNestedUnder(MlirPassManager passManager,MlirStringRef operationName)49f61d1028SMehdi Amini MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
50f61d1028SMehdi Amini                                                 MlirStringRef operationName) {
51f61d1028SMehdi Amini   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
52f61d1028SMehdi Amini }
53f61d1028SMehdi Amini 
mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,MlirStringRef operationName)54f61d1028SMehdi Amini MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
55f61d1028SMehdi Amini                                                   MlirStringRef operationName) {
56f61d1028SMehdi Amini   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
57f61d1028SMehdi Amini }
58f61d1028SMehdi Amini 
mlirPassManagerAddOwnedPass(MlirPassManager passManager,MlirPass pass)59f61d1028SMehdi Amini void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
60f61d1028SMehdi Amini   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
61f61d1028SMehdi Amini }
62f61d1028SMehdi Amini 
mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,MlirPass pass)63f61d1028SMehdi Amini void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
64f61d1028SMehdi Amini                                    MlirPass pass) {
65f61d1028SMehdi Amini   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
66f61d1028SMehdi Amini }
67aeb4b1a9SMehdi Amini 
mlirPrintPassPipeline(MlirOpPassManager passManager,MlirStringCallback callback,void * userData)68aeb4b1a9SMehdi Amini void mlirPrintPassPipeline(MlirOpPassManager passManager,
69aeb4b1a9SMehdi Amini                            MlirStringCallback callback, void *userData) {
70aeb4b1a9SMehdi Amini   detail::CallbackOstream stream(callback, userData);
71aeb4b1a9SMehdi Amini   unwrap(passManager)->printAsTextualPipeline(stream);
72aeb4b1a9SMehdi Amini }
73aeb4b1a9SMehdi Amini 
mlirParsePassPipeline(MlirOpPassManager passManager,MlirStringRef pipeline)74aeb4b1a9SMehdi Amini MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
75aeb4b1a9SMehdi Amini                                         MlirStringRef pipeline) {
76aeb4b1a9SMehdi Amini   // TODO: errors are sent to std::errs() at the moment, we should pass in a
77aeb4b1a9SMehdi Amini   // stream and redirect to a diagnostic.
78aeb4b1a9SMehdi Amini   return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
79aeb4b1a9SMehdi Amini }
80*2387fadeSDaniel Resnick 
81*2387fadeSDaniel Resnick //===----------------------------------------------------------------------===//
82*2387fadeSDaniel Resnick // External Pass API.
83*2387fadeSDaniel Resnick //===----------------------------------------------------------------------===//
84*2387fadeSDaniel Resnick 
85*2387fadeSDaniel Resnick namespace mlir {
86*2387fadeSDaniel Resnick class ExternalPass;
87*2387fadeSDaniel Resnick } // namespace mlir
88*2387fadeSDaniel Resnick DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
89*2387fadeSDaniel Resnick 
90*2387fadeSDaniel Resnick namespace mlir {
91*2387fadeSDaniel Resnick /// This pass class wraps external passes defined in other languages using the
92*2387fadeSDaniel Resnick /// MLIR C-interface
93*2387fadeSDaniel Resnick class ExternalPass : public Pass {
94*2387fadeSDaniel Resnick public:
ExternalPass(TypeID passID,StringRef name,StringRef argument,StringRef description,Optional<StringRef> opName,ArrayRef<MlirDialectHandle> dependentDialects,MlirExternalPassCallbacks callbacks,void * userData)95*2387fadeSDaniel Resnick   ExternalPass(TypeID passID, StringRef name, StringRef argument,
96*2387fadeSDaniel Resnick                StringRef description, Optional<StringRef> opName,
97*2387fadeSDaniel Resnick                ArrayRef<MlirDialectHandle> dependentDialects,
98*2387fadeSDaniel Resnick                MlirExternalPassCallbacks callbacks, void *userData)
99*2387fadeSDaniel Resnick       : Pass(passID, opName), id(passID), name(name), argument(argument),
100*2387fadeSDaniel Resnick         description(description), dependentDialects(dependentDialects),
101*2387fadeSDaniel Resnick         callbacks(callbacks), userData(userData) {
102*2387fadeSDaniel Resnick     callbacks.construct(userData);
103*2387fadeSDaniel Resnick   }
104*2387fadeSDaniel Resnick 
~ExternalPass()105*2387fadeSDaniel Resnick   ~ExternalPass() override { callbacks.destruct(userData); }
106*2387fadeSDaniel Resnick 
getName() const107*2387fadeSDaniel Resnick   StringRef getName() const override { return name; }
getArgument() const108*2387fadeSDaniel Resnick   StringRef getArgument() const override { return argument; }
getDescription() const109*2387fadeSDaniel Resnick   StringRef getDescription() const override { return description; }
110*2387fadeSDaniel Resnick 
getDependentDialects(DialectRegistry & registry) const111*2387fadeSDaniel Resnick   void getDependentDialects(DialectRegistry &registry) const override {
112*2387fadeSDaniel Resnick     MlirDialectRegistry cRegistry = wrap(&registry);
113*2387fadeSDaniel Resnick     for (MlirDialectHandle dialect : dependentDialects)
114*2387fadeSDaniel Resnick       mlirDialectHandleInsertDialect(dialect, cRegistry);
115*2387fadeSDaniel Resnick   }
116*2387fadeSDaniel Resnick 
signalPassFailure()117*2387fadeSDaniel Resnick   void signalPassFailure() { Pass::signalPassFailure(); }
118*2387fadeSDaniel Resnick 
119*2387fadeSDaniel Resnick protected:
initialize(MLIRContext * ctx)120*2387fadeSDaniel Resnick   LogicalResult initialize(MLIRContext *ctx) override {
121*2387fadeSDaniel Resnick     if (callbacks.initialize)
122*2387fadeSDaniel Resnick       return unwrap(callbacks.initialize(wrap(ctx), userData));
123*2387fadeSDaniel Resnick     return success();
124*2387fadeSDaniel Resnick   }
125*2387fadeSDaniel Resnick 
canScheduleOn(RegisteredOperationName opName) const126*2387fadeSDaniel Resnick   bool canScheduleOn(RegisteredOperationName opName) const override {
127*2387fadeSDaniel Resnick     if (Optional<StringRef> specifiedOpName = getOpName())
128*2387fadeSDaniel Resnick       return opName.getStringRef() == specifiedOpName;
129*2387fadeSDaniel Resnick     return true;
130*2387fadeSDaniel Resnick   }
131*2387fadeSDaniel Resnick 
runOnOperation()132*2387fadeSDaniel Resnick   void runOnOperation() override {
133*2387fadeSDaniel Resnick     callbacks.run(wrap(getOperation()), wrap(this), userData);
134*2387fadeSDaniel Resnick   }
135*2387fadeSDaniel Resnick 
clonePass() const136*2387fadeSDaniel Resnick   std::unique_ptr<Pass> clonePass() const override {
137*2387fadeSDaniel Resnick     void *clonedUserData = callbacks.clone(userData);
138*2387fadeSDaniel Resnick     return std::make_unique<ExternalPass>(id, name, argument, description,
139*2387fadeSDaniel Resnick                                           getOpName(), dependentDialects,
140*2387fadeSDaniel Resnick                                           callbacks, clonedUserData);
141*2387fadeSDaniel Resnick   }
142*2387fadeSDaniel Resnick 
143*2387fadeSDaniel Resnick private:
144*2387fadeSDaniel Resnick   TypeID id;
145*2387fadeSDaniel Resnick   std::string name;
146*2387fadeSDaniel Resnick   std::string argument;
147*2387fadeSDaniel Resnick   std::string description;
148*2387fadeSDaniel Resnick   std::vector<MlirDialectHandle> dependentDialects;
149*2387fadeSDaniel Resnick   MlirExternalPassCallbacks callbacks;
150*2387fadeSDaniel Resnick   void *userData;
151*2387fadeSDaniel Resnick };
152*2387fadeSDaniel Resnick } // namespace mlir
153*2387fadeSDaniel Resnick 
mlirCreateExternalPass(MlirTypeID passID,MlirStringRef name,MlirStringRef argument,MlirStringRef description,MlirStringRef opName,intptr_t nDependentDialects,MlirDialectHandle * dependentDialects,MlirExternalPassCallbacks callbacks,void * userData)154*2387fadeSDaniel Resnick MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
155*2387fadeSDaniel Resnick                                 MlirStringRef argument,
156*2387fadeSDaniel Resnick                                 MlirStringRef description, MlirStringRef opName,
157*2387fadeSDaniel Resnick                                 intptr_t nDependentDialects,
158*2387fadeSDaniel Resnick                                 MlirDialectHandle *dependentDialects,
159*2387fadeSDaniel Resnick                                 MlirExternalPassCallbacks callbacks,
160*2387fadeSDaniel Resnick                                 void *userData) {
161*2387fadeSDaniel Resnick   return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
162*2387fadeSDaniel Resnick       unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
163*2387fadeSDaniel Resnick       opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
164*2387fadeSDaniel Resnick       {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
165*2387fadeSDaniel Resnick       userData)));
166*2387fadeSDaniel Resnick }
167*2387fadeSDaniel Resnick 
mlirExternalPassSignalFailure(MlirExternalPass pass)168*2387fadeSDaniel Resnick void mlirExternalPassSignalFailure(MlirExternalPass pass) {
169*2387fadeSDaniel Resnick   unwrap(pass)->signalPassFailure();
170*2387fadeSDaniel Resnick }
171