1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/Module.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/Types.h" 17 #include "mlir/Parser.h" 18 #include "llvm/Support/raw_ostream.h" 19 20 using namespace mlir; 21 22 /* ========================================================================== */ 23 /* Printing helper. */ 24 /* ========================================================================== */ 25 26 namespace { 27 /// A simple raw ostream subclass that forwards write_impl calls to the 28 /// user-supplied callback together with opaque user-supplied data. 29 class CallbackOstream : public llvm::raw_ostream { 30 public: 31 CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback, 32 void *opaqueData) 33 : callback(callback), opaqueData(opaqueData), pos(0u) {} 34 35 void write_impl(const char *ptr, size_t size) override { 36 callback(ptr, size, opaqueData); 37 pos += size; 38 } 39 40 uint64_t current_pos() const override { return pos; } 41 42 private: 43 std::function<void(const char *, intptr_t, void *)> callback; 44 void *opaqueData; 45 uint64_t pos; 46 }; 47 } // end namespace 48 49 /* ========================================================================== */ 50 /* Context API. */ 51 /* ========================================================================== */ 52 53 MlirContext mlirContextCreate() { 54 auto *context = new MLIRContext(/*loadAllDialects=*/false); 55 return wrap(context); 56 } 57 58 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 59 60 /* ========================================================================== */ 61 /* Location API. */ 62 /* ========================================================================== */ 63 64 MlirLocation mlirLocationFileLineColGet(MlirContext context, 65 const char *filename, unsigned line, 66 unsigned col) { 67 return wrap(FileLineColLoc::get(filename, line, col, unwrap(context))); 68 } 69 70 MlirLocation mlirLocationUnknownGet(MlirContext context) { 71 return wrap(UnknownLoc::get(unwrap(context))); 72 } 73 74 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 75 void *userData) { 76 CallbackOstream stream(callback, userData); 77 unwrap(location).print(stream); 78 stream.flush(); 79 } 80 81 /* ========================================================================== */ 82 /* Module API. */ 83 /* ========================================================================== */ 84 85 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 86 return wrap(ModuleOp::create(unwrap(location))); 87 } 88 89 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) { 90 OwningModuleRef owning = parseSourceString(module, unwrap(context)); 91 if (!owning) 92 return MlirModule{nullptr}; 93 return MlirModule{owning.release().getOperation()}; 94 } 95 96 void mlirModuleDestroy(MlirModule module) { 97 // Transfer ownership to an OwningModuleRef so that its destructor is called. 98 OwningModuleRef(unwrap(module)); 99 } 100 101 MlirOperation mlirModuleGetOperation(MlirModule module) { 102 return wrap(unwrap(module).getOperation()); 103 } 104 105 /* ========================================================================== */ 106 /* Operation state API. */ 107 /* ========================================================================== */ 108 109 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) { 110 MlirOperationState state; 111 state.name = name; 112 state.location = loc; 113 state.nResults = 0; 114 state.results = nullptr; 115 state.nOperands = 0; 116 state.operands = nullptr; 117 state.nRegions = 0; 118 state.regions = nullptr; 119 state.nSuccessors = 0; 120 state.successors = nullptr; 121 state.nAttributes = 0; 122 state.attributes = nullptr; 123 return state; 124 } 125 126 #define APPEND_ELEMS(type, sizeName, elemName) \ 127 state->elemName = \ 128 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 129 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 130 state->sizeName += n; 131 132 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 133 MlirType *results) { 134 APPEND_ELEMS(MlirType, nResults, results); 135 } 136 137 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 138 MlirValue *operands) { 139 APPEND_ELEMS(MlirValue, nOperands, operands); 140 } 141 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 142 MlirRegion *regions) { 143 APPEND_ELEMS(MlirRegion, nRegions, regions); 144 } 145 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 146 MlirBlock *successors) { 147 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 148 } 149 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 150 MlirNamedAttribute *attributes) { 151 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 152 } 153 154 /* ========================================================================== */ 155 /* Operation API. */ 156 /* ========================================================================== */ 157 158 MlirOperation mlirOperationCreate(const MlirOperationState *state) { 159 assert(state); 160 OperationState cppState(unwrap(state->location), state->name); 161 SmallVector<Type, 4> resultStorage; 162 SmallVector<Value, 8> operandStorage; 163 SmallVector<Block *, 2> successorStorage; 164 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 165 cppState.addOperands( 166 unwrapList(state->nOperands, state->operands, operandStorage)); 167 cppState.addSuccessors( 168 unwrapList(state->nSuccessors, state->successors, successorStorage)); 169 170 cppState.attributes.reserve(state->nAttributes); 171 for (intptr_t i = 0; i < state->nAttributes; ++i) 172 cppState.addAttribute(state->attributes[i].name, 173 unwrap(state->attributes[i].attribute)); 174 175 for (intptr_t i = 0; i < state->nRegions; ++i) 176 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 177 178 MlirOperation result = wrap(Operation::create(cppState)); 179 free(state->results); 180 free(state->operands); 181 free(state->successors); 182 free(state->regions); 183 free(state->attributes); 184 return result; 185 } 186 187 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 188 189 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } 190 191 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 192 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 193 } 194 195 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 196 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 197 } 198 199 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 200 return wrap(unwrap(op)->getNextNode()); 201 } 202 203 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 204 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 205 } 206 207 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 208 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 209 } 210 211 intptr_t mlirOperationGetNumResults(MlirOperation op) { 212 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 213 } 214 215 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 216 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 217 } 218 219 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 220 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 221 } 222 223 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 224 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 225 } 226 227 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 228 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 229 } 230 231 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 232 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 233 return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)}; 234 } 235 236 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 237 const char *name) { 238 return wrap(unwrap(op)->getAttr(name)); 239 } 240 241 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 242 void *userData) { 243 CallbackOstream stream(callback, userData); 244 unwrap(op)->print(stream); 245 stream.flush(); 246 } 247 248 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 249 250 /* ========================================================================== */ 251 /* Region API. */ 252 /* ========================================================================== */ 253 254 MlirRegion mlirRegionCreate() { return wrap(new Region); } 255 256 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 257 Region *cppRegion = unwrap(region); 258 if (cppRegion->empty()) 259 return wrap(static_cast<Block *>(nullptr)); 260 return wrap(&cppRegion->front()); 261 } 262 263 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 264 unwrap(region)->push_back(unwrap(block)); 265 } 266 267 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 268 MlirBlock block) { 269 auto &blockList = unwrap(region)->getBlocks(); 270 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 271 } 272 273 void mlirRegionDestroy(MlirRegion region) { 274 delete static_cast<Region *>(region.ptr); 275 } 276 277 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } 278 279 /* ========================================================================== */ 280 /* Block API. */ 281 /* ========================================================================== */ 282 283 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) { 284 Block *b = new Block; 285 for (intptr_t i = 0; i < nArgs; ++i) 286 b->addArgument(unwrap(args[i])); 287 return wrap(b); 288 } 289 290 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 291 return wrap(unwrap(block)->getNextNode()); 292 } 293 294 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 295 Block *cppBlock = unwrap(block); 296 if (cppBlock->empty()) 297 return wrap(static_cast<Operation *>(nullptr)); 298 return wrap(&cppBlock->front()); 299 } 300 301 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 302 unwrap(block)->push_back(unwrap(operation)); 303 } 304 305 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 306 MlirOperation operation) { 307 auto &opList = unwrap(block)->getOperations(); 308 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 309 } 310 311 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 312 313 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } 314 315 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 316 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 317 } 318 319 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 320 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 321 } 322 323 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 324 void *userData) { 325 CallbackOstream stream(callback, userData); 326 unwrap(block)->print(stream); 327 stream.flush(); 328 } 329 330 /* ========================================================================== */ 331 /* Value API. */ 332 /* ========================================================================== */ 333 334 MlirType mlirValueGetType(MlirValue value) { 335 return wrap(unwrap(value).getType()); 336 } 337 338 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 339 void *userData) { 340 CallbackOstream stream(callback, userData); 341 unwrap(value).print(stream); 342 stream.flush(); 343 } 344 345 /* ========================================================================== */ 346 /* Type API. */ 347 /* ========================================================================== */ 348 349 MlirType mlirTypeParseGet(MlirContext context, const char *type) { 350 return wrap(mlir::parseType(type, unwrap(context))); 351 } 352 353 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } 354 355 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 356 CallbackOstream stream(callback, userData); 357 unwrap(type).print(stream); 358 stream.flush(); 359 } 360 361 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 362 363 /* ========================================================================== */ 364 /* Attribute API. */ 365 /* ========================================================================== */ 366 367 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) { 368 return wrap(mlir::parseAttribute(attr, unwrap(context))); 369 } 370 371 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 372 return unwrap(a1) == unwrap(a2); 373 } 374 375 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 376 void *userData) { 377 CallbackOstream stream(callback, userData); 378 unwrap(attr).print(stream); 379 stream.flush(); 380 } 381 382 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 383 384 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { 385 return MlirNamedAttribute{name, attr}; 386 } 387