1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/Dialect.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/IR/Verifier.h" 21 #include "mlir/Parser.h" 22 23 using namespace mlir; 24 25 //===----------------------------------------------------------------------===// 26 // Context API. 27 //===----------------------------------------------------------------------===// 28 29 MlirContext mlirContextCreate() { 30 auto *context = new MLIRContext; 31 return wrap(context); 32 } 33 34 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 35 return unwrap(ctx1) == unwrap(ctx2); 36 } 37 38 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 39 40 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { 41 unwrap(context)->allowUnregisteredDialects(allow); 42 } 43 44 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { 45 return unwrap(context)->allowsUnregisteredDialects(); 46 } 47 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 48 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 49 } 50 51 // TODO: expose a cheaper way than constructing + sorting a vector only to take 52 // its size. 53 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 54 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 55 } 56 57 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 58 MlirStringRef name) { 59 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // Dialect API. 64 //===----------------------------------------------------------------------===// 65 66 MlirContext mlirDialectGetContext(MlirDialect dialect) { 67 return wrap(unwrap(dialect)->getContext()); 68 } 69 70 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 71 return unwrap(dialect1) == unwrap(dialect2); 72 } 73 74 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 75 return wrap(unwrap(dialect)->getNamespace()); 76 } 77 78 //===----------------------------------------------------------------------===// 79 // Printing flags API. 80 //===----------------------------------------------------------------------===// 81 82 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 83 return wrap(new OpPrintingFlags()); 84 } 85 86 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 87 delete unwrap(flags); 88 } 89 90 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 91 intptr_t largeElementLimit) { 92 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 93 } 94 95 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, 96 bool prettyForm) { 97 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); 98 } 99 100 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 101 unwrap(flags)->printGenericOpForm(); 102 } 103 104 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 105 unwrap(flags)->useLocalScope(); 106 } 107 108 //===----------------------------------------------------------------------===// 109 // Location API. 110 //===----------------------------------------------------------------------===// 111 112 MlirLocation mlirLocationFileLineColGet(MlirContext context, 113 MlirStringRef filename, unsigned line, 114 unsigned col) { 115 return wrap( 116 FileLineColLoc::get(unwrap(filename), line, col, unwrap(context))); 117 } 118 119 MlirLocation mlirLocationUnknownGet(MlirContext context) { 120 return wrap(UnknownLoc::get(unwrap(context))); 121 } 122 123 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { 124 return unwrap(l1) == unwrap(l2); 125 } 126 127 MlirContext mlirLocationGetContext(MlirLocation location) { 128 return wrap(unwrap(location).getContext()); 129 } 130 131 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 132 void *userData) { 133 detail::CallbackOstream stream(callback, userData); 134 unwrap(location).print(stream); 135 } 136 137 //===----------------------------------------------------------------------===// 138 // Module API. 139 //===----------------------------------------------------------------------===// 140 141 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 142 return wrap(ModuleOp::create(unwrap(location))); 143 } 144 145 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { 146 OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context)); 147 if (!owning) 148 return MlirModule{nullptr}; 149 return MlirModule{owning.release().getOperation()}; 150 } 151 152 MlirContext mlirModuleGetContext(MlirModule module) { 153 return wrap(unwrap(module).getContext()); 154 } 155 156 MlirBlock mlirModuleGetBody(MlirModule module) { 157 return wrap(unwrap(module).getBody()); 158 } 159 160 void mlirModuleDestroy(MlirModule module) { 161 // Transfer ownership to an OwningModuleRef so that its destructor is called. 162 OwningModuleRef(unwrap(module)); 163 } 164 165 MlirOperation mlirModuleGetOperation(MlirModule module) { 166 return wrap(unwrap(module).getOperation()); 167 } 168 169 //===----------------------------------------------------------------------===// 170 // Operation state API. 171 //===----------------------------------------------------------------------===// 172 173 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { 174 MlirOperationState state; 175 state.name = name; 176 state.location = loc; 177 state.nResults = 0; 178 state.results = nullptr; 179 state.nOperands = 0; 180 state.operands = nullptr; 181 state.nRegions = 0; 182 state.regions = nullptr; 183 state.nSuccessors = 0; 184 state.successors = nullptr; 185 state.nAttributes = 0; 186 state.attributes = nullptr; 187 return state; 188 } 189 190 #define APPEND_ELEMS(type, sizeName, elemName) \ 191 state->elemName = \ 192 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 193 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 194 state->sizeName += n; 195 196 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 197 MlirType const *results) { 198 APPEND_ELEMS(MlirType, nResults, results); 199 } 200 201 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 202 MlirValue const *operands) { 203 APPEND_ELEMS(MlirValue, nOperands, operands); 204 } 205 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 206 MlirRegion const *regions) { 207 APPEND_ELEMS(MlirRegion, nRegions, regions); 208 } 209 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 210 MlirBlock const *successors) { 211 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 212 } 213 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 214 MlirNamedAttribute const *attributes) { 215 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // Operation API. 220 //===----------------------------------------------------------------------===// 221 222 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 223 assert(state); 224 OperationState cppState(unwrap(state->location), unwrap(state->name)); 225 SmallVector<Type, 4> resultStorage; 226 SmallVector<Value, 8> operandStorage; 227 SmallVector<Block *, 2> successorStorage; 228 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 229 cppState.addOperands( 230 unwrapList(state->nOperands, state->operands, operandStorage)); 231 cppState.addSuccessors( 232 unwrapList(state->nSuccessors, state->successors, successorStorage)); 233 234 cppState.attributes.reserve(state->nAttributes); 235 for (intptr_t i = 0; i < state->nAttributes; ++i) 236 cppState.addAttribute(unwrap(state->attributes[i].name), 237 unwrap(state->attributes[i].attribute)); 238 239 for (intptr_t i = 0; i < state->nRegions; ++i) 240 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 241 242 MlirOperation result = wrap(Operation::create(cppState)); 243 free(state->results); 244 free(state->operands); 245 free(state->successors); 246 free(state->regions); 247 free(state->attributes); 248 return result; 249 } 250 251 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 252 253 bool mlirOperationEqual(MlirOperation op, MlirOperation other) { 254 return unwrap(op) == unwrap(other); 255 } 256 257 MlirIdentifier mlirOperationGetName(MlirOperation op) { 258 return wrap(unwrap(op)->getName().getIdentifier()); 259 } 260 261 MlirBlock mlirOperationGetBlock(MlirOperation op) { 262 return wrap(unwrap(op)->getBlock()); 263 } 264 265 MlirOperation mlirOperationGetParentOperation(MlirOperation op) { 266 return wrap(unwrap(op)->getParentOp()); 267 } 268 269 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 270 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 271 } 272 273 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 274 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 275 } 276 277 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 278 return wrap(unwrap(op)->getNextNode()); 279 } 280 281 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 282 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 283 } 284 285 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 286 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 287 } 288 289 intptr_t mlirOperationGetNumResults(MlirOperation op) { 290 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 291 } 292 293 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 294 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 295 } 296 297 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 298 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 299 } 300 301 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 302 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 303 } 304 305 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 306 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 307 } 308 309 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 310 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 311 return MlirNamedAttribute{wrap(attr.first.strref()), wrap(attr.second)}; 312 } 313 314 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 315 MlirStringRef name) { 316 return wrap(unwrap(op)->getAttr(unwrap(name))); 317 } 318 319 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, 320 MlirAttribute attr) { 321 unwrap(op)->setAttr(unwrap(name), unwrap(attr)); 322 } 323 324 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { 325 auto removeResult = unwrap(op)->removeAttr(unwrap(name)); 326 return removeResult == MutableDictionaryAttr::RemoveResult::Removed; 327 } 328 329 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 330 void *userData) { 331 detail::CallbackOstream stream(callback, userData); 332 unwrap(op)->print(stream); 333 } 334 335 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 336 MlirStringCallback callback, void *userData) { 337 detail::CallbackOstream stream(callback, userData); 338 unwrap(op)->print(stream, *unwrap(flags)); 339 } 340 341 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 342 343 bool mlirOperationVerify(MlirOperation op) { 344 return succeeded(verify(unwrap(op))); 345 } 346 347 //===----------------------------------------------------------------------===// 348 // Region API. 349 //===----------------------------------------------------------------------===// 350 351 MlirRegion mlirRegionCreate() { return wrap(new Region); } 352 353 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 354 Region *cppRegion = unwrap(region); 355 if (cppRegion->empty()) 356 return wrap(static_cast<Block *>(nullptr)); 357 return wrap(&cppRegion->front()); 358 } 359 360 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 361 unwrap(region)->push_back(unwrap(block)); 362 } 363 364 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 365 MlirBlock block) { 366 auto &blockList = unwrap(region)->getBlocks(); 367 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 368 } 369 370 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 371 MlirBlock block) { 372 Region *cppRegion = unwrap(region); 373 if (mlirBlockIsNull(reference)) { 374 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 375 return; 376 } 377 378 assert(unwrap(reference)->getParent() == unwrap(region) && 379 "expected reference block to belong to the region"); 380 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 381 unwrap(block)); 382 } 383 384 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 385 MlirBlock block) { 386 if (mlirBlockIsNull(reference)) 387 return mlirRegionAppendOwnedBlock(region, block); 388 389 assert(unwrap(reference)->getParent() == unwrap(region) && 390 "expected reference block to belong to the region"); 391 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 392 unwrap(block)); 393 } 394 395 void mlirRegionDestroy(MlirRegion region) { 396 delete static_cast<Region *>(region.ptr); 397 } 398 399 //===----------------------------------------------------------------------===// 400 // Block API. 401 //===----------------------------------------------------------------------===// 402 403 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) { 404 Block *b = new Block; 405 for (intptr_t i = 0; i < nArgs; ++i) 406 b->addArgument(unwrap(args[i])); 407 return wrap(b); 408 } 409 410 bool mlirBlockEqual(MlirBlock block, MlirBlock other) { 411 return unwrap(block) == unwrap(other); 412 } 413 414 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 415 return wrap(unwrap(block)->getNextNode()); 416 } 417 418 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 419 Block *cppBlock = unwrap(block); 420 if (cppBlock->empty()) 421 return wrap(static_cast<Operation *>(nullptr)); 422 return wrap(&cppBlock->front()); 423 } 424 425 MlirOperation mlirBlockGetTerminator(MlirBlock block) { 426 Block *cppBlock = unwrap(block); 427 if (cppBlock->empty()) 428 return wrap(static_cast<Operation *>(nullptr)); 429 Operation &back = cppBlock->back(); 430 if (!back.isKnownTerminator()) 431 return wrap(static_cast<Operation *>(nullptr)); 432 return wrap(&back); 433 } 434 435 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 436 unwrap(block)->push_back(unwrap(operation)); 437 } 438 439 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 440 MlirOperation operation) { 441 auto &opList = unwrap(block)->getOperations(); 442 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 443 } 444 445 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 446 MlirOperation reference, 447 MlirOperation operation) { 448 Block *cppBlock = unwrap(block); 449 if (mlirOperationIsNull(reference)) { 450 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 451 return; 452 } 453 454 assert(unwrap(reference)->getBlock() == unwrap(block) && 455 "expected reference operation to belong to the block"); 456 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 457 unwrap(operation)); 458 } 459 460 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 461 MlirOperation reference, 462 MlirOperation operation) { 463 if (mlirOperationIsNull(reference)) 464 return mlirBlockAppendOwnedOperation(block, operation); 465 466 assert(unwrap(reference)->getBlock() == unwrap(block) && 467 "expected reference operation to belong to the block"); 468 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 469 unwrap(operation)); 470 } 471 472 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 473 474 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 475 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 476 } 477 478 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 479 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 480 } 481 482 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 483 void *userData) { 484 detail::CallbackOstream stream(callback, userData); 485 unwrap(block)->print(stream); 486 } 487 488 //===----------------------------------------------------------------------===// 489 // Value API. 490 //===----------------------------------------------------------------------===// 491 492 bool mlirValueEqual(MlirValue value1, MlirValue value2) { 493 return unwrap(value1) == unwrap(value2); 494 } 495 496 bool mlirValueIsABlockArgument(MlirValue value) { 497 return unwrap(value).isa<BlockArgument>(); 498 } 499 500 bool mlirValueIsAOpResult(MlirValue value) { 501 return unwrap(value).isa<OpResult>(); 502 } 503 504 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 505 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 506 } 507 508 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 509 return static_cast<intptr_t>( 510 unwrap(value).cast<BlockArgument>().getArgNumber()); 511 } 512 513 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 514 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 515 } 516 517 MlirOperation mlirOpResultGetOwner(MlirValue value) { 518 return wrap(unwrap(value).cast<OpResult>().getOwner()); 519 } 520 521 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 522 return static_cast<intptr_t>( 523 unwrap(value).cast<OpResult>().getResultNumber()); 524 } 525 526 MlirType mlirValueGetType(MlirValue value) { 527 return wrap(unwrap(value).getType()); 528 } 529 530 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 531 532 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 533 void *userData) { 534 detail::CallbackOstream stream(callback, userData); 535 unwrap(value).print(stream); 536 } 537 538 //===----------------------------------------------------------------------===// 539 // Type API. 540 //===----------------------------------------------------------------------===// 541 542 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { 543 return wrap(mlir::parseType(unwrap(type), unwrap(context))); 544 } 545 546 MlirContext mlirTypeGetContext(MlirType type) { 547 return wrap(unwrap(type).getContext()); 548 } 549 550 bool mlirTypeEqual(MlirType t1, MlirType t2) { 551 return unwrap(t1) == unwrap(t2); 552 } 553 554 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 555 detail::CallbackOstream stream(callback, userData); 556 unwrap(type).print(stream); 557 } 558 559 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 560 561 //===----------------------------------------------------------------------===// 562 // Attribute API. 563 //===----------------------------------------------------------------------===// 564 565 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { 566 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context))); 567 } 568 569 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 570 return wrap(unwrap(attribute).getContext()); 571 } 572 573 MlirType mlirAttributeGetType(MlirAttribute attribute) { 574 return wrap(unwrap(attribute).getType()); 575 } 576 577 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 578 return unwrap(a1) == unwrap(a2); 579 } 580 581 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 582 void *userData) { 583 detail::CallbackOstream stream(callback, userData); 584 unwrap(attr).print(stream); 585 } 586 587 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 588 589 MlirNamedAttribute mlirNamedAttributeGet(MlirStringRef name, 590 MlirAttribute attr) { 591 return MlirNamedAttribute{name, attr}; 592 } 593 594 //===----------------------------------------------------------------------===// 595 // Identifier API. 596 //===----------------------------------------------------------------------===// 597 598 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { 599 return wrap(Identifier::get(unwrap(str), unwrap(context))); 600 } 601 602 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { 603 return unwrap(ident) == unwrap(other); 604 } 605 606 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { 607 return wrap(unwrap(ident).strref()); 608 } 609