1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Module.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/Types.h" 17 #include "mlir/Parser.h" 18 #include "llvm/Support/raw_ostream.h" 19 20 using namespace mlir; 21 22 /* ========================================================================== */ 23 /* Printing helper. */ 24 /* ========================================================================== */ 25 26 namespace { 27 /// A simple raw ostream subclass that forwards write_impl calls to the 28 /// user-supplied callback together with opaque user-supplied data. 29 class CallbackOstream : public llvm::raw_ostream { 30 public: 31 CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback, 32 void *opaqueData) 33 : callback(callback), opaqueData(opaqueData), pos(0u) {} 34 35 void write_impl(const char *ptr, size_t size) override { 36 callback(ptr, size, opaqueData); 37 pos += size; 38 } 39 40 uint64_t current_pos() const override { return pos; } 41 42 private: 43 std::function<void(const char *, intptr_t, void *)> callback; 44 void *opaqueData; 45 uint64_t pos; 46 }; 47 } // end namespace 48 49 /* ========================================================================== */ 50 /* Context API. */ 51 /* ========================================================================== */ 52 53 MlirContext mlirContextCreate() { 54 auto *context = new MLIRContext(/*loadAllDialects=*/false); 55 return wrap(context); 56 } 57 58 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 59 return unwrap(ctx1) == unwrap(ctx2); 60 } 61 62 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 63 64 /* ========================================================================== */ 65 /* Location API. */ 66 /* ========================================================================== */ 67 68 MlirLocation mlirLocationFileLineColGet(MlirContext context, 69 const char *filename, unsigned line, 70 unsigned col) { 71 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 72 } 73 74 MlirLocation mlirLocationUnknownGet(MlirContext context) { 75 return wrap(UnknownLoc::get(unwrap(context))); 76 } 77 78 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 79 void *userData) { 80 CallbackOstream stream(callback, userData); 81 unwrap(location).print(stream); 82 stream.flush(); 83 } 84 85 /* ========================================================================== */ 86 /* Module API. */ 87 /* ========================================================================== */ 88 89 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 90 return wrap(ModuleOp::create(unwrap(location))); 91 } 92 93 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 94 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 95 if (!owning) 96 return MlirModule{nullptr}; 97 return MlirModule{owning.release().getOperation()}; 98 } 99 100 void mlirModuleDestroy(MlirModule module) { 101 // Transfer ownership to an OwningModuleRef so that its destructor is called. 102 OwningModuleRef(unwrap(module)); 103 } 104 105 MlirOperation mlirModuleGetOperation(MlirModule module) { 106 return wrap(unwrap(module).getOperation()); 107 } 108 109 /* ========================================================================== */ 110 /* Operation state API. */ 111 /* ========================================================================== */ 112 113 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 114 MlirOperationState state; 115 state.name = name; 116 state.location = loc; 117 state.nResults = 0; 118 state.results = nullptr; 119 state.nOperands = 0; 120 state.operands = nullptr; 121 state.nRegions = 0; 122 state.regions = nullptr; 123 state.nSuccessors = 0; 124 state.successors = nullptr; 125 state.nAttributes = 0; 126 state.attributes = nullptr; 127 return state; 128 } 129 130 #define APPEND_ELEMS(type, sizeName, elemName) \ 131 state->elemName = \ 132 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 133 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 134 state->sizeName += n; 135 136 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 137 MlirType *results) { 138 APPEND_ELEMS(MlirType, nResults, results); 139 } 140 141 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 142 MlirValue *operands) { 143 APPEND_ELEMS(MlirValue, nOperands, operands); 144 } 145 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 146 MlirRegion *regions) { 147 APPEND_ELEMS(MlirRegion, nRegions, regions); 148 } 149 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 150 MlirBlock *successors) { 151 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 152 } 153 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 154 MlirNamedAttribute *attributes) { 155 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 156 } 157 158 /* ========================================================================== */ 159 /* Operation API. */ 160 /* ========================================================================== */ 161 162 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 163 assert(state); 164 OperationState cppState(unwrap(state->location), state->name); 165 SmallVector<Type, 4> resultStorage; 166 SmallVector<Value, 8> operandStorage; 167 SmallVector<Block *, 2> successorStorage; 168 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 169 cppState.addOperands( 170 unwrapList(state->nOperands, state->operands, operandStorage)); 171 cppState.addSuccessors( 172 unwrapList(state->nSuccessors, state->successors, successorStorage)); 173 174 cppState.attributes.reserve(state->nAttributes); 175 for (intptr_t i = 0; i < state->nAttributes; ++i) 176 cppState.addAttribute(state->attributes[i].name, 177 unwrap(state->attributes[i].attribute)); 178 179 for (intptr_t i = 0; i < state->nRegions; ++i) 180 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 181 182 MlirOperation result = wrap(Operation::create(cppState)); 183 free(state->results); 184 free(state->operands); 185 free(state->successors); 186 free(state->regions); 187 free(state->attributes); 188 return result; 189 } 190 191 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 192 193 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } 194 195 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 196 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 197 } 198 199 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 200 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 201 } 202 203 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 204 return wrap(unwrap(op)->getNextNode()); 205 } 206 207 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 208 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 209 } 210 211 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 212 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 213 } 214 215 intptr_t mlirOperationGetNumResults(MlirOperation op) { 216 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 217 } 218 219 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 220 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 221 } 222 223 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 224 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 225 } 226 227 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 228 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 229 } 230 231 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 232 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 233 } 234 235 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 236 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 237 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 238 } 239 240 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 241 const char *name) { 242 return wrap(unwrap(op)->getAttr(name)); 243 } 244 245 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 246 void *userData) { 247 CallbackOstream stream(callback, userData); 248 unwrap(op)->print(stream); 249 stream.flush(); 250 } 251 252 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 253 254 /* ========================================================================== */ 255 /* Region API. */ 256 /* ========================================================================== */ 257 258 MlirRegion mlirRegionCreate() { return wrap(new Region); } 259 260 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 261 Region *cppRegion = unwrap(region); 262 if (cppRegion->empty()) 263 return wrap(static_cast<Block *>(nullptr)); 264 return wrap(&cppRegion->front()); 265 } 266 267 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 268 unwrap(region)->push_back(unwrap(block)); 269 } 270 271 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 272 MlirBlock block) { 273 auto &blockList = unwrap(region)->getBlocks(); 274 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 275 } 276 277 void mlirRegionDestroy(MlirRegion region) { 278 delete static_cast<Region *>(region.ptr); 279 } 280 281 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } 282 283 /* ========================================================================== */ 284 /* Block API. */ 285 /* ========================================================================== */ 286 287 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 288 Block *b = new Block; 289 for (intptr_t i = 0; i < nArgs; ++i) 290 b->addArgument(unwrap(args[i])); 291 return wrap(b); 292 } 293 294 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 295 return wrap(unwrap(block)->getNextNode()); 296 } 297 298 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 299 Block *cppBlock = unwrap(block); 300 if (cppBlock->empty()) 301 return wrap(static_cast<Operation *>(nullptr)); 302 return wrap(&cppBlock->front()); 303 } 304 305 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 306 unwrap(block)->push_back(unwrap(operation)); 307 } 308 309 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 310 MlirOperation operation) { 311 auto &opList = unwrap(block)->getOperations(); 312 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 313 } 314 315 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 316 317 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } 318 319 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 320 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 321 } 322 323 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 324 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 325 } 326 327 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 328 void *userData) { 329 CallbackOstream stream(callback, userData); 330 unwrap(block)->print(stream); 331 stream.flush(); 332 } 333 334 /* ========================================================================== */ 335 /* Value API. */ 336 /* ========================================================================== */ 337 338 MlirType mlirValueGetType(MlirValue value) { 339 return wrap(unwrap(value).getType()); 340 } 341 342 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 343 void *userData) { 344 CallbackOstream stream(callback, userData); 345 unwrap(value).print(stream); 346 stream.flush(); 347 } 348 349 /* ========================================================================== */ 350 /* Type API. */ 351 /* ========================================================================== */ 352 353 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 354 return wrap(mlir::parseType(type, unwrap(context))); 355 } 356 357 MlirContext mlirTypeGetContext(MlirType type) { 358 return wrap(unwrap(type).getContext()); 359 } 360 361 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } 362 363 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 364 CallbackOstream stream(callback, userData); 365 unwrap(type).print(stream); 366 stream.flush(); 367 } 368 369 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 370 371 /* ========================================================================== */ 372 /* Attribute API. */ 373 /* ========================================================================== */ 374 375 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 376 return wrap(mlir::parseAttribute(attr, unwrap(context))); 377 } 378 379 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 380 return unwrap(a1) == unwrap(a2); 381 } 382 383 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 384 void *userData) { 385 CallbackOstream stream(callback, userData); 386 unwrap(attr).print(stream); 387 stream.flush(); 388 } 389 390 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 391 392 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 393 return MlirNamedAttribute{name, attr}; 394 } 395