1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 11 #include "mlir/IR/Attributes.h" 12 #include "mlir/IR/Module.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/IR/Types.h" 15 #include "mlir/Parser.h" 16 #include "llvm/Support/raw_ostream.h" 17 18 using namespace mlir; 19 20 /* ========================================================================== */ 21 /* Definitions of methods for non-owning structures used in C API. */ 22 /* ========================================================================== */ 23 24 #define DEFINE_C_API_PTR_METHODS(name, cpptype) \ 25 static name wrap(cpptype *cpp) { return name{cpp}; } \ 26 static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); } 27 28 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext) 29 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation) 30 DEFINE_C_API_PTR_METHODS(MlirBlock, Block) 31 DEFINE_C_API_PTR_METHODS(MlirRegion, Region) 32 33 #define DEFINE_C_API_METHODS(name, cpptype) \ 34 static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; } \ 35 static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); } 36 37 DEFINE_C_API_METHODS(MlirAttribute, Attribute) 38 DEFINE_C_API_METHODS(MlirLocation, Location); 39 DEFINE_C_API_METHODS(MlirType, Type) 40 DEFINE_C_API_METHODS(MlirValue, Value) 41 DEFINE_C_API_METHODS(MlirModule, ModuleOp) 42 43 template <typename CppTy, typename CTy> 44 static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first, 45 SmallVectorImpl<CppTy> &storage) { 46 static_assert( 47 std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value, 48 "incompatible C and C++ types"); 49 50 if (size == 0) 51 return llvm::None; 52 53 assert(storage.empty() && "expected to populate storage"); 54 storage.reserve(size); 55 for (intptr_t i = 0; i < size; ++i) 56 storage.push_back(unwrap(*(first + i))); 57 return storage; 58 } 59 60 /* ========================================================================== */ 61 /* Printing helper. */ 62 /* ========================================================================== */ 63 64 namespace { 65 /// A simple raw ostream subclass that forwards write_impl calls to the 66 /// user-supplied callback together with opaque user-supplied data. 67 class CallbackOstream : public llvm::raw_ostream { 68 public: 69 CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback, 70 void *opaqueData) 71 : callback(callback), opaqueData(opaqueData), pos(0u) {} 72 73 void write_impl(const char *ptr, size_t size) override { 74 callback(ptr, size, opaqueData); 75 pos += size; 76 } 77 78 uint64_t current_pos() const override { return pos; } 79 80 private: 81 std::function<void(const char *, intptr_t, void *)> callback; 82 void *opaqueData; 83 uint64_t pos; 84 }; 85 } // end namespace 86 87 /* ========================================================================== */ 88 /* Context API. */ 89 /* ========================================================================== */ 90 91 MlirContext mlirContextCreate() { 92 auto *context = new MLIRContext; 93 return wrap(context); 94 } 95 96 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 97 98 /* ========================================================================== */ 99 /* Location API. */ 100 /* ========================================================================== */ 101 102 MlirLocation mlirLocationFileLineColGet(MlirContext context, 103 const char *filename, unsigned line, 104 unsigned col) { 105 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 106 } 107 108 MlirLocation mlirLocationUnknownGet(MlirContext context) { 109 return wrap(UnknownLoc::get(unwrap(context))); 110 } 111 112 void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, 113 void *userData) { 114 CallbackOstream stream(callback, userData); 115 unwrap(location).print(stream); 116 stream.flush(); 117 } 118 119 /* ========================================================================== */ 120 /* Module API. */ 121 /* ========================================================================== */ 122 123 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 124 return wrap(ModuleOp::create(unwrap(location))); 125 } 126 127 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 128 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 129 if (!owning) 130 return MlirModule{nullptr}; 131 return MlirModule{owning.release().getOperation()}; 132 } 133 134 void mlirModuleDestroy(MlirModule module) { 135 // Transfer ownership to an OwningModuleRef so that its destructor is called. 136 OwningModuleRef(unwrap(module)); 137 } 138 139 MlirOperation mlirModuleGetOperation(MlirModule module) { 140 return wrap(unwrap(module).getOperation()); 141 } 142 143 /* ========================================================================== */ 144 /* Operation state API. */ 145 /* ========================================================================== */ 146 147 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 148 MlirOperationState state; 149 state.name = name; 150 state.location = loc; 151 state.nResults = 0; 152 state.results = nullptr; 153 state.nOperands = 0; 154 state.operands = nullptr; 155 state.nRegions = 0; 156 state.regions = nullptr; 157 state.nSuccessors = 0; 158 state.successors = nullptr; 159 state.nAttributes = 0; 160 state.attributes = nullptr; 161 return state; 162 } 163 164 #define APPEND_ELEMS(type, sizeName, elemName) \ 165 state->elemName = \ 166 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 167 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 168 state->sizeName += n; 169 170 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 171 MlirType *results) { 172 APPEND_ELEMS(MlirType, nResults, results); 173 } 174 175 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 176 MlirValue *operands) { 177 APPEND_ELEMS(MlirValue, nOperands, operands); 178 } 179 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 180 MlirRegion *regions) { 181 APPEND_ELEMS(MlirRegion, nRegions, regions); 182 } 183 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 184 MlirBlock *successors) { 185 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 186 } 187 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 188 MlirNamedAttribute *attributes) { 189 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 190 } 191 192 /* ========================================================================== */ 193 /* Operation API. */ 194 /* ========================================================================== */ 195 196 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 197 assert(state); 198 OperationState cppState(unwrap(state->location), state->name); 199 SmallVector<Type, 4> resultStorage; 200 SmallVector<Value, 8> operandStorage; 201 SmallVector<Block *, 2> successorStorage; 202 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 203 cppState.addOperands( 204 unwrapList(state->nOperands, state->operands, operandStorage)); 205 cppState.addSuccessors( 206 unwrapList(state->nSuccessors, state->successors, successorStorage)); 207 208 cppState.attributes.reserve(state->nAttributes); 209 for (intptr_t i = 0; i < state->nAttributes; ++i) 210 cppState.addAttribute(state->attributes[i].name, 211 unwrap(state->attributes[i].attribute)); 212 213 for (intptr_t i = 0; i < state->nRegions; ++i) 214 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 215 216 MlirOperation result = wrap(Operation::create(cppState)); 217 free(state->results); 218 free(state->operands); 219 free(state->successors); 220 free(state->regions); 221 free(state->attributes); 222 return result; 223 } 224 225 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 226 227 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } 228 229 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 230 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 231 } 232 233 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 234 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 235 } 236 237 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 238 return wrap(unwrap(op)->getNextNode()); 239 } 240 241 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 242 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 243 } 244 245 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 246 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 247 } 248 249 intptr_t mlirOperationGetNumResults(MlirOperation op) { 250 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 251 } 252 253 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 254 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 255 } 256 257 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 258 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 259 } 260 261 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 262 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 263 } 264 265 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 266 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 267 } 268 269 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 270 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 271 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 272 } 273 274 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 275 const char *name) { 276 return wrap(unwrap(op)->getAttr(name)); 277 } 278 279 void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, 280 void *userData) { 281 CallbackOstream stream(callback, userData); 282 unwrap(op)->print(stream); 283 stream.flush(); 284 } 285 286 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 287 288 /* ========================================================================== */ 289 /* Region API. */ 290 /* ========================================================================== */ 291 292 MlirRegion mlirRegionCreate() { return wrap(new Region); } 293 294 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 295 Region *cppRegion = unwrap(region); 296 if (cppRegion->empty()) 297 return wrap(static_cast<Block *>(nullptr)); 298 return wrap(&cppRegion->front()); 299 } 300 301 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 302 unwrap(region)->push_back(unwrap(block)); 303 } 304 305 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 306 MlirBlock block) { 307 auto &blockList = unwrap(region)->getBlocks(); 308 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 309 } 310 311 void mlirRegionDestroy(MlirRegion region) { 312 delete static_cast<Region *>(region.ptr); 313 } 314 315 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } 316 317 /* ========================================================================== */ 318 /* Block API. */ 319 /* ========================================================================== */ 320 321 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 322 Block *b = new Block; 323 for (intptr_t i = 0; i < nArgs; ++i) 324 b->addArgument(unwrap(args[i])); 325 return wrap(b); 326 } 327 328 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 329 return wrap(unwrap(block)->getNextNode()); 330 } 331 332 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 333 Block *cppBlock = unwrap(block); 334 if (cppBlock->empty()) 335 return wrap(static_cast<Operation *>(nullptr)); 336 return wrap(&cppBlock->front()); 337 } 338 339 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 340 unwrap(block)->push_back(unwrap(operation)); 341 } 342 343 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 344 MlirOperation operation) { 345 auto &opList = unwrap(block)->getOperations(); 346 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 347 } 348 349 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 350 351 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } 352 353 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 354 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 355 } 356 357 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 358 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 359 } 360 361 void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, 362 void *userData) { 363 CallbackOstream stream(callback, userData); 364 unwrap(block)->print(stream); 365 stream.flush(); 366 } 367 368 /* ========================================================================== */ 369 /* Value API. */ 370 /* ========================================================================== */ 371 372 MlirType mlirValueGetType(MlirValue value) { 373 return wrap(unwrap(value).getType()); 374 } 375 376 void mlirValuePrint(MlirValue value, MlirPrintCallback callback, 377 void *userData) { 378 CallbackOstream stream(callback, userData); 379 unwrap(value).print(stream); 380 stream.flush(); 381 } 382 383 /* ========================================================================== */ 384 /* Type API. */ 385 /* ========================================================================== */ 386 387 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 388 return wrap(mlir::parseType(type, unwrap(context))); 389 } 390 391 void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) { 392 CallbackOstream stream(callback, userData); 393 unwrap(type).print(stream); 394 stream.flush(); 395 } 396 397 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 398 399 /* ========================================================================== */ 400 /* Attribute API. */ 401 /* ========================================================================== */ 402 403 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 404 return wrap(mlir::parseAttribute(attr, unwrap(context))); 405 } 406 407 void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, 408 void *userData) { 409 CallbackOstream stream(callback, userData); 410 unwrap(attr).print(stream); 411 stream.flush(); 412 } 413 414 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 415 416 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 417 return MlirNamedAttribute{name, attr}; 418 } 419