1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/Dialect.h" 18 #include "mlir/IR/Location.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/IR/Types.h" 21 #include "mlir/IR/Verifier.h" 22 #include "mlir/Interfaces/InferTypeOpInterface.h" 23 #include "mlir/Parser.h" 24 25 #include "llvm/Support/Debug.h" 26 #include <cstddef> 27 28 using namespace mlir; 29 30 //===----------------------------------------------------------------------===// 31 // Context API. 32 //===----------------------------------------------------------------------===// 33 34 MlirContext mlirContextCreate() { 35 auto *context = new MLIRContext; 36 return wrap(context); 37 } 38 39 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 40 return unwrap(ctx1) == unwrap(ctx2); 41 } 42 43 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 44 45 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { 46 unwrap(context)->allowUnregisteredDialects(allow); 47 } 48 49 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { 50 return unwrap(context)->allowsUnregisteredDialects(); 51 } 52 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 53 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 54 } 55 56 // TODO: expose a cheaper way than constructing + sorting a vector only to take 57 // its size. 58 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 59 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 60 } 61 62 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 63 MlirStringRef name) { 64 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 65 } 66 67 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { 68 return unwrap(context)->isOperationRegistered(unwrap(name)); 69 } 70 71 void mlirContextEnableMultithreading(MlirContext context, bool enable) { 72 return unwrap(context)->enableMultithreading(enable); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // Dialect API. 77 //===----------------------------------------------------------------------===// 78 79 MlirContext mlirDialectGetContext(MlirDialect dialect) { 80 return wrap(unwrap(dialect)->getContext()); 81 } 82 83 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 84 return unwrap(dialect1) == unwrap(dialect2); 85 } 86 87 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 88 return wrap(unwrap(dialect)->getNamespace()); 89 } 90 91 //===----------------------------------------------------------------------===// 92 // Printing flags API. 93 //===----------------------------------------------------------------------===// 94 95 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 96 return wrap(new OpPrintingFlags()); 97 } 98 99 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 100 delete unwrap(flags); 101 } 102 103 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 104 intptr_t largeElementLimit) { 105 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 106 } 107 108 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, 109 bool prettyForm) { 110 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); 111 } 112 113 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 114 unwrap(flags)->printGenericOpForm(); 115 } 116 117 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 118 unwrap(flags)->useLocalScope(); 119 } 120 121 //===----------------------------------------------------------------------===// 122 // Location API. 123 //===----------------------------------------------------------------------===// 124 125 MlirLocation mlirLocationFileLineColGet(MlirContext context, 126 MlirStringRef filename, unsigned line, 127 unsigned col) { 128 return wrap(Location( 129 FileLineColLoc::get(unwrap(context), unwrap(filename), line, col))); 130 } 131 132 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { 133 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); 134 } 135 136 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, 137 MlirLocation const *locations, 138 MlirAttribute metadata) { 139 SmallVector<Location, 4> locs; 140 ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs); 141 return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); 142 } 143 144 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, 145 MlirLocation childLoc) { 146 if (mlirLocationIsNull(childLoc)) 147 return wrap( 148 Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); 149 return wrap(Location(NameLoc::get( 150 StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); 151 } 152 153 MlirLocation mlirLocationUnknownGet(MlirContext context) { 154 return wrap(Location(UnknownLoc::get(unwrap(context)))); 155 } 156 157 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { 158 return unwrap(l1) == unwrap(l2); 159 } 160 161 MlirContext mlirLocationGetContext(MlirLocation location) { 162 return wrap(unwrap(location).getContext()); 163 } 164 165 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 166 void *userData) { 167 detail::CallbackOstream stream(callback, userData); 168 unwrap(location).print(stream); 169 } 170 171 //===----------------------------------------------------------------------===// 172 // Module API. 173 //===----------------------------------------------------------------------===// 174 175 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 176 return wrap(ModuleOp::create(unwrap(location))); 177 } 178 179 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { 180 OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context)); 181 if (!owning) 182 return MlirModule{nullptr}; 183 return MlirModule{owning.release().getOperation()}; 184 } 185 186 MlirContext mlirModuleGetContext(MlirModule module) { 187 return wrap(unwrap(module).getContext()); 188 } 189 190 MlirBlock mlirModuleGetBody(MlirModule module) { 191 return wrap(unwrap(module).getBody()); 192 } 193 194 void mlirModuleDestroy(MlirModule module) { 195 // Transfer ownership to an OwningModuleRef so that its destructor is called. 196 OwningModuleRef(unwrap(module)); 197 } 198 199 MlirOperation mlirModuleGetOperation(MlirModule module) { 200 return wrap(unwrap(module).getOperation()); 201 } 202 203 MlirModule mlirModuleFromOperation(MlirOperation op) { 204 return wrap(dyn_cast<ModuleOp>(unwrap(op))); 205 } 206 207 //===----------------------------------------------------------------------===// 208 // Operation state API. 209 //===----------------------------------------------------------------------===// 210 211 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { 212 MlirOperationState state; 213 state.name = name; 214 state.location = loc; 215 state.nResults = 0; 216 state.results = nullptr; 217 state.nOperands = 0; 218 state.operands = nullptr; 219 state.nRegions = 0; 220 state.regions = nullptr; 221 state.nSuccessors = 0; 222 state.successors = nullptr; 223 state.nAttributes = 0; 224 state.attributes = nullptr; 225 state.enableResultTypeInference = false; 226 return state; 227 } 228 229 #define APPEND_ELEMS(type, sizeName, elemName) \ 230 state->elemName = \ 231 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 232 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 233 state->sizeName += n; 234 235 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 236 MlirType const *results) { 237 APPEND_ELEMS(MlirType, nResults, results); 238 } 239 240 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 241 MlirValue const *operands) { 242 APPEND_ELEMS(MlirValue, nOperands, operands); 243 } 244 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 245 MlirRegion const *regions) { 246 APPEND_ELEMS(MlirRegion, nRegions, regions); 247 } 248 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 249 MlirBlock const *successors) { 250 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 251 } 252 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 253 MlirNamedAttribute const *attributes) { 254 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 255 } 256 257 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { 258 state->enableResultTypeInference = true; 259 } 260 261 //===----------------------------------------------------------------------===// 262 // Operation API. 263 //===----------------------------------------------------------------------===// 264 265 static LogicalResult inferOperationTypes(OperationState &state) { 266 MLIRContext *context = state.getContext(); 267 Optional<RegisteredOperationName> info = state.name.getRegisteredInfo(); 268 if (!info) { 269 emitError(state.location) 270 << "type inference was requested for the operation " << state.name 271 << ", but the operation was not registered. Ensure that the dialect " 272 "containing the operation is linked into MLIR and registered with " 273 "the context"; 274 return failure(); 275 } 276 277 // Fallback to inference via an op interface. 278 auto *inferInterface = info->getInterface<InferTypeOpInterface>(); 279 if (!inferInterface) { 280 emitError(state.location) 281 << "type inference was requested for the operation " << state.name 282 << ", but the operation does not support type inference. Result " 283 "types must be specified explicitly."; 284 return failure(); 285 } 286 287 if (succeeded(inferInterface->inferReturnTypes( 288 context, state.location, state.operands, 289 state.attributes.getDictionary(context), state.regions, state.types))) 290 return success(); 291 292 // Diagnostic emitted by interface. 293 return failure(); 294 } 295 296 MlirOperation mlirOperationCreate(MlirOperationState *state) { 297 assert(state); 298 OperationState cppState(unwrap(state->location), unwrap(state->name)); 299 SmallVector<Type, 4> resultStorage; 300 SmallVector<Value, 8> operandStorage; 301 SmallVector<Block *, 2> successorStorage; 302 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 303 cppState.addOperands( 304 unwrapList(state->nOperands, state->operands, operandStorage)); 305 cppState.addSuccessors( 306 unwrapList(state->nSuccessors, state->successors, successorStorage)); 307 308 cppState.attributes.reserve(state->nAttributes); 309 for (intptr_t i = 0; i < state->nAttributes; ++i) 310 cppState.addAttribute(unwrap(state->attributes[i].name), 311 unwrap(state->attributes[i].attribute)); 312 313 for (intptr_t i = 0; i < state->nRegions; ++i) 314 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 315 316 free(state->results); 317 free(state->operands); 318 free(state->successors); 319 free(state->regions); 320 free(state->attributes); 321 322 // Infer result types. 323 if (state->enableResultTypeInference) { 324 assert(cppState.types.empty() && 325 "result type inference enabled and result types provided"); 326 if (failed(inferOperationTypes(cppState))) 327 return {nullptr}; 328 } 329 330 MlirOperation result = wrap(Operation::create(cppState)); 331 return result; 332 } 333 334 MlirOperation mlirOperationClone(MlirOperation op) { 335 return wrap(unwrap(op)->clone()); 336 } 337 338 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 339 340 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); } 341 342 bool mlirOperationEqual(MlirOperation op, MlirOperation other) { 343 return unwrap(op) == unwrap(other); 344 } 345 346 MlirContext mlirOperationGetContext(MlirOperation op) { 347 return wrap(unwrap(op)->getContext()); 348 } 349 350 MlirLocation mlirOperationGetLocation(MlirOperation op) { 351 return wrap(unwrap(op)->getLoc()); 352 } 353 354 MlirTypeID mlirOperationGetTypeID(MlirOperation op) { 355 if (auto info = unwrap(op)->getRegisteredInfo()) 356 return wrap(info->getTypeID()); 357 return {nullptr}; 358 } 359 360 MlirIdentifier mlirOperationGetName(MlirOperation op) { 361 return wrap(unwrap(op)->getName().getIdentifier()); 362 } 363 364 MlirBlock mlirOperationGetBlock(MlirOperation op) { 365 return wrap(unwrap(op)->getBlock()); 366 } 367 368 MlirOperation mlirOperationGetParentOperation(MlirOperation op) { 369 return wrap(unwrap(op)->getParentOp()); 370 } 371 372 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 373 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 374 } 375 376 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 377 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 378 } 379 380 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { 381 Operation *cppOp = unwrap(op); 382 if (cppOp->getNumRegions() == 0) 383 return wrap(static_cast<Region *>(nullptr)); 384 return wrap(&cppOp->getRegion(0)); 385 } 386 387 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { 388 Region *cppRegion = unwrap(region); 389 Operation *parent = cppRegion->getParentOp(); 390 intptr_t next = cppRegion->getRegionNumber() + 1; 391 if (parent->getNumRegions() > next) 392 return wrap(&parent->getRegion(next)); 393 return wrap(static_cast<Region *>(nullptr)); 394 } 395 396 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 397 return wrap(unwrap(op)->getNextNode()); 398 } 399 400 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 401 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 402 } 403 404 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 405 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 406 } 407 408 void mlirOperationSetOperand(MlirOperation op, intptr_t pos, 409 MlirValue newValue) { 410 unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue)); 411 } 412 413 intptr_t mlirOperationGetNumResults(MlirOperation op) { 414 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 415 } 416 417 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 418 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 419 } 420 421 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 422 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 423 } 424 425 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 426 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 427 } 428 429 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 430 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 431 } 432 433 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 434 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 435 return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; 436 } 437 438 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 439 MlirStringRef name) { 440 return wrap(unwrap(op)->getAttr(unwrap(name))); 441 } 442 443 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, 444 MlirAttribute attr) { 445 unwrap(op)->setAttr(unwrap(name), unwrap(attr)); 446 } 447 448 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { 449 return !!unwrap(op)->removeAttr(unwrap(name)); 450 } 451 452 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 453 void *userData) { 454 detail::CallbackOstream stream(callback, userData); 455 unwrap(op)->print(stream); 456 } 457 458 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 459 MlirStringCallback callback, void *userData) { 460 detail::CallbackOstream stream(callback, userData); 461 unwrap(op)->print(stream, *unwrap(flags)); 462 } 463 464 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 465 466 bool mlirOperationVerify(MlirOperation op) { 467 return succeeded(verify(unwrap(op))); 468 } 469 470 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { 471 return unwrap(op)->moveAfter(unwrap(other)); 472 } 473 474 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { 475 return unwrap(op)->moveBefore(unwrap(other)); 476 } 477 478 //===----------------------------------------------------------------------===// 479 // Region API. 480 //===----------------------------------------------------------------------===// 481 482 MlirRegion mlirRegionCreate() { return wrap(new Region); } 483 484 bool mlirRegionEqual(MlirRegion region, MlirRegion other) { 485 return unwrap(region) == unwrap(other); 486 } 487 488 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 489 Region *cppRegion = unwrap(region); 490 if (cppRegion->empty()) 491 return wrap(static_cast<Block *>(nullptr)); 492 return wrap(&cppRegion->front()); 493 } 494 495 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 496 unwrap(region)->push_back(unwrap(block)); 497 } 498 499 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 500 MlirBlock block) { 501 auto &blockList = unwrap(region)->getBlocks(); 502 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 503 } 504 505 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 506 MlirBlock block) { 507 Region *cppRegion = unwrap(region); 508 if (mlirBlockIsNull(reference)) { 509 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 510 return; 511 } 512 513 assert(unwrap(reference)->getParent() == unwrap(region) && 514 "expected reference block to belong to the region"); 515 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 516 unwrap(block)); 517 } 518 519 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 520 MlirBlock block) { 521 if (mlirBlockIsNull(reference)) 522 return mlirRegionAppendOwnedBlock(region, block); 523 524 assert(unwrap(reference)->getParent() == unwrap(region) && 525 "expected reference block to belong to the region"); 526 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 527 unwrap(block)); 528 } 529 530 void mlirRegionDestroy(MlirRegion region) { 531 delete static_cast<Region *>(region.ptr); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // Block API. 536 //===----------------------------------------------------------------------===// 537 538 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, 539 MlirLocation const *locs) { 540 Block *b = new Block; 541 for (intptr_t i = 0; i < nArgs; ++i) 542 b->addArgument(unwrap(args[i]), unwrap(locs[i])); 543 return wrap(b); 544 } 545 546 bool mlirBlockEqual(MlirBlock block, MlirBlock other) { 547 return unwrap(block) == unwrap(other); 548 } 549 550 MlirOperation mlirBlockGetParentOperation(MlirBlock block) { 551 return wrap(unwrap(block)->getParentOp()); 552 } 553 554 MlirRegion mlirBlockGetParentRegion(MlirBlock block) { 555 return wrap(unwrap(block)->getParent()); 556 } 557 558 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 559 return wrap(unwrap(block)->getNextNode()); 560 } 561 562 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 563 Block *cppBlock = unwrap(block); 564 if (cppBlock->empty()) 565 return wrap(static_cast<Operation *>(nullptr)); 566 return wrap(&cppBlock->front()); 567 } 568 569 MlirOperation mlirBlockGetTerminator(MlirBlock block) { 570 Block *cppBlock = unwrap(block); 571 if (cppBlock->empty()) 572 return wrap(static_cast<Operation *>(nullptr)); 573 Operation &back = cppBlock->back(); 574 if (!back.hasTrait<OpTrait::IsTerminator>()) 575 return wrap(static_cast<Operation *>(nullptr)); 576 return wrap(&back); 577 } 578 579 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 580 unwrap(block)->push_back(unwrap(operation)); 581 } 582 583 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 584 MlirOperation operation) { 585 auto &opList = unwrap(block)->getOperations(); 586 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 587 } 588 589 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 590 MlirOperation reference, 591 MlirOperation operation) { 592 Block *cppBlock = unwrap(block); 593 if (mlirOperationIsNull(reference)) { 594 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 595 return; 596 } 597 598 assert(unwrap(reference)->getBlock() == unwrap(block) && 599 "expected reference operation to belong to the block"); 600 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 601 unwrap(operation)); 602 } 603 604 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 605 MlirOperation reference, 606 MlirOperation operation) { 607 if (mlirOperationIsNull(reference)) 608 return mlirBlockAppendOwnedOperation(block, operation); 609 610 assert(unwrap(reference)->getBlock() == unwrap(block) && 611 "expected reference operation to belong to the block"); 612 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 613 unwrap(operation)); 614 } 615 616 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 617 618 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 619 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 620 } 621 622 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, 623 MlirLocation loc) { 624 return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc))); 625 } 626 627 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 628 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 629 } 630 631 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 632 void *userData) { 633 detail::CallbackOstream stream(callback, userData); 634 unwrap(block)->print(stream); 635 } 636 637 //===----------------------------------------------------------------------===// 638 // Value API. 639 //===----------------------------------------------------------------------===// 640 641 bool mlirValueEqual(MlirValue value1, MlirValue value2) { 642 return unwrap(value1) == unwrap(value2); 643 } 644 645 bool mlirValueIsABlockArgument(MlirValue value) { 646 return unwrap(value).isa<BlockArgument>(); 647 } 648 649 bool mlirValueIsAOpResult(MlirValue value) { 650 return unwrap(value).isa<OpResult>(); 651 } 652 653 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 654 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 655 } 656 657 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 658 return static_cast<intptr_t>( 659 unwrap(value).cast<BlockArgument>().getArgNumber()); 660 } 661 662 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 663 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 664 } 665 666 MlirOperation mlirOpResultGetOwner(MlirValue value) { 667 return wrap(unwrap(value).cast<OpResult>().getOwner()); 668 } 669 670 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 671 return static_cast<intptr_t>( 672 unwrap(value).cast<OpResult>().getResultNumber()); 673 } 674 675 MlirType mlirValueGetType(MlirValue value) { 676 return wrap(unwrap(value).getType()); 677 } 678 679 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 680 681 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 682 void *userData) { 683 detail::CallbackOstream stream(callback, userData); 684 unwrap(value).print(stream); 685 } 686 687 //===----------------------------------------------------------------------===// 688 // Type API. 689 //===----------------------------------------------------------------------===// 690 691 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { 692 return wrap(mlir::parseType(unwrap(type), unwrap(context))); 693 } 694 695 MlirContext mlirTypeGetContext(MlirType type) { 696 return wrap(unwrap(type).getContext()); 697 } 698 699 MlirTypeID mlirTypeGetTypeID(MlirType type) { 700 return wrap(unwrap(type).getTypeID()); 701 } 702 703 bool mlirTypeEqual(MlirType t1, MlirType t2) { 704 return unwrap(t1) == unwrap(t2); 705 } 706 707 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 708 detail::CallbackOstream stream(callback, userData); 709 unwrap(type).print(stream); 710 } 711 712 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 713 714 //===----------------------------------------------------------------------===// 715 // Attribute API. 716 //===----------------------------------------------------------------------===// 717 718 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { 719 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context))); 720 } 721 722 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 723 return wrap(unwrap(attribute).getContext()); 724 } 725 726 MlirType mlirAttributeGetType(MlirAttribute attribute) { 727 return wrap(unwrap(attribute).getType()); 728 } 729 730 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { 731 return wrap(unwrap(attr).getTypeID()); 732 } 733 734 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 735 return unwrap(a1) == unwrap(a2); 736 } 737 738 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 739 void *userData) { 740 detail::CallbackOstream stream(callback, userData); 741 unwrap(attr).print(stream); 742 } 743 744 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 745 746 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, 747 MlirAttribute attr) { 748 return MlirNamedAttribute{name, attr}; 749 } 750 751 //===----------------------------------------------------------------------===// 752 // Identifier API. 753 //===----------------------------------------------------------------------===// 754 755 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { 756 return wrap(StringAttr::get(unwrap(context), unwrap(str))); 757 } 758 759 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { 760 return wrap(unwrap(ident).getContext()); 761 } 762 763 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { 764 return unwrap(ident) == unwrap(other); 765 } 766 767 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { 768 return wrap(unwrap(ident).strref()); 769 } 770 771 //===----------------------------------------------------------------------===// 772 // TypeID API. 773 //===----------------------------------------------------------------------===// 774 775 bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) { 776 return unwrap(typeID1) == unwrap(typeID2); 777 } 778 779 size_t mlirTypeIDHashValue(MlirTypeID typeID) { 780 return hash_value(unwrap(typeID)); 781 } 782 783 //===----------------------------------------------------------------------===// 784 // Symbol and SymbolTable API. 785 //===----------------------------------------------------------------------===// 786 787 MlirStringRef mlirSymbolTableGetSymbolAttributeName() { 788 return wrap(SymbolTable::getSymbolAttrName()); 789 } 790 791 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { 792 return wrap(SymbolTable::getVisibilityAttrName()); 793 } 794 795 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { 796 if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>()) 797 return wrap(static_cast<SymbolTable *>(nullptr)); 798 return wrap(new SymbolTable(unwrap(operation))); 799 } 800 801 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { 802 delete unwrap(symbolTable); 803 } 804 805 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, 806 MlirStringRef name) { 807 return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length))); 808 } 809 810 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, 811 MlirOperation operation) { 812 return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation))); 813 } 814 815 void mlirSymbolTableErase(MlirSymbolTable symbolTable, 816 MlirOperation operation) { 817 unwrap(symbolTable)->erase(unwrap(operation)); 818 } 819 820 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, 821 MlirStringRef newSymbol, 822 MlirOperation from) { 823 auto *cppFrom = unwrap(from); 824 auto *context = cppFrom->getContext(); 825 auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); 826 auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); 827 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, 828 unwrap(from))); 829 } 830 831 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, 832 void (*callback)(MlirOperation, bool, 833 void *userData), 834 void *userData) { 835 SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible, 836 [&](Operation *foundOpCpp, bool isVisible) { 837 callback(wrap(foundOpCpp), isVisible, 838 userData); 839 }); 840 } 841