1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/Module.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/Parser.h" 21 22 using namespace mlir; 23 24 /* ========================================================================== */ 25 /* Context API. */ 26 /* ========================================================================== */ 27 28 MlirContext mlirContextCreate() { 29 auto *context = new MLIRContext(/*loadAllDialects=*/false); 30 return wrap(context); 31 } 32 33 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 34 return unwrap(ctx1) == unwrap(ctx2); 35 } 36 37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 38 39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) { 40 unwrap(context)->allowUnregisteredDialects(allow); 41 } 42 43 int mlirContextGetAllowUnregisteredDialects(MlirContext context) { 44 return unwrap(context)->allowsUnregisteredDialects(); 45 } 46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 47 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 48 } 49 50 // TODO: expose a cheaper way than constructing + sorting a vector only to take 51 // its size. 52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 53 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 54 } 55 56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 57 MlirStringRef name) { 58 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 59 } 60 61 /* ========================================================================== */ 62 /* Dialect API. */ 63 /* ========================================================================== */ 64 65 MlirContext mlirDialectGetContext(MlirDialect dialect) { 66 return wrap(unwrap(dialect)->getContext()); 67 } 68 69 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 70 return unwrap(dialect1) == unwrap(dialect2); 71 } 72 73 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 74 return wrap(unwrap(dialect)->getNamespace()); 75 } 76 77 /* ========================================================================== */ 78 /* Location API. */ 79 /* ========================================================================== */ 80 81 MlirLocation mlirLocationFileLineColGet(MlirContext context, 82 const char *filename, unsigned line, 83 unsigned col) { 84 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 85 } 86 87 MlirLocation mlirLocationUnknownGet(MlirContext context) { 88 return wrap(UnknownLoc::get(unwrap(context))); 89 } 90 91 MlirContext mlirLocationGetContext(MlirLocation location) { 92 return wrap(unwrap(location).getContext()); 93 } 94 95 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 96 void *userData) { 97 detail::CallbackOstream stream(callback, userData); 98 unwrap(location).print(stream); 99 stream.flush(); 100 } 101 102 /* ========================================================================== */ 103 /* Module API. */ 104 /* ========================================================================== */ 105 106 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 107 return wrap(ModuleOp::create(unwrap(location))); 108 } 109 110 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 111 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 112 if (!owning) 113 return MlirModule{nullptr}; 114 return MlirModule{owning.release().getOperation()}; 115 } 116 117 MlirContext mlirModuleGetContext(MlirModule module) { 118 return wrap(unwrap(module).getContext()); 119 } 120 121 void mlirModuleDestroy(MlirModule module) { 122 // Transfer ownership to an OwningModuleRef so that its destructor is called. 123 OwningModuleRef(unwrap(module)); 124 } 125 126 MlirOperation mlirModuleGetOperation(MlirModule module) { 127 return wrap(unwrap(module).getOperation()); 128 } 129 130 /* ========================================================================== */ 131 /* Operation state API. */ 132 /* ========================================================================== */ 133 134 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 135 MlirOperationState state; 136 state.name = name; 137 state.location = loc; 138 state.nResults = 0; 139 state.results = nullptr; 140 state.nOperands = 0; 141 state.operands = nullptr; 142 state.nRegions = 0; 143 state.regions = nullptr; 144 state.nSuccessors = 0; 145 state.successors = nullptr; 146 state.nAttributes = 0; 147 state.attributes = nullptr; 148 return state; 149 } 150 151 #define APPEND_ELEMS(type, sizeName, elemName) \ 152 state->elemName = \ 153 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 154 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 155 state->sizeName += n; 156 157 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 158 MlirType *results) { 159 APPEND_ELEMS(MlirType, nResults, results); 160 } 161 162 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 163 MlirValue *operands) { 164 APPEND_ELEMS(MlirValue, nOperands, operands); 165 } 166 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 167 MlirRegion *regions) { 168 APPEND_ELEMS(MlirRegion, nRegions, regions); 169 } 170 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 171 MlirBlock *successors) { 172 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 173 } 174 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 175 MlirNamedAttribute *attributes) { 176 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 177 } 178 179 /* ========================================================================== */ 180 /* Operation API. */ 181 /* ========================================================================== */ 182 183 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 184 assert(state); 185 OperationState cppState(unwrap(state->location), state->name); 186 SmallVector<Type, 4> resultStorage; 187 SmallVector<Value, 8> operandStorage; 188 SmallVector<Block *, 2> successorStorage; 189 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 190 cppState.addOperands( 191 unwrapList(state->nOperands, state->operands, operandStorage)); 192 cppState.addSuccessors( 193 unwrapList(state->nSuccessors, state->successors, successorStorage)); 194 195 cppState.attributes.reserve(state->nAttributes); 196 for (intptr_t i = 0; i < state->nAttributes; ++i) 197 cppState.addAttribute(state->attributes[i].name, 198 unwrap(state->attributes[i].attribute)); 199 200 for (intptr_t i = 0; i < state->nRegions; ++i) 201 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 202 203 MlirOperation result = wrap(Operation::create(cppState)); 204 free(state->results); 205 free(state->operands); 206 free(state->successors); 207 free(state->regions); 208 free(state->attributes); 209 return result; 210 } 211 212 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 213 214 int mlirOperationEqual(MlirOperation op, MlirOperation other) { 215 return unwrap(op) == unwrap(other); 216 } 217 218 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 219 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 220 } 221 222 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 223 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 224 } 225 226 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 227 return wrap(unwrap(op)->getNextNode()); 228 } 229 230 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 231 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 232 } 233 234 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 235 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 236 } 237 238 intptr_t mlirOperationGetNumResults(MlirOperation op) { 239 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 240 } 241 242 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 243 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 244 } 245 246 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 247 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 248 } 249 250 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 251 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 252 } 253 254 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 255 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 256 } 257 258 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 259 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 260 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 261 } 262 263 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 264 const char *name) { 265 return wrap(unwrap(op)->getAttr(name)); 266 } 267 268 void mlirOperationSetAttributeByName(MlirOperation op, const char *name, 269 MlirAttribute attr) { 270 unwrap(op)->setAttr(name, unwrap(attr)); 271 } 272 273 int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) { 274 auto removeResult = unwrap(op)->removeAttr(name); 275 return removeResult == MutableDictionaryAttr::RemoveResult::Removed; 276 } 277 278 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 279 void *userData) { 280 detail::CallbackOstream stream(callback, userData); 281 unwrap(op)->print(stream); 282 stream.flush(); 283 } 284 285 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 286 287 /* ========================================================================== */ 288 /* Region API. */ 289 /* ========================================================================== */ 290 291 MlirRegion mlirRegionCreate() { return wrap(new Region); } 292 293 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 294 Region *cppRegion = unwrap(region); 295 if (cppRegion->empty()) 296 return wrap(static_cast<Block *>(nullptr)); 297 return wrap(&cppRegion->front()); 298 } 299 300 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 301 unwrap(region)->push_back(unwrap(block)); 302 } 303 304 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 305 MlirBlock block) { 306 auto &blockList = unwrap(region)->getBlocks(); 307 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 308 } 309 310 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 311 MlirBlock block) { 312 Region *cppRegion = unwrap(region); 313 if (mlirBlockIsNull(reference)) { 314 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 315 return; 316 } 317 318 assert(unwrap(reference)->getParent() == unwrap(region) && 319 "expected reference block to belong to the region"); 320 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 321 unwrap(block)); 322 } 323 324 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 325 MlirBlock block) { 326 if (mlirBlockIsNull(reference)) 327 return mlirRegionAppendOwnedBlock(region, block); 328 329 assert(unwrap(reference)->getParent() == unwrap(region) && 330 "expected reference block to belong to the region"); 331 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 332 unwrap(block)); 333 } 334 335 void mlirRegionDestroy(MlirRegion region) { 336 delete static_cast<Region *>(region.ptr); 337 } 338 339 /* ========================================================================== */ 340 /* Block API. */ 341 /* ========================================================================== */ 342 343 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 344 Block *b = new Block; 345 for (intptr_t i = 0; i < nArgs; ++i) 346 b->addArgument(unwrap(args[i])); 347 return wrap(b); 348 } 349 350 int mlirBlockEqual(MlirBlock block, MlirBlock other) { 351 return unwrap(block) == unwrap(other); 352 } 353 354 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 355 return wrap(unwrap(block)->getNextNode()); 356 } 357 358 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 359 Block *cppBlock = unwrap(block); 360 if (cppBlock->empty()) 361 return wrap(static_cast<Operation *>(nullptr)); 362 return wrap(&cppBlock->front()); 363 } 364 365 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 366 unwrap(block)->push_back(unwrap(operation)); 367 } 368 369 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 370 MlirOperation operation) { 371 auto &opList = unwrap(block)->getOperations(); 372 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 373 } 374 375 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 376 MlirOperation reference, 377 MlirOperation operation) { 378 Block *cppBlock = unwrap(block); 379 if (mlirOperationIsNull(reference)) { 380 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 381 return; 382 } 383 384 assert(unwrap(reference)->getBlock() == unwrap(block) && 385 "expected reference operation to belong to the block"); 386 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 387 unwrap(operation)); 388 } 389 390 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 391 MlirOperation reference, 392 MlirOperation operation) { 393 if (mlirOperationIsNull(reference)) 394 return mlirBlockAppendOwnedOperation(block, operation); 395 396 assert(unwrap(reference)->getBlock() == unwrap(block) && 397 "expected reference operation to belong to the block"); 398 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 399 unwrap(operation)); 400 } 401 402 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 403 404 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 405 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 406 } 407 408 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 409 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 410 } 411 412 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 413 void *userData) { 414 detail::CallbackOstream stream(callback, userData); 415 unwrap(block)->print(stream); 416 stream.flush(); 417 } 418 419 /* ========================================================================== */ 420 /* Value API. */ 421 /* ========================================================================== */ 422 423 int mlirValueIsABlockArgument(MlirValue value) { 424 return unwrap(value).isa<BlockArgument>(); 425 } 426 427 int mlirValueIsAOpResult(MlirValue value) { 428 return unwrap(value).isa<OpResult>(); 429 } 430 431 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 432 return wrap(unwrap(value).cast<BlockArgument>().getOwner()); 433 } 434 435 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 436 return static_cast<intptr_t>( 437 unwrap(value).cast<BlockArgument>().getArgNumber()); 438 } 439 440 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 441 unwrap(value).cast<BlockArgument>().setType(unwrap(type)); 442 } 443 444 MlirOperation mlirOpResultGetOwner(MlirValue value) { 445 return wrap(unwrap(value).cast<OpResult>().getOwner()); 446 } 447 448 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 449 return static_cast<intptr_t>( 450 unwrap(value).cast<OpResult>().getResultNumber()); 451 } 452 453 MlirType mlirValueGetType(MlirValue value) { 454 return wrap(unwrap(value).getType()); 455 } 456 457 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 458 459 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 460 void *userData) { 461 detail::CallbackOstream stream(callback, userData); 462 unwrap(value).print(stream); 463 stream.flush(); 464 } 465 466 /* ========================================================================== */ 467 /* Type API. */ 468 /* ========================================================================== */ 469 470 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 471 return wrap(mlir::parseType(type, unwrap(context))); 472 } 473 474 MlirContext mlirTypeGetContext(MlirType type) { 475 return wrap(unwrap(type).getContext()); 476 } 477 478 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } 479 480 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 481 detail::CallbackOstream stream(callback, userData); 482 unwrap(type).print(stream); 483 stream.flush(); 484 } 485 486 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 487 488 /* ========================================================================== */ 489 /* Attribute API. */ 490 /* ========================================================================== */ 491 492 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 493 return wrap(mlir::parseAttribute(attr, unwrap(context))); 494 } 495 496 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 497 return wrap(unwrap(attribute).getContext()); 498 } 499 500 MlirType mlirAttributeGetType(MlirAttribute attribute) { 501 return wrap(unwrap(attribute).getType()); 502 } 503 504 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 505 return unwrap(a1) == unwrap(a2); 506 } 507 508 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 509 void *userData) { 510 detail::CallbackOstream stream(callback, userData); 511 unwrap(attr).print(stream); 512 stream.flush(); 513 } 514 515 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 516 517 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 518 return MlirNamedAttribute{name, attr}; 519 } 520