1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 11 #include "mlir/IR/Attributes.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/IR/Module.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/IR/Types.h" 16 #include "mlir/Parser.h" 17 #include "llvm/Support/raw_ostream.h" 18 19 using namespace mlir; 20 21 /* ========================================================================== */ 22 /* Definitions of methods for non-owning structures used in C API. */ 23 /* ========================================================================== */ 24 25 #define DEFINE_C_API_PTR_METHODS(name, cpptype) \ 26 static name wrap(cpptype *cpp) { return name{cpp}; } \ 27 static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); } 28 29 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext) 30 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation) 31 DEFINE_C_API_PTR_METHODS(MlirBlock, Block) 32 DEFINE_C_API_PTR_METHODS(MlirRegion, Region) 33 34 #define DEFINE_C_API_METHODS(name, cpptype) \ 35 static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; } \ 36 static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); } 37 38 DEFINE_C_API_METHODS(MlirAttribute, Attribute) 39 DEFINE_C_API_METHODS(MlirLocation, Location); 40 DEFINE_C_API_METHODS(MlirType, Type) 41 DEFINE_C_API_METHODS(MlirValue, Value) 42 DEFINE_C_API_METHODS(MlirModule, ModuleOp) 43 44 template <typename CppTy, typename CTy> 45 static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first, 46 SmallVectorImpl<CppTy> &storage) { 47 static_assert( 48 std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value, 49 "incompatible C and C++ types"); 50 51 if (size == 0) 52 return llvm::None; 53 54 assert(storage.empty() && "expected to populate storage"); 55 storage.reserve(size); 56 for (intptr_t i = 0; i < size; ++i) 57 storage.push_back(unwrap(*(first + i))); 58 return storage; 59 } 60 61 /* ========================================================================== */ 62 /* Printing helper. */ 63 /* ========================================================================== */ 64 65 namespace { 66 /// A simple raw ostream subclass that forwards write_impl calls to the 67 /// user-supplied callback together with opaque user-supplied data. 68 class CallbackOstream : public llvm::raw_ostream { 69 public: 70 CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback, 71 void *opaqueData) 72 : callback(callback), opaqueData(opaqueData), pos(0u) {} 73 74 void write_impl(const char *ptr, size_t size) override { 75 callback(ptr, size, opaqueData); 76 pos += size; 77 } 78 79 uint64_t current_pos() const override { return pos; } 80 81 private: 82 std::function<void(const char *, intptr_t, void *)> callback; 83 void *opaqueData; 84 uint64_t pos; 85 }; 86 } // end namespace 87 88 /* ========================================================================== */ 89 /* Context API. */ 90 /* ========================================================================== */ 91 92 MlirContext mlirContextCreate() { 93 auto *context = new MLIRContext(false); 94 return wrap(context); 95 } 96 97 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 98 99 void mlirContextLoadAllDialects(MlirContext context) { 100 unwrap(context)->loadAllGloballyRegisteredDialects(); 101 } 102 103 /* ========================================================================== */ 104 /* Location API. */ 105 /* ========================================================================== */ 106 107 MlirLocation mlirLocationFileLineColGet(MlirContext context, 108 const char *filename, unsigned line, 109 unsigned col) { 110 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 111 } 112 113 MlirLocation mlirLocationUnknownGet(MlirContext context) { 114 return wrap(UnknownLoc::get(unwrap(context))); 115 } 116 117 void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, 118 void *userData) { 119 CallbackOstream stream(callback, userData); 120 unwrap(location).print(stream); 121 stream.flush(); 122 } 123 124 /* ========================================================================== */ 125 /* Module API. */ 126 /* ========================================================================== */ 127 128 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 129 return wrap(ModuleOp::create(unwrap(location))); 130 } 131 132 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 133 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 134 return MlirModule{owning.release().getOperation()}; 135 } 136 137 void mlirModuleDestroy(MlirModule module) { 138 // Transfer ownership to an OwningModuleRef so that its destructor is called. 139 OwningModuleRef(unwrap(module)); 140 } 141 142 MlirOperation mlirModuleGetOperation(MlirModule module) { 143 return wrap(unwrap(module).getOperation()); 144 } 145 146 /* ========================================================================== */ 147 /* Operation state API. */ 148 /* ========================================================================== */ 149 150 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 151 MlirOperationState state; 152 state.name = name; 153 state.location = loc; 154 state.nResults = 0; 155 state.results = nullptr; 156 state.nOperands = 0; 157 state.operands = nullptr; 158 state.nRegions = 0; 159 state.regions = nullptr; 160 state.nSuccessors = 0; 161 state.successors = nullptr; 162 state.nAttributes = 0; 163 state.attributes = nullptr; 164 return state; 165 } 166 167 #define APPEND_ELEMS(type, sizeName, elemName) \ 168 state->elemName = \ 169 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 170 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 171 state->sizeName += n; 172 173 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 174 MlirType *results) { 175 APPEND_ELEMS(MlirType, nResults, results); 176 } 177 178 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 179 MlirValue *operands) { 180 APPEND_ELEMS(MlirValue, nOperands, operands); 181 } 182 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 183 MlirRegion *regions) { 184 APPEND_ELEMS(MlirRegion, nRegions, regions); 185 } 186 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 187 MlirBlock *successors) { 188 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 189 } 190 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 191 MlirNamedAttribute *attributes) { 192 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 193 } 194 195 /* ========================================================================== */ 196 /* Operation API. */ 197 /* ========================================================================== */ 198 199 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 200 assert(state); 201 OperationState cppState(unwrap(state->location), state->name); 202 SmallVector<Type, 4> resultStorage; 203 SmallVector<Value, 8> operandStorage; 204 SmallVector<Block *, 2> successorStorage; 205 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 206 cppState.addOperands( 207 unwrapList(state->nOperands, state->operands, operandStorage)); 208 cppState.addSuccessors( 209 unwrapList(state->nSuccessors, state->successors, successorStorage)); 210 211 cppState.attributes.reserve(state->nAttributes); 212 for (intptr_t i = 0; i < state->nAttributes; ++i) 213 cppState.addAttribute(state->attributes[i].name, 214 unwrap(state->attributes[i].attribute)); 215 216 for (intptr_t i = 0; i < state->nRegions; ++i) 217 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 218 219 MlirOperation result = wrap(Operation::create(cppState)); 220 free(state->results); 221 free(state->operands); 222 free(state->successors); 223 free(state->regions); 224 free(state->attributes); 225 return result; 226 } 227 228 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 229 230 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } 231 232 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 233 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 234 } 235 236 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 237 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 238 } 239 240 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 241 return wrap(unwrap(op)->getNextNode()); 242 } 243 244 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 245 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 246 } 247 248 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 249 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 250 } 251 252 intptr_t mlirOperationGetNumResults(MlirOperation op) { 253 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 254 } 255 256 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 257 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 258 } 259 260 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 261 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 262 } 263 264 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 265 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 266 } 267 268 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 269 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 270 } 271 272 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 273 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 274 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 275 } 276 277 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 278 const char *name) { 279 return wrap(unwrap(op)->getAttr(name)); 280 } 281 282 void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, 283 void *userData) { 284 CallbackOstream stream(callback, userData); 285 unwrap(op)->print(stream); 286 stream.flush(); 287 } 288 289 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 290 291 /* ========================================================================== */ 292 /* Region API. */ 293 /* ========================================================================== */ 294 295 MlirRegion mlirRegionCreate() { return wrap(new Region); } 296 297 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 298 Region *cppRegion = unwrap(region); 299 if (cppRegion->empty()) 300 return wrap(static_cast<Block *>(nullptr)); 301 return wrap(&cppRegion->front()); 302 } 303 304 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 305 unwrap(region)->push_back(unwrap(block)); 306 } 307 308 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 309 MlirBlock block) { 310 auto &blockList = unwrap(region)->getBlocks(); 311 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 312 } 313 314 void mlirRegionDestroy(MlirRegion region) { 315 delete static_cast<Region *>(region.ptr); 316 } 317 318 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } 319 320 /* ========================================================================== */ 321 /* Block API. */ 322 /* ========================================================================== */ 323 324 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 325 Block *b = new Block; 326 for (intptr_t i = 0; i < nArgs; ++i) 327 b->addArgument(unwrap(args[i])); 328 return wrap(b); 329 } 330 331 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 332 return wrap(unwrap(block)->getNextNode()); 333 } 334 335 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 336 Block *cppBlock = unwrap(block); 337 if (cppBlock->empty()) 338 return wrap(static_cast<Operation *>(nullptr)); 339 return wrap(&cppBlock->front()); 340 } 341 342 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 343 unwrap(block)->push_back(unwrap(operation)); 344 } 345 346 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 347 MlirOperation operation) { 348 auto &opList = unwrap(block)->getOperations(); 349 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 350 } 351 352 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 353 354 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } 355 356 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 357 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 358 } 359 360 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 361 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 362 } 363 364 void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, 365 void *userData) { 366 CallbackOstream stream(callback, userData); 367 unwrap(block)->print(stream); 368 stream.flush(); 369 } 370 371 /* ========================================================================== */ 372 /* Value API. */ 373 /* ========================================================================== */ 374 375 MlirType mlirValueGetType(MlirValue value) { 376 return wrap(unwrap(value).getType()); 377 } 378 379 void mlirValuePrint(MlirValue value, MlirPrintCallback callback, 380 void *userData) { 381 CallbackOstream stream(callback, userData); 382 unwrap(value).print(stream); 383 stream.flush(); 384 } 385 386 /* ========================================================================== */ 387 /* Type API. */ 388 /* ========================================================================== */ 389 390 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 391 return wrap(mlir::parseType(type, unwrap(context))); 392 } 393 394 void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) { 395 CallbackOstream stream(callback, userData); 396 unwrap(type).print(stream); 397 stream.flush(); 398 } 399 400 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 401 402 /* ========================================================================== */ 403 /* Attribute API. */ 404 /* ========================================================================== */ 405 406 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 407 return wrap(mlir::parseAttribute(attr, unwrap(context))); 408 } 409 410 void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, 411 void *userData) { 412 CallbackOstream stream(callback, userData); 413 unwrap(attr).print(stream); 414 stream.flush(); 415 } 416 417 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 418 419 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 420 return MlirNamedAttribute{name, attr}; 421 } 422