1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/IR.h"
10 #include "mlir-c/Support.h"
11
12 #include "mlir/AsmParser/AsmParser.h"
13 #include "mlir/CAPI/IR.h"
14 #include "mlir/CAPI/Support.h"
15 #include "mlir/CAPI/Utils.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/Types.h"
22 #include "mlir/IR/Verifier.h"
23 #include "mlir/Interfaces/InferTypeOpInterface.h"
24 #include "mlir/Parser/Parser.h"
25
26 #include "llvm/Support/Debug.h"
27 #include <cstddef>
28
29 using namespace mlir;
30
31 //===----------------------------------------------------------------------===//
32 // Context API.
33 //===----------------------------------------------------------------------===//
34
mlirContextCreate()35 MlirContext mlirContextCreate() {
36 auto *context = new MLIRContext;
37 return wrap(context);
38 }
39
mlirContextEqual(MlirContext ctx1,MlirContext ctx2)40 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
41 return unwrap(ctx1) == unwrap(ctx2);
42 }
43
mlirContextDestroy(MlirContext context)44 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
45
mlirContextSetAllowUnregisteredDialects(MlirContext context,bool allow)46 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
47 unwrap(context)->allowUnregisteredDialects(allow);
48 }
49
mlirContextGetAllowUnregisteredDialects(MlirContext context)50 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
51 return unwrap(context)->allowsUnregisteredDialects();
52 }
mlirContextGetNumRegisteredDialects(MlirContext context)53 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
54 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
55 }
56
mlirContextAppendDialectRegistry(MlirContext ctx,MlirDialectRegistry registry)57 void mlirContextAppendDialectRegistry(MlirContext ctx,
58 MlirDialectRegistry registry) {
59 unwrap(ctx)->appendDialectRegistry(*unwrap(registry));
60 }
61
62 // TODO: expose a cheaper way than constructing + sorting a vector only to take
63 // its size.
mlirContextGetNumLoadedDialects(MlirContext context)64 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
65 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
66 }
67
mlirContextGetOrLoadDialect(MlirContext context,MlirStringRef name)68 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
69 MlirStringRef name) {
70 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
71 }
72
mlirContextIsRegisteredOperation(MlirContext context,MlirStringRef name)73 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
74 return unwrap(context)->isOperationRegistered(unwrap(name));
75 }
76
mlirContextEnableMultithreading(MlirContext context,bool enable)77 void mlirContextEnableMultithreading(MlirContext context, bool enable) {
78 return unwrap(context)->enableMultithreading(enable);
79 }
80
mlirContextLoadAllAvailableDialects(MlirContext context)81 void mlirContextLoadAllAvailableDialects(MlirContext context) {
82 unwrap(context)->loadAllAvailableDialects();
83 }
84
85 //===----------------------------------------------------------------------===//
86 // Dialect API.
87 //===----------------------------------------------------------------------===//
88
mlirDialectGetContext(MlirDialect dialect)89 MlirContext mlirDialectGetContext(MlirDialect dialect) {
90 return wrap(unwrap(dialect)->getContext());
91 }
92
mlirDialectEqual(MlirDialect dialect1,MlirDialect dialect2)93 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
94 return unwrap(dialect1) == unwrap(dialect2);
95 }
96
mlirDialectGetNamespace(MlirDialect dialect)97 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
98 return wrap(unwrap(dialect)->getNamespace());
99 }
100
101 //===----------------------------------------------------------------------===//
102 // DialectRegistry API.
103 //===----------------------------------------------------------------------===//
104
mlirDialectRegistryCreate()105 MlirDialectRegistry mlirDialectRegistryCreate() {
106 return wrap(new DialectRegistry());
107 }
108
mlirDialectRegistryDestroy(MlirDialectRegistry registry)109 void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
110 delete unwrap(registry);
111 }
112
113 //===----------------------------------------------------------------------===//
114 // Printing flags API.
115 //===----------------------------------------------------------------------===//
116
mlirOpPrintingFlagsCreate()117 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
118 return wrap(new OpPrintingFlags());
119 }
120
mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags)121 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
122 delete unwrap(flags);
123 }
124
mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,intptr_t largeElementLimit)125 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
126 intptr_t largeElementLimit) {
127 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
128 }
129
mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,bool prettyForm)130 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
131 bool prettyForm) {
132 unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
133 }
134
mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags)135 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
136 unwrap(flags)->printGenericOpForm();
137 }
138
mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags)139 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
140 unwrap(flags)->useLocalScope();
141 }
142
143 //===----------------------------------------------------------------------===//
144 // Location API.
145 //===----------------------------------------------------------------------===//
146
mlirLocationFileLineColGet(MlirContext context,MlirStringRef filename,unsigned line,unsigned col)147 MlirLocation mlirLocationFileLineColGet(MlirContext context,
148 MlirStringRef filename, unsigned line,
149 unsigned col) {
150 return wrap(Location(
151 FileLineColLoc::get(unwrap(context), unwrap(filename), line, col)));
152 }
153
mlirLocationCallSiteGet(MlirLocation callee,MlirLocation caller)154 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
155 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
156 }
157
mlirLocationFusedGet(MlirContext ctx,intptr_t nLocations,MlirLocation const * locations,MlirAttribute metadata)158 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
159 MlirLocation const *locations,
160 MlirAttribute metadata) {
161 SmallVector<Location, 4> locs;
162 ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs);
163 return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
164 }
165
mlirLocationNameGet(MlirContext context,MlirStringRef name,MlirLocation childLoc)166 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
167 MlirLocation childLoc) {
168 if (mlirLocationIsNull(childLoc))
169 return wrap(
170 Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name)))));
171 return wrap(Location(NameLoc::get(
172 StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
173 }
174
mlirLocationUnknownGet(MlirContext context)175 MlirLocation mlirLocationUnknownGet(MlirContext context) {
176 return wrap(Location(UnknownLoc::get(unwrap(context))));
177 }
178
mlirLocationEqual(MlirLocation l1,MlirLocation l2)179 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
180 return unwrap(l1) == unwrap(l2);
181 }
182
mlirLocationGetContext(MlirLocation location)183 MlirContext mlirLocationGetContext(MlirLocation location) {
184 return wrap(unwrap(location).getContext());
185 }
186
mlirLocationPrint(MlirLocation location,MlirStringCallback callback,void * userData)187 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
188 void *userData) {
189 detail::CallbackOstream stream(callback, userData);
190 unwrap(location).print(stream);
191 }
192
193 //===----------------------------------------------------------------------===//
194 // Module API.
195 //===----------------------------------------------------------------------===//
196
mlirModuleCreateEmpty(MlirLocation location)197 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
198 return wrap(ModuleOp::create(unwrap(location)));
199 }
200
mlirModuleCreateParse(MlirContext context,MlirStringRef module)201 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
202 OwningOpRef<ModuleOp> owning =
203 parseSourceString<ModuleOp>(unwrap(module), unwrap(context));
204 if (!owning)
205 return MlirModule{nullptr};
206 return MlirModule{owning.release().getOperation()};
207 }
208
mlirModuleGetContext(MlirModule module)209 MlirContext mlirModuleGetContext(MlirModule module) {
210 return wrap(unwrap(module).getContext());
211 }
212
mlirModuleGetBody(MlirModule module)213 MlirBlock mlirModuleGetBody(MlirModule module) {
214 return wrap(unwrap(module).getBody());
215 }
216
mlirModuleDestroy(MlirModule module)217 void mlirModuleDestroy(MlirModule module) {
218 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is
219 // called.
220 OwningOpRef<ModuleOp>(unwrap(module));
221 }
222
mlirModuleGetOperation(MlirModule module)223 MlirOperation mlirModuleGetOperation(MlirModule module) {
224 return wrap(unwrap(module).getOperation());
225 }
226
mlirModuleFromOperation(MlirOperation op)227 MlirModule mlirModuleFromOperation(MlirOperation op) {
228 return wrap(dyn_cast<ModuleOp>(unwrap(op)));
229 }
230
231 //===----------------------------------------------------------------------===//
232 // Operation state API.
233 //===----------------------------------------------------------------------===//
234
mlirOperationStateGet(MlirStringRef name,MlirLocation loc)235 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
236 MlirOperationState state;
237 state.name = name;
238 state.location = loc;
239 state.nResults = 0;
240 state.results = nullptr;
241 state.nOperands = 0;
242 state.operands = nullptr;
243 state.nRegions = 0;
244 state.regions = nullptr;
245 state.nSuccessors = 0;
246 state.successors = nullptr;
247 state.nAttributes = 0;
248 state.attributes = nullptr;
249 state.enableResultTypeInference = false;
250 return state;
251 }
252
// Appends `n` elements from the local array `elemName` to the array
// `state->elemName`, growing it with realloc and bumping `state->sizeName`.
// Expects `state` and `n` to be in scope at the expansion site.
//
// The body is wrapped in do/while so the macro expands to a single statement,
// non-positive `n` is a no-op (avoids memcpy from a possibly-null source and
// a shrinking realloc), and the realloc result is checked before it replaces
// the existing pointer.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    auto *grownStorage = static_cast<type *>(                                  \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    assert(grownStorage && "out of memory while growing operation state");     \
    state->elemName = grownStorage;                                            \
    memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));     \
    state->sizeName += n;                                                      \
  } while (false)
258
mlirOperationStateAddResults(MlirOperationState * state,intptr_t n,MlirType const * results)259 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
260 MlirType const *results) {
261 APPEND_ELEMS(MlirType, nResults, results);
262 }
263
mlirOperationStateAddOperands(MlirOperationState * state,intptr_t n,MlirValue const * operands)264 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
265 MlirValue const *operands) {
266 APPEND_ELEMS(MlirValue, nOperands, operands);
267 }
mlirOperationStateAddOwnedRegions(MlirOperationState * state,intptr_t n,MlirRegion const * regions)268 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
269 MlirRegion const *regions) {
270 APPEND_ELEMS(MlirRegion, nRegions, regions);
271 }
mlirOperationStateAddSuccessors(MlirOperationState * state,intptr_t n,MlirBlock const * successors)272 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
273 MlirBlock const *successors) {
274 APPEND_ELEMS(MlirBlock, nSuccessors, successors);
275 }
mlirOperationStateAddAttributes(MlirOperationState * state,intptr_t n,MlirNamedAttribute const * attributes)276 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
277 MlirNamedAttribute const *attributes) {
278 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
279 }
280
mlirOperationStateEnableResultTypeInference(MlirOperationState * state)281 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
282 state->enableResultTypeInference = true;
283 }
284
285 //===----------------------------------------------------------------------===//
286 // Operation API.
287 //===----------------------------------------------------------------------===//
288
inferOperationTypes(OperationState & state)289 static LogicalResult inferOperationTypes(OperationState &state) {
290 MLIRContext *context = state.getContext();
291 Optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
292 if (!info) {
293 emitError(state.location)
294 << "type inference was requested for the operation " << state.name
295 << ", but the operation was not registered. Ensure that the dialect "
296 "containing the operation is linked into MLIR and registered with "
297 "the context";
298 return failure();
299 }
300
301 // Fallback to inference via an op interface.
302 auto *inferInterface = info->getInterface<InferTypeOpInterface>();
303 if (!inferInterface) {
304 emitError(state.location)
305 << "type inference was requested for the operation " << state.name
306 << ", but the operation does not support type inference. Result "
307 "types must be specified explicitly.";
308 return failure();
309 }
310
311 if (succeeded(inferInterface->inferReturnTypes(
312 context, state.location, state.operands,
313 state.attributes.getDictionary(context), state.regions, state.types)))
314 return success();
315
316 // Diagnostic emitted by interface.
317 return failure();
318 }
319
mlirOperationCreate(MlirOperationState * state)320 MlirOperation mlirOperationCreate(MlirOperationState *state) {
321 assert(state);
322 OperationState cppState(unwrap(state->location), unwrap(state->name));
323 SmallVector<Type, 4> resultStorage;
324 SmallVector<Value, 8> operandStorage;
325 SmallVector<Block *, 2> successorStorage;
326 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
327 cppState.addOperands(
328 unwrapList(state->nOperands, state->operands, operandStorage));
329 cppState.addSuccessors(
330 unwrapList(state->nSuccessors, state->successors, successorStorage));
331
332 cppState.attributes.reserve(state->nAttributes);
333 for (intptr_t i = 0; i < state->nAttributes; ++i)
334 cppState.addAttribute(unwrap(state->attributes[i].name),
335 unwrap(state->attributes[i].attribute));
336
337 for (intptr_t i = 0; i < state->nRegions; ++i)
338 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
339
340 free(state->results);
341 free(state->operands);
342 free(state->successors);
343 free(state->regions);
344 free(state->attributes);
345
346 // Infer result types.
347 if (state->enableResultTypeInference) {
348 assert(cppState.types.empty() &&
349 "result type inference enabled and result types provided");
350 if (failed(inferOperationTypes(cppState)))
351 return {nullptr};
352 }
353
354 MlirOperation result = wrap(Operation::create(cppState));
355 return result;
356 }
357
mlirOperationClone(MlirOperation op)358 MlirOperation mlirOperationClone(MlirOperation op) {
359 return wrap(unwrap(op)->clone());
360 }
361
mlirOperationDestroy(MlirOperation op)362 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
363
mlirOperationRemoveFromParent(MlirOperation op)364 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); }
365
mlirOperationEqual(MlirOperation op,MlirOperation other)366 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
367 return unwrap(op) == unwrap(other);
368 }
369
mlirOperationGetContext(MlirOperation op)370 MlirContext mlirOperationGetContext(MlirOperation op) {
371 return wrap(unwrap(op)->getContext());
372 }
373
mlirOperationGetLocation(MlirOperation op)374 MlirLocation mlirOperationGetLocation(MlirOperation op) {
375 return wrap(unwrap(op)->getLoc());
376 }
377
mlirOperationGetTypeID(MlirOperation op)378 MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
379 if (auto info = unwrap(op)->getRegisteredInfo())
380 return wrap(info->getTypeID());
381 return {nullptr};
382 }
383
mlirOperationGetName(MlirOperation op)384 MlirIdentifier mlirOperationGetName(MlirOperation op) {
385 return wrap(unwrap(op)->getName().getIdentifier());
386 }
387
mlirOperationGetBlock(MlirOperation op)388 MlirBlock mlirOperationGetBlock(MlirOperation op) {
389 return wrap(unwrap(op)->getBlock());
390 }
391
mlirOperationGetParentOperation(MlirOperation op)392 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
393 return wrap(unwrap(op)->getParentOp());
394 }
395
mlirOperationGetNumRegions(MlirOperation op)396 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
397 return static_cast<intptr_t>(unwrap(op)->getNumRegions());
398 }
399
mlirOperationGetRegion(MlirOperation op,intptr_t pos)400 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
401 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
402 }
403
mlirOperationGetFirstRegion(MlirOperation op)404 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
405 Operation *cppOp = unwrap(op);
406 if (cppOp->getNumRegions() == 0)
407 return wrap(static_cast<Region *>(nullptr));
408 return wrap(&cppOp->getRegion(0));
409 }
410
mlirRegionGetNextInOperation(MlirRegion region)411 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
412 Region *cppRegion = unwrap(region);
413 Operation *parent = cppRegion->getParentOp();
414 intptr_t next = cppRegion->getRegionNumber() + 1;
415 if (parent->getNumRegions() > next)
416 return wrap(&parent->getRegion(next));
417 return wrap(static_cast<Region *>(nullptr));
418 }
419
mlirOperationGetNextInBlock(MlirOperation op)420 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
421 return wrap(unwrap(op)->getNextNode());
422 }
423
mlirOperationGetNumOperands(MlirOperation op)424 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
425 return static_cast<intptr_t>(unwrap(op)->getNumOperands());
426 }
427
mlirOperationGetOperand(MlirOperation op,intptr_t pos)428 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
429 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
430 }
431
mlirOperationSetOperand(MlirOperation op,intptr_t pos,MlirValue newValue)432 void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
433 MlirValue newValue) {
434 unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
435 }
436
mlirOperationGetNumResults(MlirOperation op)437 intptr_t mlirOperationGetNumResults(MlirOperation op) {
438 return static_cast<intptr_t>(unwrap(op)->getNumResults());
439 }
440
mlirOperationGetResult(MlirOperation op,intptr_t pos)441 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
442 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
443 }
444
mlirOperationGetNumSuccessors(MlirOperation op)445 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
446 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
447 }
448
mlirOperationGetSuccessor(MlirOperation op,intptr_t pos)449 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
450 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
451 }
452
mlirOperationGetNumAttributes(MlirOperation op)453 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
454 return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
455 }
456
mlirOperationGetAttribute(MlirOperation op,intptr_t pos)457 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
458 NamedAttribute attr = unwrap(op)->getAttrs()[pos];
459 return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
460 }
461
mlirOperationGetAttributeByName(MlirOperation op,MlirStringRef name)462 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
463 MlirStringRef name) {
464 return wrap(unwrap(op)->getAttr(unwrap(name)));
465 }
466
mlirOperationSetAttributeByName(MlirOperation op,MlirStringRef name,MlirAttribute attr)467 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
468 MlirAttribute attr) {
469 unwrap(op)->setAttr(unwrap(name), unwrap(attr));
470 }
471
mlirOperationRemoveAttributeByName(MlirOperation op,MlirStringRef name)472 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
473 return !!unwrap(op)->removeAttr(unwrap(name));
474 }
475
mlirOperationPrint(MlirOperation op,MlirStringCallback callback,void * userData)476 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
477 void *userData) {
478 detail::CallbackOstream stream(callback, userData);
479 unwrap(op)->print(stream);
480 }
481
mlirOperationPrintWithFlags(MlirOperation op,MlirOpPrintingFlags flags,MlirStringCallback callback,void * userData)482 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
483 MlirStringCallback callback, void *userData) {
484 detail::CallbackOstream stream(callback, userData);
485 unwrap(op)->print(stream, *unwrap(flags));
486 }
487
mlirOperationDump(MlirOperation op)488 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
489
mlirOperationVerify(MlirOperation op)490 bool mlirOperationVerify(MlirOperation op) {
491 return succeeded(verify(unwrap(op)));
492 }
493
mlirOperationMoveAfter(MlirOperation op,MlirOperation other)494 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
495 return unwrap(op)->moveAfter(unwrap(other));
496 }
497
mlirOperationMoveBefore(MlirOperation op,MlirOperation other)498 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
499 return unwrap(op)->moveBefore(unwrap(other));
500 }
501
502 //===----------------------------------------------------------------------===//
503 // Region API.
504 //===----------------------------------------------------------------------===//
505
mlirRegionCreate()506 MlirRegion mlirRegionCreate() { return wrap(new Region); }
507
mlirRegionEqual(MlirRegion region,MlirRegion other)508 bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
509 return unwrap(region) == unwrap(other);
510 }
511
mlirRegionGetFirstBlock(MlirRegion region)512 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
513 Region *cppRegion = unwrap(region);
514 if (cppRegion->empty())
515 return wrap(static_cast<Block *>(nullptr));
516 return wrap(&cppRegion->front());
517 }
518
mlirRegionAppendOwnedBlock(MlirRegion region,MlirBlock block)519 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
520 unwrap(region)->push_back(unwrap(block));
521 }
522
mlirRegionInsertOwnedBlock(MlirRegion region,intptr_t pos,MlirBlock block)523 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
524 MlirBlock block) {
525 auto &blockList = unwrap(region)->getBlocks();
526 blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
527 }
528
mlirRegionInsertOwnedBlockAfter(MlirRegion region,MlirBlock reference,MlirBlock block)529 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
530 MlirBlock block) {
531 Region *cppRegion = unwrap(region);
532 if (mlirBlockIsNull(reference)) {
533 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
534 return;
535 }
536
537 assert(unwrap(reference)->getParent() == unwrap(region) &&
538 "expected reference block to belong to the region");
539 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
540 unwrap(block));
541 }
542
mlirRegionInsertOwnedBlockBefore(MlirRegion region,MlirBlock reference,MlirBlock block)543 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
544 MlirBlock block) {
545 if (mlirBlockIsNull(reference))
546 return mlirRegionAppendOwnedBlock(region, block);
547
548 assert(unwrap(reference)->getParent() == unwrap(region) &&
549 "expected reference block to belong to the region");
550 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
551 unwrap(block));
552 }
553
mlirRegionDestroy(MlirRegion region)554 void mlirRegionDestroy(MlirRegion region) {
555 delete static_cast<Region *>(region.ptr);
556 }
557
558 //===----------------------------------------------------------------------===//
559 // Block API.
560 //===----------------------------------------------------------------------===//
561
mlirBlockCreate(intptr_t nArgs,MlirType const * args,MlirLocation const * locs)562 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args,
563 MlirLocation const *locs) {
564 Block *b = new Block;
565 for (intptr_t i = 0; i < nArgs; ++i)
566 b->addArgument(unwrap(args[i]), unwrap(locs[i]));
567 return wrap(b);
568 }
569
mlirBlockEqual(MlirBlock block,MlirBlock other)570 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
571 return unwrap(block) == unwrap(other);
572 }
573
mlirBlockGetParentOperation(MlirBlock block)574 MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
575 return wrap(unwrap(block)->getParentOp());
576 }
577
mlirBlockGetParentRegion(MlirBlock block)578 MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
579 return wrap(unwrap(block)->getParent());
580 }
581
mlirBlockGetNextInRegion(MlirBlock block)582 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
583 return wrap(unwrap(block)->getNextNode());
584 }
585
mlirBlockGetFirstOperation(MlirBlock block)586 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
587 Block *cppBlock = unwrap(block);
588 if (cppBlock->empty())
589 return wrap(static_cast<Operation *>(nullptr));
590 return wrap(&cppBlock->front());
591 }
592
mlirBlockGetTerminator(MlirBlock block)593 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
594 Block *cppBlock = unwrap(block);
595 if (cppBlock->empty())
596 return wrap(static_cast<Operation *>(nullptr));
597 Operation &back = cppBlock->back();
598 if (!back.hasTrait<OpTrait::IsTerminator>())
599 return wrap(static_cast<Operation *>(nullptr));
600 return wrap(&back);
601 }
602
mlirBlockAppendOwnedOperation(MlirBlock block,MlirOperation operation)603 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
604 unwrap(block)->push_back(unwrap(operation));
605 }
606
mlirBlockInsertOwnedOperation(MlirBlock block,intptr_t pos,MlirOperation operation)607 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
608 MlirOperation operation) {
609 auto &opList = unwrap(block)->getOperations();
610 opList.insert(std::next(opList.begin(), pos), unwrap(operation));
611 }
612
mlirBlockInsertOwnedOperationAfter(MlirBlock block,MlirOperation reference,MlirOperation operation)613 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
614 MlirOperation reference,
615 MlirOperation operation) {
616 Block *cppBlock = unwrap(block);
617 if (mlirOperationIsNull(reference)) {
618 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
619 return;
620 }
621
622 assert(unwrap(reference)->getBlock() == unwrap(block) &&
623 "expected reference operation to belong to the block");
624 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
625 unwrap(operation));
626 }
627
mlirBlockInsertOwnedOperationBefore(MlirBlock block,MlirOperation reference,MlirOperation operation)628 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
629 MlirOperation reference,
630 MlirOperation operation) {
631 if (mlirOperationIsNull(reference))
632 return mlirBlockAppendOwnedOperation(block, operation);
633
634 assert(unwrap(reference)->getBlock() == unwrap(block) &&
635 "expected reference operation to belong to the block");
636 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
637 unwrap(operation));
638 }
639
mlirBlockDestroy(MlirBlock block)640 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
641
mlirBlockDetach(MlirBlock block)642 void mlirBlockDetach(MlirBlock block) {
643 Block *b = unwrap(block);
644 b->getParent()->getBlocks().remove(b);
645 }
646
mlirBlockGetNumArguments(MlirBlock block)647 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
648 return static_cast<intptr_t>(unwrap(block)->getNumArguments());
649 }
650
mlirBlockAddArgument(MlirBlock block,MlirType type,MlirLocation loc)651 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
652 MlirLocation loc) {
653 return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc)));
654 }
655
mlirBlockGetArgument(MlirBlock block,intptr_t pos)656 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
657 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
658 }
659
mlirBlockPrint(MlirBlock block,MlirStringCallback callback,void * userData)660 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
661 void *userData) {
662 detail::CallbackOstream stream(callback, userData);
663 unwrap(block)->print(stream);
664 }
665
666 //===----------------------------------------------------------------------===//
667 // Value API.
668 //===----------------------------------------------------------------------===//
669
mlirValueEqual(MlirValue value1,MlirValue value2)670 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
671 return unwrap(value1) == unwrap(value2);
672 }
673
mlirValueIsABlockArgument(MlirValue value)674 bool mlirValueIsABlockArgument(MlirValue value) {
675 return unwrap(value).isa<BlockArgument>();
676 }
677
mlirValueIsAOpResult(MlirValue value)678 bool mlirValueIsAOpResult(MlirValue value) {
679 return unwrap(value).isa<OpResult>();
680 }
681
mlirBlockArgumentGetOwner(MlirValue value)682 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
683 return wrap(unwrap(value).cast<BlockArgument>().getOwner());
684 }
685
mlirBlockArgumentGetArgNumber(MlirValue value)686 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
687 return static_cast<intptr_t>(
688 unwrap(value).cast<BlockArgument>().getArgNumber());
689 }
690
mlirBlockArgumentSetType(MlirValue value,MlirType type)691 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
692 unwrap(value).cast<BlockArgument>().setType(unwrap(type));
693 }
694
mlirOpResultGetOwner(MlirValue value)695 MlirOperation mlirOpResultGetOwner(MlirValue value) {
696 return wrap(unwrap(value).cast<OpResult>().getOwner());
697 }
698
mlirOpResultGetResultNumber(MlirValue value)699 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
700 return static_cast<intptr_t>(
701 unwrap(value).cast<OpResult>().getResultNumber());
702 }
703
mlirValueGetType(MlirValue value)704 MlirType mlirValueGetType(MlirValue value) {
705 return wrap(unwrap(value).getType());
706 }
707
mlirValueDump(MlirValue value)708 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
709
mlirValuePrint(MlirValue value,MlirStringCallback callback,void * userData)710 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
711 void *userData) {
712 detail::CallbackOstream stream(callback, userData);
713 unwrap(value).print(stream);
714 }
715
716 //===----------------------------------------------------------------------===//
717 // Type API.
718 //===----------------------------------------------------------------------===//
719
mlirTypeParseGet(MlirContext context,MlirStringRef type)720 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
721 return wrap(mlir::parseType(unwrap(type), unwrap(context)));
722 }
723
mlirTypeGetContext(MlirType type)724 MlirContext mlirTypeGetContext(MlirType type) {
725 return wrap(unwrap(type).getContext());
726 }
727
mlirTypeGetTypeID(MlirType type)728 MlirTypeID mlirTypeGetTypeID(MlirType type) {
729 return wrap(unwrap(type).getTypeID());
730 }
731
mlirTypeEqual(MlirType t1,MlirType t2)732 bool mlirTypeEqual(MlirType t1, MlirType t2) {
733 return unwrap(t1) == unwrap(t2);
734 }
735
mlirTypePrint(MlirType type,MlirStringCallback callback,void * userData)736 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
737 detail::CallbackOstream stream(callback, userData);
738 unwrap(type).print(stream);
739 }
740
mlirTypeDump(MlirType type)741 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
742
743 //===----------------------------------------------------------------------===//
744 // Attribute API.
745 //===----------------------------------------------------------------------===//
746
mlirAttributeParseGet(MlirContext context,MlirStringRef attr)747 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
748 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
749 }
750
mlirAttributeGetContext(MlirAttribute attribute)751 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
752 return wrap(unwrap(attribute).getContext());
753 }
754
mlirAttributeGetType(MlirAttribute attribute)755 MlirType mlirAttributeGetType(MlirAttribute attribute) {
756 return wrap(unwrap(attribute).getType());
757 }
758
mlirAttributeGetTypeID(MlirAttribute attr)759 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
760 return wrap(unwrap(attr).getTypeID());
761 }
762
mlirAttributeEqual(MlirAttribute a1,MlirAttribute a2)763 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
764 return unwrap(a1) == unwrap(a2);
765 }
766
mlirAttributePrint(MlirAttribute attr,MlirStringCallback callback,void * userData)767 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
768 void *userData) {
769 detail::CallbackOstream stream(callback, userData);
770 unwrap(attr).print(stream);
771 }
772
mlirAttributeDump(MlirAttribute attr)773 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
774
mlirNamedAttributeGet(MlirIdentifier name,MlirAttribute attr)775 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
776 MlirAttribute attr) {
777 return MlirNamedAttribute{name, attr};
778 }
779
780 //===----------------------------------------------------------------------===//
781 // Identifier API.
782 //===----------------------------------------------------------------------===//
783
mlirIdentifierGet(MlirContext context,MlirStringRef str)784 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
785 return wrap(StringAttr::get(unwrap(context), unwrap(str)));
786 }
787
mlirIdentifierGetContext(MlirIdentifier ident)788 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
789 return wrap(unwrap(ident).getContext());
790 }
791
mlirIdentifierEqual(MlirIdentifier ident,MlirIdentifier other)792 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
793 return unwrap(ident) == unwrap(other);
794 }
795
mlirIdentifierStr(MlirIdentifier ident)796 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
797 return wrap(unwrap(ident).strref());
798 }
799
800 //===----------------------------------------------------------------------===//
801 // Symbol and SymbolTable API.
802 //===----------------------------------------------------------------------===//
803
mlirSymbolTableGetSymbolAttributeName()804 MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
805 return wrap(SymbolTable::getSymbolAttrName());
806 }
807
mlirSymbolTableGetVisibilityAttributeName()808 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
809 return wrap(SymbolTable::getVisibilityAttrName());
810 }
811
mlirSymbolTableCreate(MlirOperation operation)812 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
813 if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
814 return wrap(static_cast<SymbolTable *>(nullptr));
815 return wrap(new SymbolTable(unwrap(operation)));
816 }
817
mlirSymbolTableDestroy(MlirSymbolTable symbolTable)818 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
819 delete unwrap(symbolTable);
820 }
821
mlirSymbolTableLookup(MlirSymbolTable symbolTable,MlirStringRef name)822 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
823 MlirStringRef name) {
824 return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length)));
825 }
826
mlirSymbolTableInsert(MlirSymbolTable symbolTable,MlirOperation operation)827 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
828 MlirOperation operation) {
829 return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation)));
830 }
831
mlirSymbolTableErase(MlirSymbolTable symbolTable,MlirOperation operation)832 void mlirSymbolTableErase(MlirSymbolTable symbolTable,
833 MlirOperation operation) {
834 unwrap(symbolTable)->erase(unwrap(operation));
835 }
836
mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,MlirStringRef newSymbol,MlirOperation from)837 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
838 MlirStringRef newSymbol,
839 MlirOperation from) {
840 auto *cppFrom = unwrap(from);
841 auto *context = cppFrom->getContext();
842 auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
843 auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
844 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
845 unwrap(from)));
846 }
847
mlirSymbolTableWalkSymbolTables(MlirOperation from,bool allSymUsesVisible,void (* callback)(MlirOperation,bool,void * userData),void * userData)848 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
849 void (*callback)(MlirOperation, bool,
850 void *userData),
851 void *userData) {
852 SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible,
853 [&](Operation *foundOpCpp, bool isVisible) {
854 callback(wrap(foundOpCpp), isVisible,
855 userData);
856 });
857 }
858