1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/Module.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/Parser.h" 21 22 using namespace mlir; 23 24 /* ========================================================================== */ 25 /* Context API. */ 26 /* ========================================================================== */ 27 28 MlirContext mlirContextCreate() { 29 auto *context = new MLIRContext; 30 return wrap(context); 31 } 32 33 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 34 return unwrap(ctx1) == unwrap(ctx2); 35 } 36 37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 38 39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) { 40 unwrap(context)->allowUnregisteredDialects(allow); 41 } 42 43 int mlirContextGetAllowUnregisteredDialects(MlirContext context) { 44 return unwrap(context)->allowsUnregisteredDialects(); 45 } 46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 47 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 48 } 49 50 // TODO: expose a cheaper way than constructing + sorting a vector only to take 51 // its size. 52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 53 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 54 } 55 56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 57 MlirStringRef name) { 58 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 59 } 60 61 /* ========================================================================== */ 62 /* Dialect API. */ 63 /* ========================================================================== */ 64 65 MlirContext mlirDialectGetContext(MlirDialect dialect) { 66 return wrap(unwrap(dialect)->getContext()); 67 } 68 69 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 70 return unwrap(dialect1) == unwrap(dialect2); 71 } 72 73 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 74 return wrap(unwrap(dialect)->getNamespace()); 75 } 76 77 /* ========================================================================== */ 78 /* Printing flags API. */ 79 /* ========================================================================== */ 80 81 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 82 return wrap(new OpPrintingFlags()); 83 } 84 85 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 86 delete unwrap(flags); 87 } 88 89 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 90 intptr_t largeElementLimit) { 91 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 92 } 93 94 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, 95 int prettyForm) { 96 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); 97 } 98 99 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 100 unwrap(flags)->printGenericOpForm(); 101 } 102 103 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 104 unwrap(flags)->useLocalScope(); 105 } 106 107 /* ========================================================================== */ 108 /* Location API. */ 109 /* ========================================================================== */ 110 111 MlirLocation mlirLocationFileLineColGet(MlirContext context, 112 const char *filename, unsigned line, 113 unsigned col) { 114 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 115 } 116 117 MlirLocation mlirLocationUnknownGet(MlirContext context) { 118 return wrap(UnknownLoc::get(unwrap(context))); 119 } 120 121 MlirContext mlirLocationGetContext(MlirLocation location) { 122 return wrap(unwrap(location).getContext()); 123 } 124 125 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 126 void *userData) { 127 detail::CallbackOstream stream(callback, userData); 128 unwrap(location).print(stream); 129 stream.flush(); 130 } 131 132 /* ========================================================================== */ 133 /* Module API. */ 134 /* ========================================================================== */ 135 136 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 137 return wrap(ModuleOp::create(unwrap(location))); 138 } 139 140 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 141 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 142 if (!owning) 143 return MlirModule{nullptr}; 144 return MlirModule{owning.release().getOperation()}; 145 } 146 147 MlirContext mlirModuleGetContext(MlirModule module) { 148 return wrap(unwrap(module).getContext()); 149 } 150 151 void mlirModuleDestroy(MlirModule module) { 152 // Transfer ownership to an OwningModuleRef so that its destructor is called. 153 OwningModuleRef(unwrap(module)); 154 } 155 156 MlirOperation mlirModuleGetOperation(MlirModule module) { 157 return wrap(unwrap(module).getOperation()); 158 } 159 160 /* ========================================================================== */ 161 /* Operation state API. */ 162 /* ========================================================================== */ 163 164 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 165 MlirOperationState state; 166 state.name = name; 167 state.location = loc; 168 state.nResults = 0; 169 state.results = nullptr; 170 state.nOperands = 0; 171 state.operands = nullptr; 172 state.nRegions = 0; 173 state.regions = nullptr; 174 state.nSuccessors = 0; 175 state.successors = nullptr; 176 state.nAttributes = 0; 177 state.attributes = nullptr; 178 return state; 179 } 180 181 #define APPEND_ELEMS(type, sizeName, elemName) \ 182 state->elemName = \ 183 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 184 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 185 state->sizeName += n; 186 187 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 188 MlirType *results) { 189 APPEND_ELEMS(MlirType, nResults, results); 190 } 191 192 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 193 MlirValue *operands) { 194 APPEND_ELEMS(MlirValue, nOperands, operands); 195 } 196 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 197 MlirRegion *regions) { 198 APPEND_ELEMS(MlirRegion, nRegions, regions); 199 } 200 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 201 MlirBlock *successors) { 202 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 203 } 204 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 205 MlirNamedAttribute *attributes) { 206 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 207 } 208 209 /* ========================================================================== */ 210 /* Operation API. */ 211 /* ========================================================================== */ 212 213 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 214 assert(state); 215 OperationState cppState(unwrap(state->location), state->name); 216 SmallVector<Type, 4> resultStorage; 217 SmallVector<Value, 8> operandStorage; 218 SmallVector<Block *, 2> successorStorage; 219 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 220 cppState.addOperands( 221 unwrapList(state->nOperands, state->operands, operandStorage)); 222 cppState.addSuccessors( 223 unwrapList(state->nSuccessors, state->successors, successorStorage)); 224 225 cppState.attributes.reserve(state->nAttributes); 226 for (intptr_t i = 0; i < state->nAttributes; ++i) 227 cppState.addAttribute(state->attributes[i].name, 228 unwrap(state->attributes[i].attribute)); 229 230 for (intptr_t i = 0; i < state->nRegions; ++i) 231 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 232 233 MlirOperation result = wrap(Operation::create(cppState)); 234 free(state->results); 235 free(state->operands); 236 free(state->successors); 237 free(state->regions); 238 free(state->attributes); 239 return result; 240 } 241 242 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 243 244 int mlirOperationEqual(MlirOperation op, MlirOperation other) { 245 return unwrap(op) == unwrap(other); 246 } 247 248 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 249 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 250 } 251 252 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 253 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 254 } 255 256 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 257 return wrap(unwrap(op)->getNextNode()); 258 } 259 260 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 261 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 262 } 263 264 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 265 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 266 } 267 268 intptr_t mlirOperationGetNumResults(MlirOperation op) { 269 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 270 } 271 272 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 273 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 274 } 275 276 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 277 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 278 } 279 280 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 281 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 282 } 283 284 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 285 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 286 } 287 288 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 289 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 290 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 291 } 292 293 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 294 const char *name) { 295 return wrap(unwrap(op)->getAttr(name)); 296 } 297 298 void mlirOperationSetAttributeByName(MlirOperation op, const char *name, 299 MlirAttribute attr) { 300 unwrap(op)->setAttr(name, unwrap(attr)); 301 } 302 303 int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) { 304 auto removeResult = unwrap(op)->removeAttr(name); 305 return removeResult == MutableDictionaryAttr::RemoveResult::Removed; 306 } 307 308 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 309 void *userData) { 310 detail::CallbackOstream stream(callback, userData); 311 unwrap(op)->print(stream); 312 stream.flush(); 313 } 314 315 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 316 MlirStringCallback callback, void *userData) { 317 detail::CallbackOstream stream(callback, userData); 318 unwrap(op)->print(stream, *unwrap(flags)); 319 stream.flush(); 320 } 321 322 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 323 324 /* ========================================================================== */ 325 /* Region API. */ 326 /* ========================================================================== */ 327 328 MlirRegion mlirRegionCreate() { return wrap(new Region); } 329 330 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 331 Region *cppRegion = unwrap(region); 332 if (cppRegion->empty()) 333 return wrap(static_cast<Block *>(nullptr)); 334 return wrap(&cppRegion->front()); 335 } 336 337 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 338 unwrap(region)->push_back(unwrap(block)); 339 } 340 341 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 342 MlirBlock block) { 343 auto &blockList = unwrap(region)->getBlocks(); 344 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 345 } 346 347 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 348 MlirBlock block) { 349 Region *cppRegion = unwrap(region); 350 if (mlirBlockIsNull(reference)) { 351 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 352 return; 353 } 354 355 assert(unwrap(reference)->getParent() == unwrap(region) && 356 "expected reference block to belong to the region"); 357 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 358 unwrap(block)); 359 } 360 361 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 362 MlirBlock block) { 363 if (mlirBlockIsNull(reference)) 364 return mlirRegionAppendOwnedBlock(region, block); 365 366 assert(unwrap(reference)->getParent() == unwrap(region) && 367 "expected reference block to belong to the region"); 368 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 369 unwrap(block)); 370 } 371 372 void mlirRegionDestroy(MlirRegion region) { 373 delete static_cast<Region *>(region.ptr); 374 } 375 376 /* ========================================================================== */ 377 /* Block API. */ 378 /* ========================================================================== */ 379 380 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 381 Block *b = new Block; 382 for (intptr_t i = 0; i < nArgs; ++i) 383 b->addArgument(unwrap(args[i])); 384 return wrap(b); 385 } 386 387 int mlirBlockEqual(MlirBlock block, MlirBlock other) { 388 return unwrap(block) == unwrap(other); 389 } 390 391 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 392 return wrap(unwrap(block)->getNextNode()); 393 } 394 395 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 396 Block *cppBlock = unwrap(block); 397 if (cppBlock->empty()) 398 return wrap(static_cast<Operation *>(nullptr)); 399 return wrap(&cppBlock->front()); 400 } 401 402 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 403 unwrap(block)->push_back(unwrap(operation)); 404 } 405 406 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 407 MlirOperation operation) { 408 auto &opList = unwrap(block)->getOperations(); 409 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 410 } 411 412 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 413 MlirOperation reference, 414 MlirOperation operation) { 415 Block *cppBlock = unwrap(block); 416 if (mlirOperationIsNull(reference)) { 417 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 418 return; 419 } 420 421 assert(unwrap(reference)->getBlock() == unwrap(block) && 422 "expected reference operation to belong to the block"); 423 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 424 unwrap(operation)); 425 } 426 427 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 428 MlirOperation reference, 429 MlirOperation operation) { 430 if (mlirOperationIsNull(reference)) 431 return mlirBlockAppendOwnedOperation(block, operation); 432 433 assert(unwrap(reference)->getBlock() == unwrap(block) && 434 "expected reference operation to belong to the block"); 435 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 436 unwrap(operation)); 437 } 438 439 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 440 441 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 442 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 443 } 444 445 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 446 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 447 } 448 449 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 450 void *userData) { 451 detail::CallbackOstream stream(callback, userData); 452 unwrap(block)->print(stream); 453 stream.flush(); 454 } 455 456 /* ========================================================================== */ 457 /* Value API. */ 458 /* ========================================================================== */ 459 460 int mlirValueIsABlockArgument(MlirValue value) { 461 return unwrap(value).isa<BlockArgument>(); 462 } 463 464 int mlirValueIsAOpResult(MlirValue value) { 465 return unwrap(value).isa<OpResult>(); 466 } 467 468 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 469 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 470 } 471 472 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 473 return static_cast<intptr_t>( 474 unwrap(value).cast<BlockArgument>().getArgNumber()); 475 } 476 477 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 478 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 479 } 480 481 MlirOperation mlirOpResultGetOwner(MlirValue value) { 482 return wrap(unwrap(value).cast<OpResult>().getOwner()); 483 } 484 485 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 486 return static_cast<intptr_t>( 487 unwrap(value).cast<OpResult>().getResultNumber()); 488 } 489 490 MlirType mlirValueGetType(MlirValue value) { 491 return wrap(unwrap(value).getType()); 492 } 493 494 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 495 496 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 497 void *userData) { 498 detail::CallbackOstream stream(callback, userData); 499 unwrap(value).print(stream); 500 stream.flush(); 501 } 502 503 /* ========================================================================== */ 504 /* Type API. */ 505 /* ========================================================================== */ 506 507 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 508 return wrap(mlir::parseType(type, unwrap(context))); 509 } 510 511 MlirContext mlirTypeGetContext(MlirType type) { 512 return wrap(unwrap(type).getContext()); 513 } 514 515 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } 516 517 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 518 detail::CallbackOstream stream(callback, userData); 519 unwrap(type).print(stream); 520 stream.flush(); 521 } 522 523 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 524 525 /* ========================================================================== */ 526 /* Attribute API. */ 527 /* ========================================================================== */ 528 529 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 530 return wrap(mlir::parseAttribute(attr, unwrap(context))); 531 } 532 533 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 534 return wrap(unwrap(attribute).getContext()); 535 } 536 537 MlirType mlirAttributeGetType(MlirAttribute attribute) { 538 return wrap(unwrap(attribute).getType()); 539 } 540 541 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 542 return unwrap(a1) == unwrap(a2); 543 } 544 545 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 546 void *userData) { 547 detail::CallbackOstream stream(callback, userData); 548 unwrap(attr).print(stream); 549 stream.flush(); 550 } 551 552 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 553 554 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 555 return MlirNamedAttribute{name, attr}; 556 } 557