1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/AsmParser/AsmParser.h" 13 #include "mlir/CAPI/IR.h" 14 #include "mlir/CAPI/Support.h" 15 #include "mlir/CAPI/Utils.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/Dialect.h" 19 #include "mlir/IR/Location.h" 20 #include "mlir/IR/Operation.h" 21 #include "mlir/IR/Types.h" 22 #include "mlir/IR/Verifier.h" 23 #include "mlir/Interfaces/InferTypeOpInterface.h" 24 #include "mlir/Parser/Parser.h" 25 26 #include "llvm/Support/Debug.h" 27 #include <cstddef> 28 29 using namespace mlir; 30 31 //===----------------------------------------------------------------------===// 32 // Context API. 33 //===----------------------------------------------------------------------===// 34 35 MlirContext mlirContextCreate() { 36 auto *context = new MLIRContext; 37 return wrap(context); 38 } 39 40 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 41 return unwrap(ctx1) == unwrap(ctx2); 42 } 43 44 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 45 46 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { 47 unwrap(context)->allowUnregisteredDialects(allow); 48 } 49 50 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { 51 return unwrap(context)->allowsUnregisteredDialects(); 52 } 53 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 54 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 55 } 56 57 void mlirContextAppendDialectRegistry(MlirContext ctx, 58 MlirDialectRegistry registry) { 59 unwrap(ctx)->appendDialectRegistry(*unwrap(registry)); 60 } 61 62 // TODO: expose a cheaper way than constructing + sorting a vector only to take 63 // its size. 64 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 65 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 66 } 67 68 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 69 MlirStringRef name) { 70 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 71 } 72 73 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { 74 return unwrap(context)->isOperationRegistered(unwrap(name)); 75 } 76 77 void mlirContextEnableMultithreading(MlirContext context, bool enable) { 78 return unwrap(context)->enableMultithreading(enable); 79 } 80 81 void mlirContextLoadAllAvailableDialects(MlirContext context) { 82 unwrap(context)->loadAllAvailableDialects(); 83 } 84 85 //===----------------------------------------------------------------------===// 86 // Dialect API. 87 //===----------------------------------------------------------------------===// 88 89 MlirContext mlirDialectGetContext(MlirDialect dialect) { 90 return wrap(unwrap(dialect)->getContext()); 91 } 92 93 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 94 return unwrap(dialect1) == unwrap(dialect2); 95 } 96 97 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 98 return wrap(unwrap(dialect)->getNamespace()); 99 } 100 101 //===----------------------------------------------------------------------===// 102 // DialectRegistry API. 103 //===----------------------------------------------------------------------===// 104 105 MlirDialectRegistry mlirDialectRegistryCreate() { 106 return wrap(new DialectRegistry()); 107 } 108 109 void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { 110 delete unwrap(registry); 111 } 112 113 //===----------------------------------------------------------------------===// 114 // Printing flags API. 115 //===----------------------------------------------------------------------===// 116 117 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 118 return wrap(new OpPrintingFlags()); 119 } 120 121 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 122 delete unwrap(flags); 123 } 124 125 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 126 intptr_t largeElementLimit) { 127 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 128 } 129 130 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, 131 bool prettyForm) { 132 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); 133 } 134 135 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 136 unwrap(flags)->printGenericOpForm(); 137 } 138 139 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 140 unwrap(flags)->useLocalScope(); 141 } 142 143 //===----------------------------------------------------------------------===// 144 // Location API. 145 //===----------------------------------------------------------------------===// 146 147 MlirLocation mlirLocationFileLineColGet(MlirContext context, 148 MlirStringRef filename, unsigned line, 149 unsigned col) { 150 return wrap(Location( 151 FileLineColLoc::get(unwrap(context), unwrap(filename), line, col))); 152 } 153 154 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { 155 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); 156 } 157 158 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, 159 MlirLocation const *locations, 160 MlirAttribute metadata) { 161 SmallVector<Location, 4> locs; 162 ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs); 163 return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); 164 } 165 166 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, 167 MlirLocation childLoc) { 168 if (mlirLocationIsNull(childLoc)) 169 return wrap( 170 Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); 171 return wrap(Location(NameLoc::get( 172 StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); 173 } 174 175 MlirLocation mlirLocationUnknownGet(MlirContext context) { 176 return wrap(Location(UnknownLoc::get(unwrap(context)))); 177 } 178 179 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { 180 return unwrap(l1) == unwrap(l2); 181 } 182 183 MlirContext mlirLocationGetContext(MlirLocation location) { 184 return wrap(unwrap(location).getContext()); 185 } 186 187 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 188 void *userData) { 189 detail::CallbackOstream stream(callback, userData); 190 unwrap(location).print(stream); 191 } 192 193 //===----------------------------------------------------------------------===// 194 // Module API. 195 //===----------------------------------------------------------------------===// 196 197 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 198 return wrap(ModuleOp::create(unwrap(location))); 199 } 200 201 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { 202 OwningOpRef<ModuleOp> owning = 203 parseSourceString<ModuleOp>(unwrap(module), unwrap(context)); 204 if (!owning) 205 return MlirModule{nullptr}; 206 return MlirModule{owning.release().getOperation()}; 207 } 208 209 MlirContext mlirModuleGetContext(MlirModule module) { 210 return wrap(unwrap(module).getContext()); 211 } 212 213 MlirBlock mlirModuleGetBody(MlirModule module) { 214 return wrap(unwrap(module).getBody()); 215 } 216 217 void mlirModuleDestroy(MlirModule module) { 218 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is 219 // called. 220 OwningOpRef<ModuleOp>(unwrap(module)); 221 } 222 223 MlirOperation mlirModuleGetOperation(MlirModule module) { 224 return wrap(unwrap(module).getOperation()); 225 } 226 227 MlirModule mlirModuleFromOperation(MlirOperation op) { 228 return wrap(dyn_cast<ModuleOp>(unwrap(op))); 229 } 230 231 //===----------------------------------------------------------------------===// 232 // Operation state API. 233 //===----------------------------------------------------------------------===// 234 235 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { 236 MlirOperationState state; 237 state.name = name; 238 state.location = loc; 239 state.nResults = 0; 240 state.results = nullptr; 241 state.nOperands = 0; 242 state.operands = nullptr; 243 state.nRegions = 0; 244 state.regions = nullptr; 245 state.nSuccessors = 0; 246 state.successors = nullptr; 247 state.nAttributes = 0; 248 state.attributes = nullptr; 249 state.enableResultTypeInference = false; 250 return state; 251 } 252 253 #define APPEND_ELEMS(type, sizeName, elemName) \ 254 state->elemName = \ 255 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 256 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 257 state->sizeName += n; 258 259 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 260 MlirType const *results) { 261 APPEND_ELEMS(MlirType, nResults, results); 262 } 263 264 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 265 MlirValue const *operands) { 266 APPEND_ELEMS(MlirValue, nOperands, operands); 267 } 268 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 269 MlirRegion const *regions) { 270 APPEND_ELEMS(MlirRegion, nRegions, regions); 271 } 272 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 273 MlirBlock const *successors) { 274 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 275 } 276 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 277 MlirNamedAttribute const *attributes) { 278 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 279 } 280 281 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { 282 state->enableResultTypeInference = true; 283 } 284 285 //===----------------------------------------------------------------------===// 286 // Operation API. 287 //===----------------------------------------------------------------------===// 288 289 static LogicalResult inferOperationTypes(OperationState &state) { 290 MLIRContext *context = state.getContext(); 291 Optional<RegisteredOperationName> info = state.name.getRegisteredInfo(); 292 if (!info) { 293 emitError(state.location) 294 << "type inference was requested for the operation " << state.name 295 << ", but the operation was not registered. Ensure that the dialect " 296 "containing the operation is linked into MLIR and registered with " 297 "the context"; 298 return failure(); 299 } 300 301 // Fallback to inference via an op interface. 302 auto *inferInterface = info->getInterface<InferTypeOpInterface>(); 303 if (!inferInterface) { 304 emitError(state.location) 305 << "type inference was requested for the operation " << state.name 306 << ", but the operation does not support type inference. Result " 307 "types must be specified explicitly."; 308 return failure(); 309 } 310 311 if (succeeded(inferInterface->inferReturnTypes( 312 context, state.location, state.operands, 313 state.attributes.getDictionary(context), state.regions, state.types))) 314 return success(); 315 316 // Diagnostic emitted by interface. 317 return failure(); 318 } 319 320 MlirOperation mlirOperationCreate(MlirOperationState *state) { 321 assert(state); 322 OperationState cppState(unwrap(state->location), unwrap(state->name)); 323 SmallVector<Type, 4> resultStorage; 324 SmallVector<Value, 8> operandStorage; 325 SmallVector<Block *, 2> successorStorage; 326 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 327 cppState.addOperands( 328 unwrapList(state->nOperands, state->operands, operandStorage)); 329 cppState.addSuccessors( 330 unwrapList(state->nSuccessors, state->successors, successorStorage)); 331 332 cppState.attributes.reserve(state->nAttributes); 333 for (intptr_t i = 0; i < state->nAttributes; ++i) 334 cppState.addAttribute(unwrap(state->attributes[i].name), 335 unwrap(state->attributes[i].attribute)); 336 337 for (intptr_t i = 0; i < state->nRegions; ++i) 338 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 339 340 free(state->results); 341 free(state->operands); 342 free(state->successors); 343 free(state->regions); 344 free(state->attributes); 345 346 // Infer result types. 347 if (state->enableResultTypeInference) { 348 assert(cppState.types.empty() && 349 "result type inference enabled and result types provided"); 350 if (failed(inferOperationTypes(cppState))) 351 return {nullptr}; 352 } 353 354 MlirOperation result = wrap(Operation::create(cppState)); 355 return result; 356 } 357 358 MlirOperation mlirOperationClone(MlirOperation op) { 359 return wrap(unwrap(op)->clone()); 360 } 361 362 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 363 364 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); } 365 366 bool mlirOperationEqual(MlirOperation op, MlirOperation other) { 367 return unwrap(op) == unwrap(other); 368 } 369 370 MlirContext mlirOperationGetContext(MlirOperation op) { 371 return wrap(unwrap(op)->getContext()); 372 } 373 374 MlirLocation mlirOperationGetLocation(MlirOperation op) { 375 return wrap(unwrap(op)->getLoc()); 376 } 377 378 MlirTypeID mlirOperationGetTypeID(MlirOperation op) { 379 if (auto info = unwrap(op)->getRegisteredInfo()) 380 return wrap(info->getTypeID()); 381 return {nullptr}; 382 } 383 384 MlirIdentifier mlirOperationGetName(MlirOperation op) { 385 return wrap(unwrap(op)->getName().getIdentifier()); 386 } 387 388 MlirBlock mlirOperationGetBlock(MlirOperation op) { 389 return wrap(unwrap(op)->getBlock()); 390 } 391 392 MlirOperation mlirOperationGetParentOperation(MlirOperation op) { 393 return wrap(unwrap(op)->getParentOp()); 394 } 395 396 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 397 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 398 } 399 400 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 401 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 402 } 403 404 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { 405 Operation *cppOp = unwrap(op); 406 if (cppOp->getNumRegions() == 0) 407 return wrap(static_cast<Region *>(nullptr)); 408 return wrap(&cppOp->getRegion(0)); 409 } 410 411 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { 412 Region *cppRegion = unwrap(region); 413 Operation *parent = cppRegion->getParentOp(); 414 intptr_t next = cppRegion->getRegionNumber() + 1; 415 if (parent->getNumRegions() > next) 416 return wrap(&parent->getRegion(next)); 417 return wrap(static_cast<Region *>(nullptr)); 418 } 419 420 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 421 return wrap(unwrap(op)->getNextNode()); 422 } 423 424 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 425 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 426 } 427 428 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 429 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 430 } 431 432 void mlirOperationSetOperand(MlirOperation op, intptr_t pos, 433 MlirValue newValue) { 434 unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue)); 435 } 436 437 intptr_t mlirOperationGetNumResults(MlirOperation op) { 438 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 439 } 440 441 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 442 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 443 } 444 445 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 446 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 447 } 448 449 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 450 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 451 } 452 453 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 454 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 455 } 456 457 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 458 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 459 return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; 460 } 461 462 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 463 MlirStringRef name) { 464 return wrap(unwrap(op)->getAttr(unwrap(name))); 465 } 466 467 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, 468 MlirAttribute attr) { 469 unwrap(op)->setAttr(unwrap(name), unwrap(attr)); 470 } 471 472 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { 473 return !!unwrap(op)->removeAttr(unwrap(name)); 474 } 475 476 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 477 void *userData) { 478 detail::CallbackOstream stream(callback, userData); 479 unwrap(op)->print(stream); 480 } 481 482 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 483 MlirStringCallback callback, void *userData) { 484 detail::CallbackOstream stream(callback, userData); 485 unwrap(op)->print(stream, *unwrap(flags)); 486 } 487 488 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 489 490 bool mlirOperationVerify(MlirOperation op) { 491 return succeeded(verify(unwrap(op))); 492 } 493 494 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { 495 return unwrap(op)->moveAfter(unwrap(other)); 496 } 497 498 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { 499 return unwrap(op)->moveBefore(unwrap(other)); 500 } 501 502 //===----------------------------------------------------------------------===// 503 // Region API. 504 //===----------------------------------------------------------------------===// 505 506 MlirRegion mlirRegionCreate() { return wrap(new Region); } 507 508 bool mlirRegionEqual(MlirRegion region, MlirRegion other) { 509 return unwrap(region) == unwrap(other); 510 } 511 512 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 513 Region *cppRegion = unwrap(region); 514 if (cppRegion->empty()) 515 return wrap(static_cast<Block *>(nullptr)); 516 return wrap(&cppRegion->front()); 517 } 518 519 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 520 unwrap(region)->push_back(unwrap(block)); 521 } 522 523 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 524 MlirBlock block) { 525 auto &blockList = unwrap(region)->getBlocks(); 526 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 527 } 528 529 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 530 MlirBlock block) { 531 Region *cppRegion = unwrap(region); 532 if (mlirBlockIsNull(reference)) { 533 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 534 return; 535 } 536 537 assert(unwrap(reference)->getParent() == unwrap(region) && 538 "expected reference block to belong to the region"); 539 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 540 unwrap(block)); 541 } 542 543 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 544 MlirBlock block) { 545 if (mlirBlockIsNull(reference)) 546 return mlirRegionAppendOwnedBlock(region, block); 547 548 assert(unwrap(reference)->getParent() == unwrap(region) && 549 "expected reference block to belong to the region"); 550 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 551 unwrap(block)); 552 } 553 554 void mlirRegionDestroy(MlirRegion region) { 555 delete static_cast<Region *>(region.ptr); 556 } 557 558 //===----------------------------------------------------------------------===// 559 // Block API. 560 //===----------------------------------------------------------------------===// 561 562 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, 563 MlirLocation const *locs) { 564 Block *b = new Block; 565 for (intptr_t i = 0; i < nArgs; ++i) 566 b->addArgument(unwrap(args[i]), unwrap(locs[i])); 567 return wrap(b); 568 } 569 570 bool mlirBlockEqual(MlirBlock block, MlirBlock other) { 571 return unwrap(block) == unwrap(other); 572 } 573 574 MlirOperation mlirBlockGetParentOperation(MlirBlock block) { 575 return wrap(unwrap(block)->getParentOp()); 576 } 577 578 MlirRegion mlirBlockGetParentRegion(MlirBlock block) { 579 return wrap(unwrap(block)->getParent()); 580 } 581 582 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 583 return wrap(unwrap(block)->getNextNode()); 584 } 585 586 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 587 Block *cppBlock = unwrap(block); 588 if (cppBlock->empty()) 589 return wrap(static_cast<Operation *>(nullptr)); 590 return wrap(&cppBlock->front()); 591 } 592 593 MlirOperation mlirBlockGetTerminator(MlirBlock block) { 594 Block *cppBlock = unwrap(block); 595 if (cppBlock->empty()) 596 return wrap(static_cast<Operation *>(nullptr)); 597 Operation &back = cppBlock->back(); 598 if (!back.hasTrait<OpTrait::IsTerminator>()) 599 return wrap(static_cast<Operation *>(nullptr)); 600 return wrap(&back); 601 } 602 603 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 604 unwrap(block)->push_back(unwrap(operation)); 605 } 606 607 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 608 MlirOperation operation) { 609 auto &opList = unwrap(block)->getOperations(); 610 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 611 } 612 613 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 614 MlirOperation reference, 615 MlirOperation operation) { 616 Block *cppBlock = unwrap(block); 617 if (mlirOperationIsNull(reference)) { 618 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 619 return; 620 } 621 622 assert(unwrap(reference)->getBlock() == unwrap(block) && 623 "expected reference operation to belong to the block"); 624 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 625 unwrap(operation)); 626 } 627 628 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 629 MlirOperation reference, 630 MlirOperation operation) { 631 if (mlirOperationIsNull(reference)) 632 return mlirBlockAppendOwnedOperation(block, operation); 633 634 assert(unwrap(reference)->getBlock() == unwrap(block) && 635 "expected reference operation to belong to the block"); 636 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 637 unwrap(operation)); 638 } 639 640 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 641 642 void mlirBlockDetach(MlirBlock block) { 643 Block *b = unwrap(block); 644 b->getParent()->getBlocks().remove(b); 645 } 646 647 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 648 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 649 } 650 651 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, 652 MlirLocation loc) { 653 return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc))); 654 } 655 656 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 657 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 658 } 659 660 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 661 void *userData) { 662 detail::CallbackOstream stream(callback, userData); 663 unwrap(block)->print(stream); 664 } 665 666 //===----------------------------------------------------------------------===// 667 // Value API. 668 //===----------------------------------------------------------------------===// 669 670 bool mlirValueEqual(MlirValue value1, MlirValue value2) { 671 return unwrap(value1) == unwrap(value2); 672 } 673 674 bool mlirValueIsABlockArgument(MlirValue value) { 675 return unwrap(value).isa<BlockArgument>(); 676 } 677 678 bool mlirValueIsAOpResult(MlirValue value) { 679 return unwrap(value).isa<OpResult>(); 680 } 681 682 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 683 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 684 } 685 686 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 687 return static_cast<intptr_t>( 688 unwrap(value).cast<BlockArgument>().getArgNumber()); 689 } 690 691 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 692 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 693 } 694 695 MlirOperation mlirOpResultGetOwner(MlirValue value) { 696 return wrap(unwrap(value).cast<OpResult>().getOwner()); 697 } 698 699 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 700 return static_cast<intptr_t>( 701 unwrap(value).cast<OpResult>().getResultNumber()); 702 } 703 704 MlirType mlirValueGetType(MlirValue value) { 705 return wrap(unwrap(value).getType()); 706 } 707 708 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 709 710 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 711 void *userData) { 712 detail::CallbackOstream stream(callback, userData); 713 unwrap(value).print(stream); 714 } 715 716 //===----------------------------------------------------------------------===// 717 // Type API. 718 //===----------------------------------------------------------------------===// 719 720 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { 721 return wrap(mlir::parseType(unwrap(type), unwrap(context))); 722 } 723 724 MlirContext mlirTypeGetContext(MlirType type) { 725 return wrap(unwrap(type).getContext()); 726 } 727 728 MlirTypeID mlirTypeGetTypeID(MlirType type) { 729 return wrap(unwrap(type).getTypeID()); 730 } 731 732 bool mlirTypeEqual(MlirType t1, MlirType t2) { 733 return unwrap(t1) == unwrap(t2); 734 } 735 736 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 737 detail::CallbackOstream stream(callback, userData); 738 unwrap(type).print(stream); 739 } 740 741 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 742 743 //===----------------------------------------------------------------------===// 744 // Attribute API. 745 //===----------------------------------------------------------------------===// 746 747 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { 748 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context))); 749 } 750 751 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 752 return wrap(unwrap(attribute).getContext()); 753 } 754 755 MlirType mlirAttributeGetType(MlirAttribute attribute) { 756 return wrap(unwrap(attribute).getType()); 757 } 758 759 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { 760 return wrap(unwrap(attr).getTypeID()); 761 } 762 763 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 764 return unwrap(a1) == unwrap(a2); 765 } 766 767 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 768 void *userData) { 769 detail::CallbackOstream stream(callback, userData); 770 unwrap(attr).print(stream); 771 } 772 773 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 774 775 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, 776 MlirAttribute attr) { 777 return MlirNamedAttribute{name, attr}; 778 } 779 780 //===----------------------------------------------------------------------===// 781 // Identifier API. 782 //===----------------------------------------------------------------------===// 783 784 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { 785 return wrap(StringAttr::get(unwrap(context), unwrap(str))); 786 } 787 788 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { 789 return wrap(unwrap(ident).getContext()); 790 } 791 792 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { 793 return unwrap(ident) == unwrap(other); 794 } 795 796 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { 797 return wrap(unwrap(ident).strref()); 798 } 799 800 //===----------------------------------------------------------------------===// 801 // Symbol and SymbolTable API. 802 //===----------------------------------------------------------------------===// 803 804 MlirStringRef mlirSymbolTableGetSymbolAttributeName() { 805 return wrap(SymbolTable::getSymbolAttrName()); 806 } 807 808 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { 809 return wrap(SymbolTable::getVisibilityAttrName()); 810 } 811 812 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { 813 if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>()) 814 return wrap(static_cast<SymbolTable *>(nullptr)); 815 return wrap(new SymbolTable(unwrap(operation))); 816 } 817 818 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { 819 delete unwrap(symbolTable); 820 } 821 822 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, 823 MlirStringRef name) { 824 return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length))); 825 } 826 827 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, 828 MlirOperation operation) { 829 return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation))); 830 } 831 832 void mlirSymbolTableErase(MlirSymbolTable symbolTable, 833 MlirOperation operation) { 834 unwrap(symbolTable)->erase(unwrap(operation)); 835 } 836 837 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, 838 MlirStringRef newSymbol, 839 MlirOperation from) { 840 auto *cppFrom = unwrap(from); 841 auto *context = cppFrom->getContext(); 842 auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); 843 auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); 844 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, 845 unwrap(from))); 846 } 847 848 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, 849 void (*callback)(MlirOperation, bool, 850 void *userData), 851 void *userData) { 852 SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible, 853 [&](Operation *foundOpCpp, bool isVisible) { 854 callback(wrap(foundOpCpp), isVisible, 855 userData); 856 }); 857 } 858