1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/Module.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/Parser.h" 21 22 using namespace mlir; 23 24 /* ========================================================================== */ 25 /* Context API. */ 26 /* ========================================================================== */ 27 28 MlirContext mlirContextCreate() { 29 auto *context = new MLIRContext(/*loadAllDialects=*/false); 30 return wrap(context); 31 } 32 33 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 34 return unwrap(ctx1) == unwrap(ctx2); 35 } 36 37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 38 39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) { 40 unwrap(context)->allowUnregisteredDialects(allow); 41 } 42 43 int mlirContextGetAllowUnregisteredDialects(MlirContext context) { 44 return unwrap(context)->allowsUnregisteredDialects(); 45 } 46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 47 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 48 } 49 50 // TODO: expose a cheaper way than constructing + sorting a vector only to take 51 // its size. 52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 53 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 54 } 55 56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 57 MlirStringRef name) { 58 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 59 } 60 61 /* ========================================================================== */ 62 /* Dialect API. */ 63 /* ========================================================================== */ 64 65 MlirContext mlirDialectGetContext(MlirDialect dialect) { 66 return wrap(unwrap(dialect)->getContext()); 67 } 68 69 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 70 return unwrap(dialect1) == unwrap(dialect2); 71 } 72 73 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 74 return wrap(unwrap(dialect)->getNamespace()); 75 } 76 77 /* ========================================================================== */ 78 /* Location API. */ 79 /* ========================================================================== */ 80 81 MlirLocation mlirLocationFileLineColGet(MlirContext context, 82 const char *filename, unsigned line, 83 unsigned col) { 84 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 85 } 86 87 MlirLocation mlirLocationUnknownGet(MlirContext context) { 88 return wrap(UnknownLoc::get(unwrap(context))); 89 } 90 91 MlirContext mlirLocationGetContext(MlirLocation location) { 92 return wrap(unwrap(location).getContext()); 93 } 94 95 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 96 void *userData) { 97 detail::CallbackOstream stream(callback, userData); 98 unwrap(location).print(stream); 99 stream.flush(); 100 } 101 102 /* ========================================================================== */ 103 /* Module API. */ 104 /* ========================================================================== */ 105 106 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 107 return wrap(ModuleOp::create(unwrap(location))); 108 } 109 110 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 111 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 112 if (!owning) 113 return MlirModule{nullptr}; 114 return MlirModule{owning.release().getOperation()}; 115 } 116 117 MlirContext mlirModuleGetContext(MlirModule module) { 118 return wrap(unwrap(module).getContext()); 119 } 120 121 void mlirModuleDestroy(MlirModule module) { 122 // Transfer ownership to an OwningModuleRef so that its destructor is called. 123 OwningModuleRef(unwrap(module)); 124 } 125 126 MlirOperation mlirModuleGetOperation(MlirModule module) { 127 return wrap(unwrap(module).getOperation()); 128 } 129 130 /* ========================================================================== */ 131 /* Operation state API. */ 132 /* ========================================================================== */ 133 134 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 135 MlirOperationState state; 136 state.name = name; 137 state.location = loc; 138 state.nResults = 0; 139 state.results = nullptr; 140 state.nOperands = 0; 141 state.operands = nullptr; 142 state.nRegions = 0; 143 state.regions = nullptr; 144 state.nSuccessors = 0; 145 state.successors = nullptr; 146 state.nAttributes = 0; 147 state.attributes = nullptr; 148 return state; 149 } 150 151 #define APPEND_ELEMS(type, sizeName, elemName) \ 152 state->elemName = \ 153 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 154 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 155 state->sizeName += n; 156 157 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 158 MlirType *results) { 159 APPEND_ELEMS(MlirType, nResults, results); 160 } 161 162 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 163 MlirValue *operands) { 164 APPEND_ELEMS(MlirValue, nOperands, operands); 165 } 166 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 167 MlirRegion *regions) { 168 APPEND_ELEMS(MlirRegion, nRegions, regions); 169 } 170 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 171 MlirBlock *successors) { 172 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 173 } 174 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 175 MlirNamedAttribute *attributes) { 176 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 177 } 178 179 /* ========================================================================== */ 180 /* Operation API. */ 181 /* ========================================================================== */ 182 183 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 184 assert(state); 185 OperationState cppState(unwrap(state->location), state->name); 186 SmallVector<Type, 4> resultStorage; 187 SmallVector<Value, 8> operandStorage; 188 SmallVector<Block *, 2> successorStorage; 189 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 190 cppState.addOperands( 191 unwrapList(state->nOperands, state->operands, operandStorage)); 192 cppState.addSuccessors( 193 unwrapList(state->nSuccessors, state->successors, successorStorage)); 194 195 cppState.attributes.reserve(state->nAttributes); 196 for (intptr_t i = 0; i < state->nAttributes; ++i) 197 cppState.addAttribute(state->attributes[i].name, 198 unwrap(state->attributes[i].attribute)); 199 200 for (intptr_t i = 0; i < state->nRegions; ++i) 201 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 202 203 MlirOperation result = wrap(Operation::create(cppState)); 204 free(state->results); 205 free(state->operands); 206 free(state->successors); 207 free(state->regions); 208 free(state->attributes); 209 return result; 210 } 211 212 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 213 214 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 215 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 216 } 217 218 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 219 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 220 } 221 222 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 223 return wrap(unwrap(op)->getNextNode()); 224 } 225 226 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 227 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 228 } 229 230 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 231 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 232 } 233 234 intptr_t mlirOperationGetNumResults(MlirOperation op) { 235 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 236 } 237 238 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 239 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 240 } 241 242 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 243 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 244 } 245 246 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 247 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 248 } 249 250 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 251 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 252 } 253 254 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 255 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 256 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 257 } 258 259 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 260 const char *name) { 261 return wrap(unwrap(op)->getAttr(name)); 262 } 263 264 void mlirOperationSetAttributeByName(MlirOperation op, const char *name, 265 MlirAttribute attr) { 266 unwrap(op)->setAttr(name, unwrap(attr)); 267 } 268 269 int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) { 270 auto removeResult = unwrap(op)->removeAttr(name); 271 return removeResult == MutableDictionaryAttr::RemoveResult::Removed; 272 } 273 274 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 275 void *userData) { 276 detail::CallbackOstream stream(callback, userData); 277 unwrap(op)->print(stream); 278 stream.flush(); 279 } 280 281 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 282 283 /* ========================================================================== */ 284 /* Region API. */ 285 /* ========================================================================== */ 286 287 MlirRegion mlirRegionCreate() { return wrap(new Region); } 288 289 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 290 Region *cppRegion = unwrap(region); 291 if (cppRegion->empty()) 292 return wrap(static_cast<Block *>(nullptr)); 293 return wrap(&cppRegion->front()); 294 } 295 296 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 297 unwrap(region)->push_back(unwrap(block)); 298 } 299 300 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 301 MlirBlock block) { 302 auto &blockList = unwrap(region)->getBlocks(); 303 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 304 } 305 306 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 307 MlirBlock block) { 308 Region *cppRegion = unwrap(region); 309 if (mlirBlockIsNull(reference)) { 310 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 311 return; 312 } 313 314 assert(unwrap(reference)->getParent() == unwrap(region) && 315 "expected reference block to belong to the region"); 316 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 317 unwrap(block)); 318 } 319 320 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 321 MlirBlock block) { 322 if (mlirBlockIsNull(reference)) 323 return mlirRegionAppendOwnedBlock(region, block); 324 325 assert(unwrap(reference)->getParent() == unwrap(region) && 326 "expected reference block to belong to the region"); 327 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 328 unwrap(block)); 329 } 330 331 void mlirRegionDestroy(MlirRegion region) { 332 delete static_cast<Region *>(region.ptr); 333 } 334 335 /* ========================================================================== */ 336 /* Block API. */ 337 /* ========================================================================== */ 338 339 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 340 Block *b = new Block; 341 for (intptr_t i = 0; i < nArgs; ++i) 342 b->addArgument(unwrap(args[i])); 343 return wrap(b); 344 } 345 346 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 347 return wrap(unwrap(block)->getNextNode()); 348 } 349 350 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 351 Block *cppBlock = unwrap(block); 352 if (cppBlock->empty()) 353 return wrap(static_cast<Operation *>(nullptr)); 354 return wrap(&cppBlock->front()); 355 } 356 357 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 358 unwrap(block)->push_back(unwrap(operation)); 359 } 360 361 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 362 MlirOperation operation) { 363 auto &opList = unwrap(block)->getOperations(); 364 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 365 } 366 367 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 368 MlirOperation reference, 369 MlirOperation operation) { 370 Block *cppBlock = unwrap(block); 371 if (mlirOperationIsNull(reference)) { 372 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 373 return; 374 } 375 376 assert(unwrap(reference)->getBlock() == unwrap(block) && 377 "expected reference operation to belong to the block"); 378 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 379 unwrap(operation)); 380 } 381 382 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 383 MlirOperation reference, 384 MlirOperation operation) { 385 if (mlirOperationIsNull(reference)) 386 return mlirBlockAppendOwnedOperation(block, operation); 387 388 assert(unwrap(reference)->getBlock() == unwrap(block) && 389 "expected reference operation to belong to the block"); 390 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 391 unwrap(operation)); 392 } 393 394 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 395 396 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 397 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 398 } 399 400 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 401 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 402 } 403 404 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 405 void *userData) { 406 detail::CallbackOstream stream(callback, userData); 407 unwrap(block)->print(stream); 408 stream.flush(); 409 } 410 411 /* ========================================================================== */ 412 /* Value API. */ 413 /* ========================================================================== */ 414 415 MlirType mlirValueGetType(MlirValue value) { 416 return wrap(unwrap(value).getType()); 417 } 418 419 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 420 void *userData) { 421 detail::CallbackOstream stream(callback, userData); 422 unwrap(value).print(stream); 423 stream.flush(); 424 } 425 426 /* ========================================================================== */ 427 /* Type API. */ 428 /* ========================================================================== */ 429 430 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 431 return wrap(mlir::parseType(type, unwrap(context))); 432 } 433 434 MlirContext mlirTypeGetContext(MlirType type) { 435 return wrap(unwrap(type).getContext()); 436 } 437 438 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } 439 440 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 441 detail::CallbackOstream stream(callback, userData); 442 unwrap(type).print(stream); 443 stream.flush(); 444 } 445 446 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 447 448 /* ========================================================================== */ 449 /* Attribute API. */ 450 /* ========================================================================== */ 451 452 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 453 return wrap(mlir::parseAttribute(attr, unwrap(context))); 454 } 455 456 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 457 return wrap(unwrap(attribute).getContext()); 458 } 459 460 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 461 return unwrap(a1) == unwrap(a2); 462 } 463 464 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 465 void *userData) { 466 detail::CallbackOstream stream(callback, userData); 467 unwrap(attr).print(stream); 468 stream.flush(); 469 } 470 471 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 472 473 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 474 return MlirNamedAttribute{name, attr}; 475 } 476