1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/Dialect.h" 18 #include "mlir/IR/Location.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/IR/Types.h" 21 #include "mlir/IR/Verifier.h" 22 #include "mlir/Interfaces/InferTypeOpInterface.h" 23 #include "mlir/Parser/Parser.h" 24 25 #include "llvm/Support/Debug.h" 26 #include <cstddef> 27 28 using namespace mlir; 29 30 //===----------------------------------------------------------------------===// 31 // Context API. 32 //===----------------------------------------------------------------------===// 33 34 MlirContext mlirContextCreate() { 35 auto *context = new MLIRContext; 36 return wrap(context); 37 } 38 39 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 40 return unwrap(ctx1) == unwrap(ctx2); 41 } 42 43 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 44 45 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { 46 unwrap(context)->allowUnregisteredDialects(allow); 47 } 48 49 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { 50 return unwrap(context)->allowsUnregisteredDialects(); 51 } 52 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 53 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 54 } 55 56 void mlirContextAppendDialectRegistry(MlirContext ctx, 57 MlirDialectRegistry registry) { 58 unwrap(ctx)->appendDialectRegistry(*unwrap(registry)); 59 } 60 61 // TODO: expose a cheaper way than constructing + sorting a vector only to take 62 // its size. 63 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 64 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 65 } 66 67 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 68 MlirStringRef name) { 69 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 70 } 71 72 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { 73 return unwrap(context)->isOperationRegistered(unwrap(name)); 74 } 75 76 void mlirContextEnableMultithreading(MlirContext context, bool enable) { 77 return unwrap(context)->enableMultithreading(enable); 78 } 79 80 void mlirContextLoadAllAvailableDialects(MlirContext context) { 81 unwrap(context)->loadAllAvailableDialects(); 82 } 83 84 //===----------------------------------------------------------------------===// 85 // Dialect API. 86 //===----------------------------------------------------------------------===// 87 88 MlirContext mlirDialectGetContext(MlirDialect dialect) { 89 return wrap(unwrap(dialect)->getContext()); 90 } 91 92 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 93 return unwrap(dialect1) == unwrap(dialect2); 94 } 95 96 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 97 return wrap(unwrap(dialect)->getNamespace()); 98 } 99 100 //===----------------------------------------------------------------------===// 101 // DialectRegistry API. 102 //===----------------------------------------------------------------------===// 103 104 MlirDialectRegistry mlirDialectRegistryCreate() { 105 return wrap(new DialectRegistry()); 106 } 107 108 void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { 109 delete unwrap(registry); 110 } 111 112 //===----------------------------------------------------------------------===// 113 // Printing flags API. 114 //===----------------------------------------------------------------------===// 115 116 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 117 return wrap(new OpPrintingFlags()); 118 } 119 120 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 121 delete unwrap(flags); 122 } 123 124 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 125 intptr_t largeElementLimit) { 126 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 127 } 128 129 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, 130 bool prettyForm) { 131 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); 132 } 133 134 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 135 unwrap(flags)->printGenericOpForm(); 136 } 137 138 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 139 unwrap(flags)->useLocalScope(); 140 } 141 142 //===----------------------------------------------------------------------===// 143 // Location API. 144 //===----------------------------------------------------------------------===// 145 146 MlirLocation mlirLocationFileLineColGet(MlirContext context, 147 MlirStringRef filename, unsigned line, 148 unsigned col) { 149 return wrap(Location( 150 FileLineColLoc::get(unwrap(context), unwrap(filename), line, col))); 151 } 152 153 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { 154 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); 155 } 156 157 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, 158 MlirLocation const *locations, 159 MlirAttribute metadata) { 160 SmallVector<Location, 4> locs; 161 ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs); 162 return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); 163 } 164 165 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, 166 MlirLocation childLoc) { 167 if (mlirLocationIsNull(childLoc)) 168 return wrap( 169 Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); 170 return wrap(Location(NameLoc::get( 171 StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); 172 } 173 174 MlirLocation mlirLocationUnknownGet(MlirContext context) { 175 return wrap(Location(UnknownLoc::get(unwrap(context)))); 176 } 177 178 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { 179 return unwrap(l1) == unwrap(l2); 180 } 181 182 MlirContext mlirLocationGetContext(MlirLocation location) { 183 return wrap(unwrap(location).getContext()); 184 } 185 186 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 187 void *userData) { 188 detail::CallbackOstream stream(callback, userData); 189 unwrap(location).print(stream); 190 } 191 192 //===----------------------------------------------------------------------===// 193 // Module API. 194 //===----------------------------------------------------------------------===// 195 196 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 197 return wrap(ModuleOp::create(unwrap(location))); 198 } 199 200 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { 201 OwningOpRef<ModuleOp> owning = 202 parseSourceString<ModuleOp>(unwrap(module), unwrap(context)); 203 if (!owning) 204 return MlirModule{nullptr}; 205 return MlirModule{owning.release().getOperation()}; 206 } 207 208 MlirContext mlirModuleGetContext(MlirModule module) { 209 return wrap(unwrap(module).getContext()); 210 } 211 212 MlirBlock mlirModuleGetBody(MlirModule module) { 213 return wrap(unwrap(module).getBody()); 214 } 215 216 void mlirModuleDestroy(MlirModule module) { 217 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is 218 // called. 219 OwningOpRef<ModuleOp>(unwrap(module)); 220 } 221 222 MlirOperation mlirModuleGetOperation(MlirModule module) { 223 return wrap(unwrap(module).getOperation()); 224 } 225 226 MlirModule mlirModuleFromOperation(MlirOperation op) { 227 return wrap(dyn_cast<ModuleOp>(unwrap(op))); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // Operation state API. 232 //===----------------------------------------------------------------------===// 233 234 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { 235 MlirOperationState state; 236 state.name = name; 237 state.location = loc; 238 state.nResults = 0; 239 state.results = nullptr; 240 state.nOperands = 0; 241 state.operands = nullptr; 242 state.nRegions = 0; 243 state.regions = nullptr; 244 state.nSuccessors = 0; 245 state.successors = nullptr; 246 state.nAttributes = 0; 247 state.attributes = nullptr; 248 state.enableResultTypeInference = false; 249 return state; 250 } 251 252 #define APPEND_ELEMS(type, sizeName, elemName) \ 253 state->elemName = \ 254 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 255 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 256 state->sizeName += n; 257 258 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 259 MlirType const *results) { 260 APPEND_ELEMS(MlirType, nResults, results); 261 } 262 263 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 264 MlirValue const *operands) { 265 APPEND_ELEMS(MlirValue, nOperands, operands); 266 } 267 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 268 MlirRegion const *regions) { 269 APPEND_ELEMS(MlirRegion, nRegions, regions); 270 } 271 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 272 MlirBlock const *successors) { 273 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 274 } 275 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 276 MlirNamedAttribute const *attributes) { 277 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 278 } 279 280 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { 281 state->enableResultTypeInference = true; 282 } 283 284 //===----------------------------------------------------------------------===// 285 // Operation API. 286 //===----------------------------------------------------------------------===// 287 288 static LogicalResult inferOperationTypes(OperationState &state) { 289 MLIRContext *context = state.getContext(); 290 Optional<RegisteredOperationName> info = state.name.getRegisteredInfo(); 291 if (!info) { 292 emitError(state.location) 293 << "type inference was requested for the operation " << state.name 294 << ", but the operation was not registered. Ensure that the dialect " 295 "containing the operation is linked into MLIR and registered with " 296 "the context"; 297 return failure(); 298 } 299 300 // Fallback to inference via an op interface. 301 auto *inferInterface = info->getInterface<InferTypeOpInterface>(); 302 if (!inferInterface) { 303 emitError(state.location) 304 << "type inference was requested for the operation " << state.name 305 << ", but the operation does not support type inference. Result " 306 "types must be specified explicitly."; 307 return failure(); 308 } 309 310 if (succeeded(inferInterface->inferReturnTypes( 311 context, state.location, state.operands, 312 state.attributes.getDictionary(context), state.regions, state.types))) 313 return success(); 314 315 // Diagnostic emitted by interface. 316 return failure(); 317 } 318 319 MlirOperation mlirOperationCreate(MlirOperationState *state) { 320 assert(state); 321 OperationState cppState(unwrap(state->location), unwrap(state->name)); 322 SmallVector<Type, 4> resultStorage; 323 SmallVector<Value, 8> operandStorage; 324 SmallVector<Block *, 2> successorStorage; 325 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 326 cppState.addOperands( 327 unwrapList(state->nOperands, state->operands, operandStorage)); 328 cppState.addSuccessors( 329 unwrapList(state->nSuccessors, state->successors, successorStorage)); 330 331 cppState.attributes.reserve(state->nAttributes); 332 for (intptr_t i = 0; i < state->nAttributes; ++i) 333 cppState.addAttribute(unwrap(state->attributes[i].name), 334 unwrap(state->attributes[i].attribute)); 335 336 for (intptr_t i = 0; i < state->nRegions; ++i) 337 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 338 339 free(state->results); 340 free(state->operands); 341 free(state->successors); 342 free(state->regions); 343 free(state->attributes); 344 345 // Infer result types. 346 if (state->enableResultTypeInference) { 347 assert(cppState.types.empty() && 348 "result type inference enabled and result types provided"); 349 if (failed(inferOperationTypes(cppState))) 350 return {nullptr}; 351 } 352 353 MlirOperation result = wrap(Operation::create(cppState)); 354 return result; 355 } 356 357 MlirOperation mlirOperationClone(MlirOperation op) { 358 return wrap(unwrap(op)->clone()); 359 } 360 361 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 362 363 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); } 364 365 bool mlirOperationEqual(MlirOperation op, MlirOperation other) { 366 return unwrap(op) == unwrap(other); 367 } 368 369 MlirContext mlirOperationGetContext(MlirOperation op) { 370 return wrap(unwrap(op)->getContext()); 371 } 372 373 MlirLocation mlirOperationGetLocation(MlirOperation op) { 374 return wrap(unwrap(op)->getLoc()); 375 } 376 377 MlirTypeID mlirOperationGetTypeID(MlirOperation op) { 378 if (auto info = unwrap(op)->getRegisteredInfo()) 379 return wrap(info->getTypeID()); 380 return {nullptr}; 381 } 382 383 MlirIdentifier mlirOperationGetName(MlirOperation op) { 384 return wrap(unwrap(op)->getName().getIdentifier()); 385 } 386 387 MlirBlock mlirOperationGetBlock(MlirOperation op) { 388 return wrap(unwrap(op)->getBlock()); 389 } 390 391 MlirOperation mlirOperationGetParentOperation(MlirOperation op) { 392 return wrap(unwrap(op)->getParentOp()); 393 } 394 395 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 396 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 397 } 398 399 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 400 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 401 } 402 403 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { 404 Operation *cppOp = unwrap(op); 405 if (cppOp->getNumRegions() == 0) 406 return wrap(static_cast<Region *>(nullptr)); 407 return wrap(&cppOp->getRegion(0)); 408 } 409 410 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { 411 Region *cppRegion = unwrap(region); 412 Operation *parent = cppRegion->getParentOp(); 413 intptr_t next = cppRegion->getRegionNumber() + 1; 414 if (parent->getNumRegions() > next) 415 return wrap(&parent->getRegion(next)); 416 return wrap(static_cast<Region *>(nullptr)); 417 } 418 419 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 420 return wrap(unwrap(op)->getNextNode()); 421 } 422 423 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 424 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 425 } 426 427 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 428 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 429 } 430 431 void mlirOperationSetOperand(MlirOperation op, intptr_t pos, 432 MlirValue newValue) { 433 unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue)); 434 } 435 436 intptr_t mlirOperationGetNumResults(MlirOperation op) { 437 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 438 } 439 440 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 441 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 442 } 443 444 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 445 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 446 } 447 448 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 449 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 450 } 451 452 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 453 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 454 } 455 456 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 457 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 458 return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; 459 } 460 461 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 462 MlirStringRef name) { 463 return wrap(unwrap(op)->getAttr(unwrap(name))); 464 } 465 466 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, 467 MlirAttribute attr) { 468 unwrap(op)->setAttr(unwrap(name), unwrap(attr)); 469 } 470 471 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { 472 return !!unwrap(op)->removeAttr(unwrap(name)); 473 } 474 475 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 476 void *userData) { 477 detail::CallbackOstream stream(callback, userData); 478 unwrap(op)->print(stream); 479 } 480 481 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 482 MlirStringCallback callback, void *userData) { 483 detail::CallbackOstream stream(callback, userData); 484 unwrap(op)->print(stream, *unwrap(flags)); 485 } 486 487 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 488 489 bool mlirOperationVerify(MlirOperation op) { 490 return succeeded(verify(unwrap(op))); 491 } 492 493 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { 494 return unwrap(op)->moveAfter(unwrap(other)); 495 } 496 497 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { 498 return unwrap(op)->moveBefore(unwrap(other)); 499 } 500 501 //===----------------------------------------------------------------------===// 502 // Region API. 503 //===----------------------------------------------------------------------===// 504 505 MlirRegion mlirRegionCreate() { return wrap(new Region); } 506 507 bool mlirRegionEqual(MlirRegion region, MlirRegion other) { 508 return unwrap(region) == unwrap(other); 509 } 510 511 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 512 Region *cppRegion = unwrap(region); 513 if (cppRegion->empty()) 514 return wrap(static_cast<Block *>(nullptr)); 515 return wrap(&cppRegion->front()); 516 } 517 518 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 519 unwrap(region)->push_back(unwrap(block)); 520 } 521 522 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 523 MlirBlock block) { 524 auto &blockList = unwrap(region)->getBlocks(); 525 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 526 } 527 528 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 529 MlirBlock block) { 530 Region *cppRegion = unwrap(region); 531 if (mlirBlockIsNull(reference)) { 532 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 533 return; 534 } 535 536 assert(unwrap(reference)->getParent() == unwrap(region) && 537 "expected reference block to belong to the region"); 538 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 539 unwrap(block)); 540 } 541 542 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 543 MlirBlock block) { 544 if (mlirBlockIsNull(reference)) 545 return mlirRegionAppendOwnedBlock(region, block); 546 547 assert(unwrap(reference)->getParent() == unwrap(region) && 548 "expected reference block to belong to the region"); 549 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 550 unwrap(block)); 551 } 552 553 void mlirRegionDestroy(MlirRegion region) { 554 delete static_cast<Region *>(region.ptr); 555 } 556 557 //===----------------------------------------------------------------------===// 558 // Block API. 559 //===----------------------------------------------------------------------===// 560 561 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, 562 MlirLocation const *locs) { 563 Block *b = new Block; 564 for (intptr_t i = 0; i < nArgs; ++i) 565 b->addArgument(unwrap(args[i]), unwrap(locs[i])); 566 return wrap(b); 567 } 568 569 bool mlirBlockEqual(MlirBlock block, MlirBlock other) { 570 return unwrap(block) == unwrap(other); 571 } 572 573 MlirOperation mlirBlockGetParentOperation(MlirBlock block) { 574 return wrap(unwrap(block)->getParentOp()); 575 } 576 577 MlirRegion mlirBlockGetParentRegion(MlirBlock block) { 578 return wrap(unwrap(block)->getParent()); 579 } 580 581 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 582 return wrap(unwrap(block)->getNextNode()); 583 } 584 585 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 586 Block *cppBlock = unwrap(block); 587 if (cppBlock->empty()) 588 return wrap(static_cast<Operation *>(nullptr)); 589 return wrap(&cppBlock->front()); 590 } 591 592 MlirOperation mlirBlockGetTerminator(MlirBlock block) { 593 Block *cppBlock = unwrap(block); 594 if (cppBlock->empty()) 595 return wrap(static_cast<Operation *>(nullptr)); 596 Operation &back = cppBlock->back(); 597 if (!back.hasTrait<OpTrait::IsTerminator>()) 598 return wrap(static_cast<Operation *>(nullptr)); 599 return wrap(&back); 600 } 601 602 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 603 unwrap(block)->push_back(unwrap(operation)); 604 } 605 606 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 607 MlirOperation operation) { 608 auto &opList = unwrap(block)->getOperations(); 609 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 610 } 611 612 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 613 MlirOperation reference, 614 MlirOperation operation) { 615 Block *cppBlock = unwrap(block); 616 if (mlirOperationIsNull(reference)) { 617 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 618 return; 619 } 620 621 assert(unwrap(reference)->getBlock() == unwrap(block) && 622 "expected reference operation to belong to the block"); 623 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 624 unwrap(operation)); 625 } 626 627 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 628 MlirOperation reference, 629 MlirOperation operation) { 630 if (mlirOperationIsNull(reference)) 631 return mlirBlockAppendOwnedOperation(block, operation); 632 633 assert(unwrap(reference)->getBlock() == unwrap(block) && 634 "expected reference operation to belong to the block"); 635 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 636 unwrap(operation)); 637 } 638 639 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 640 641 void mlirBlockDetach(MlirBlock block) { 642 Block *b = unwrap(block); 643 b->getParent()->getBlocks().remove(b); 644 } 645 646 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 647 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 648 } 649 650 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, 651 MlirLocation loc) { 652 return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc))); 653 } 654 655 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 656 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 657 } 658 659 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 660 void *userData) { 661 detail::CallbackOstream stream(callback, userData); 662 unwrap(block)->print(stream); 663 } 664 665 //===----------------------------------------------------------------------===// 666 // Value API. 667 //===----------------------------------------------------------------------===// 668 669 bool mlirValueEqual(MlirValue value1, MlirValue value2) { 670 return unwrap(value1) == unwrap(value2); 671 } 672 673 bool mlirValueIsABlockArgument(MlirValue value) { 674 return unwrap(value).isa<BlockArgument>(); 675 } 676 677 bool mlirValueIsAOpResult(MlirValue value) { 678 return unwrap(value).isa<OpResult>(); 679 } 680 681 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 682 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 683 } 684 685 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 686 return static_cast<intptr_t>( 687 unwrap(value).cast<BlockArgument>().getArgNumber()); 688 } 689 690 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 691 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 692 } 693 694 MlirOperation mlirOpResultGetOwner(MlirValue value) { 695 return wrap(unwrap(value).cast<OpResult>().getOwner()); 696 } 697 698 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 699 return static_cast<intptr_t>( 700 unwrap(value).cast<OpResult>().getResultNumber()); 701 } 702 703 MlirType mlirValueGetType(MlirValue value) { 704 return wrap(unwrap(value).getType()); 705 } 706 707 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 708 709 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 710 void *userData) { 711 detail::CallbackOstream stream(callback, userData); 712 unwrap(value).print(stream); 713 } 714 715 //===----------------------------------------------------------------------===// 716 // Type API. 717 //===----------------------------------------------------------------------===// 718 719 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { 720 return wrap(mlir::parseType(unwrap(type), unwrap(context))); 721 } 722 723 MlirContext mlirTypeGetContext(MlirType type) { 724 return wrap(unwrap(type).getContext()); 725 } 726 727 MlirTypeID mlirTypeGetTypeID(MlirType type) { 728 return wrap(unwrap(type).getTypeID()); 729 } 730 731 bool mlirTypeEqual(MlirType t1, MlirType t2) { 732 return unwrap(t1) == unwrap(t2); 733 } 734 735 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 736 detail::CallbackOstream stream(callback, userData); 737 unwrap(type).print(stream); 738 } 739 740 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 741 742 //===----------------------------------------------------------------------===// 743 // Attribute API. 744 //===----------------------------------------------------------------------===// 745 746 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { 747 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context))); 748 } 749 750 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 751 return wrap(unwrap(attribute).getContext()); 752 } 753 754 MlirType mlirAttributeGetType(MlirAttribute attribute) { 755 return wrap(unwrap(attribute).getType()); 756 } 757 758 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { 759 return wrap(unwrap(attr).getTypeID()); 760 } 761 762 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 763 return unwrap(a1) == unwrap(a2); 764 } 765 766 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 767 void *userData) { 768 detail::CallbackOstream stream(callback, userData); 769 unwrap(attr).print(stream); 770 } 771 772 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 773 774 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, 775 MlirAttribute attr) { 776 return MlirNamedAttribute{name, attr}; 777 } 778 779 //===----------------------------------------------------------------------===// 780 // Identifier API. 781 //===----------------------------------------------------------------------===// 782 783 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { 784 return wrap(StringAttr::get(unwrap(context), unwrap(str))); 785 } 786 787 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { 788 return wrap(unwrap(ident).getContext()); 789 } 790 791 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { 792 return unwrap(ident) == unwrap(other); 793 } 794 795 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { 796 return wrap(unwrap(ident).strref()); 797 } 798 799 //===----------------------------------------------------------------------===// 800 // Symbol and SymbolTable API. 801 //===----------------------------------------------------------------------===// 802 803 MlirStringRef mlirSymbolTableGetSymbolAttributeName() { 804 return wrap(SymbolTable::getSymbolAttrName()); 805 } 806 807 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { 808 return wrap(SymbolTable::getVisibilityAttrName()); 809 } 810 811 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { 812 if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>()) 813 return wrap(static_cast<SymbolTable *>(nullptr)); 814 return wrap(new SymbolTable(unwrap(operation))); 815 } 816 817 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { 818 delete unwrap(symbolTable); 819 } 820 821 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, 822 MlirStringRef name) { 823 return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length))); 824 } 825 826 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, 827 MlirOperation operation) { 828 return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation))); 829 } 830 831 void mlirSymbolTableErase(MlirSymbolTable symbolTable, 832 MlirOperation operation) { 833 unwrap(symbolTable)->erase(unwrap(operation)); 834 } 835 836 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, 837 MlirStringRef newSymbol, 838 MlirOperation from) { 839 auto *cppFrom = unwrap(from); 840 auto *context = cppFrom->getContext(); 841 auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); 842 auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); 843 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, 844 unwrap(from))); 845 } 846 847 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, 848 void (*callback)(MlirOperation, bool, 849 void *userData), 850 void *userData) { 851 SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible, 852 [&](Operation *foundOpCpp, bool isVisible) { 853 callback(wrap(foundOpCpp), isVisible, 854 userData); 855 }); 856 } 857