1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/Dialect.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/Parser.h" 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // Context API. 26 //===----------------------------------------------------------------------===// 27 28 MlirContext mlirContextCreate() { 29 auto *context = new MLIRContext; 30 return wrap(context); 31 } 32 33 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 34 return unwrap(ctx1) == unwrap(ctx2); 35 } 36 37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 38 39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { 40 unwrap(context)->allowUnregisteredDialects(allow); 41 } 42 43 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { 44 return unwrap(context)->allowsUnregisteredDialects(); 45 } 46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 47 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 48 } 49 50 // TODO: expose a cheaper way than constructing + sorting a vector only to take 51 // its size. 52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 53 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 54 } 55 56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 57 MlirStringRef name) { 58 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // Dialect API. 63 //===----------------------------------------------------------------------===// 64 65 MlirContext mlirDialectGetContext(MlirDialect dialect) { 66 return wrap(unwrap(dialect)->getContext()); 67 } 68 69 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 70 return unwrap(dialect1) == unwrap(dialect2); 71 } 72 73 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 74 return wrap(unwrap(dialect)->getNamespace()); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Printing flags API. 79 //===----------------------------------------------------------------------===// 80 81 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 82 return wrap(new OpPrintingFlags()); 83 } 84 85 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 86 delete unwrap(flags); 87 } 88 89 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 90 intptr_t largeElementLimit) { 91 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 92 } 93 94 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, 95 bool prettyForm) { 96 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); 97 } 98 99 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 100 unwrap(flags)->printGenericOpForm(); 101 } 102 103 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 104 unwrap(flags)->useLocalScope(); 105 } 106 107 //===----------------------------------------------------------------------===// 108 // Location API. 109 //===----------------------------------------------------------------------===// 110 111 MlirLocation mlirLocationFileLineColGet(MlirContext context, 112 MlirStringRef filename, unsigned line, 113 unsigned col) { 114 return wrap( 115 FileLineColLoc::get(unwrap(filename), line, col, unwrap(context))); 116 } 117 118 MlirLocation mlirLocationUnknownGet(MlirContext context) { 119 return wrap(UnknownLoc::get(unwrap(context))); 120 } 121 122 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { 123 return unwrap(l1) == unwrap(l2); 124 } 125 126 MlirContext mlirLocationGetContext(MlirLocation location) { 127 return wrap(unwrap(location).getContext()); 128 } 129 130 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 131 void *userData) { 132 detail::CallbackOstream stream(callback, userData); 133 unwrap(location).print(stream); 134 } 135 136 //===----------------------------------------------------------------------===// 137 // Module API. 138 //===----------------------------------------------------------------------===// 139 140 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 141 return wrap(ModuleOp::create(unwrap(location))); 142 } 143 144 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { 145 OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context)); 146 if (!owning) 147 return MlirModule{nullptr}; 148 return MlirModule{owning.release().getOperation()}; 149 } 150 151 MlirContext mlirModuleGetContext(MlirModule module) { 152 return wrap(unwrap(module).getContext()); 153 } 154 155 MlirBlock mlirModuleGetBody(MlirModule module) { 156 return wrap(unwrap(module).getBody()); 157 } 158 159 void mlirModuleDestroy(MlirModule module) { 160 // Transfer ownership to an OwningModuleRef so that its destructor is called. 161 OwningModuleRef(unwrap(module)); 162 } 163 164 MlirOperation mlirModuleGetOperation(MlirModule module) { 165 return wrap(unwrap(module).getOperation()); 166 } 167 168 //===----------------------------------------------------------------------===// 169 // Operation state API. 170 //===----------------------------------------------------------------------===// 171 172 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { 173 MlirOperationState state; 174 state.name = name; 175 state.location = loc; 176 state.nResults = 0; 177 state.results = nullptr; 178 state.nOperands = 0; 179 state.operands = nullptr; 180 state.nRegions = 0; 181 state.regions = nullptr; 182 state.nSuccessors = 0; 183 state.successors = nullptr; 184 state.nAttributes = 0; 185 state.attributes = nullptr; 186 return state; 187 } 188 189 #define APPEND_ELEMS(type, sizeName, elemName) \ 190 state->elemName = \ 191 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 192 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 193 state->sizeName += n; 194 195 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 196 MlirType const *results) { 197 APPEND_ELEMS(MlirType, nResults, results); 198 } 199 200 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 201 MlirValue const *operands) { 202 APPEND_ELEMS(MlirValue, nOperands, operands); 203 } 204 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 205 MlirRegion const *regions) { 206 APPEND_ELEMS(MlirRegion, nRegions, regions); 207 } 208 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 209 MlirBlock const *successors) { 210 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 211 } 212 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 213 MlirNamedAttribute const *attributes) { 214 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // Operation API. 219 //===----------------------------------------------------------------------===// 220 221 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 222 assert(state); 223 OperationState cppState(unwrap(state->location), unwrap(state->name)); 224 SmallVector<Type, 4> resultStorage; 225 SmallVector<Value, 8> operandStorage; 226 SmallVector<Block *, 2> successorStorage; 227 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 228 cppState.addOperands( 229 unwrapList(state->nOperands, state->operands, operandStorage)); 230 cppState.addSuccessors( 231 unwrapList(state->nSuccessors, state->successors, successorStorage)); 232 233 cppState.attributes.reserve(state->nAttributes); 234 for (intptr_t i = 0; i < state->nAttributes; ++i) 235 cppState.addAttribute(unwrap(state->attributes[i].name), 236 unwrap(state->attributes[i].attribute)); 237 238 for (intptr_t i = 0; i < state->nRegions; ++i) 239 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 240 241 MlirOperation result = wrap(Operation::create(cppState)); 242 free(state->results); 243 free(state->operands); 244 free(state->successors); 245 free(state->regions); 246 free(state->attributes); 247 return result; 248 } 249 250 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 251 252 bool mlirOperationEqual(MlirOperation op, MlirOperation other) { 253 return unwrap(op) == unwrap(other); 254 } 255 256 MlirIdentifier mlirOperationGetName(MlirOperation op) { 257 return wrap(unwrap(op)->getName().getIdentifier()); 258 } 259 260 MlirBlock mlirOperationGetBlock(MlirOperation op) { 261 return wrap(unwrap(op)->getBlock()); 262 } 263 264 MlirOperation mlirOperationGetParentOperation(MlirOperation op) { 265 return wrap(unwrap(op)->getParentOp()); 266 } 267 268 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 269 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 270 } 271 272 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 273 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 274 } 275 276 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 277 return wrap(unwrap(op)->getNextNode()); 278 } 279 280 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 281 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 282 } 283 284 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 285 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 286 } 287 288 intptr_t mlirOperationGetNumResults(MlirOperation op) { 289 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 290 } 291 292 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 293 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 294 } 295 296 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 297 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 298 } 299 300 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 301 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 302 } 303 304 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 305 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 306 } 307 308 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 309 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 310 return MlirNamedAttribute{wrap(attr.first.strref()), wrap(attr.second)}; 311 } 312 313 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 314 MlirStringRef name) { 315 return wrap(unwrap(op)->getAttr(unwrap(name))); 316 } 317 318 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, 319 MlirAttribute attr) { 320 unwrap(op)->setAttr(unwrap(name), unwrap(attr)); 321 } 322 323 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { 324 auto removeResult = unwrap(op)->removeAttr(unwrap(name)); 325 return removeResult == MutableDictionaryAttr::RemoveResult::Removed; 326 } 327 328 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 329 void *userData) { 330 detail::CallbackOstream stream(callback, userData); 331 unwrap(op)->print(stream); 332 } 333 334 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 335 MlirStringCallback callback, void *userData) { 336 detail::CallbackOstream stream(callback, userData); 337 unwrap(op)->print(stream, *unwrap(flags)); 338 } 339 340 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 341 342 //===----------------------------------------------------------------------===// 343 // Region API. 344 //===----------------------------------------------------------------------===// 345 346 MlirRegion mlirRegionCreate() { return wrap(new Region); } 347 348 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 349 Region *cppRegion = unwrap(region); 350 if (cppRegion->empty()) 351 return wrap(static_cast<Block *>(nullptr)); 352 return wrap(&cppRegion->front()); 353 } 354 355 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 356 unwrap(region)->push_back(unwrap(block)); 357 } 358 359 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 360 MlirBlock block) { 361 auto &blockList = unwrap(region)->getBlocks(); 362 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 363 } 364 365 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 366 MlirBlock block) { 367 Region *cppRegion = unwrap(region); 368 if (mlirBlockIsNull(reference)) { 369 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 370 return; 371 } 372 373 assert(unwrap(reference)->getParent() == unwrap(region) && 374 "expected reference block to belong to the region"); 375 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 376 unwrap(block)); 377 } 378 379 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 380 MlirBlock block) { 381 if (mlirBlockIsNull(reference)) 382 return mlirRegionAppendOwnedBlock(region, block); 383 384 assert(unwrap(reference)->getParent() == unwrap(region) && 385 "expected reference block to belong to the region"); 386 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 387 unwrap(block)); 388 } 389 390 void mlirRegionDestroy(MlirRegion region) { 391 delete static_cast<Region *>(region.ptr); 392 } 393 394 //===----------------------------------------------------------------------===// 395 // Block API. 396 //===----------------------------------------------------------------------===// 397 398 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) { 399 Block *b = new Block; 400 for (intptr_t i = 0; i < nArgs; ++i) 401 b->addArgument(unwrap(args[i])); 402 return wrap(b); 403 } 404 405 bool mlirBlockEqual(MlirBlock block, MlirBlock other) { 406 return unwrap(block) == unwrap(other); 407 } 408 409 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 410 return wrap(unwrap(block)->getNextNode()); 411 } 412 413 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 414 Block *cppBlock = unwrap(block); 415 if (cppBlock->empty()) 416 return wrap(static_cast<Operation *>(nullptr)); 417 return wrap(&cppBlock->front()); 418 } 419 420 MlirOperation mlirBlockGetTerminator(MlirBlock block) { 421 Block *cppBlock = unwrap(block); 422 if (cppBlock->empty()) 423 return wrap(static_cast<Operation *>(nullptr)); 424 Operation &back = cppBlock->back(); 425 if (!back.isKnownTerminator()) 426 return wrap(static_cast<Operation *>(nullptr)); 427 return wrap(&back); 428 } 429 430 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 431 unwrap(block)->push_back(unwrap(operation)); 432 } 433 434 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 435 MlirOperation operation) { 436 auto &opList = unwrap(block)->getOperations(); 437 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 438 } 439 440 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 441 MlirOperation reference, 442 MlirOperation operation) { 443 Block *cppBlock = unwrap(block); 444 if (mlirOperationIsNull(reference)) { 445 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 446 return; 447 } 448 449 assert(unwrap(reference)->getBlock() == unwrap(block) && 450 "expected reference operation to belong to the block"); 451 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 452 unwrap(operation)); 453 } 454 455 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 456 MlirOperation reference, 457 MlirOperation operation) { 458 if (mlirOperationIsNull(reference)) 459 return mlirBlockAppendOwnedOperation(block, operation); 460 461 assert(unwrap(reference)->getBlock() == unwrap(block) && 462 "expected reference operation to belong to the block"); 463 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 464 unwrap(operation)); 465 } 466 467 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 468 469 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 470 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 471 } 472 473 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 474 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 475 } 476 477 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 478 void *userData) { 479 detail::CallbackOstream stream(callback, userData); 480 unwrap(block)->print(stream); 481 } 482 483 //===----------------------------------------------------------------------===// 484 // Value API. 485 //===----------------------------------------------------------------------===// 486 487 bool mlirValueEqual(MlirValue value1, MlirValue value2) { 488 return unwrap(value1) == unwrap(value2); 489 } 490 491 bool mlirValueIsABlockArgument(MlirValue value) { 492 return unwrap(value).isa<BlockArgument>(); 493 } 494 495 bool mlirValueIsAOpResult(MlirValue value) { 496 return unwrap(value).isa<OpResult>(); 497 } 498 499 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 500 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 501 } 502 503 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 504 return static_cast<intptr_t>( 505 unwrap(value).cast<BlockArgument>().getArgNumber()); 506 } 507 508 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 509 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 510 } 511 512 MlirOperation mlirOpResultGetOwner(MlirValue value) { 513 return wrap(unwrap(value).cast<OpResult>().getOwner()); 514 } 515 516 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 517 return static_cast<intptr_t>( 518 unwrap(value).cast<OpResult>().getResultNumber()); 519 } 520 521 MlirType mlirValueGetType(MlirValue value) { 522 return wrap(unwrap(value).getType()); 523 } 524 525 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 526 527 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 528 void *userData) { 529 detail::CallbackOstream stream(callback, userData); 530 unwrap(value).print(stream); 531 } 532 533 //===----------------------------------------------------------------------===// 534 // Type API. 535 //===----------------------------------------------------------------------===// 536 537 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { 538 return wrap(mlir::parseType(unwrap(type), unwrap(context))); 539 } 540 541 MlirContext mlirTypeGetContext(MlirType type) { 542 return wrap(unwrap(type).getContext()); 543 } 544 545 bool mlirTypeEqual(MlirType t1, MlirType t2) { 546 return unwrap(t1) == unwrap(t2); 547 } 548 549 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 550 detail::CallbackOstream stream(callback, userData); 551 unwrap(type).print(stream); 552 } 553 554 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 555 556 //===----------------------------------------------------------------------===// 557 // Attribute API. 558 //===----------------------------------------------------------------------===// 559 560 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { 561 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context))); 562 } 563 564 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 565 return wrap(unwrap(attribute).getContext()); 566 } 567 568 MlirType mlirAttributeGetType(MlirAttribute attribute) { 569 return wrap(unwrap(attribute).getType()); 570 } 571 572 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 573 return unwrap(a1) == unwrap(a2); 574 } 575 576 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 577 void *userData) { 578 detail::CallbackOstream stream(callback, userData); 579 unwrap(attr).print(stream); 580 } 581 582 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 583 584 MlirNamedAttribute mlirNamedAttributeGet(MlirStringRef name, 585 MlirAttribute attr) { 586 return MlirNamedAttribute{name, attr}; 587 } 588 589 //===----------------------------------------------------------------------===// 590 // Identifier API. 591 //===----------------------------------------------------------------------===// 592 593 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { 594 return wrap(Identifier::get(unwrap(str), unwrap(context))); 595 } 596 597 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { 598 return unwrap(ident) == unwrap(other); 599 } 600 601 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { 602 return wrap(unwrap(ident).strref()); 603 } 604