1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Pass.h" 10 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/CAPI/Pass.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/Pass/PassManager.h" 16 17 using namespace mlir; 18 19 //===----------------------------------------------------------------------===// 20 // PassManager/OpPassManager APIs. 21 //===----------------------------------------------------------------------===// 22 23 MlirPassManager mlirPassManagerCreate(MlirContext ctx) { 24 return wrap(new PassManager(unwrap(ctx))); 25 } 26 27 void mlirPassManagerDestroy(MlirPassManager passManager) { 28 delete unwrap(passManager); 29 } 30 31 MlirOpPassManager 32 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) { 33 return wrap(static_cast<OpPassManager *>(unwrap(passManager))); 34 } 35 36 MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager, 37 MlirModule module) { 38 return wrap(unwrap(passManager)->run(unwrap(module))); 39 } 40 41 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { 42 return unwrap(passManager)->enableIRPrinting(); 43 } 44 45 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { 46 unwrap(passManager)->enableVerifier(enable); 47 } 48 49 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, 50 MlirStringRef operationName) { 51 return wrap(&unwrap(passManager)->nest(unwrap(operationName))); 52 } 53 54 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, 55 MlirStringRef operationName) { 56 return wrap(&unwrap(passManager)->nest(unwrap(operationName))); 57 } 58 59 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) { 60 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass))); 61 } 62 63 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, 64 MlirPass pass) { 65 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass))); 66 } 67 68 void mlirPrintPassPipeline(MlirOpPassManager passManager, 69 MlirStringCallback callback, void *userData) { 70 detail::CallbackOstream stream(callback, userData); 71 unwrap(passManager)->printAsTextualPipeline(stream); 72 } 73 74 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, 75 MlirStringRef pipeline) { 76 // TODO: errors are sent to std::errs() at the moment, we should pass in a 77 // stream and redirect to a diagnostic. 78 return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager))); 79 } 80 81 //===----------------------------------------------------------------------===// 82 // External Pass API. 83 //===----------------------------------------------------------------------===// 84 85 namespace mlir { 86 class ExternalPass; 87 } // namespace mlir 88 DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass) 89 90 namespace mlir { 91 /// This pass class wraps external passes defined in other languages using the 92 /// MLIR C-interface 93 class ExternalPass : public Pass { 94 public: 95 ExternalPass(TypeID passID, StringRef name, StringRef argument, 96 StringRef description, Optional<StringRef> opName, 97 ArrayRef<MlirDialectHandle> dependentDialects, 98 MlirExternalPassCallbacks callbacks, void *userData) 99 : Pass(passID, opName), id(passID), name(name), argument(argument), 100 description(description), dependentDialects(dependentDialects), 101 callbacks(callbacks), userData(userData) { 102 callbacks.construct(userData); 103 } 104 105 ~ExternalPass() override { callbacks.destruct(userData); } 106 107 StringRef getName() const override { return name; } 108 StringRef getArgument() const override { return argument; } 109 StringRef getDescription() const override { return description; } 110 111 void getDependentDialects(DialectRegistry ®istry) const override { 112 MlirDialectRegistry cRegistry = wrap(®istry); 113 for (MlirDialectHandle dialect : dependentDialects) 114 mlirDialectHandleInsertDialect(dialect, cRegistry); 115 } 116 117 void signalPassFailure() { Pass::signalPassFailure(); } 118 119 protected: 120 LogicalResult initialize(MLIRContext *ctx) override { 121 if (callbacks.initialize) 122 return unwrap(callbacks.initialize(wrap(ctx), userData)); 123 return success(); 124 } 125 126 bool canScheduleOn(RegisteredOperationName opName) const override { 127 if (Optional<StringRef> specifiedOpName = getOpName()) 128 return opName.getStringRef() == specifiedOpName; 129 return true; 130 } 131 132 void runOnOperation() override { 133 callbacks.run(wrap(getOperation()), wrap(this), userData); 134 } 135 136 std::unique_ptr<Pass> clonePass() const override { 137 void *clonedUserData = callbacks.clone(userData); 138 return std::make_unique<ExternalPass>(id, name, argument, description, 139 getOpName(), dependentDialects, 140 callbacks, clonedUserData); 141 } 142 143 private: 144 TypeID id; 145 std::string name; 146 std::string argument; 147 std::string description; 148 std::vector<MlirDialectHandle> dependentDialects; 149 MlirExternalPassCallbacks callbacks; 150 void *userData; 151 }; 152 } // namespace mlir 153 154 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, 155 MlirStringRef argument, 156 MlirStringRef description, MlirStringRef opName, 157 intptr_t nDependentDialects, 158 MlirDialectHandle *dependentDialects, 159 MlirExternalPassCallbacks callbacks, 160 void *userData) { 161 return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass( 162 unwrap(passID), unwrap(name), unwrap(argument), unwrap(description), 163 opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None, 164 {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks, 165 userData))); 166 } 167 168 void mlirExternalPassSignalFailure(MlirExternalPass pass) { 169 unwrap(pass)->signalPassFailure(); 170 } 171