xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/Pass.cpp (revision 2387fade)
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Pass.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 
17 using namespace mlir;
18 
19 //===----------------------------------------------------------------------===//
20 // PassManager/OpPassManager APIs.
21 //===----------------------------------------------------------------------===//
22 
mlirPassManagerCreate(MlirContext ctx)23 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
24   return wrap(new PassManager(unwrap(ctx)));
25 }
26 
mlirPassManagerDestroy(MlirPassManager passManager)27 void mlirPassManagerDestroy(MlirPassManager passManager) {
28   delete unwrap(passManager);
29 }
30 
31 MlirOpPassManager
mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)32 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
33   return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
34 }
35 
mlirPassManagerRun(MlirPassManager passManager,MlirModule module)36 MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
37                                      MlirModule module) {
38   return wrap(unwrap(passManager)->run(unwrap(module)));
39 }
40 
mlirPassManagerEnableIRPrinting(MlirPassManager passManager)41 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
42   return unwrap(passManager)->enableIRPrinting();
43 }
44 
mlirPassManagerEnableVerifier(MlirPassManager passManager,bool enable)45 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
46   unwrap(passManager)->enableVerifier(enable);
47 }
48 
mlirPassManagerGetNestedUnder(MlirPassManager passManager,MlirStringRef operationName)49 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
50                                                 MlirStringRef operationName) {
51   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
52 }
53 
mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,MlirStringRef operationName)54 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
55                                                   MlirStringRef operationName) {
56   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
57 }
58 
mlirPassManagerAddOwnedPass(MlirPassManager passManager,MlirPass pass)59 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
60   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
61 }
62 
mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,MlirPass pass)63 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
64                                    MlirPass pass) {
65   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
66 }
67 
mlirPrintPassPipeline(MlirOpPassManager passManager,MlirStringCallback callback,void * userData)68 void mlirPrintPassPipeline(MlirOpPassManager passManager,
69                            MlirStringCallback callback, void *userData) {
70   detail::CallbackOstream stream(callback, userData);
71   unwrap(passManager)->printAsTextualPipeline(stream);
72 }
73 
mlirParsePassPipeline(MlirOpPassManager passManager,MlirStringRef pipeline)74 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
75                                         MlirStringRef pipeline) {
76   // TODO: errors are sent to std::errs() at the moment, we should pass in a
77   // stream and redirect to a diagnostic.
78   return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
79 }
80 
81 //===----------------------------------------------------------------------===//
82 // External Pass API.
83 //===----------------------------------------------------------------------===//
84 
85 namespace mlir {
86 class ExternalPass;
87 } // namespace mlir
88 DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
89 
90 namespace mlir {
91 /// This pass class wraps external passes defined in other languages using the
92 /// MLIR C-interface
93 class ExternalPass : public Pass {
94 public:
ExternalPass(TypeID passID,StringRef name,StringRef argument,StringRef description,Optional<StringRef> opName,ArrayRef<MlirDialectHandle> dependentDialects,MlirExternalPassCallbacks callbacks,void * userData)95   ExternalPass(TypeID passID, StringRef name, StringRef argument,
96                StringRef description, Optional<StringRef> opName,
97                ArrayRef<MlirDialectHandle> dependentDialects,
98                MlirExternalPassCallbacks callbacks, void *userData)
99       : Pass(passID, opName), id(passID), name(name), argument(argument),
100         description(description), dependentDialects(dependentDialects),
101         callbacks(callbacks), userData(userData) {
102     callbacks.construct(userData);
103   }
104 
~ExternalPass()105   ~ExternalPass() override { callbacks.destruct(userData); }
106 
getName() const107   StringRef getName() const override { return name; }
getArgument() const108   StringRef getArgument() const override { return argument; }
getDescription() const109   StringRef getDescription() const override { return description; }
110 
getDependentDialects(DialectRegistry & registry) const111   void getDependentDialects(DialectRegistry &registry) const override {
112     MlirDialectRegistry cRegistry = wrap(&registry);
113     for (MlirDialectHandle dialect : dependentDialects)
114       mlirDialectHandleInsertDialect(dialect, cRegistry);
115   }
116 
signalPassFailure()117   void signalPassFailure() { Pass::signalPassFailure(); }
118 
119 protected:
initialize(MLIRContext * ctx)120   LogicalResult initialize(MLIRContext *ctx) override {
121     if (callbacks.initialize)
122       return unwrap(callbacks.initialize(wrap(ctx), userData));
123     return success();
124   }
125 
canScheduleOn(RegisteredOperationName opName) const126   bool canScheduleOn(RegisteredOperationName opName) const override {
127     if (Optional<StringRef> specifiedOpName = getOpName())
128       return opName.getStringRef() == specifiedOpName;
129     return true;
130   }
131 
runOnOperation()132   void runOnOperation() override {
133     callbacks.run(wrap(getOperation()), wrap(this), userData);
134   }
135 
clonePass() const136   std::unique_ptr<Pass> clonePass() const override {
137     void *clonedUserData = callbacks.clone(userData);
138     return std::make_unique<ExternalPass>(id, name, argument, description,
139                                           getOpName(), dependentDialects,
140                                           callbacks, clonedUserData);
141   }
142 
143 private:
144   TypeID id;
145   std::string name;
146   std::string argument;
147   std::string description;
148   std::vector<MlirDialectHandle> dependentDialects;
149   MlirExternalPassCallbacks callbacks;
150   void *userData;
151 };
152 } // namespace mlir
153 
mlirCreateExternalPass(MlirTypeID passID,MlirStringRef name,MlirStringRef argument,MlirStringRef description,MlirStringRef opName,intptr_t nDependentDialects,MlirDialectHandle * dependentDialects,MlirExternalPassCallbacks callbacks,void * userData)154 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
155                                 MlirStringRef argument,
156                                 MlirStringRef description, MlirStringRef opName,
157                                 intptr_t nDependentDialects,
158                                 MlirDialectHandle *dependentDialects,
159                                 MlirExternalPassCallbacks callbacks,
160                                 void *userData) {
161   return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
162       unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
163       opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
164       {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
165       userData)));
166 }
167 
mlirExternalPassSignalFailure(MlirExternalPass pass)168 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
169   unwrap(pass)->signalPassFailure();
170 }
171