1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/Module.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/Parser.h" 21 22 using namespace mlir; 23 24 /* ========================================================================== */ 25 /* Context API. */ 26 /* ========================================================================== */ 27 28 MlirContext mlirContextCreate() { 29 auto *context = new MLIRContext(/*loadAllDialects=*/false); 30 return wrap(context); 31 } 32 33 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 34 return unwrap(ctx1) == unwrap(ctx2); 35 } 36 37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 38 39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) { 40 unwrap(context)->allowUnregisteredDialects(allow); 41 } 42 43 int mlirContextGetAllowUnregisteredDialects(MlirContext context) { 44 return unwrap(context)->allowsUnregisteredDialects(); 45 } 46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 47 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 48 } 49 50 // TODO: expose a cheaper way than constructing + sorting a vector only to take 51 // its size. 52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 53 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 54 } 55 56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 57 MlirStringRef name) { 58 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 59 } 60 61 /* ========================================================================== */ 62 /* Dialect API. */ 63 /* ========================================================================== */ 64 65 MlirContext mlirDialectGetContext(MlirDialect dialect) { 66 return wrap(unwrap(dialect)->getContext()); 67 } 68 69 int mlirDialectIsNull(MlirDialect dialect) { 70 return unwrap(dialect) == nullptr; 71 } 72 73 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 74 return unwrap(dialect1) == unwrap(dialect2); 75 } 76 77 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 78 return wrap(unwrap(dialect)->getNamespace()); 79 } 80 81 /* ========================================================================== */ 82 /* Location API. */ 83 /* ========================================================================== */ 84 85 MlirLocation mlirLocationFileLineColGet(MlirContext context, 86 const char *filename, unsigned line, 87 unsigned col) { 88 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 89 } 90 91 MlirLocation mlirLocationUnknownGet(MlirContext context) { 92 return wrap(UnknownLoc::get(unwrap(context))); 93 } 94 95 MlirContext mlirLocationGetContext(MlirLocation location) { 96 return wrap(unwrap(location).getContext()); 97 } 98 99 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 100 void *userData) { 101 detail::CallbackOstream stream(callback, userData); 102 unwrap(location).print(stream); 103 stream.flush(); 104 } 105 106 /* ========================================================================== */ 107 /* Module API. */ 108 /* ========================================================================== */ 109 110 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 111 return wrap(ModuleOp::create(unwrap(location))); 112 } 113 114 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 115 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 116 if (!owning) 117 return MlirModule{nullptr}; 118 return MlirModule{owning.release().getOperation()}; 119 } 120 121 MlirContext mlirModuleGetContext(MlirModule module) { 122 return wrap(unwrap(module).getContext()); 123 } 124 125 void mlirModuleDestroy(MlirModule module) { 126 // Transfer ownership to an OwningModuleRef so that its destructor is called. 127 OwningModuleRef(unwrap(module)); 128 } 129 130 MlirOperation mlirModuleGetOperation(MlirModule module) { 131 return wrap(unwrap(module).getOperation()); 132 } 133 134 /* ========================================================================== */ 135 /* Operation state API. */ 136 /* ========================================================================== */ 137 138 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 139 MlirOperationState state; 140 state.name = name; 141 state.location = loc; 142 state.nResults = 0; 143 state.results = nullptr; 144 state.nOperands = 0; 145 state.operands = nullptr; 146 state.nRegions = 0; 147 state.regions = nullptr; 148 state.nSuccessors = 0; 149 state.successors = nullptr; 150 state.nAttributes = 0; 151 state.attributes = nullptr; 152 return state; 153 } 154 155 #define APPEND_ELEMS(type, sizeName, elemName) \ 156 state->elemName = \ 157 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 158 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 159 state->sizeName += n; 160 161 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 162 MlirType *results) { 163 APPEND_ELEMS(MlirType, nResults, results); 164 } 165 166 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 167 MlirValue *operands) { 168 APPEND_ELEMS(MlirValue, nOperands, operands); 169 } 170 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 171 MlirRegion *regions) { 172 APPEND_ELEMS(MlirRegion, nRegions, regions); 173 } 174 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 175 MlirBlock *successors) { 176 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 177 } 178 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 179 MlirNamedAttribute *attributes) { 180 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 181 } 182 183 /* ========================================================================== */ 184 /* Operation API. */ 185 /* ========================================================================== */ 186 187 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 188 assert(state); 189 OperationState cppState(unwrap(state->location), state->name); 190 SmallVector<Type, 4> resultStorage; 191 SmallVector<Value, 8> operandStorage; 192 SmallVector<Block *, 2> successorStorage; 193 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 194 cppState.addOperands( 195 unwrapList(state->nOperands, state->operands, operandStorage)); 196 cppState.addSuccessors( 197 unwrapList(state->nSuccessors, state->successors, successorStorage)); 198 199 cppState.attributes.reserve(state->nAttributes); 200 for (intptr_t i = 0; i < state->nAttributes; ++i) 201 cppState.addAttribute(state->attributes[i].name, 202 unwrap(state->attributes[i].attribute)); 203 204 for (intptr_t i = 0; i < state->nRegions; ++i) 205 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 206 207 MlirOperation result = wrap(Operation::create(cppState)); 208 free(state->results); 209 free(state->operands); 210 free(state->successors); 211 free(state->regions); 212 free(state->attributes); 213 return result; 214 } 215 216 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 217 218 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } 219 220 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 221 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 222 } 223 224 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 225 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 226 } 227 228 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 229 return wrap(unwrap(op)->getNextNode()); 230 } 231 232 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 233 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 234 } 235 236 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 237 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 238 } 239 240 intptr_t mlirOperationGetNumResults(MlirOperation op) { 241 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 242 } 243 244 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 245 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 246 } 247 248 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 249 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 250 } 251 252 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 253 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 254 } 255 256 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 257 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 258 } 259 260 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 261 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 262 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 263 } 264 265 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 266 const char *name) { 267 return wrap(unwrap(op)->getAttr(name)); 268 } 269 270 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 271 void *userData) { 272 detail::CallbackOstream stream(callback, userData); 273 unwrap(op)->print(stream); 274 stream.flush(); 275 } 276 277 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 278 279 /* ========================================================================== */ 280 /* Region API. */ 281 /* ========================================================================== */ 282 283 MlirRegion mlirRegionCreate() { return wrap(new Region); } 284 285 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 286 Region *cppRegion = unwrap(region); 287 if (cppRegion->empty()) 288 return wrap(static_cast<Block *>(nullptr)); 289 return wrap(&cppRegion->front()); 290 } 291 292 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 293 unwrap(region)->push_back(unwrap(block)); 294 } 295 296 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 297 MlirBlock block) { 298 auto &blockList = unwrap(region)->getBlocks(); 299 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 300 } 301 302 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 303 MlirBlock block) { 304 Region *cppRegion = unwrap(region); 305 if (mlirBlockIsNull(reference)) { 306 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 307 return; 308 } 309 310 assert(unwrap(reference)->getParent() == unwrap(region) && 311 "expected reference block to belong to the region"); 312 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 313 unwrap(block)); 314 } 315 316 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 317 MlirBlock block) { 318 if (mlirBlockIsNull(reference)) 319 return mlirRegionAppendOwnedBlock(region, block); 320 321 assert(unwrap(reference)->getParent() == unwrap(region) && 322 "expected reference block to belong to the region"); 323 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 324 unwrap(block)); 325 } 326 327 void mlirRegionDestroy(MlirRegion region) { 328 delete static_cast<Region *>(region.ptr); 329 } 330 331 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } 332 333 /* ========================================================================== */ 334 /* Block API. */ 335 /* ========================================================================== */ 336 337 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 338 Block *b = new Block; 339 for (intptr_t i = 0; i < nArgs; ++i) 340 b->addArgument(unwrap(args[i])); 341 return wrap(b); 342 } 343 344 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 345 return wrap(unwrap(block)->getNextNode()); 346 } 347 348 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 349 Block *cppBlock = unwrap(block); 350 if (cppBlock->empty()) 351 return wrap(static_cast<Operation *>(nullptr)); 352 return wrap(&cppBlock->front()); 353 } 354 355 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 356 unwrap(block)->push_back(unwrap(operation)); 357 } 358 359 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 360 MlirOperation operation) { 361 auto &opList = unwrap(block)->getOperations(); 362 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 363 } 364 365 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 366 MlirOperation reference, 367 MlirOperation operation) { 368 Block *cppBlock = unwrap(block); 369 if (mlirOperationIsNull(reference)) { 370 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 371 return; 372 } 373 374 assert(unwrap(reference)->getBlock() == unwrap(block) && 375 "expected reference operation to belong to the block"); 376 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 377 unwrap(operation)); 378 } 379 380 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 381 MlirOperation reference, 382 MlirOperation operation) { 383 if (mlirOperationIsNull(reference)) 384 return mlirBlockAppendOwnedOperation(block, operation); 385 386 assert(unwrap(reference)->getBlock() == unwrap(block) && 387 "expected reference operation to belong to the block"); 388 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 389 unwrap(operation)); 390 } 391 392 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 393 394 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } 395 396 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 397 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 398 } 399 400 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 401 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 402 } 403 404 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 405 void *userData) { 406 detail::CallbackOstream stream(callback, userData); 407 unwrap(block)->print(stream); 408 stream.flush(); 409 } 410 411 /* ========================================================================== */ 412 /* Value API. */ 413 /* ========================================================================== */ 414 415 MlirType mlirValueGetType(MlirValue value) { 416 return wrap(unwrap(value).getType()); 417 } 418 419 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 420 void *userData) { 421 detail::CallbackOstream stream(callback, userData); 422 unwrap(value).print(stream); 423 stream.flush(); 424 } 425 426 /* ========================================================================== */ 427 /* Type API. */ 428 /* ========================================================================== */ 429 430 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 431 return wrap(mlir::parseType(type, unwrap(context))); 432 } 433 434 MlirContext mlirTypeGetContext(MlirType type) { 435 return wrap(unwrap(type).getContext()); 436 } 437 438 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } 439 440 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 441 detail::CallbackOstream stream(callback, userData); 442 unwrap(type).print(stream); 443 stream.flush(); 444 } 445 446 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 447 448 /* ========================================================================== */ 449 /* Attribute API. */ 450 /* ========================================================================== */ 451 452 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 453 return wrap(mlir::parseAttribute(attr, unwrap(context))); 454 } 455 456 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 457 return wrap(unwrap(attribute).getContext()); 458 } 459 460 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 461 return unwrap(a1) == unwrap(a2); 462 } 463 464 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 465 void *userData) { 466 detail::CallbackOstream stream(callback, userData); 467 unwrap(attr).print(stream); 468 stream.flush(); 469 } 470 471 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 472 473 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 474 return MlirNamedAttribute{name, attr}; 475 } 476