1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/Dialect.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/IR/Verifier.h" 21 #include "mlir/Parser.h" 22 23 using namespace mlir; 24 25 //===----------------------------------------------------------------------===// 26 // Context API. 27 //===----------------------------------------------------------------------===// 28 29 MlirContext mlirContextCreate() { 30 auto *context = new MLIRContext; 31 return wrap(context); 32 } 33 34 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 35 return unwrap(ctx1) == unwrap(ctx2); 36 } 37 38 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 39 40 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { 41 unwrap(context)->allowUnregisteredDialects(allow); 42 } 43 44 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { 45 return unwrap(context)->allowsUnregisteredDialects(); 46 } 47 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 48 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 49 } 50 51 // TODO: expose a cheaper way than constructing + sorting a vector only to take 52 // its size. 53 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 54 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 55 } 56 57 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 58 MlirStringRef name) { 59 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // Dialect API. 64 //===----------------------------------------------------------------------===// 65 66 MlirContext mlirDialectGetContext(MlirDialect dialect) { 67 return wrap(unwrap(dialect)->getContext()); 68 } 69 70 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 71 return unwrap(dialect1) == unwrap(dialect2); 72 } 73 74 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 75 return wrap(unwrap(dialect)->getNamespace()); 76 } 77 78 //===----------------------------------------------------------------------===// 79 // Printing flags API. 80 //===----------------------------------------------------------------------===// 81 82 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 83 return wrap(new OpPrintingFlags()); 84 } 85 86 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 87 delete unwrap(flags); 88 } 89 90 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 91 intptr_t largeElementLimit) { 92 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 93 } 94 95 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, 96 bool prettyForm) { 97 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); 98 } 99 100 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 101 unwrap(flags)->printGenericOpForm(); 102 } 103 104 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 105 unwrap(flags)->useLocalScope(); 106 } 107 108 //===----------------------------------------------------------------------===// 109 // Location API. 110 //===----------------------------------------------------------------------===// 111 112 MlirLocation mlirLocationFileLineColGet(MlirContext context, 113 MlirStringRef filename, unsigned line, 114 unsigned col) { 115 return wrap( 116 FileLineColLoc::get(unwrap(filename), line, col, unwrap(context))); 117 } 118 119 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { 120 return wrap(CallSiteLoc::get(unwrap(callee), unwrap(caller))); 121 } 122 123 MlirLocation mlirLocationUnknownGet(MlirContext context) { 124 return wrap(UnknownLoc::get(unwrap(context))); 125 } 126 127 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { 128 return unwrap(l1) == unwrap(l2); 129 } 130 131 MlirContext mlirLocationGetContext(MlirLocation location) { 132 return wrap(unwrap(location).getContext()); 133 } 134 135 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 136 void *userData) { 137 detail::CallbackOstream stream(callback, userData); 138 unwrap(location).print(stream); 139 } 140 141 //===----------------------------------------------------------------------===// 142 // Module API. 143 //===----------------------------------------------------------------------===// 144 145 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 146 return wrap(ModuleOp::create(unwrap(location))); 147 } 148 149 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { 150 OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context)); 151 if (!owning) 152 return MlirModule{nullptr}; 153 return MlirModule{owning.release().getOperation()}; 154 } 155 156 MlirContext mlirModuleGetContext(MlirModule module) { 157 return wrap(unwrap(module).getContext()); 158 } 159 160 MlirBlock mlirModuleGetBody(MlirModule module) { 161 return wrap(unwrap(module).getBody()); 162 } 163 164 void mlirModuleDestroy(MlirModule module) { 165 // Transfer ownership to an OwningModuleRef so that its destructor is called. 166 OwningModuleRef(unwrap(module)); 167 } 168 169 MlirOperation mlirModuleGetOperation(MlirModule module) { 170 return wrap(unwrap(module).getOperation()); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // Operation state API. 175 //===----------------------------------------------------------------------===// 176 177 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { 178 MlirOperationState state; 179 state.name = name; 180 state.location = loc; 181 state.nResults = 0; 182 state.results = nullptr; 183 state.nOperands = 0; 184 state.operands = nullptr; 185 state.nRegions = 0; 186 state.regions = nullptr; 187 state.nSuccessors = 0; 188 state.successors = nullptr; 189 state.nAttributes = 0; 190 state.attributes = nullptr; 191 return state; 192 } 193 194 #define APPEND_ELEMS(type, sizeName, elemName) \ 195 state->elemName = \ 196 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 197 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 198 state->sizeName += n; 199 200 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 201 MlirType const *results) { 202 APPEND_ELEMS(MlirType, nResults, results); 203 } 204 205 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 206 MlirValue const *operands) { 207 APPEND_ELEMS(MlirValue, nOperands, operands); 208 } 209 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 210 MlirRegion const *regions) { 211 APPEND_ELEMS(MlirRegion, nRegions, regions); 212 } 213 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 214 MlirBlock const *successors) { 215 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 216 } 217 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 218 MlirNamedAttribute const *attributes) { 219 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 220 } 221 222 //===----------------------------------------------------------------------===// 223 // Operation API. 224 //===----------------------------------------------------------------------===// 225 226 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 227 assert(state); 228 OperationState cppState(unwrap(state->location), unwrap(state->name)); 229 SmallVector<Type, 4> resultStorage; 230 SmallVector<Value, 8> operandStorage; 231 SmallVector<Block *, 2> successorStorage; 232 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 233 cppState.addOperands( 234 unwrapList(state->nOperands, state->operands, operandStorage)); 235 cppState.addSuccessors( 236 unwrapList(state->nSuccessors, state->successors, successorStorage)); 237 238 cppState.attributes.reserve(state->nAttributes); 239 for (intptr_t i = 0; i < state->nAttributes; ++i) 240 cppState.addAttribute(unwrap(state->attributes[i].name), 241 unwrap(state->attributes[i].attribute)); 242 243 for (intptr_t i = 0; i < state->nRegions; ++i) 244 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 245 246 MlirOperation result = wrap(Operation::create(cppState)); 247 free(state->results); 248 free(state->operands); 249 free(state->successors); 250 free(state->regions); 251 free(state->attributes); 252 return result; 253 } 254 255 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 256 257 bool mlirOperationEqual(MlirOperation op, MlirOperation other) { 258 return unwrap(op) == unwrap(other); 259 } 260 261 MlirIdentifier mlirOperationGetName(MlirOperation op) { 262 return wrap(unwrap(op)->getName().getIdentifier()); 263 } 264 265 MlirBlock mlirOperationGetBlock(MlirOperation op) { 266 return wrap(unwrap(op)->getBlock()); 267 } 268 269 MlirOperation mlirOperationGetParentOperation(MlirOperation op) { 270 return wrap(unwrap(op)->getParentOp()); 271 } 272 273 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 274 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 275 } 276 277 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 278 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 279 } 280 281 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 282 return wrap(unwrap(op)->getNextNode()); 283 } 284 285 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 286 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 287 } 288 289 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 290 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 291 } 292 293 intptr_t mlirOperationGetNumResults(MlirOperation op) { 294 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 295 } 296 297 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 298 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 299 } 300 301 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 302 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 303 } 304 305 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 306 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 307 } 308 309 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 310 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 311 } 312 313 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 314 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 315 return MlirNamedAttribute{wrap(attr.first), wrap(attr.second)}; 316 } 317 318 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 319 MlirStringRef name) { 320 return wrap(unwrap(op)->getAttr(unwrap(name))); 321 } 322 323 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, 324 MlirAttribute attr) { 325 unwrap(op)->setAttr(unwrap(name), unwrap(attr)); 326 } 327 328 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { 329 auto removeResult = unwrap(op)->removeAttr(unwrap(name)); 330 return removeResult == MutableDictionaryAttr::RemoveResult::Removed; 331 } 332 333 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 334 void *userData) { 335 detail::CallbackOstream stream(callback, userData); 336 unwrap(op)->print(stream); 337 } 338 339 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 340 MlirStringCallback callback, void *userData) { 341 detail::CallbackOstream stream(callback, userData); 342 unwrap(op)->print(stream, *unwrap(flags)); 343 } 344 345 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 346 347 bool mlirOperationVerify(MlirOperation op) { 348 return succeeded(verify(unwrap(op))); 349 } 350 351 //===----------------------------------------------------------------------===// 352 // Region API. 353 //===----------------------------------------------------------------------===// 354 355 MlirRegion mlirRegionCreate() { return wrap(new Region); } 356 357 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 358 Region *cppRegion = unwrap(region); 359 if (cppRegion->empty()) 360 return wrap(static_cast<Block *>(nullptr)); 361 return wrap(&cppRegion->front()); 362 } 363 364 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 365 unwrap(region)->push_back(unwrap(block)); 366 } 367 368 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 369 MlirBlock block) { 370 auto &blockList = unwrap(region)->getBlocks(); 371 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 372 } 373 374 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 375 MlirBlock block) { 376 Region *cppRegion = unwrap(region); 377 if (mlirBlockIsNull(reference)) { 378 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 379 return; 380 } 381 382 assert(unwrap(reference)->getParent() == unwrap(region) && 383 "expected reference block to belong to the region"); 384 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 385 unwrap(block)); 386 } 387 388 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 389 MlirBlock block) { 390 if (mlirBlockIsNull(reference)) 391 return mlirRegionAppendOwnedBlock(region, block); 392 393 assert(unwrap(reference)->getParent() == unwrap(region) && 394 "expected reference block to belong to the region"); 395 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 396 unwrap(block)); 397 } 398 399 void mlirRegionDestroy(MlirRegion region) { 400 delete static_cast<Region *>(region.ptr); 401 } 402 403 //===----------------------------------------------------------------------===// 404 // Block API. 405 //===----------------------------------------------------------------------===// 406 407 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) { 408 Block *b = new Block; 409 for (intptr_t i = 0; i < nArgs; ++i) 410 b->addArgument(unwrap(args[i])); 411 return wrap(b); 412 } 413 414 bool mlirBlockEqual(MlirBlock block, MlirBlock other) { 415 return unwrap(block) == unwrap(other); 416 } 417 418 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 419 return wrap(unwrap(block)->getNextNode()); 420 } 421 422 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 423 Block *cppBlock = unwrap(block); 424 if (cppBlock->empty()) 425 return wrap(static_cast<Operation *>(nullptr)); 426 return wrap(&cppBlock->front()); 427 } 428 429 MlirOperation mlirBlockGetTerminator(MlirBlock block) { 430 Block *cppBlock = unwrap(block); 431 if (cppBlock->empty()) 432 return wrap(static_cast<Operation *>(nullptr)); 433 Operation &back = cppBlock->back(); 434 if (!back.isKnownTerminator()) 435 return wrap(static_cast<Operation *>(nullptr)); 436 return wrap(&back); 437 } 438 439 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 440 unwrap(block)->push_back(unwrap(operation)); 441 } 442 443 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 444 MlirOperation operation) { 445 auto &opList = unwrap(block)->getOperations(); 446 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 447 } 448 449 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 450 MlirOperation reference, 451 MlirOperation operation) { 452 Block *cppBlock = unwrap(block); 453 if (mlirOperationIsNull(reference)) { 454 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 455 return; 456 } 457 458 assert(unwrap(reference)->getBlock() == unwrap(block) && 459 "expected reference operation to belong to the block"); 460 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 461 unwrap(operation)); 462 } 463 464 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 465 MlirOperation reference, 466 MlirOperation operation) { 467 if (mlirOperationIsNull(reference)) 468 return mlirBlockAppendOwnedOperation(block, operation); 469 470 assert(unwrap(reference)->getBlock() == unwrap(block) && 471 "expected reference operation to belong to the block"); 472 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 473 unwrap(operation)); 474 } 475 476 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 477 478 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 479 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 480 } 481 482 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 483 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 484 } 485 486 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 487 void *userData) { 488 detail::CallbackOstream stream(callback, userData); 489 unwrap(block)->print(stream); 490 } 491 492 //===----------------------------------------------------------------------===// 493 // Value API. 494 //===----------------------------------------------------------------------===// 495 496 bool mlirValueEqual(MlirValue value1, MlirValue value2) { 497 return unwrap(value1) == unwrap(value2); 498 } 499 500 bool mlirValueIsABlockArgument(MlirValue value) { 501 return unwrap(value).isa<BlockArgument>(); 502 } 503 504 bool mlirValueIsAOpResult(MlirValue value) { 505 return unwrap(value).isa<OpResult>(); 506 } 507 508 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 509 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 510 } 511 512 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 513 return static_cast<intptr_t>( 514 unwrap(value).cast<BlockArgument>().getArgNumber()); 515 } 516 517 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 518 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 519 } 520 521 MlirOperation mlirOpResultGetOwner(MlirValue value) { 522 return wrap(unwrap(value).cast<OpResult>().getOwner()); 523 } 524 525 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 526 return static_cast<intptr_t>( 527 unwrap(value).cast<OpResult>().getResultNumber()); 528 } 529 530 MlirType mlirValueGetType(MlirValue value) { 531 return wrap(unwrap(value).getType()); 532 } 533 534 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 535 536 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 537 void *userData) { 538 detail::CallbackOstream stream(callback, userData); 539 unwrap(value).print(stream); 540 } 541 542 //===----------------------------------------------------------------------===// 543 // Type API. 544 //===----------------------------------------------------------------------===// 545 546 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { 547 return wrap(mlir::parseType(unwrap(type), unwrap(context))); 548 } 549 550 MlirContext mlirTypeGetContext(MlirType type) { 551 return wrap(unwrap(type).getContext()); 552 } 553 554 bool mlirTypeEqual(MlirType t1, MlirType t2) { 555 return unwrap(t1) == unwrap(t2); 556 } 557 558 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 559 detail::CallbackOstream stream(callback, userData); 560 unwrap(type).print(stream); 561 } 562 563 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 564 565 //===----------------------------------------------------------------------===// 566 // Attribute API. 567 //===----------------------------------------------------------------------===// 568 569 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { 570 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context))); 571 } 572 573 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 574 return wrap(unwrap(attribute).getContext()); 575 } 576 577 MlirType mlirAttributeGetType(MlirAttribute attribute) { 578 return wrap(unwrap(attribute).getType()); 579 } 580 581 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 582 return unwrap(a1) == unwrap(a2); 583 } 584 585 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 586 void *userData) { 587 detail::CallbackOstream stream(callback, userData); 588 unwrap(attr).print(stream); 589 } 590 591 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 592 593 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, 594 MlirAttribute attr) { 595 return MlirNamedAttribute{name, attr}; 596 } 597 598 //===----------------------------------------------------------------------===// 599 // Identifier API. 600 //===----------------------------------------------------------------------===// 601 602 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { 603 return wrap(Identifier::get(unwrap(str), unwrap(context))); 604 } 605 606 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { 607 return unwrap(ident) == unwrap(other); 608 } 609 610 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { 611 return wrap(unwrap(ident).strref()); 612 } 613