1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/Pass.h"
10
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16
17 using namespace mlir;
18
19 //===----------------------------------------------------------------------===//
20 // PassManager/OpPassManager APIs.
21 //===----------------------------------------------------------------------===//
22
mlirPassManagerCreate(MlirContext ctx)23 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
24 return wrap(new PassManager(unwrap(ctx)));
25 }
26
mlirPassManagerDestroy(MlirPassManager passManager)27 void mlirPassManagerDestroy(MlirPassManager passManager) {
28 delete unwrap(passManager);
29 }
30
31 MlirOpPassManager
mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)32 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
33 return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
34 }
35
mlirPassManagerRun(MlirPassManager passManager,MlirModule module)36 MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
37 MlirModule module) {
38 return wrap(unwrap(passManager)->run(unwrap(module)));
39 }
40
mlirPassManagerEnableIRPrinting(MlirPassManager passManager)41 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
42 return unwrap(passManager)->enableIRPrinting();
43 }
44
mlirPassManagerEnableVerifier(MlirPassManager passManager,bool enable)45 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
46 unwrap(passManager)->enableVerifier(enable);
47 }
48
mlirPassManagerGetNestedUnder(MlirPassManager passManager,MlirStringRef operationName)49 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
50 MlirStringRef operationName) {
51 return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
52 }
53
mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,MlirStringRef operationName)54 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
55 MlirStringRef operationName) {
56 return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
57 }
58
mlirPassManagerAddOwnedPass(MlirPassManager passManager,MlirPass pass)59 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
60 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
61 }
62
mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,MlirPass pass)63 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
64 MlirPass pass) {
65 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
66 }
67
mlirPrintPassPipeline(MlirOpPassManager passManager,MlirStringCallback callback,void * userData)68 void mlirPrintPassPipeline(MlirOpPassManager passManager,
69 MlirStringCallback callback, void *userData) {
70 detail::CallbackOstream stream(callback, userData);
71 unwrap(passManager)->printAsTextualPipeline(stream);
72 }
73
mlirParsePassPipeline(MlirOpPassManager passManager,MlirStringRef pipeline)74 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
75 MlirStringRef pipeline) {
76 // TODO: errors are sent to std::errs() at the moment, we should pass in a
77 // stream and redirect to a diagnostic.
78 return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
79 }
80
81 //===----------------------------------------------------------------------===//
82 // External Pass API.
83 //===----------------------------------------------------------------------===//
84
// Forward-declare ExternalPass so the C-API handle conversions for
// MlirExternalPass can be generated before the class itself is defined below.
namespace mlir {
class ExternalPass;
} // namespace mlir
// Generates the wrap()/unwrap() conversions between the opaque C handle
// MlirExternalPass and mlir::ExternalPass * (used, e.g., by
// ExternalPass::runOnOperation and mlirExternalPassSignalFailure).
DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
89
90 namespace mlir {
91 /// This pass class wraps external passes defined in other languages using the
92 /// MLIR C-interface
93 class ExternalPass : public Pass {
94 public:
ExternalPass(TypeID passID,StringRef name,StringRef argument,StringRef description,Optional<StringRef> opName,ArrayRef<MlirDialectHandle> dependentDialects,MlirExternalPassCallbacks callbacks,void * userData)95 ExternalPass(TypeID passID, StringRef name, StringRef argument,
96 StringRef description, Optional<StringRef> opName,
97 ArrayRef<MlirDialectHandle> dependentDialects,
98 MlirExternalPassCallbacks callbacks, void *userData)
99 : Pass(passID, opName), id(passID), name(name), argument(argument),
100 description(description), dependentDialects(dependentDialects),
101 callbacks(callbacks), userData(userData) {
102 callbacks.construct(userData);
103 }
104
~ExternalPass()105 ~ExternalPass() override { callbacks.destruct(userData); }
106
getName() const107 StringRef getName() const override { return name; }
getArgument() const108 StringRef getArgument() const override { return argument; }
getDescription() const109 StringRef getDescription() const override { return description; }
110
getDependentDialects(DialectRegistry & registry) const111 void getDependentDialects(DialectRegistry ®istry) const override {
112 MlirDialectRegistry cRegistry = wrap(®istry);
113 for (MlirDialectHandle dialect : dependentDialects)
114 mlirDialectHandleInsertDialect(dialect, cRegistry);
115 }
116
signalPassFailure()117 void signalPassFailure() { Pass::signalPassFailure(); }
118
119 protected:
initialize(MLIRContext * ctx)120 LogicalResult initialize(MLIRContext *ctx) override {
121 if (callbacks.initialize)
122 return unwrap(callbacks.initialize(wrap(ctx), userData));
123 return success();
124 }
125
canScheduleOn(RegisteredOperationName opName) const126 bool canScheduleOn(RegisteredOperationName opName) const override {
127 if (Optional<StringRef> specifiedOpName = getOpName())
128 return opName.getStringRef() == specifiedOpName;
129 return true;
130 }
131
runOnOperation()132 void runOnOperation() override {
133 callbacks.run(wrap(getOperation()), wrap(this), userData);
134 }
135
clonePass() const136 std::unique_ptr<Pass> clonePass() const override {
137 void *clonedUserData = callbacks.clone(userData);
138 return std::make_unique<ExternalPass>(id, name, argument, description,
139 getOpName(), dependentDialects,
140 callbacks, clonedUserData);
141 }
142
143 private:
144 TypeID id;
145 std::string name;
146 std::string argument;
147 std::string description;
148 std::vector<MlirDialectHandle> dependentDialects;
149 MlirExternalPassCallbacks callbacks;
150 void *userData;
151 };
152 } // namespace mlir
153
mlirCreateExternalPass(MlirTypeID passID,MlirStringRef name,MlirStringRef argument,MlirStringRef description,MlirStringRef opName,intptr_t nDependentDialects,MlirDialectHandle * dependentDialects,MlirExternalPassCallbacks callbacks,void * userData)154 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
155 MlirStringRef argument,
156 MlirStringRef description, MlirStringRef opName,
157 intptr_t nDependentDialects,
158 MlirDialectHandle *dependentDialects,
159 MlirExternalPassCallbacks callbacks,
160 void *userData) {
161 return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
162 unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
163 opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
164 {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
165 userData)));
166 }
167
mlirExternalPassSignalFailure(MlirExternalPass pass)168 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
169 unwrap(pass)->signalPassFailure();
170 }
171