1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/Dialect.h" 18 #include "mlir/IR/Location.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/IR/Types.h" 21 #include "mlir/IR/Verifier.h" 22 #include "mlir/Interfaces/InferTypeOpInterface.h" 23 #include "mlir/Parser.h" 24 25 #include "llvm/Support/Debug.h" 26 #include <cstddef> 27 28 using namespace mlir; 29 30 //===----------------------------------------------------------------------===// 31 // Context API. 32 //===----------------------------------------------------------------------===// 33 34 MlirContext mlirContextCreate() { 35 auto *context = new MLIRContext; 36 return wrap(context); 37 } 38 39 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 40 return unwrap(ctx1) == unwrap(ctx2); 41 } 42 43 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 44 45 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { 46 unwrap(context)->allowUnregisteredDialects(allow); 47 } 48 49 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { 50 return unwrap(context)->allowsUnregisteredDialects(); 51 } 52 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 53 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 54 } 55 56 // TODO: expose a cheaper way than constructing + sorting a vector only to take 57 // its size. 58 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 59 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 60 } 61 62 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 63 MlirStringRef name) { 64 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 65 } 66 67 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { 68 return unwrap(context)->isOperationRegistered(unwrap(name)); 69 } 70 71 void mlirContextEnableMultithreading(MlirContext context, bool enable) { 72 return unwrap(context)->enableMultithreading(enable); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // Dialect API. 77 //===----------------------------------------------------------------------===// 78 79 MlirContext mlirDialectGetContext(MlirDialect dialect) { 80 return wrap(unwrap(dialect)->getContext()); 81 } 82 83 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 84 return unwrap(dialect1) == unwrap(dialect2); 85 } 86 87 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 88 return wrap(unwrap(dialect)->getNamespace()); 89 } 90 91 //===----------------------------------------------------------------------===// 92 // Printing flags API. 93 //===----------------------------------------------------------------------===// 94 95 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 96 return wrap(new OpPrintingFlags()); 97 } 98 99 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 100 delete unwrap(flags); 101 } 102 103 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 104 intptr_t largeElementLimit) { 105 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 106 } 107 108 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, 109 bool prettyForm) { 110 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); 111 } 112 113 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 114 unwrap(flags)->printGenericOpForm(); 115 } 116 117 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 118 unwrap(flags)->useLocalScope(); 119 } 120 121 //===----------------------------------------------------------------------===// 122 // Location API. 123 //===----------------------------------------------------------------------===// 124 125 MlirLocation mlirLocationFileLineColGet(MlirContext context, 126 MlirStringRef filename, unsigned line, 127 unsigned col) { 128 return wrap(Location( 129 FileLineColLoc::get(unwrap(context), unwrap(filename), line, col))); 130 } 131 132 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { 133 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); 134 } 135 136 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, 137 MlirLocation const *locations, 138 MlirAttribute metadata) { 139 SmallVector<Location, 4> locs; 140 ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs); 141 return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); 142 } 143 144 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, 145 MlirLocation childLoc) { 146 if (mlirLocationIsNull(childLoc)) 147 return wrap( 148 Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); 149 return wrap(Location(NameLoc::get( 150 StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); 151 } 152 153 MlirLocation mlirLocationUnknownGet(MlirContext context) { 154 return wrap(Location(UnknownLoc::get(unwrap(context)))); 155 } 156 157 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { 158 return unwrap(l1) == unwrap(l2); 159 } 160 161 MlirContext mlirLocationGetContext(MlirLocation location) { 162 return wrap(unwrap(location).getContext()); 163 } 164 165 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 166 void *userData) { 167 detail::CallbackOstream stream(callback, userData); 168 unwrap(location).print(stream); 169 } 170 171 //===----------------------------------------------------------------------===// 172 // Module API. 173 //===----------------------------------------------------------------------===// 174 175 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 176 return wrap(ModuleOp::create(unwrap(location))); 177 } 178 179 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { 180 OwningOpRef<ModuleOp> owning = 181 parseSourceString(unwrap(module), unwrap(context)); 182 if (!owning) 183 return MlirModule{nullptr}; 184 return MlirModule{owning.release().getOperation()}; 185 } 186 187 MlirContext mlirModuleGetContext(MlirModule module) { 188 return wrap(unwrap(module).getContext()); 189 } 190 191 MlirBlock mlirModuleGetBody(MlirModule module) { 192 return wrap(unwrap(module).getBody()); 193 } 194 195 void mlirModuleDestroy(MlirModule module) { 196 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is 197 // called. 198 OwningOpRef<ModuleOp>(unwrap(module)); 199 } 200 201 MlirOperation mlirModuleGetOperation(MlirModule module) { 202 return wrap(unwrap(module).getOperation()); 203 } 204 205 MlirModule mlirModuleFromOperation(MlirOperation op) { 206 return wrap(dyn_cast<ModuleOp>(unwrap(op))); 207 } 208 209 //===----------------------------------------------------------------------===// 210 // Operation state API. 211 //===----------------------------------------------------------------------===// 212 213 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { 214 MlirOperationState state; 215 state.name = name; 216 state.location = loc; 217 state.nResults = 0; 218 state.results = nullptr; 219 state.nOperands = 0; 220 state.operands = nullptr; 221 state.nRegions = 0; 222 state.regions = nullptr; 223 state.nSuccessors = 0; 224 state.successors = nullptr; 225 state.nAttributes = 0; 226 state.attributes = nullptr; 227 state.enableResultTypeInference = false; 228 return state; 229 } 230 231 #define APPEND_ELEMS(type, sizeName, elemName) \ 232 state->elemName = \ 233 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 234 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 235 state->sizeName += n; 236 237 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 238 MlirType const *results) { 239 APPEND_ELEMS(MlirType, nResults, results); 240 } 241 242 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 243 MlirValue const *operands) { 244 APPEND_ELEMS(MlirValue, nOperands, operands); 245 } 246 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 247 MlirRegion const *regions) { 248 APPEND_ELEMS(MlirRegion, nRegions, regions); 249 } 250 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 251 MlirBlock const *successors) { 252 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 253 } 254 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 255 MlirNamedAttribute const *attributes) { 256 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 257 } 258 259 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { 260 state->enableResultTypeInference = true; 261 } 262 263 //===----------------------------------------------------------------------===// 264 // Operation API. 265 //===----------------------------------------------------------------------===// 266 267 static LogicalResult inferOperationTypes(OperationState &state) { 268 MLIRContext *context = state.getContext(); 269 Optional<RegisteredOperationName> info = state.name.getRegisteredInfo(); 270 if (!info) { 271 emitError(state.location) 272 << "type inference was requested for the operation " << state.name 273 << ", but the operation was not registered. Ensure that the dialect " 274 "containing the operation is linked into MLIR and registered with " 275 "the context"; 276 return failure(); 277 } 278 279 // Fallback to inference via an op interface. 280 auto *inferInterface = info->getInterface<InferTypeOpInterface>(); 281 if (!inferInterface) { 282 emitError(state.location) 283 << "type inference was requested for the operation " << state.name 284 << ", but the operation does not support type inference. Result " 285 "types must be specified explicitly."; 286 return failure(); 287 } 288 289 if (succeeded(inferInterface->inferReturnTypes( 290 context, state.location, state.operands, 291 state.attributes.getDictionary(context), state.regions, state.types))) 292 return success(); 293 294 // Diagnostic emitted by interface. 295 return failure(); 296 } 297 298 MlirOperation mlirOperationCreate(MlirOperationState *state) { 299 assert(state); 300 OperationState cppState(unwrap(state->location), unwrap(state->name)); 301 SmallVector<Type, 4> resultStorage; 302 SmallVector<Value, 8> operandStorage; 303 SmallVector<Block *, 2> successorStorage; 304 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 305 cppState.addOperands( 306 unwrapList(state->nOperands, state->operands, operandStorage)); 307 cppState.addSuccessors( 308 unwrapList(state->nSuccessors, state->successors, successorStorage)); 309 310 cppState.attributes.reserve(state->nAttributes); 311 for (intptr_t i = 0; i < state->nAttributes; ++i) 312 cppState.addAttribute(unwrap(state->attributes[i].name), 313 unwrap(state->attributes[i].attribute)); 314 315 for (intptr_t i = 0; i < state->nRegions; ++i) 316 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 317 318 free(state->results); 319 free(state->operands); 320 free(state->successors); 321 free(state->regions); 322 free(state->attributes); 323 324 // Infer result types. 325 if (state->enableResultTypeInference) { 326 assert(cppState.types.empty() && 327 "result type inference enabled and result types provided"); 328 if (failed(inferOperationTypes(cppState))) 329 return {nullptr}; 330 } 331 332 MlirOperation result = wrap(Operation::create(cppState)); 333 return result; 334 } 335 336 MlirOperation mlirOperationClone(MlirOperation op) { 337 return wrap(unwrap(op)->clone()); 338 } 339 340 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 341 342 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); } 343 344 bool mlirOperationEqual(MlirOperation op, MlirOperation other) { 345 return unwrap(op) == unwrap(other); 346 } 347 348 MlirContext mlirOperationGetContext(MlirOperation op) { 349 return wrap(unwrap(op)->getContext()); 350 } 351 352 MlirLocation mlirOperationGetLocation(MlirOperation op) { 353 return wrap(unwrap(op)->getLoc()); 354 } 355 356 MlirTypeID mlirOperationGetTypeID(MlirOperation op) { 357 if (auto info = unwrap(op)->getRegisteredInfo()) 358 return wrap(info->getTypeID()); 359 return {nullptr}; 360 } 361 362 MlirIdentifier mlirOperationGetName(MlirOperation op) { 363 return wrap(unwrap(op)->getName().getIdentifier()); 364 } 365 366 MlirBlock mlirOperationGetBlock(MlirOperation op) { 367 return wrap(unwrap(op)->getBlock()); 368 } 369 370 MlirOperation mlirOperationGetParentOperation(MlirOperation op) { 371 return wrap(unwrap(op)->getParentOp()); 372 } 373 374 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 375 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 376 } 377 378 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 379 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 380 } 381 382 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { 383 Operation *cppOp = unwrap(op); 384 if (cppOp->getNumRegions() == 0) 385 return wrap(static_cast<Region *>(nullptr)); 386 return wrap(&cppOp->getRegion(0)); 387 } 388 389 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { 390 Region *cppRegion = unwrap(region); 391 Operation *parent = cppRegion->getParentOp(); 392 intptr_t next = cppRegion->getRegionNumber() + 1; 393 if (parent->getNumRegions() > next) 394 return wrap(&parent->getRegion(next)); 395 return wrap(static_cast<Region *>(nullptr)); 396 } 397 398 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 399 return wrap(unwrap(op)->getNextNode()); 400 } 401 402 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 403 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 404 } 405 406 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 407 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 408 } 409 410 void mlirOperationSetOperand(MlirOperation op, intptr_t pos, 411 MlirValue newValue) { 412 unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue)); 413 } 414 415 intptr_t mlirOperationGetNumResults(MlirOperation op) { 416 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 417 } 418 419 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 420 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 421 } 422 423 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 424 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 425 } 426 427 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 428 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 429 } 430 431 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 432 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 433 } 434 435 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 436 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 437 return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; 438 } 439 440 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 441 MlirStringRef name) { 442 return wrap(unwrap(op)->getAttr(unwrap(name))); 443 } 444 445 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, 446 MlirAttribute attr) { 447 unwrap(op)->setAttr(unwrap(name), unwrap(attr)); 448 } 449 450 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { 451 return !!unwrap(op)->removeAttr(unwrap(name)); 452 } 453 454 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 455 void *userData) { 456 detail::CallbackOstream stream(callback, userData); 457 unwrap(op)->print(stream); 458 } 459 460 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 461 MlirStringCallback callback, void *userData) { 462 detail::CallbackOstream stream(callback, userData); 463 unwrap(op)->print(stream, *unwrap(flags)); 464 } 465 466 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 467 468 bool mlirOperationVerify(MlirOperation op) { 469 return succeeded(verify(unwrap(op))); 470 } 471 472 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { 473 return unwrap(op)->moveAfter(unwrap(other)); 474 } 475 476 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { 477 return unwrap(op)->moveBefore(unwrap(other)); 478 } 479 480 //===----------------------------------------------------------------------===// 481 // Region API. 482 //===----------------------------------------------------------------------===// 483 484 MlirRegion mlirRegionCreate() { return wrap(new Region); } 485 486 bool mlirRegionEqual(MlirRegion region, MlirRegion other) { 487 return unwrap(region) == unwrap(other); 488 } 489 490 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 491 Region *cppRegion = unwrap(region); 492 if (cppRegion->empty()) 493 return wrap(static_cast<Block *>(nullptr)); 494 return wrap(&cppRegion->front()); 495 } 496 497 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 498 unwrap(region)->push_back(unwrap(block)); 499 } 500 501 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 502 MlirBlock block) { 503 auto &blockList = unwrap(region)->getBlocks(); 504 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 505 } 506 507 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 508 MlirBlock block) { 509 Region *cppRegion = unwrap(region); 510 if (mlirBlockIsNull(reference)) { 511 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 512 return; 513 } 514 515 assert(unwrap(reference)->getParent() == unwrap(region) && 516 "expected reference block to belong to the region"); 517 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 518 unwrap(block)); 519 } 520 521 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 522 MlirBlock block) { 523 if (mlirBlockIsNull(reference)) 524 return mlirRegionAppendOwnedBlock(region, block); 525 526 assert(unwrap(reference)->getParent() == unwrap(region) && 527 "expected reference block to belong to the region"); 528 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 529 unwrap(block)); 530 } 531 532 void mlirRegionDestroy(MlirRegion region) { 533 delete static_cast<Region *>(region.ptr); 534 } 535 536 //===----------------------------------------------------------------------===// 537 // Block API. 538 //===----------------------------------------------------------------------===// 539 540 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, 541 MlirLocation const *locs) { 542 Block *b = new Block; 543 for (intptr_t i = 0; i < nArgs; ++i) 544 b->addArgument(unwrap(args[i]), unwrap(locs[i])); 545 return wrap(b); 546 } 547 548 bool mlirBlockEqual(MlirBlock block, MlirBlock other) { 549 return unwrap(block) == unwrap(other); 550 } 551 552 MlirOperation mlirBlockGetParentOperation(MlirBlock block) { 553 return wrap(unwrap(block)->getParentOp()); 554 } 555 556 MlirRegion mlirBlockGetParentRegion(MlirBlock block) { 557 return wrap(unwrap(block)->getParent()); 558 } 559 560 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 561 return wrap(unwrap(block)->getNextNode()); 562 } 563 564 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 565 Block *cppBlock = unwrap(block); 566 if (cppBlock->empty()) 567 return wrap(static_cast<Operation *>(nullptr)); 568 return wrap(&cppBlock->front()); 569 } 570 571 MlirOperation mlirBlockGetTerminator(MlirBlock block) { 572 Block *cppBlock = unwrap(block); 573 if (cppBlock->empty()) 574 return wrap(static_cast<Operation *>(nullptr)); 575 Operation &back = cppBlock->back(); 576 if (!back.hasTrait<OpTrait::IsTerminator>()) 577 return wrap(static_cast<Operation *>(nullptr)); 578 return wrap(&back); 579 } 580 581 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 582 unwrap(block)->push_back(unwrap(operation)); 583 } 584 585 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 586 MlirOperation operation) { 587 auto &opList = unwrap(block)->getOperations(); 588 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 589 } 590 591 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 592 MlirOperation reference, 593 MlirOperation operation) { 594 Block *cppBlock = unwrap(block); 595 if (mlirOperationIsNull(reference)) { 596 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 597 return; 598 } 599 600 assert(unwrap(reference)->getBlock() == unwrap(block) && 601 "expected reference operation to belong to the block"); 602 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 603 unwrap(operation)); 604 } 605 606 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 607 MlirOperation reference, 608 MlirOperation operation) { 609 if (mlirOperationIsNull(reference)) 610 return mlirBlockAppendOwnedOperation(block, operation); 611 612 assert(unwrap(reference)->getBlock() == unwrap(block) && 613 "expected reference operation to belong to the block"); 614 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 615 unwrap(operation)); 616 } 617 618 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 619 620 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 621 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 622 } 623 624 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, 625 MlirLocation loc) { 626 return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc))); 627 } 628 629 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 630 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 631 } 632 633 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 634 void *userData) { 635 detail::CallbackOstream stream(callback, userData); 636 unwrap(block)->print(stream); 637 } 638 639 //===----------------------------------------------------------------------===// 640 // Value API. 641 //===----------------------------------------------------------------------===// 642 643 bool mlirValueEqual(MlirValue value1, MlirValue value2) { 644 return unwrap(value1) == unwrap(value2); 645 } 646 647 bool mlirValueIsABlockArgument(MlirValue value) { 648 return unwrap(value).isa<BlockArgument>(); 649 } 650 651 bool mlirValueIsAOpResult(MlirValue value) { 652 return unwrap(value).isa<OpResult>(); 653 } 654 655 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 656 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 657 } 658 659 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 660 return static_cast<intptr_t>( 661 unwrap(value).cast<BlockArgument>().getArgNumber()); 662 } 663 664 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 665 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 666 } 667 668 MlirOperation mlirOpResultGetOwner(MlirValue value) { 669 return wrap(unwrap(value).cast<OpResult>().getOwner()); 670 } 671 672 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 673 return static_cast<intptr_t>( 674 unwrap(value).cast<OpResult>().getResultNumber()); 675 } 676 677 MlirType mlirValueGetType(MlirValue value) { 678 return wrap(unwrap(value).getType()); 679 } 680 681 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 682 683 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 684 void *userData) { 685 detail::CallbackOstream stream(callback, userData); 686 unwrap(value).print(stream); 687 } 688 689 //===----------------------------------------------------------------------===// 690 // Type API. 691 //===----------------------------------------------------------------------===// 692 693 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { 694 return wrap(mlir::parseType(unwrap(type), unwrap(context))); 695 } 696 697 MlirContext mlirTypeGetContext(MlirType type) { 698 return wrap(unwrap(type).getContext()); 699 } 700 701 MlirTypeID mlirTypeGetTypeID(MlirType type) { 702 return wrap(unwrap(type).getTypeID()); 703 } 704 705 bool mlirTypeEqual(MlirType t1, MlirType t2) { 706 return unwrap(t1) == unwrap(t2); 707 } 708 709 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 710 detail::CallbackOstream stream(callback, userData); 711 unwrap(type).print(stream); 712 } 713 714 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 715 716 //===----------------------------------------------------------------------===// 717 // Attribute API. 718 //===----------------------------------------------------------------------===// 719 720 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { 721 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context))); 722 } 723 724 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 725 return wrap(unwrap(attribute).getContext()); 726 } 727 728 MlirType mlirAttributeGetType(MlirAttribute attribute) { 729 return wrap(unwrap(attribute).getType()); 730 } 731 732 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { 733 return wrap(unwrap(attr).getTypeID()); 734 } 735 736 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 737 return unwrap(a1) == unwrap(a2); 738 } 739 740 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 741 void *userData) { 742 detail::CallbackOstream stream(callback, userData); 743 unwrap(attr).print(stream); 744 } 745 746 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 747 748 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, 749 MlirAttribute attr) { 750 return MlirNamedAttribute{name, attr}; 751 } 752 753 //===----------------------------------------------------------------------===// 754 // Identifier API. 755 //===----------------------------------------------------------------------===// 756 757 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { 758 return wrap(StringAttr::get(unwrap(context), unwrap(str))); 759 } 760 761 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { 762 return wrap(unwrap(ident).getContext()); 763 } 764 765 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { 766 return unwrap(ident) == unwrap(other); 767 } 768 769 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { 770 return wrap(unwrap(ident).strref()); 771 } 772 773 //===----------------------------------------------------------------------===// 774 // TypeID API. 775 //===----------------------------------------------------------------------===// 776 777 bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) { 778 return unwrap(typeID1) == unwrap(typeID2); 779 } 780 781 size_t mlirTypeIDHashValue(MlirTypeID typeID) { 782 return hash_value(unwrap(typeID)); 783 } 784 785 //===----------------------------------------------------------------------===// 786 // Symbol and SymbolTable API. 787 //===----------------------------------------------------------------------===// 788 789 MlirStringRef mlirSymbolTableGetSymbolAttributeName() { 790 return wrap(SymbolTable::getSymbolAttrName()); 791 } 792 793 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { 794 return wrap(SymbolTable::getVisibilityAttrName()); 795 } 796 797 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { 798 if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>()) 799 return wrap(static_cast<SymbolTable *>(nullptr)); 800 return wrap(new SymbolTable(unwrap(operation))); 801 } 802 803 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { 804 delete unwrap(symbolTable); 805 } 806 807 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, 808 MlirStringRef name) { 809 return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length))); 810 } 811 812 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, 813 MlirOperation operation) { 814 return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation))); 815 } 816 817 void mlirSymbolTableErase(MlirSymbolTable symbolTable, 818 MlirOperation operation) { 819 unwrap(symbolTable)->erase(unwrap(operation)); 820 } 821 822 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, 823 MlirStringRef newSymbol, 824 MlirOperation from) { 825 auto *cppFrom = unwrap(from); 826 auto *context = cppFrom->getContext(); 827 auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); 828 auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); 829 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, 830 unwrap(from))); 831 } 832 833 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, 834 void (*callback)(MlirOperation, bool, 835 void *userData), 836 void *userData) { 837 SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible, 838 [&](Operation *foundOpCpp, bool isVisible) { 839 callback(wrap(foundOpCpp), isVisible, 840 userData); 841 }); 842 } 843