1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 11 #include "mlir/IR/Attributes.h" 12 #include "mlir/IR/Module.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/IR/Types.h" 15 #include "mlir/Parser.h" 16 #include "llvm/Support/raw_ostream.h" 17 18 using namespace mlir; 19 20 /* ========================================================================== */ 21 /* Definitions of methods for non-owning structures used in C API. */ 22 /* ========================================================================== */ 23 24 #define DEFINE_C_API_PTR_METHODS(name, cpptype) \ 25 static name wrap(cpptype *cpp) { return name{cpp}; } \ 26 static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); } 27 28 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext) 29 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation) 30 DEFINE_C_API_PTR_METHODS(MlirBlock, Block) 31 DEFINE_C_API_PTR_METHODS(MlirRegion, Region) 32 33 #define DEFINE_C_API_METHODS(name, cpptype) \ 34 static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; } \ 35 static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); } 36 37 DEFINE_C_API_METHODS(MlirAttribute, Attribute) 38 DEFINE_C_API_METHODS(MlirLocation, Location); 39 DEFINE_C_API_METHODS(MlirType, Type) 40 DEFINE_C_API_METHODS(MlirValue, Value) 41 DEFINE_C_API_METHODS(MlirModule, ModuleOp) 42 43 template <typename CppTy, typename CTy> 44 static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first, 45 SmallVectorImpl<CppTy> &storage) { 46 static_assert( 47 std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value, 48 "incompatible C and C++ types"); 49 50 if (size == 0) 51 return llvm::None; 52 53 assert(storage.empty() && "expected to populate storage"); 54 storage.reserve(size); 55 for (intptr_t i = 0; i < size; ++i) 56 storage.push_back(unwrap(*(first + i))); 57 return storage; 58 } 59 60 /* ========================================================================== */ 61 /* Printing helper. */ 62 /* ========================================================================== */ 63 64 namespace { 65 /// A simple raw ostream subclass that forwards write_impl calls to the 66 /// user-supplied callback together with opaque user-supplied data. 67 class CallbackOstream : public llvm::raw_ostream { 68 public: 69 CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback, 70 void *opaqueData) 71 : callback(callback), opaqueData(opaqueData), pos(0u) {} 72 73 void write_impl(const char *ptr, size_t size) override { 74 callback(ptr, size, opaqueData); 75 pos += size; 76 } 77 78 uint64_t current_pos() const override { return pos; } 79 80 private: 81 std::function<void(const char *, intptr_t, void *)> callback; 82 void *opaqueData; 83 uint64_t pos; 84 }; 85 } // end namespace 86 87 /* ========================================================================== */ 88 /* Context API. */ 89 /* ========================================================================== */ 90 91 MlirContext mlirContextCreate() { 92 auto *context = new MLIRContext; 93 return wrap(context); 94 } 95 96 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 97 98 /* ========================================================================== */ 99 /* Location API. */ 100 /* ========================================================================== */ 101 102 MlirLocation mlirLocationFileLineColGet(MlirContext context, 103 const char *filename, unsigned line, 104 unsigned col) { 105 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 106 } 107 108 MlirLocation mlirLocationUnknownGet(MlirContext context) { 109 return wrap(UnknownLoc::get(unwrap(context))); 110 } 111 112 void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, 113 void *userData) { 114 CallbackOstream stream(callback, userData); 115 unwrap(location).print(stream); 116 stream.flush(); 117 } 118 119 /* ========================================================================== */ 120 /* Module API. */ 121 /* ========================================================================== */ 122 123 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 124 return wrap(ModuleOp::create(unwrap(location))); 125 } 126 127 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 128 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 129 return MlirModule{owning.release().getOperation()}; 130 } 131 132 void mlirModuleDestroy(MlirModule module) { 133 // Transfer ownership to an OwningModuleRef so that its destructor is called. 134 OwningModuleRef(unwrap(module)); 135 } 136 137 MlirOperation mlirModuleGetOperation(MlirModule module) { 138 return wrap(unwrap(module).getOperation()); 139 } 140 141 /* ========================================================================== */ 142 /* Operation state API. */ 143 /* ========================================================================== */ 144 145 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 146 MlirOperationState state; 147 state.name = name; 148 state.location = loc; 149 state.nResults = 0; 150 state.results = nullptr; 151 state.nOperands = 0; 152 state.operands = nullptr; 153 state.nRegions = 0; 154 state.regions = nullptr; 155 state.nSuccessors = 0; 156 state.successors = nullptr; 157 state.nAttributes = 0; 158 state.attributes = nullptr; 159 return state; 160 } 161 162 #define APPEND_ELEMS(type, sizeName, elemName) \ 163 state->elemName = \ 164 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 165 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 166 state->sizeName += n; 167 168 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 169 MlirType *results) { 170 APPEND_ELEMS(MlirType, nResults, results); 171 } 172 173 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 174 MlirValue *operands) { 175 APPEND_ELEMS(MlirValue, nOperands, operands); 176 } 177 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 178 MlirRegion *regions) { 179 APPEND_ELEMS(MlirRegion, nRegions, regions); 180 } 181 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 182 MlirBlock *successors) { 183 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 184 } 185 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 186 MlirNamedAttribute *attributes) { 187 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 188 } 189 190 /* ========================================================================== */ 191 /* Operation API. */ 192 /* ========================================================================== */ 193 194 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 195 assert(state); 196 OperationState cppState(unwrap(state->location), state->name); 197 SmallVector<Type, 4> resultStorage; 198 SmallVector<Value, 8> operandStorage; 199 SmallVector<Block *, 2> successorStorage; 200 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 201 cppState.addOperands( 202 unwrapList(state->nOperands, state->operands, operandStorage)); 203 cppState.addSuccessors( 204 unwrapList(state->nSuccessors, state->successors, successorStorage)); 205 206 cppState.attributes.reserve(state->nAttributes); 207 for (intptr_t i = 0; i < state->nAttributes; ++i) 208 cppState.addAttribute(state->attributes[i].name, 209 unwrap(state->attributes[i].attribute)); 210 211 for (intptr_t i = 0; i < state->nRegions; ++i) 212 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 213 214 MlirOperation result = wrap(Operation::create(cppState)); 215 free(state->results); 216 free(state->operands); 217 free(state->successors); 218 free(state->regions); 219 free(state->attributes); 220 return result; 221 } 222 223 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 224 225 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } 226 227 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 228 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 229 } 230 231 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 232 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 233 } 234 235 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 236 return wrap(unwrap(op)->getNextNode()); 237 } 238 239 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 240 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 241 } 242 243 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 244 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 245 } 246 247 intptr_t mlirOperationGetNumResults(MlirOperation op) { 248 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 249 } 250 251 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 252 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 253 } 254 255 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 256 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 257 } 258 259 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 260 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 261 } 262 263 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 264 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 265 } 266 267 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 268 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 269 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 270 } 271 272 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 273 const char *name) { 274 return wrap(unwrap(op)->getAttr(name)); 275 } 276 277 void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, 278 void *userData) { 279 CallbackOstream stream(callback, userData); 280 unwrap(op)->print(stream); 281 stream.flush(); 282 } 283 284 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 285 286 /* ========================================================================== */ 287 /* Region API. */ 288 /* ========================================================================== */ 289 290 MlirRegion mlirRegionCreate() { return wrap(new Region); } 291 292 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 293 Region *cppRegion = unwrap(region); 294 if (cppRegion->empty()) 295 return wrap(static_cast<Block *>(nullptr)); 296 return wrap(&cppRegion->front()); 297 } 298 299 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 300 unwrap(region)->push_back(unwrap(block)); 301 } 302 303 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 304 MlirBlock block) { 305 auto &blockList = unwrap(region)->getBlocks(); 306 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 307 } 308 309 void mlirRegionDestroy(MlirRegion region) { 310 delete static_cast<Region *>(region.ptr); 311 } 312 313 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } 314 315 /* ========================================================================== */ 316 /* Block API. */ 317 /* ========================================================================== */ 318 319 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 320 Block *b = new Block; 321 for (intptr_t i = 0; i < nArgs; ++i) 322 b->addArgument(unwrap(args[i])); 323 return wrap(b); 324 } 325 326 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 327 return wrap(unwrap(block)->getNextNode()); 328 } 329 330 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 331 Block *cppBlock = unwrap(block); 332 if (cppBlock->empty()) 333 return wrap(static_cast<Operation *>(nullptr)); 334 return wrap(&cppBlock->front()); 335 } 336 337 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 338 unwrap(block)->push_back(unwrap(operation)); 339 } 340 341 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 342 MlirOperation operation) { 343 auto &opList = unwrap(block)->getOperations(); 344 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 345 } 346 347 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 348 349 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } 350 351 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 352 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 353 } 354 355 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 356 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 357 } 358 359 void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, 360 void *userData) { 361 CallbackOstream stream(callback, userData); 362 unwrap(block)->print(stream); 363 stream.flush(); 364 } 365 366 /* ========================================================================== */ 367 /* Value API. */ 368 /* ========================================================================== */ 369 370 MlirType mlirValueGetType(MlirValue value) { 371 return wrap(unwrap(value).getType()); 372 } 373 374 void mlirValuePrint(MlirValue value, MlirPrintCallback callback, 375 void *userData) { 376 CallbackOstream stream(callback, userData); 377 unwrap(value).print(stream); 378 stream.flush(); 379 } 380 381 /* ========================================================================== */ 382 /* Type API. */ 383 /* ========================================================================== */ 384 385 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 386 return wrap(mlir::parseType(type, unwrap(context))); 387 } 388 389 void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) { 390 CallbackOstream stream(callback, userData); 391 unwrap(type).print(stream); 392 stream.flush(); 393 } 394 395 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 396 397 /* ========================================================================== */ 398 /* Attribute API. */ 399 /* ========================================================================== */ 400 401 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 402 return wrap(mlir::parseAttribute(attr, unwrap(context))); 403 } 404 405 void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, 406 void *userData) { 407 CallbackOstream stream(callback, userData); 408 unwrap(attr).print(stream); 409 stream.flush(); 410 } 411 412 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 413 414 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 415 return MlirNamedAttribute{name, attr}; 416 } 417