1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 11 #include "mlir/IR/Attributes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/Module.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/IR/Types.h" 16 #include "mlir/InitAllDialects.h" 17 #include "mlir/Parser.h" 18 #include "llvm/Support/raw_ostream.h" 19 20 using namespace mlir; 21 22 /* ========================================================================== */ 23 /* Definitions of methods for non-owning structures used in C API. */ 24 /* ========================================================================== */ 25 26 #define DEFINE_C_API_PTR_METHODS(name, cpptype) \ 27 static name wrap(cpptype *cpp) { return name{cpp}; } \ 28 static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); } 29 30 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext) 31 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation) 32 DEFINE_C_API_PTR_METHODS(MlirBlock, Block) 33 DEFINE_C_API_PTR_METHODS(MlirRegion, Region) 34 35 #define DEFINE_C_API_METHODS(name, cpptype) \ 36 static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; } \ 37 static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); } 38 39 DEFINE_C_API_METHODS(MlirAttribute, Attribute) 40 DEFINE_C_API_METHODS(MlirLocation, Location); 41 DEFINE_C_API_METHODS(MlirType, Type) 42 DEFINE_C_API_METHODS(MlirValue, Value) 43 DEFINE_C_API_METHODS(MlirModule, ModuleOp) 44 45 template <typename CppTy, typename CTy> 46 static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first, 47 SmallVectorImpl<CppTy> &storage) { 48 static_assert( 49 std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value, 50 "incompatible C and C++ types"); 51 52 if (size == 0) 53 return llvm::None; 54 55 assert(storage.empty() && "expected to populate storage"); 56 storage.reserve(size); 57 for (intptr_t i = 0; i < size; ++i) 58 storage.push_back(unwrap(*(first + i))); 59 return storage; 60 } 61 62 /* ========================================================================== */ 63 /* Printing helper. */ 64 /* ========================================================================== */ 65 66 namespace { 67 /// A simple raw ostream subclass that forwards write_impl calls to the 68 /// user-supplied callback together with opaque user-supplied data. 69 class CallbackOstream : public llvm::raw_ostream { 70 public: 71 CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback, 72 void *opaqueData) 73 : callback(callback), opaqueData(opaqueData), pos(0u) {} 74 75 void write_impl(const char *ptr, size_t size) override { 76 callback(ptr, size, opaqueData); 77 pos += size; 78 } 79 80 uint64_t current_pos() const override { return pos; } 81 82 private: 83 std::function<void(const char *, intptr_t, void *)> callback; 84 void *opaqueData; 85 uint64_t pos; 86 }; 87 } // end namespace 88 89 /* ========================================================================== */ 90 /* Context API. */ 91 /* ========================================================================== */ 92 93 MlirContext mlirContextCreate() { 94 auto *context = new MLIRContext(false); 95 return wrap(context); 96 } 97 98 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 99 100 void mlirContextLoadAllDialects(MlirContext context) { 101 registerAllDialects(unwrap(context)); 102 getGlobalDialectRegistry().loadAll(unwrap(context)); 103 } 104 105 /* ========================================================================== */ 106 /* Location API. */ 107 /* ========================================================================== */ 108 109 MlirLocation mlirLocationFileLineColGet(MlirContext context, 110 const char *filename, unsigned line, 111 unsigned col) { 112 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 113 } 114 115 MlirLocation mlirLocationUnknownGet(MlirContext context) { 116 return wrap(UnknownLoc::get(unwrap(context))); 117 } 118 119 void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, 120 void *userData) { 121 CallbackOstream stream(callback, userData); 122 unwrap(location).print(stream); 123 stream.flush(); 124 } 125 126 /* ========================================================================== */ 127 /* Module API. */ 128 /* ========================================================================== */ 129 130 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 131 return wrap(ModuleOp::create(unwrap(location))); 132 } 133 134 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 135 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 136 return MlirModule{owning.release().getOperation()}; 137 } 138 139 void mlirModuleDestroy(MlirModule module) { 140 // Transfer ownership to an OwningModuleRef so that its destructor is called. 141 OwningModuleRef(unwrap(module)); 142 } 143 144 MlirOperation mlirModuleGetOperation(MlirModule module) { 145 return wrap(unwrap(module).getOperation()); 146 } 147 148 /* ========================================================================== */ 149 /* Operation state API. */ 150 /* ========================================================================== */ 151 152 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 153 MlirOperationState state; 154 state.name = name; 155 state.location = loc; 156 state.nResults = 0; 157 state.results = nullptr; 158 state.nOperands = 0; 159 state.operands = nullptr; 160 state.nRegions = 0; 161 state.regions = nullptr; 162 state.nSuccessors = 0; 163 state.successors = nullptr; 164 state.nAttributes = 0; 165 state.attributes = nullptr; 166 return state; 167 } 168 169 #define APPEND_ELEMS(type, sizeName, elemName) \ 170 state->elemName = \ 171 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 172 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 173 state->sizeName += n; 174 175 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 176 MlirType *results) { 177 APPEND_ELEMS(MlirType, nResults, results); 178 } 179 180 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 181 MlirValue *operands) { 182 APPEND_ELEMS(MlirValue, nOperands, operands); 183 } 184 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 185 MlirRegion *regions) { 186 APPEND_ELEMS(MlirRegion, nRegions, regions); 187 } 188 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 189 MlirBlock *successors) { 190 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 191 } 192 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 193 MlirNamedAttribute *attributes) { 194 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 195 } 196 197 /* ========================================================================== */ 198 /* Operation API. */ 199 /* ========================================================================== */ 200 201 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 202 assert(state); 203 OperationState cppState(unwrap(state->location), state->name); 204 SmallVector<Type, 4> resultStorage; 205 SmallVector<Value, 8> operandStorage; 206 SmallVector<Block *, 2> successorStorage; 207 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 208 cppState.addOperands( 209 unwrapList(state->nOperands, state->operands, operandStorage)); 210 cppState.addSuccessors( 211 unwrapList(state->nSuccessors, state->successors, successorStorage)); 212 213 cppState.attributes.reserve(state->nAttributes); 214 for (intptr_t i = 0; i < state->nAttributes; ++i) 215 cppState.addAttribute(state->attributes[i].name, 216 unwrap(state->attributes[i].attribute)); 217 218 for (intptr_t i = 0; i < state->nRegions; ++i) 219 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 220 221 MlirOperation result = wrap(Operation::create(cppState)); 222 free(state->results); 223 free(state->operands); 224 free(state->successors); 225 free(state->regions); 226 free(state->attributes); 227 return result; 228 } 229 230 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 231 232 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } 233 234 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 235 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 236 } 237 238 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 239 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 240 } 241 242 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 243 return wrap(unwrap(op)->getNextNode()); 244 } 245 246 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 247 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 248 } 249 250 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 251 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 252 } 253 254 intptr_t mlirOperationGetNumResults(MlirOperation op) { 255 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 256 } 257 258 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 259 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 260 } 261 262 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 263 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 264 } 265 266 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 267 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 268 } 269 270 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 271 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 272 } 273 274 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 275 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 276 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 277 } 278 279 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 280 const char *name) { 281 return wrap(unwrap(op)->getAttr(name)); 282 } 283 284 void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, 285 void *userData) { 286 CallbackOstream stream(callback, userData); 287 unwrap(op)->print(stream); 288 stream.flush(); 289 } 290 291 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 292 293 /* ========================================================================== */ 294 /* Region API. */ 295 /* ========================================================================== */ 296 297 MlirRegion mlirRegionCreate() { return wrap(new Region); } 298 299 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 300 Region *cppRegion = unwrap(region); 301 if (cppRegion->empty()) 302 return wrap(static_cast<Block *>(nullptr)); 303 return wrap(&cppRegion->front()); 304 } 305 306 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 307 unwrap(region)->push_back(unwrap(block)); 308 } 309 310 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 311 MlirBlock block) { 312 auto &blockList = unwrap(region)->getBlocks(); 313 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 314 } 315 316 void mlirRegionDestroy(MlirRegion region) { 317 delete static_cast<Region *>(region.ptr); 318 } 319 320 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } 321 322 /* ========================================================================== */ 323 /* Block API. */ 324 /* ========================================================================== */ 325 326 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 327 Block *b = new Block; 328 for (intptr_t i = 0; i < nArgs; ++i) 329 b->addArgument(unwrap(args[i])); 330 return wrap(b); 331 } 332 333 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 334 return wrap(unwrap(block)->getNextNode()); 335 } 336 337 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 338 Block *cppBlock = unwrap(block); 339 if (cppBlock->empty()) 340 return wrap(static_cast<Operation *>(nullptr)); 341 return wrap(&cppBlock->front()); 342 } 343 344 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 345 unwrap(block)->push_back(unwrap(operation)); 346 } 347 348 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 349 MlirOperation operation) { 350 auto &opList = unwrap(block)->getOperations(); 351 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 352 } 353 354 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 355 356 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } 357 358 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 359 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 360 } 361 362 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 363 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 364 } 365 366 void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, 367 void *userData) { 368 CallbackOstream stream(callback, userData); 369 unwrap(block)->print(stream); 370 stream.flush(); 371 } 372 373 /* ========================================================================== */ 374 /* Value API. */ 375 /* ========================================================================== */ 376 377 MlirType mlirValueGetType(MlirValue value) { 378 return wrap(unwrap(value).getType()); 379 } 380 381 void mlirValuePrint(MlirValue value, MlirPrintCallback callback, 382 void *userData) { 383 CallbackOstream stream(callback, userData); 384 unwrap(value).print(stream); 385 stream.flush(); 386 } 387 388 /* ========================================================================== */ 389 /* Type API. */ 390 /* ========================================================================== */ 391 392 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 393 return wrap(mlir::parseType(type, unwrap(context))); 394 } 395 396 void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) { 397 CallbackOstream stream(callback, userData); 398 unwrap(type).print(stream); 399 stream.flush(); 400 } 401 402 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 403 404 /* ========================================================================== */ 405 /* Attribute API. */ 406 /* ========================================================================== */ 407 408 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 409 return wrap(mlir::parseAttribute(attr, unwrap(context))); 410 } 411 412 void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, 413 void *userData) { 414 CallbackOstream stream(callback, userData); 415 unwrap(attr).print(stream); 416 stream.flush(); 417 } 418 419 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 420 421 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 422 return MlirNamedAttribute{name, attr}; 423 } 424