1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/Module.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/Parser.h" 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // Context API. 26 //===----------------------------------------------------------------------===// 27 28 MlirContext mlirContextCreate() { 29 auto *context = new MLIRContext; 30 return wrap(context); 31 } 32 33 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 34 return unwrap(ctx1) == unwrap(ctx2); 35 } 36 37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 38 39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) { 40 unwrap(context)->allowUnregisteredDialects(allow); 41 } 42 43 int mlirContextGetAllowUnregisteredDialects(MlirContext context) { 44 return unwrap(context)->allowsUnregisteredDialects(); 45 } 46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 47 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 48 } 49 50 // TODO: expose a cheaper way than constructing + sorting a vector only to take 51 // its size. 52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 53 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 54 } 55 56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 57 MlirStringRef name) { 58 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // Dialect API. 63 //===----------------------------------------------------------------------===// 64 65 MlirContext mlirDialectGetContext(MlirDialect dialect) { 66 return wrap(unwrap(dialect)->getContext()); 67 } 68 69 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 70 return unwrap(dialect1) == unwrap(dialect2); 71 } 72 73 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 74 return wrap(unwrap(dialect)->getNamespace()); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Printing flags API. 79 //===----------------------------------------------------------------------===// 80 81 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 82 return wrap(new OpPrintingFlags()); 83 } 84 85 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 86 delete unwrap(flags); 87 } 88 89 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 90 intptr_t largeElementLimit) { 91 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 92 } 93 94 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, 95 int prettyForm) { 96 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); 97 } 98 99 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 100 unwrap(flags)->printGenericOpForm(); 101 } 102 103 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 104 unwrap(flags)->useLocalScope(); 105 } 106 107 //===----------------------------------------------------------------------===// 108 // Location API. 109 //===----------------------------------------------------------------------===// 110 111 MlirLocation mlirLocationFileLineColGet(MlirContext context, 112 const char *filename, unsigned line, 113 unsigned col) { 114 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 115 } 116 117 MlirLocation mlirLocationUnknownGet(MlirContext context) { 118 return wrap(UnknownLoc::get(unwrap(context))); 119 } 120 121 MlirContext mlirLocationGetContext(MlirLocation location) { 122 return wrap(unwrap(location).getContext()); 123 } 124 125 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 126 void *userData) { 127 detail::CallbackOstream stream(callback, userData); 128 unwrap(location).print(stream); 129 } 130 131 //===----------------------------------------------------------------------===// 132 // Module API. 133 //===----------------------------------------------------------------------===// 134 135 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 136 return wrap(ModuleOp::create(unwrap(location))); 137 } 138 139 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 140 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 141 if (!owning) 142 return MlirModule{nullptr}; 143 return MlirModule{owning.release().getOperation()}; 144 } 145 146 MlirContext mlirModuleGetContext(MlirModule module) { 147 return wrap(unwrap(module).getContext()); 148 } 149 150 MlirBlock mlirModuleGetBody(MlirModule module) { 151 return wrap(unwrap(module).getBody()); 152 } 153 154 void mlirModuleDestroy(MlirModule module) { 155 // Transfer ownership to an OwningModuleRef so that its destructor is called. 156 OwningModuleRef(unwrap(module)); 157 } 158 159 MlirOperation mlirModuleGetOperation(MlirModule module) { 160 return wrap(unwrap(module).getOperation()); 161 } 162 163 //===----------------------------------------------------------------------===// 164 // Operation state API. 165 //===----------------------------------------------------------------------===// 166 167 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 168 MlirOperationState state; 169 state.name = name; 170 state.location = loc; 171 state.nResults = 0; 172 state.results = nullptr; 173 state.nOperands = 0; 174 state.operands = nullptr; 175 state.nRegions = 0; 176 state.regions = nullptr; 177 state.nSuccessors = 0; 178 state.successors = nullptr; 179 state.nAttributes = 0; 180 state.attributes = nullptr; 181 return state; 182 } 183 184 #define APPEND_ELEMS(type, sizeName, elemName) \ 185 state->elemName = \ 186 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 187 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 188 state->sizeName += n; 189 190 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 191 MlirType *results) { 192 APPEND_ELEMS(MlirType, nResults, results); 193 } 194 195 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 196 MlirValue *operands) { 197 APPEND_ELEMS(MlirValue, nOperands, operands); 198 } 199 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 200 MlirRegion *regions) { 201 APPEND_ELEMS(MlirRegion, nRegions, regions); 202 } 203 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 204 MlirBlock *successors) { 205 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 206 } 207 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 208 MlirNamedAttribute *attributes) { 209 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 210 } 211 212 //===----------------------------------------------------------------------===// 213 // Operation API. 214 //===----------------------------------------------------------------------===// 215 216 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 217 assert(state); 218 OperationState cppState(unwrap(state->location), state->name); 219 SmallVector<Type, 4> resultStorage; 220 SmallVector<Value, 8> operandStorage; 221 SmallVector<Block *, 2> successorStorage; 222 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 223 cppState.addOperands( 224 unwrapList(state->nOperands, state->operands, operandStorage)); 225 cppState.addSuccessors( 226 unwrapList(state->nSuccessors, state->successors, successorStorage)); 227 228 cppState.attributes.reserve(state->nAttributes); 229 for (intptr_t i = 0; i < state->nAttributes; ++i) 230 cppState.addAttribute(state->attributes[i].name, 231 unwrap(state->attributes[i].attribute)); 232 233 for (intptr_t i = 0; i < state->nRegions; ++i) 234 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 235 236 MlirOperation result = wrap(Operation::create(cppState)); 237 free(state->results); 238 free(state->operands); 239 free(state->successors); 240 free(state->regions); 241 free(state->attributes); 242 return result; 243 } 244 245 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 246 247 int mlirOperationEqual(MlirOperation op, MlirOperation other) { 248 return unwrap(op) == unwrap(other); 249 } 250 251 MlirIdentifier mlirOperationGetName(MlirOperation op) { 252 return wrap(unwrap(op)->getName().getIdentifier()); 253 } 254 255 MlirBlock mlirOperationGetBlock(MlirOperation op) { 256 return wrap(unwrap(op)->getBlock()); 257 } 258 259 MlirOperation mlirOperationGetParentOperation(MlirOperation op) { 260 return wrap(unwrap(op)->getParentOp()); 261 } 262 263 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 264 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 265 } 266 267 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 268 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 269 } 270 271 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 272 return wrap(unwrap(op)->getNextNode()); 273 } 274 275 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 276 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 277 } 278 279 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 280 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 281 } 282 283 intptr_t mlirOperationGetNumResults(MlirOperation op) { 284 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 285 } 286 287 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 288 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 289 } 290 291 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 292 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 293 } 294 295 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 296 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 297 } 298 299 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 300 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 301 } 302 303 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 304 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 305 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 306 } 307 308 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 309 const char *name) { 310 return wrap(unwrap(op)->getAttr(name)); 311 } 312 313 void mlirOperationSetAttributeByName(MlirOperation op, const char *name, 314 MlirAttribute attr) { 315 unwrap(op)->setAttr(name, unwrap(attr)); 316 } 317 318 int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) { 319 auto removeResult = unwrap(op)->removeAttr(name); 320 return removeResult == MutableDictionaryAttr::RemoveResult::Removed; 321 } 322 323 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 324 void *userData) { 325 detail::CallbackOstream stream(callback, userData); 326 unwrap(op)->print(stream); 327 } 328 329 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 330 MlirStringCallback callback, void *userData) { 331 detail::CallbackOstream stream(callback, userData); 332 unwrap(op)->print(stream, *unwrap(flags)); 333 } 334 335 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 336 337 //===----------------------------------------------------------------------===// 338 // Region API. 339 //===----------------------------------------------------------------------===// 340 341 MlirRegion mlirRegionCreate() { return wrap(new Region); } 342 343 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 344 Region *cppRegion = unwrap(region); 345 if (cppRegion->empty()) 346 return wrap(static_cast<Block *>(nullptr)); 347 return wrap(&cppRegion->front()); 348 } 349 350 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 351 unwrap(region)->push_back(unwrap(block)); 352 } 353 354 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 355 MlirBlock block) { 356 auto &blockList = unwrap(region)->getBlocks(); 357 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 358 } 359 360 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 361 MlirBlock block) { 362 Region *cppRegion = unwrap(region); 363 if (mlirBlockIsNull(reference)) { 364 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 365 return; 366 } 367 368 assert(unwrap(reference)->getParent() == unwrap(region) && 369 "expected reference block to belong to the region"); 370 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 371 unwrap(block)); 372 } 373 374 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 375 MlirBlock block) { 376 if (mlirBlockIsNull(reference)) 377 return mlirRegionAppendOwnedBlock(region, block); 378 379 assert(unwrap(reference)->getParent() == unwrap(region) && 380 "expected reference block to belong to the region"); 381 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 382 unwrap(block)); 383 } 384 385 void mlirRegionDestroy(MlirRegion region) { 386 delete static_cast<Region *>(region.ptr); 387 } 388 389 //===----------------------------------------------------------------------===// 390 // Block API. 391 //===----------------------------------------------------------------------===// 392 393 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 394 Block *b = new Block; 395 for (intptr_t i = 0; i < nArgs; ++i) 396 b->addArgument(unwrap(args[i])); 397 return wrap(b); 398 } 399 400 int mlirBlockEqual(MlirBlock block, MlirBlock other) { 401 return unwrap(block) == unwrap(other); 402 } 403 404 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 405 return wrap(unwrap(block)->getNextNode()); 406 } 407 408 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 409 Block *cppBlock = unwrap(block); 410 if (cppBlock->empty()) 411 return wrap(static_cast<Operation *>(nullptr)); 412 return wrap(&cppBlock->front()); 413 } 414 415 MlirOperation mlirBlockGetTerminator(MlirBlock block) { 416 Block *cppBlock = unwrap(block); 417 if (cppBlock->empty()) 418 return wrap(static_cast<Operation *>(nullptr)); 419 Operation &back = cppBlock->back(); 420 if (!back.isKnownTerminator()) 421 return wrap(static_cast<Operation *>(nullptr)); 422 return wrap(&back); 423 } 424 425 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 426 unwrap(block)->push_back(unwrap(operation)); 427 } 428 429 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 430 MlirOperation operation) { 431 auto &opList = unwrap(block)->getOperations(); 432 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 433 } 434 435 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 436 MlirOperation reference, 437 MlirOperation operation) { 438 Block *cppBlock = unwrap(block); 439 if (mlirOperationIsNull(reference)) { 440 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 441 return; 442 } 443 444 assert(unwrap(reference)->getBlock() == unwrap(block) && 445 "expected reference operation to belong to the block"); 446 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 447 unwrap(operation)); 448 } 449 450 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 451 MlirOperation reference, 452 MlirOperation operation) { 453 if (mlirOperationIsNull(reference)) 454 return mlirBlockAppendOwnedOperation(block, operation); 455 456 assert(unwrap(reference)->getBlock() == unwrap(block) && 457 "expected reference operation to belong to the block"); 458 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 459 unwrap(operation)); 460 } 461 462 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 463 464 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 465 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 466 } 467 468 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 469 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 470 } 471 472 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 473 void *userData) { 474 detail::CallbackOstream stream(callback, userData); 475 unwrap(block)->print(stream); 476 } 477 478 //===----------------------------------------------------------------------===// 479 // Value API. 480 //===----------------------------------------------------------------------===// 481 482 int mlirValueEqual(MlirValue value1, MlirValue value2) { 483 return unwrap(value1) == unwrap(value2); 484 } 485 486 int mlirValueIsABlockArgument(MlirValue value) { 487 return unwrap(value).isa<BlockArgument>(); 488 } 489 490 int mlirValueIsAOpResult(MlirValue value) { 491 return unwrap(value).isa<OpResult>(); 492 } 493 494 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 495 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 496 } 497 498 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 499 return static_cast<intptr_t>( 500 unwrap(value).cast<BlockArgument>().getArgNumber()); 501 } 502 503 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 504 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 505 } 506 507 MlirOperation mlirOpResultGetOwner(MlirValue value) { 508 return wrap(unwrap(value).cast<OpResult>().getOwner()); 509 } 510 511 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 512 return static_cast<intptr_t>( 513 unwrap(value).cast<OpResult>().getResultNumber()); 514 } 515 516 MlirType mlirValueGetType(MlirValue value) { 517 return wrap(unwrap(value).getType()); 518 } 519 520 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 521 522 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 523 void *userData) { 524 detail::CallbackOstream stream(callback, userData); 525 unwrap(value).print(stream); 526 } 527 528 //===----------------------------------------------------------------------===// 529 // Type API. 530 //===----------------------------------------------------------------------===// 531 532 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 533 return wrap(mlir::parseType(type, unwrap(context))); 534 } 535 536 MlirContext mlirTypeGetContext(MlirType type) { 537 return wrap(unwrap(type).getContext()); 538 } 539 540 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } 541 542 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 543 detail::CallbackOstream stream(callback, userData); 544 unwrap(type).print(stream); 545 } 546 547 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 548 549 //===----------------------------------------------------------------------===// 550 // Attribute API. 551 //===----------------------------------------------------------------------===// 552 553 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 554 return wrap(mlir::parseAttribute(attr, unwrap(context))); 555 } 556 557 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 558 return wrap(unwrap(attribute).getContext()); 559 } 560 561 MlirType mlirAttributeGetType(MlirAttribute attribute) { 562 return wrap(unwrap(attribute).getType()); 563 } 564 565 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 566 return unwrap(a1) == unwrap(a2); 567 } 568 569 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 570 void *userData) { 571 detail::CallbackOstream stream(callback, userData); 572 unwrap(attr).print(stream); 573 } 574 575 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 576 577 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 578 return MlirNamedAttribute{name, attr}; 579 } 580 581 //===----------------------------------------------------------------------===// 582 // Identifier API. 583 //===----------------------------------------------------------------------===// 584 585 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { 586 return wrap(Identifier::get(unwrap(str), unwrap(context))); 587 } 588 589 int mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { 590 return unwrap(ident) == unwrap(other); 591 } 592 593 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { 594 return wrap(unwrap(ident).strref()); 595 } 596