1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/CAPI/Utils.h" 13 #include "mlir/IR/Attributes.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Module.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/IR/Types.h" 18 #include "mlir/Parser.h" 19 20 using namespace mlir; 21 22 /* ========================================================================== */ 23 /* Context API. */ 24 /* ========================================================================== */ 25 26 MlirContext mlirContextCreate() { 27 auto *context = new MLIRContext(/*loadAllDialects=*/false); 28 return wrap(context); 29 } 30 31 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 32 return unwrap(ctx1) == unwrap(ctx2); 33 } 34 35 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 36 37 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) { 38 unwrap(context)->allowUnregisteredDialects(allow); 39 } 40 41 int mlirContextGetAllowUnregisteredDialects(MlirContext context) { 42 return unwrap(context)->allowsUnregisteredDialects(); 43 } 44 45 /* ========================================================================== */ 46 /* Location API. */ 47 /* ========================================================================== */ 48 49 MlirLocation mlirLocationFileLineColGet(MlirContext context, 50 const char *filename, unsigned line, 51 unsigned col) { 52 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 53 } 54 55 MlirLocation mlirLocationUnknownGet(MlirContext context) { 56 return wrap(UnknownLoc::get(unwrap(context))); 57 } 58 59 MlirContext mlirLocationGetContext(MlirLocation location) { 60 return wrap(unwrap(location).getContext()); 61 } 62 63 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 64 void *userData) { 65 detail::CallbackOstream stream(callback, userData); 66 unwrap(location).print(stream); 67 stream.flush(); 68 } 69 70 /* ========================================================================== */ 71 /* Module API. */ 72 /* ========================================================================== */ 73 74 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 75 return wrap(ModuleOp::create(unwrap(location))); 76 } 77 78 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 79 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 80 if (!owning) 81 return MlirModule{nullptr}; 82 return MlirModule{owning.release().getOperation()}; 83 } 84 85 MlirContext mlirModuleGetContext(MlirModule module) { 86 return wrap(unwrap(module).getContext()); 87 } 88 89 void mlirModuleDestroy(MlirModule module) { 90 // Transfer ownership to an OwningModuleRef so that its destructor is called. 91 OwningModuleRef(unwrap(module)); 92 } 93 94 MlirOperation mlirModuleGetOperation(MlirModule module) { 95 return wrap(unwrap(module).getOperation()); 96 } 97 98 /* ========================================================================== */ 99 /* Operation state API. */ 100 /* ========================================================================== */ 101 102 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 103 MlirOperationState state; 104 state.name = name; 105 state.location = loc; 106 state.nResults = 0; 107 state.results = nullptr; 108 state.nOperands = 0; 109 state.operands = nullptr; 110 state.nRegions = 0; 111 state.regions = nullptr; 112 state.nSuccessors = 0; 113 state.successors = nullptr; 114 state.nAttributes = 0; 115 state.attributes = nullptr; 116 return state; 117 } 118 119 #define APPEND_ELEMS(type, sizeName, elemName) \ 120 state->elemName = \ 121 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 122 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 123 state->sizeName += n; 124 125 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 126 MlirType *results) { 127 APPEND_ELEMS(MlirType, nResults, results); 128 } 129 130 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 131 MlirValue *operands) { 132 APPEND_ELEMS(MlirValue, nOperands, operands); 133 } 134 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 135 MlirRegion *regions) { 136 APPEND_ELEMS(MlirRegion, nRegions, regions); 137 } 138 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 139 MlirBlock *successors) { 140 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 141 } 142 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 143 MlirNamedAttribute *attributes) { 144 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 145 } 146 147 /* ========================================================================== */ 148 /* Operation API. */ 149 /* ========================================================================== */ 150 151 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 152 assert(state); 153 OperationState cppState(unwrap(state->location), state->name); 154 SmallVector<Type, 4> resultStorage; 155 SmallVector<Value, 8> operandStorage; 156 SmallVector<Block *, 2> successorStorage; 157 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 158 cppState.addOperands( 159 unwrapList(state->nOperands, state->operands, operandStorage)); 160 cppState.addSuccessors( 161 unwrapList(state->nSuccessors, state->successors, successorStorage)); 162 163 cppState.attributes.reserve(state->nAttributes); 164 for (intptr_t i = 0; i < state->nAttributes; ++i) 165 cppState.addAttribute(state->attributes[i].name, 166 unwrap(state->attributes[i].attribute)); 167 168 for (intptr_t i = 0; i < state->nRegions; ++i) 169 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 170 171 MlirOperation result = wrap(Operation::create(cppState)); 172 free(state->results); 173 free(state->operands); 174 free(state->successors); 175 free(state->regions); 176 free(state->attributes); 177 return result; 178 } 179 180 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 181 182 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } 183 184 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 185 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 186 } 187 188 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 189 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 190 } 191 192 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 193 return wrap(unwrap(op)->getNextNode()); 194 } 195 196 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 197 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 198 } 199 200 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 201 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 202 } 203 204 intptr_t mlirOperationGetNumResults(MlirOperation op) { 205 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 206 } 207 208 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 209 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 210 } 211 212 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 213 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 214 } 215 216 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 217 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 218 } 219 220 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 221 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 222 } 223 224 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 225 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 226 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 227 } 228 229 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 230 const char *name) { 231 return wrap(unwrap(op)->getAttr(name)); 232 } 233 234 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 235 void *userData) { 236 detail::CallbackOstream stream(callback, userData); 237 unwrap(op)->print(stream); 238 stream.flush(); 239 } 240 241 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 242 243 /* ========================================================================== */ 244 /* Region API. */ 245 /* ========================================================================== */ 246 247 MlirRegion mlirRegionCreate() { return wrap(new Region); } 248 249 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 250 Region *cppRegion = unwrap(region); 251 if (cppRegion->empty()) 252 return wrap(static_cast<Block *>(nullptr)); 253 return wrap(&cppRegion->front()); 254 } 255 256 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 257 unwrap(region)->push_back(unwrap(block)); 258 } 259 260 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 261 MlirBlock block) { 262 auto &blockList = unwrap(region)->getBlocks(); 263 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 264 } 265 266 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 267 MlirBlock block) { 268 Region *cppRegion = unwrap(region); 269 if (mlirBlockIsNull(reference)) { 270 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 271 return; 272 } 273 274 assert(unwrap(reference)->getParent() == unwrap(region) && 275 "expected reference block to belong to the region"); 276 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 277 unwrap(block)); 278 } 279 280 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 281 MlirBlock block) { 282 if (mlirBlockIsNull(reference)) 283 return mlirRegionAppendOwnedBlock(region, block); 284 285 assert(unwrap(reference)->getParent() == unwrap(region) && 286 "expected reference block to belong to the region"); 287 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 288 unwrap(block)); 289 } 290 291 void mlirRegionDestroy(MlirRegion region) { 292 delete static_cast<Region *>(region.ptr); 293 } 294 295 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } 296 297 /* ========================================================================== */ 298 /* Block API. */ 299 /* ========================================================================== */ 300 301 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 302 Block *b = new Block; 303 for (intptr_t i = 0; i < nArgs; ++i) 304 b->addArgument(unwrap(args[i])); 305 return wrap(b); 306 } 307 308 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 309 return wrap(unwrap(block)->getNextNode()); 310 } 311 312 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 313 Block *cppBlock = unwrap(block); 314 if (cppBlock->empty()) 315 return wrap(static_cast<Operation *>(nullptr)); 316 return wrap(&cppBlock->front()); 317 } 318 319 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 320 unwrap(block)->push_back(unwrap(operation)); 321 } 322 323 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 324 MlirOperation operation) { 325 auto &opList = unwrap(block)->getOperations(); 326 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 327 } 328 329 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 330 MlirOperation reference, 331 MlirOperation operation) { 332 Block *cppBlock = unwrap(block); 333 if (mlirOperationIsNull(reference)) { 334 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 335 return; 336 } 337 338 assert(unwrap(reference)->getBlock() == unwrap(block) && 339 "expected reference operation to belong to the block"); 340 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 341 unwrap(operation)); 342 } 343 344 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 345 MlirOperation reference, 346 MlirOperation operation) { 347 if (mlirOperationIsNull(reference)) 348 return mlirBlockAppendOwnedOperation(block, operation); 349 350 assert(unwrap(reference)->getBlock() == unwrap(block) && 351 "expected reference operation to belong to the block"); 352 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 353 unwrap(operation)); 354 } 355 356 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 357 358 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } 359 360 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 361 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 362 } 363 364 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 365 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 366 } 367 368 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 369 void *userData) { 370 detail::CallbackOstream stream(callback, userData); 371 unwrap(block)->print(stream); 372 stream.flush(); 373 } 374 375 /* ========================================================================== */ 376 /* Value API. */ 377 /* ========================================================================== */ 378 379 MlirType mlirValueGetType(MlirValue value) { 380 return wrap(unwrap(value).getType()); 381 } 382 383 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 384 void *userData) { 385 detail::CallbackOstream stream(callback, userData); 386 unwrap(value).print(stream); 387 stream.flush(); 388 } 389 390 /* ========================================================================== */ 391 /* Type API. */ 392 /* ========================================================================== */ 393 394 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 395 return wrap(mlir::parseType(type, unwrap(context))); 396 } 397 398 MlirContext mlirTypeGetContext(MlirType type) { 399 return wrap(unwrap(type).getContext()); 400 } 401 402 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } 403 404 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 405 detail::CallbackOstream stream(callback, userData); 406 unwrap(type).print(stream); 407 stream.flush(); 408 } 409 410 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 411 412 /* ========================================================================== */ 413 /* Attribute API. */ 414 /* ========================================================================== */ 415 416 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 417 return wrap(mlir::parseAttribute(attr, unwrap(context))); 418 } 419 420 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 421 return wrap(unwrap(attribute).getContext()); 422 } 423 424 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 425 return unwrap(a1) == unwrap(a2); 426 } 427 428 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 429 void *userData) { 430 detail::CallbackOstream stream(callback, userData); 431 unwrap(attr).print(stream); 432 stream.flush(); 433 } 434 435 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 436 437 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 438 return MlirNamedAttribute{name, attr}; 439 } 440