1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/CAPI/Utils.h" 13 #include "mlir/IR/Attributes.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Module.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/IR/Types.h" 18 #include "mlir/Parser.h" 19 20 using namespace mlir; 21 22 /* ========================================================================== */ 23 /* Context API. */ 24 /* ========================================================================== */ 25 26 MlirContext mlirContextCreate() { 27 auto *context = new MLIRContext(/*loadAllDialects=*/false); 28 return wrap(context); 29 } 30 31 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 32 return unwrap(ctx1) == unwrap(ctx2); 33 } 34 35 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 36 37 /* ========================================================================== */ 38 /* Location API. */ 39 /* ========================================================================== */ 40 41 MlirLocation mlirLocationFileLineColGet(MlirContext context, 42 const char *filename, unsigned line, 43 unsigned col) { 44 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 45 } 46 47 MlirLocation mlirLocationUnknownGet(MlirContext context) { 48 return wrap(UnknownLoc::get(unwrap(context))); 49 } 50 51 MlirContext mlirLocationGetContext(MlirLocation location) { 52 return wrap(unwrap(location).getContext()); 53 } 54 55 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 56 void *userData) { 57 detail::CallbackOstream stream(callback, userData); 58 unwrap(location).print(stream); 59 stream.flush(); 60 } 61 62 /* ========================================================================== */ 63 /* Module API. */ 64 /* ========================================================================== */ 65 66 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 67 return wrap(ModuleOp::create(unwrap(location))); 68 } 69 70 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 71 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 72 if (!owning) 73 return MlirModule{nullptr}; 74 return MlirModule{owning.release().getOperation()}; 75 } 76 77 MlirContext mlirModuleGetContext(MlirModule module) { 78 return wrap(unwrap(module).getContext()); 79 } 80 81 void mlirModuleDestroy(MlirModule module) { 82 // Transfer ownership to an OwningModuleRef so that its destructor is called. 83 OwningModuleRef(unwrap(module)); 84 } 85 86 MlirOperation mlirModuleGetOperation(MlirModule module) { 87 return wrap(unwrap(module).getOperation()); 88 } 89 90 /* ========================================================================== */ 91 /* Operation state API. */ 92 /* ========================================================================== */ 93 94 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 95 MlirOperationState state; 96 state.name = name; 97 state.location = loc; 98 state.nResults = 0; 99 state.results = nullptr; 100 state.nOperands = 0; 101 state.operands = nullptr; 102 state.nRegions = 0; 103 state.regions = nullptr; 104 state.nSuccessors = 0; 105 state.successors = nullptr; 106 state.nAttributes = 0; 107 state.attributes = nullptr; 108 return state; 109 } 110 111 #define APPEND_ELEMS(type, sizeName, elemName) \ 112 state->elemName = \ 113 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 114 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 115 state->sizeName += n; 116 117 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 118 MlirType *results) { 119 APPEND_ELEMS(MlirType, nResults, results); 120 } 121 122 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 123 MlirValue *operands) { 124 APPEND_ELEMS(MlirValue, nOperands, operands); 125 } 126 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 127 MlirRegion *regions) { 128 APPEND_ELEMS(MlirRegion, nRegions, regions); 129 } 130 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 131 MlirBlock *successors) { 132 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 133 } 134 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 135 MlirNamedAttribute *attributes) { 136 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 137 } 138 139 /* ========================================================================== */ 140 /* Operation API. */ 141 /* ========================================================================== */ 142 143 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 144 assert(state); 145 OperationState cppState(unwrap(state->location), state->name); 146 SmallVector<Type, 4> resultStorage; 147 SmallVector<Value, 8> operandStorage; 148 SmallVector<Block *, 2> successorStorage; 149 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 150 cppState.addOperands( 151 unwrapList(state->nOperands, state->operands, operandStorage)); 152 cppState.addSuccessors( 153 unwrapList(state->nSuccessors, state->successors, successorStorage)); 154 155 cppState.attributes.reserve(state->nAttributes); 156 for (intptr_t i = 0; i < state->nAttributes; ++i) 157 cppState.addAttribute(state->attributes[i].name, 158 unwrap(state->attributes[i].attribute)); 159 160 for (intptr_t i = 0; i < state->nRegions; ++i) 161 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 162 163 MlirOperation result = wrap(Operation::create(cppState)); 164 free(state->results); 165 free(state->operands); 166 free(state->successors); 167 free(state->regions); 168 free(state->attributes); 169 return result; 170 } 171 172 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 173 174 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } 175 176 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 177 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 178 } 179 180 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 181 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 182 } 183 184 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 185 return wrap(unwrap(op)->getNextNode()); 186 } 187 188 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 189 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 190 } 191 192 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 193 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 194 } 195 196 intptr_t mlirOperationGetNumResults(MlirOperation op) { 197 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 198 } 199 200 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 201 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 202 } 203 204 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 205 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 206 } 207 208 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 209 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 210 } 211 212 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 213 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 214 } 215 216 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 217 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 218 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 219 } 220 221 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 222 const char *name) { 223 return wrap(unwrap(op)->getAttr(name)); 224 } 225 226 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 227 void *userData) { 228 detail::CallbackOstream stream(callback, userData); 229 unwrap(op)->print(stream); 230 stream.flush(); 231 } 232 233 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 234 235 /* ========================================================================== */ 236 /* Region API. */ 237 /* ========================================================================== */ 238 239 MlirRegion mlirRegionCreate() { return wrap(new Region); } 240 241 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 242 Region *cppRegion = unwrap(region); 243 if (cppRegion->empty()) 244 return wrap(static_cast<Block *>(nullptr)); 245 return wrap(&cppRegion->front()); 246 } 247 248 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 249 unwrap(region)->push_back(unwrap(block)); 250 } 251 252 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 253 MlirBlock block) { 254 auto &blockList = unwrap(region)->getBlocks(); 255 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 256 } 257 258 void mlirRegionDestroy(MlirRegion region) { 259 delete static_cast<Region *>(region.ptr); 260 } 261 262 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } 263 264 /* ========================================================================== */ 265 /* Block API. */ 266 /* ========================================================================== */ 267 268 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 269 Block *b = new Block; 270 for (intptr_t i = 0; i < nArgs; ++i) 271 b->addArgument(unwrap(args[i])); 272 return wrap(b); 273 } 274 275 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 276 return wrap(unwrap(block)->getNextNode()); 277 } 278 279 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 280 Block *cppBlock = unwrap(block); 281 if (cppBlock->empty()) 282 return wrap(static_cast<Operation *>(nullptr)); 283 return wrap(&cppBlock->front()); 284 } 285 286 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 287 unwrap(block)->push_back(unwrap(operation)); 288 } 289 290 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 291 MlirOperation operation) { 292 auto &opList = unwrap(block)->getOperations(); 293 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 294 } 295 296 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 297 298 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } 299 300 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 301 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 302 } 303 304 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 305 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 306 } 307 308 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 309 void *userData) { 310 detail::CallbackOstream stream(callback, userData); 311 unwrap(block)->print(stream); 312 stream.flush(); 313 } 314 315 /* ========================================================================== */ 316 /* Value API. */ 317 /* ========================================================================== */ 318 319 MlirType mlirValueGetType(MlirValue value) { 320 return wrap(unwrap(value).getType()); 321 } 322 323 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 324 void *userData) { 325 detail::CallbackOstream stream(callback, userData); 326 unwrap(value).print(stream); 327 stream.flush(); 328 } 329 330 /* ========================================================================== */ 331 /* Type API. */ 332 /* ========================================================================== */ 333 334 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 335 return wrap(mlir::parseType(type, unwrap(context))); 336 } 337 338 MlirContext mlirTypeGetContext(MlirType type) { 339 return wrap(unwrap(type).getContext()); 340 } 341 342 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } 343 344 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 345 detail::CallbackOstream stream(callback, userData); 346 unwrap(type).print(stream); 347 stream.flush(); 348 } 349 350 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 351 352 /* ========================================================================== */ 353 /* Attribute API. */ 354 /* ========================================================================== */ 355 356 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 357 return wrap(mlir::parseAttribute(attr, unwrap(context))); 358 } 359 360 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 361 return wrap(unwrap(attribute).getContext()); 362 } 363 364 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 365 return unwrap(a1) == unwrap(a2); 366 } 367 368 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 369 void *userData) { 370 detail::CallbackOstream stream(callback, userData); 371 unwrap(attr).print(stream); 372 stream.flush(); 373 } 374 375 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 376 377 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 378 return MlirNamedAttribute{name, attr}; 379 } 380