xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision aad04534)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/IR.h"
10 #include "mlir-c/Support.h"
11 
12 #include "mlir/CAPI/IR.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/Dialect.h"
18 #include "mlir/IR/Location.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/IR/Types.h"
21 #include "mlir/IR/Verifier.h"
22 #include "mlir/Interfaces/InferTypeOpInterface.h"
23 #include "mlir/Parser.h"
24 
25 #include "llvm/Support/Debug.h"
26 #include <cstddef>
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Context API.
32 //===----------------------------------------------------------------------===//
33 
34 MlirContext mlirContextCreate() {
35   auto *context = new MLIRContext;
36   return wrap(context);
37 }
38 
39 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
40   return unwrap(ctx1) == unwrap(ctx2);
41 }
42 
43 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
44 
45 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
46   unwrap(context)->allowUnregisteredDialects(allow);
47 }
48 
49 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
50   return unwrap(context)->allowsUnregisteredDialects();
51 }
52 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
53   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
54 }
55 
56 // TODO: expose a cheaper way than constructing + sorting a vector only to take
57 // its size.
58 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
59   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
60 }
61 
62 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
63                                         MlirStringRef name) {
64   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
65 }
66 
67 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
68   return unwrap(context)->isOperationRegistered(unwrap(name));
69 }
70 
71 void mlirContextEnableMultithreading(MlirContext context, bool enable) {
72   return unwrap(context)->enableMultithreading(enable);
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // Dialect API.
77 //===----------------------------------------------------------------------===//
78 
79 MlirContext mlirDialectGetContext(MlirDialect dialect) {
80   return wrap(unwrap(dialect)->getContext());
81 }
82 
83 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
84   return unwrap(dialect1) == unwrap(dialect2);
85 }
86 
87 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
88   return wrap(unwrap(dialect)->getNamespace());
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Printing flags API.
93 //===----------------------------------------------------------------------===//
94 
95 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
96   return wrap(new OpPrintingFlags());
97 }
98 
99 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
100   delete unwrap(flags);
101 }
102 
103 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
104                                                 intptr_t largeElementLimit) {
105   unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
106 }
107 
108 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
109                                         bool prettyForm) {
110   unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
111 }
112 
113 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
114   unwrap(flags)->printGenericOpForm();
115 }
116 
117 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
118   unwrap(flags)->useLocalScope();
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // Location API.
123 //===----------------------------------------------------------------------===//
124 
125 MlirLocation mlirLocationFileLineColGet(MlirContext context,
126                                         MlirStringRef filename, unsigned line,
127                                         unsigned col) {
128   return wrap(Location(
129       FileLineColLoc::get(unwrap(context), unwrap(filename), line, col)));
130 }
131 
132 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
133   return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
134 }
135 
136 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
137                                   MlirLocation const *locations,
138                                   MlirAttribute metadata) {
139   SmallVector<Location, 4> locs;
140   ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs);
141   return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
142 }
143 
144 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
145                                  MlirLocation childLoc) {
146   if (mlirLocationIsNull(childLoc))
147     return wrap(
148         Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name)))));
149   return wrap(Location(NameLoc::get(
150       StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
151 }
152 
153 MlirLocation mlirLocationUnknownGet(MlirContext context) {
154   return wrap(Location(UnknownLoc::get(unwrap(context))));
155 }
156 
157 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
158   return unwrap(l1) == unwrap(l2);
159 }
160 
161 MlirContext mlirLocationGetContext(MlirLocation location) {
162   return wrap(unwrap(location).getContext());
163 }
164 
165 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
166                        void *userData) {
167   detail::CallbackOstream stream(callback, userData);
168   unwrap(location).print(stream);
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Module API.
173 //===----------------------------------------------------------------------===//
174 
175 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
176   return wrap(ModuleOp::create(unwrap(location)));
177 }
178 
179 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
180   OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context));
181   if (!owning)
182     return MlirModule{nullptr};
183   return MlirModule{owning.release().getOperation()};
184 }
185 
186 MlirContext mlirModuleGetContext(MlirModule module) {
187   return wrap(unwrap(module).getContext());
188 }
189 
190 MlirBlock mlirModuleGetBody(MlirModule module) {
191   return wrap(unwrap(module).getBody());
192 }
193 
194 void mlirModuleDestroy(MlirModule module) {
195   // Transfer ownership to an OwningModuleRef so that its destructor is called.
196   OwningModuleRef(unwrap(module));
197 }
198 
199 MlirOperation mlirModuleGetOperation(MlirModule module) {
200   return wrap(unwrap(module).getOperation());
201 }
202 
203 MlirModule mlirModuleFromOperation(MlirOperation op) {
204   return wrap(dyn_cast<ModuleOp>(unwrap(op)));
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // Operation state API.
209 //===----------------------------------------------------------------------===//
210 
211 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
212   MlirOperationState state;
213   state.name = name;
214   state.location = loc;
215   state.nResults = 0;
216   state.results = nullptr;
217   state.nOperands = 0;
218   state.operands = nullptr;
219   state.nRegions = 0;
220   state.regions = nullptr;
221   state.nSuccessors = 0;
222   state.successors = nullptr;
223   state.nAttributes = 0;
224   state.attributes = nullptr;
225   state.enableResultTypeInference = false;
226   return state;
227 }
228 
// Appends `n` elements read from the array argument `elemName` to the
// realloc-managed buffer `state->elemName` and bumps `state->sizeName`.
// Expects `state` and `n` to be in scope at the expansion site.
//
// The body is wrapped in do/while so the expansion is a single statement (safe
// under an unbraced if/else), and a non-positive `n` is a no-op so memcpy is
// never called with a null source and the size is never decreased.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    state->elemName = (type *)realloc(                                         \
        state->elemName, (state->sizeName + n) * sizeof(type));                \
    memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));     \
    state->sizeName += n;                                                      \
  } while (false)
234 
235 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
236                                   MlirType const *results) {
237   APPEND_ELEMS(MlirType, nResults, results);
238 }
239 
240 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
241                                    MlirValue const *operands) {
242   APPEND_ELEMS(MlirValue, nOperands, operands);
243 }
244 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
245                                        MlirRegion const *regions) {
246   APPEND_ELEMS(MlirRegion, nRegions, regions);
247 }
248 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
249                                      MlirBlock const *successors) {
250   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
251 }
252 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
253                                      MlirNamedAttribute const *attributes) {
254   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
255 }
256 
257 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
258   state->enableResultTypeInference = true;
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // Operation API.
263 //===----------------------------------------------------------------------===//
264 
265 static LogicalResult inferOperationTypes(OperationState &state) {
266   MLIRContext *context = state.getContext();
267   Optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
268   if (!info) {
269     emitError(state.location)
270         << "type inference was requested for the operation " << state.name
271         << ", but the operation was not registered. Ensure that the dialect "
272            "containing the operation is linked into MLIR and registered with "
273            "the context";
274     return failure();
275   }
276 
277   // Fallback to inference via an op interface.
278   auto *inferInterface = info->getInterface<InferTypeOpInterface>();
279   if (!inferInterface) {
280     emitError(state.location)
281         << "type inference was requested for the operation " << state.name
282         << ", but the operation does not support type inference. Result "
283            "types must be specified explicitly.";
284     return failure();
285   }
286 
287   if (succeeded(inferInterface->inferReturnTypes(
288           context, state.location, state.operands,
289           state.attributes.getDictionary(context), state.regions, state.types)))
290     return success();
291 
292   // Diagnostic emitted by interface.
293   return failure();
294 }
295 
296 MlirOperation mlirOperationCreate(MlirOperationState *state) {
297   assert(state);
298   OperationState cppState(unwrap(state->location), unwrap(state->name));
299   SmallVector<Type, 4> resultStorage;
300   SmallVector<Value, 8> operandStorage;
301   SmallVector<Block *, 2> successorStorage;
302   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
303   cppState.addOperands(
304       unwrapList(state->nOperands, state->operands, operandStorage));
305   cppState.addSuccessors(
306       unwrapList(state->nSuccessors, state->successors, successorStorage));
307 
308   cppState.attributes.reserve(state->nAttributes);
309   for (intptr_t i = 0; i < state->nAttributes; ++i)
310     cppState.addAttribute(unwrap(state->attributes[i].name),
311                           unwrap(state->attributes[i].attribute));
312 
313   for (intptr_t i = 0; i < state->nRegions; ++i)
314     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
315 
316   free(state->results);
317   free(state->operands);
318   free(state->successors);
319   free(state->regions);
320   free(state->attributes);
321 
322   // Infer result types.
323   if (state->enableResultTypeInference) {
324     assert(cppState.types.empty() &&
325            "result type inference enabled and result types provided");
326     if (failed(inferOperationTypes(cppState)))
327       return {nullptr};
328   }
329 
330   MlirOperation result = wrap(Operation::create(cppState));
331   return result;
332 }
333 
334 MlirOperation mlirOperationClone(MlirOperation op) {
335   return wrap(unwrap(op)->clone());
336 }
337 
338 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
339 
340 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); }
341 
342 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
343   return unwrap(op) == unwrap(other);
344 }
345 
346 MlirContext mlirOperationGetContext(MlirOperation op) {
347   return wrap(unwrap(op)->getContext());
348 }
349 
350 MlirLocation mlirOperationGetLocation(MlirOperation op) {
351   return wrap(unwrap(op)->getLoc());
352 }
353 
354 MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
355   if (auto info = unwrap(op)->getRegisteredInfo())
356     return wrap(info->getTypeID());
357   return {nullptr};
358 }
359 
360 MlirIdentifier mlirOperationGetName(MlirOperation op) {
361   return wrap(unwrap(op)->getName().getIdentifier());
362 }
363 
364 MlirBlock mlirOperationGetBlock(MlirOperation op) {
365   return wrap(unwrap(op)->getBlock());
366 }
367 
368 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
369   return wrap(unwrap(op)->getParentOp());
370 }
371 
372 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
373   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
374 }
375 
376 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
377   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
378 }
379 
380 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
381   Operation *cppOp = unwrap(op);
382   if (cppOp->getNumRegions() == 0)
383     return wrap(static_cast<Region *>(nullptr));
384   return wrap(&cppOp->getRegion(0));
385 }
386 
387 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
388   Region *cppRegion = unwrap(region);
389   Operation *parent = cppRegion->getParentOp();
390   intptr_t next = cppRegion->getRegionNumber() + 1;
391   if (parent->getNumRegions() > next)
392     return wrap(&parent->getRegion(next));
393   return wrap(static_cast<Region *>(nullptr));
394 }
395 
396 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
397   return wrap(unwrap(op)->getNextNode());
398 }
399 
400 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
401   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
402 }
403 
404 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
405   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
406 }
407 
408 void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
409                              MlirValue newValue) {
410   unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
411 }
412 
413 intptr_t mlirOperationGetNumResults(MlirOperation op) {
414   return static_cast<intptr_t>(unwrap(op)->getNumResults());
415 }
416 
417 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
418   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
419 }
420 
421 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
422   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
423 }
424 
425 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
426   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
427 }
428 
429 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
430   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
431 }
432 
433 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
434   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
435   return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
436 }
437 
438 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
439                                               MlirStringRef name) {
440   return wrap(unwrap(op)->getAttr(unwrap(name)));
441 }
442 
443 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
444                                      MlirAttribute attr) {
445   unwrap(op)->setAttr(unwrap(name), unwrap(attr));
446 }
447 
448 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
449   return !!unwrap(op)->removeAttr(unwrap(name));
450 }
451 
452 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
453                         void *userData) {
454   detail::CallbackOstream stream(callback, userData);
455   unwrap(op)->print(stream);
456 }
457 
458 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
459                                  MlirStringCallback callback, void *userData) {
460   detail::CallbackOstream stream(callback, userData);
461   unwrap(op)->print(stream, *unwrap(flags));
462 }
463 
464 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
465 
466 bool mlirOperationVerify(MlirOperation op) {
467   return succeeded(verify(unwrap(op)));
468 }
469 
470 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
471   return unwrap(op)->moveAfter(unwrap(other));
472 }
473 
474 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
475   return unwrap(op)->moveBefore(unwrap(other));
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // Region API.
480 //===----------------------------------------------------------------------===//
481 
482 MlirRegion mlirRegionCreate() { return wrap(new Region); }
483 
484 bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
485   return unwrap(region) == unwrap(other);
486 }
487 
488 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
489   Region *cppRegion = unwrap(region);
490   if (cppRegion->empty())
491     return wrap(static_cast<Block *>(nullptr));
492   return wrap(&cppRegion->front());
493 }
494 
495 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
496   unwrap(region)->push_back(unwrap(block));
497 }
498 
499 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
500                                 MlirBlock block) {
501   auto &blockList = unwrap(region)->getBlocks();
502   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
503 }
504 
505 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
506                                      MlirBlock block) {
507   Region *cppRegion = unwrap(region);
508   if (mlirBlockIsNull(reference)) {
509     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
510     return;
511   }
512 
513   assert(unwrap(reference)->getParent() == unwrap(region) &&
514          "expected reference block to belong to the region");
515   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
516                                      unwrap(block));
517 }
518 
519 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
520                                       MlirBlock block) {
521   if (mlirBlockIsNull(reference))
522     return mlirRegionAppendOwnedBlock(region, block);
523 
524   assert(unwrap(reference)->getParent() == unwrap(region) &&
525          "expected reference block to belong to the region");
526   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
527                                      unwrap(block));
528 }
529 
530 void mlirRegionDestroy(MlirRegion region) {
531   delete static_cast<Region *>(region.ptr);
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // Block API.
536 //===----------------------------------------------------------------------===//
537 
538 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args,
539                           MlirLocation const *locs) {
540   Block *b = new Block;
541   for (intptr_t i = 0; i < nArgs; ++i)
542     b->addArgument(unwrap(args[i]), unwrap(locs[i]));
543   return wrap(b);
544 }
545 
546 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
547   return unwrap(block) == unwrap(other);
548 }
549 
550 MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
551   return wrap(unwrap(block)->getParentOp());
552 }
553 
554 MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
555   return wrap(unwrap(block)->getParent());
556 }
557 
558 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
559   return wrap(unwrap(block)->getNextNode());
560 }
561 
562 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
563   Block *cppBlock = unwrap(block);
564   if (cppBlock->empty())
565     return wrap(static_cast<Operation *>(nullptr));
566   return wrap(&cppBlock->front());
567 }
568 
569 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
570   Block *cppBlock = unwrap(block);
571   if (cppBlock->empty())
572     return wrap(static_cast<Operation *>(nullptr));
573   Operation &back = cppBlock->back();
574   if (!back.hasTrait<OpTrait::IsTerminator>())
575     return wrap(static_cast<Operation *>(nullptr));
576   return wrap(&back);
577 }
578 
579 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
580   unwrap(block)->push_back(unwrap(operation));
581 }
582 
583 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
584                                    MlirOperation operation) {
585   auto &opList = unwrap(block)->getOperations();
586   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
587 }
588 
589 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
590                                         MlirOperation reference,
591                                         MlirOperation operation) {
592   Block *cppBlock = unwrap(block);
593   if (mlirOperationIsNull(reference)) {
594     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
595     return;
596   }
597 
598   assert(unwrap(reference)->getBlock() == unwrap(block) &&
599          "expected reference operation to belong to the block");
600   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
601                                         unwrap(operation));
602 }
603 
604 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
605                                          MlirOperation reference,
606                                          MlirOperation operation) {
607   if (mlirOperationIsNull(reference))
608     return mlirBlockAppendOwnedOperation(block, operation);
609 
610   assert(unwrap(reference)->getBlock() == unwrap(block) &&
611          "expected reference operation to belong to the block");
612   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
613                                         unwrap(operation));
614 }
615 
616 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
617 
618 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
619   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
620 }
621 
622 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
623                                MlirLocation loc) {
624   return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc)));
625 }
626 
627 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
628   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
629 }
630 
631 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
632                     void *userData) {
633   detail::CallbackOstream stream(callback, userData);
634   unwrap(block)->print(stream);
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // Value API.
639 //===----------------------------------------------------------------------===//
640 
641 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
642   return unwrap(value1) == unwrap(value2);
643 }
644 
645 bool mlirValueIsABlockArgument(MlirValue value) {
646   return unwrap(value).isa<BlockArgument>();
647 }
648 
649 bool mlirValueIsAOpResult(MlirValue value) {
650   return unwrap(value).isa<OpResult>();
651 }
652 
653 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
654   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
655 }
656 
657 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
658   return static_cast<intptr_t>(
659       unwrap(value).cast<BlockArgument>().getArgNumber());
660 }
661 
662 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
663   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
664 }
665 
666 MlirOperation mlirOpResultGetOwner(MlirValue value) {
667   return wrap(unwrap(value).cast<OpResult>().getOwner());
668 }
669 
670 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
671   return static_cast<intptr_t>(
672       unwrap(value).cast<OpResult>().getResultNumber());
673 }
674 
675 MlirType mlirValueGetType(MlirValue value) {
676   return wrap(unwrap(value).getType());
677 }
678 
679 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
680 
681 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
682                     void *userData) {
683   detail::CallbackOstream stream(callback, userData);
684   unwrap(value).print(stream);
685 }
686 
687 //===----------------------------------------------------------------------===//
688 // Type API.
689 //===----------------------------------------------------------------------===//
690 
691 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
692   return wrap(mlir::parseType(unwrap(type), unwrap(context)));
693 }
694 
695 MlirContext mlirTypeGetContext(MlirType type) {
696   return wrap(unwrap(type).getContext());
697 }
698 
699 MlirTypeID mlirTypeGetTypeID(MlirType type) {
700   return wrap(unwrap(type).getTypeID());
701 }
702 
703 bool mlirTypeEqual(MlirType t1, MlirType t2) {
704   return unwrap(t1) == unwrap(t2);
705 }
706 
707 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
708   detail::CallbackOstream stream(callback, userData);
709   unwrap(type).print(stream);
710 }
711 
712 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
713 
714 //===----------------------------------------------------------------------===//
715 // Attribute API.
716 //===----------------------------------------------------------------------===//
717 
718 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
719   return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
720 }
721 
722 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
723   return wrap(unwrap(attribute).getContext());
724 }
725 
726 MlirType mlirAttributeGetType(MlirAttribute attribute) {
727   return wrap(unwrap(attribute).getType());
728 }
729 
730 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
731   return wrap(unwrap(attr).getTypeID());
732 }
733 
734 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
735   return unwrap(a1) == unwrap(a2);
736 }
737 
738 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
739                         void *userData) {
740   detail::CallbackOstream stream(callback, userData);
741   unwrap(attr).print(stream);
742 }
743 
744 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
745 
746 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
747                                          MlirAttribute attr) {
748   return MlirNamedAttribute{name, attr};
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // Identifier API.
753 //===----------------------------------------------------------------------===//
754 
755 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
756   return wrap(StringAttr::get(unwrap(context), unwrap(str)));
757 }
758 
759 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
760   return wrap(unwrap(ident).getContext());
761 }
762 
763 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
764   return unwrap(ident) == unwrap(other);
765 }
766 
767 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
768   return wrap(unwrap(ident).strref());
769 }
770 
771 //===----------------------------------------------------------------------===//
772 // TypeID API.
773 //===----------------------------------------------------------------------===//
774 
775 bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
776   return unwrap(typeID1) == unwrap(typeID2);
777 }
778 
779 size_t mlirTypeIDHashValue(MlirTypeID typeID) {
780   return hash_value(unwrap(typeID));
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // Symbol and SymbolTable API.
785 //===----------------------------------------------------------------------===//
786 
787 MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
788   return wrap(SymbolTable::getSymbolAttrName());
789 }
790 
791 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
792   return wrap(SymbolTable::getVisibilityAttrName());
793 }
794 
795 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
796   if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
797     return wrap(static_cast<SymbolTable *>(nullptr));
798   return wrap(new SymbolTable(unwrap(operation)));
799 }
800 
801 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
802   delete unwrap(symbolTable);
803 }
804 
805 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
806                                     MlirStringRef name) {
807   return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length)));
808 }
809 
810 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
811                                     MlirOperation operation) {
812   return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation)));
813 }
814 
815 void mlirSymbolTableErase(MlirSymbolTable symbolTable,
816                           MlirOperation operation) {
817   unwrap(symbolTable)->erase(unwrap(operation));
818 }
819 
820 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
821                                                       MlirStringRef newSymbol,
822                                                       MlirOperation from) {
823   auto *cppFrom = unwrap(from);
824   auto *context = cppFrom->getContext();
825   auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
826   auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
827   return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
828                                                 unwrap(from)));
829 }
830 
831 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
832                                      void (*callback)(MlirOperation, bool,
833                                                       void *userData),
834                                      void *userData) {
835   SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible,
836                                 [&](Operation *foundOpCpp, bool isVisible) {
837                                   callback(wrap(foundOpCpp), isVisible,
838                                            userData);
839                                 });
840 }
841