xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision 5d2b8fa1)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/IR.h"
10 #include "mlir-c/Support.h"
11 
12 #include "mlir/CAPI/IR.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/Dialect.h"
18 #include "mlir/IR/Location.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/IR/Types.h"
21 #include "mlir/IR/Verifier.h"
22 #include "mlir/Interfaces/InferTypeOpInterface.h"
23 #include "mlir/Parser.h"
24 
25 #include "llvm/Support/Debug.h"
26 #include <cstddef>
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Context API.
32 //===----------------------------------------------------------------------===//
33 
34 MlirContext mlirContextCreate() {
35   auto *context = new MLIRContext;
36   return wrap(context);
37 }
38 
39 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
40   return unwrap(ctx1) == unwrap(ctx2);
41 }
42 
43 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
44 
45 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
46   unwrap(context)->allowUnregisteredDialects(allow);
47 }
48 
49 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
50   return unwrap(context)->allowsUnregisteredDialects();
51 }
52 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
53   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
54 }
55 
56 // TODO: expose a cheaper way than constructing + sorting a vector only to take
57 // its size.
58 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
59   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
60 }
61 
62 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
63                                         MlirStringRef name) {
64   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
65 }
66 
67 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
68   return unwrap(context)->isOperationRegistered(unwrap(name));
69 }
70 
71 void mlirContextEnableMultithreading(MlirContext context, bool enable) {
72   return unwrap(context)->enableMultithreading(enable);
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // Dialect API.
77 //===----------------------------------------------------------------------===//
78 
79 MlirContext mlirDialectGetContext(MlirDialect dialect) {
80   return wrap(unwrap(dialect)->getContext());
81 }
82 
83 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
84   return unwrap(dialect1) == unwrap(dialect2);
85 }
86 
87 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
88   return wrap(unwrap(dialect)->getNamespace());
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Printing flags API.
93 //===----------------------------------------------------------------------===//
94 
95 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
96   return wrap(new OpPrintingFlags());
97 }
98 
99 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
100   delete unwrap(flags);
101 }
102 
103 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
104                                                 intptr_t largeElementLimit) {
105   unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
106 }
107 
108 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
109                                         bool prettyForm) {
110   unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
111 }
112 
113 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
114   unwrap(flags)->printGenericOpForm();
115 }
116 
117 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
118   unwrap(flags)->useLocalScope();
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // Location API.
123 //===----------------------------------------------------------------------===//
124 
125 MlirLocation mlirLocationFileLineColGet(MlirContext context,
126                                         MlirStringRef filename, unsigned line,
127                                         unsigned col) {
128   return wrap(Location(
129       FileLineColLoc::get(unwrap(context), unwrap(filename), line, col)));
130 }
131 
132 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
133   return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
134 }
135 
136 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
137                                   MlirLocation const *locations,
138                                   MlirAttribute metadata) {
139   SmallVector<Location, 4> locs;
140   ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs);
141   return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
142 }
143 
144 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
145                                  MlirLocation childLoc) {
146   if (mlirLocationIsNull(childLoc))
147     return wrap(
148         Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name)))));
149   return wrap(Location(NameLoc::get(
150       StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
151 }
152 
153 MlirLocation mlirLocationUnknownGet(MlirContext context) {
154   return wrap(Location(UnknownLoc::get(unwrap(context))));
155 }
156 
157 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
158   return unwrap(l1) == unwrap(l2);
159 }
160 
161 MlirContext mlirLocationGetContext(MlirLocation location) {
162   return wrap(unwrap(location).getContext());
163 }
164 
165 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
166                        void *userData) {
167   detail::CallbackOstream stream(callback, userData);
168   unwrap(location).print(stream);
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Module API.
173 //===----------------------------------------------------------------------===//
174 
175 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
176   return wrap(ModuleOp::create(unwrap(location)));
177 }
178 
179 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
180   OwningOpRef<ModuleOp> owning =
181       parseSourceString(unwrap(module), unwrap(context));
182   if (!owning)
183     return MlirModule{nullptr};
184   return MlirModule{owning.release().getOperation()};
185 }
186 
187 MlirContext mlirModuleGetContext(MlirModule module) {
188   return wrap(unwrap(module).getContext());
189 }
190 
191 MlirBlock mlirModuleGetBody(MlirModule module) {
192   return wrap(unwrap(module).getBody());
193 }
194 
195 void mlirModuleDestroy(MlirModule module) {
196   // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is
197   // called.
198   OwningOpRef<ModuleOp>(unwrap(module));
199 }
200 
201 MlirOperation mlirModuleGetOperation(MlirModule module) {
202   return wrap(unwrap(module).getOperation());
203 }
204 
205 MlirModule mlirModuleFromOperation(MlirOperation op) {
206   return wrap(dyn_cast<ModuleOp>(unwrap(op)));
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // Operation state API.
211 //===----------------------------------------------------------------------===//
212 
213 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
214   MlirOperationState state;
215   state.name = name;
216   state.location = loc;
217   state.nResults = 0;
218   state.results = nullptr;
219   state.nOperands = 0;
220   state.operands = nullptr;
221   state.nRegions = 0;
222   state.regions = nullptr;
223   state.nSuccessors = 0;
224   state.successors = nullptr;
225   state.nAttributes = 0;
226   state.attributes = nullptr;
227   state.enableResultTypeInference = false;
228   return state;
229 }
230 
// Appends `n` elements read from the local pointer `elemName` to the
// realloc-managed array `state->elemName` and bumps `state->sizeName` by `n`.
// Expects `state` and `n` to be in scope at the expansion site.
//
// The body is wrapped in do/while so the macro behaves as a single statement,
// and a non-positive `n` is a no-op: this avoids calling memcpy with a
// possibly-null source pointer and a zero length.
// NOTE(review): a failing realloc is still not handled here.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    state->elemName = static_cast<type *>(                                     \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));     \
    state->sizeName += n;                                                      \
  } while (false)
236 
237 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
238                                   MlirType const *results) {
239   APPEND_ELEMS(MlirType, nResults, results);
240 }
241 
242 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
243                                    MlirValue const *operands) {
244   APPEND_ELEMS(MlirValue, nOperands, operands);
245 }
246 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
247                                        MlirRegion const *regions) {
248   APPEND_ELEMS(MlirRegion, nRegions, regions);
249 }
250 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
251                                      MlirBlock const *successors) {
252   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
253 }
254 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
255                                      MlirNamedAttribute const *attributes) {
256   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
257 }
258 
259 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
260   state->enableResultTypeInference = true;
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // Operation API.
265 //===----------------------------------------------------------------------===//
266 
267 static LogicalResult inferOperationTypes(OperationState &state) {
268   MLIRContext *context = state.getContext();
269   Optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
270   if (!info) {
271     emitError(state.location)
272         << "type inference was requested for the operation " << state.name
273         << ", but the operation was not registered. Ensure that the dialect "
274            "containing the operation is linked into MLIR and registered with "
275            "the context";
276     return failure();
277   }
278 
279   // Fallback to inference via an op interface.
280   auto *inferInterface = info->getInterface<InferTypeOpInterface>();
281   if (!inferInterface) {
282     emitError(state.location)
283         << "type inference was requested for the operation " << state.name
284         << ", but the operation does not support type inference. Result "
285            "types must be specified explicitly.";
286     return failure();
287   }
288 
289   if (succeeded(inferInterface->inferReturnTypes(
290           context, state.location, state.operands,
291           state.attributes.getDictionary(context), state.regions, state.types)))
292     return success();
293 
294   // Diagnostic emitted by interface.
295   return failure();
296 }
297 
298 MlirOperation mlirOperationCreate(MlirOperationState *state) {
299   assert(state);
300   OperationState cppState(unwrap(state->location), unwrap(state->name));
301   SmallVector<Type, 4> resultStorage;
302   SmallVector<Value, 8> operandStorage;
303   SmallVector<Block *, 2> successorStorage;
304   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
305   cppState.addOperands(
306       unwrapList(state->nOperands, state->operands, operandStorage));
307   cppState.addSuccessors(
308       unwrapList(state->nSuccessors, state->successors, successorStorage));
309 
310   cppState.attributes.reserve(state->nAttributes);
311   for (intptr_t i = 0; i < state->nAttributes; ++i)
312     cppState.addAttribute(unwrap(state->attributes[i].name),
313                           unwrap(state->attributes[i].attribute));
314 
315   for (intptr_t i = 0; i < state->nRegions; ++i)
316     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
317 
318   free(state->results);
319   free(state->operands);
320   free(state->successors);
321   free(state->regions);
322   free(state->attributes);
323 
324   // Infer result types.
325   if (state->enableResultTypeInference) {
326     assert(cppState.types.empty() &&
327            "result type inference enabled and result types provided");
328     if (failed(inferOperationTypes(cppState)))
329       return {nullptr};
330   }
331 
332   MlirOperation result = wrap(Operation::create(cppState));
333   return result;
334 }
335 
336 MlirOperation mlirOperationClone(MlirOperation op) {
337   return wrap(unwrap(op)->clone());
338 }
339 
340 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
341 
342 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); }
343 
344 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
345   return unwrap(op) == unwrap(other);
346 }
347 
348 MlirContext mlirOperationGetContext(MlirOperation op) {
349   return wrap(unwrap(op)->getContext());
350 }
351 
352 MlirLocation mlirOperationGetLocation(MlirOperation op) {
353   return wrap(unwrap(op)->getLoc());
354 }
355 
356 MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
357   if (auto info = unwrap(op)->getRegisteredInfo())
358     return wrap(info->getTypeID());
359   return {nullptr};
360 }
361 
362 MlirIdentifier mlirOperationGetName(MlirOperation op) {
363   return wrap(unwrap(op)->getName().getIdentifier());
364 }
365 
366 MlirBlock mlirOperationGetBlock(MlirOperation op) {
367   return wrap(unwrap(op)->getBlock());
368 }
369 
370 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
371   return wrap(unwrap(op)->getParentOp());
372 }
373 
374 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
375   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
376 }
377 
378 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
379   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
380 }
381 
382 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
383   Operation *cppOp = unwrap(op);
384   if (cppOp->getNumRegions() == 0)
385     return wrap(static_cast<Region *>(nullptr));
386   return wrap(&cppOp->getRegion(0));
387 }
388 
389 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
390   Region *cppRegion = unwrap(region);
391   Operation *parent = cppRegion->getParentOp();
392   intptr_t next = cppRegion->getRegionNumber() + 1;
393   if (parent->getNumRegions() > next)
394     return wrap(&parent->getRegion(next));
395   return wrap(static_cast<Region *>(nullptr));
396 }
397 
398 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
399   return wrap(unwrap(op)->getNextNode());
400 }
401 
402 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
403   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
404 }
405 
406 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
407   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
408 }
409 
410 void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
411                              MlirValue newValue) {
412   unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
413 }
414 
415 intptr_t mlirOperationGetNumResults(MlirOperation op) {
416   return static_cast<intptr_t>(unwrap(op)->getNumResults());
417 }
418 
419 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
420   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
421 }
422 
423 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
424   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
425 }
426 
427 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
428   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
429 }
430 
431 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
432   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
433 }
434 
435 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
436   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
437   return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
438 }
439 
440 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
441                                               MlirStringRef name) {
442   return wrap(unwrap(op)->getAttr(unwrap(name)));
443 }
444 
445 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
446                                      MlirAttribute attr) {
447   unwrap(op)->setAttr(unwrap(name), unwrap(attr));
448 }
449 
450 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
451   return !!unwrap(op)->removeAttr(unwrap(name));
452 }
453 
454 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
455                         void *userData) {
456   detail::CallbackOstream stream(callback, userData);
457   unwrap(op)->print(stream);
458 }
459 
460 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
461                                  MlirStringCallback callback, void *userData) {
462   detail::CallbackOstream stream(callback, userData);
463   unwrap(op)->print(stream, *unwrap(flags));
464 }
465 
466 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
467 
468 bool mlirOperationVerify(MlirOperation op) {
469   return succeeded(verify(unwrap(op)));
470 }
471 
472 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
473   return unwrap(op)->moveAfter(unwrap(other));
474 }
475 
476 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
477   return unwrap(op)->moveBefore(unwrap(other));
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // Region API.
482 //===----------------------------------------------------------------------===//
483 
484 MlirRegion mlirRegionCreate() { return wrap(new Region); }
485 
486 bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
487   return unwrap(region) == unwrap(other);
488 }
489 
490 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
491   Region *cppRegion = unwrap(region);
492   if (cppRegion->empty())
493     return wrap(static_cast<Block *>(nullptr));
494   return wrap(&cppRegion->front());
495 }
496 
497 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
498   unwrap(region)->push_back(unwrap(block));
499 }
500 
501 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
502                                 MlirBlock block) {
503   auto &blockList = unwrap(region)->getBlocks();
504   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
505 }
506 
507 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
508                                      MlirBlock block) {
509   Region *cppRegion = unwrap(region);
510   if (mlirBlockIsNull(reference)) {
511     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
512     return;
513   }
514 
515   assert(unwrap(reference)->getParent() == unwrap(region) &&
516          "expected reference block to belong to the region");
517   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
518                                      unwrap(block));
519 }
520 
521 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
522                                       MlirBlock block) {
523   if (mlirBlockIsNull(reference))
524     return mlirRegionAppendOwnedBlock(region, block);
525 
526   assert(unwrap(reference)->getParent() == unwrap(region) &&
527          "expected reference block to belong to the region");
528   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
529                                      unwrap(block));
530 }
531 
532 void mlirRegionDestroy(MlirRegion region) {
533   delete static_cast<Region *>(region.ptr);
534 }
535 
536 //===----------------------------------------------------------------------===//
537 // Block API.
538 //===----------------------------------------------------------------------===//
539 
540 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args,
541                           MlirLocation const *locs) {
542   Block *b = new Block;
543   for (intptr_t i = 0; i < nArgs; ++i)
544     b->addArgument(unwrap(args[i]), unwrap(locs[i]));
545   return wrap(b);
546 }
547 
548 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
549   return unwrap(block) == unwrap(other);
550 }
551 
552 MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
553   return wrap(unwrap(block)->getParentOp());
554 }
555 
556 MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
557   return wrap(unwrap(block)->getParent());
558 }
559 
560 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
561   return wrap(unwrap(block)->getNextNode());
562 }
563 
564 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
565   Block *cppBlock = unwrap(block);
566   if (cppBlock->empty())
567     return wrap(static_cast<Operation *>(nullptr));
568   return wrap(&cppBlock->front());
569 }
570 
571 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
572   Block *cppBlock = unwrap(block);
573   if (cppBlock->empty())
574     return wrap(static_cast<Operation *>(nullptr));
575   Operation &back = cppBlock->back();
576   if (!back.hasTrait<OpTrait::IsTerminator>())
577     return wrap(static_cast<Operation *>(nullptr));
578   return wrap(&back);
579 }
580 
581 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
582   unwrap(block)->push_back(unwrap(operation));
583 }
584 
585 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
586                                    MlirOperation operation) {
587   auto &opList = unwrap(block)->getOperations();
588   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
589 }
590 
591 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
592                                         MlirOperation reference,
593                                         MlirOperation operation) {
594   Block *cppBlock = unwrap(block);
595   if (mlirOperationIsNull(reference)) {
596     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
597     return;
598   }
599 
600   assert(unwrap(reference)->getBlock() == unwrap(block) &&
601          "expected reference operation to belong to the block");
602   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
603                                         unwrap(operation));
604 }
605 
606 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
607                                          MlirOperation reference,
608                                          MlirOperation operation) {
609   if (mlirOperationIsNull(reference))
610     return mlirBlockAppendOwnedOperation(block, operation);
611 
612   assert(unwrap(reference)->getBlock() == unwrap(block) &&
613          "expected reference operation to belong to the block");
614   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
615                                         unwrap(operation));
616 }
617 
618 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
619 
620 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
621   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
622 }
623 
624 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
625                                MlirLocation loc) {
626   return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc)));
627 }
628 
629 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
630   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
631 }
632 
633 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
634                     void *userData) {
635   detail::CallbackOstream stream(callback, userData);
636   unwrap(block)->print(stream);
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // Value API.
641 //===----------------------------------------------------------------------===//
642 
643 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
644   return unwrap(value1) == unwrap(value2);
645 }
646 
647 bool mlirValueIsABlockArgument(MlirValue value) {
648   return unwrap(value).isa<BlockArgument>();
649 }
650 
651 bool mlirValueIsAOpResult(MlirValue value) {
652   return unwrap(value).isa<OpResult>();
653 }
654 
655 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
656   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
657 }
658 
659 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
660   return static_cast<intptr_t>(
661       unwrap(value).cast<BlockArgument>().getArgNumber());
662 }
663 
664 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
665   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
666 }
667 
668 MlirOperation mlirOpResultGetOwner(MlirValue value) {
669   return wrap(unwrap(value).cast<OpResult>().getOwner());
670 }
671 
672 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
673   return static_cast<intptr_t>(
674       unwrap(value).cast<OpResult>().getResultNumber());
675 }
676 
677 MlirType mlirValueGetType(MlirValue value) {
678   return wrap(unwrap(value).getType());
679 }
680 
681 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
682 
683 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
684                     void *userData) {
685   detail::CallbackOstream stream(callback, userData);
686   unwrap(value).print(stream);
687 }
688 
689 //===----------------------------------------------------------------------===//
690 // Type API.
691 //===----------------------------------------------------------------------===//
692 
693 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
694   return wrap(mlir::parseType(unwrap(type), unwrap(context)));
695 }
696 
697 MlirContext mlirTypeGetContext(MlirType type) {
698   return wrap(unwrap(type).getContext());
699 }
700 
701 MlirTypeID mlirTypeGetTypeID(MlirType type) {
702   return wrap(unwrap(type).getTypeID());
703 }
704 
705 bool mlirTypeEqual(MlirType t1, MlirType t2) {
706   return unwrap(t1) == unwrap(t2);
707 }
708 
709 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
710   detail::CallbackOstream stream(callback, userData);
711   unwrap(type).print(stream);
712 }
713 
714 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
715 
716 //===----------------------------------------------------------------------===//
717 // Attribute API.
718 //===----------------------------------------------------------------------===//
719 
720 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
721   return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
722 }
723 
724 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
725   return wrap(unwrap(attribute).getContext());
726 }
727 
728 MlirType mlirAttributeGetType(MlirAttribute attribute) {
729   return wrap(unwrap(attribute).getType());
730 }
731 
732 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
733   return wrap(unwrap(attr).getTypeID());
734 }
735 
736 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
737   return unwrap(a1) == unwrap(a2);
738 }
739 
740 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
741                         void *userData) {
742   detail::CallbackOstream stream(callback, userData);
743   unwrap(attr).print(stream);
744 }
745 
746 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
747 
748 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
749                                          MlirAttribute attr) {
750   return MlirNamedAttribute{name, attr};
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // Identifier API.
755 //===----------------------------------------------------------------------===//
756 
757 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
758   return wrap(StringAttr::get(unwrap(context), unwrap(str)));
759 }
760 
761 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
762   return wrap(unwrap(ident).getContext());
763 }
764 
765 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
766   return unwrap(ident) == unwrap(other);
767 }
768 
769 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
770   return wrap(unwrap(ident).strref());
771 }
772 
773 //===----------------------------------------------------------------------===//
774 // TypeID API.
775 //===----------------------------------------------------------------------===//
776 
777 bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
778   return unwrap(typeID1) == unwrap(typeID2);
779 }
780 
781 size_t mlirTypeIDHashValue(MlirTypeID typeID) {
782   return hash_value(unwrap(typeID));
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // Symbol and SymbolTable API.
787 //===----------------------------------------------------------------------===//
788 
789 MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
790   return wrap(SymbolTable::getSymbolAttrName());
791 }
792 
793 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
794   return wrap(SymbolTable::getVisibilityAttrName());
795 }
796 
797 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
798   if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
799     return wrap(static_cast<SymbolTable *>(nullptr));
800   return wrap(new SymbolTable(unwrap(operation)));
801 }
802 
803 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
804   delete unwrap(symbolTable);
805 }
806 
807 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
808                                     MlirStringRef name) {
809   return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length)));
810 }
811 
812 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
813                                     MlirOperation operation) {
814   return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation)));
815 }
816 
817 void mlirSymbolTableErase(MlirSymbolTable symbolTable,
818                           MlirOperation operation) {
819   unwrap(symbolTable)->erase(unwrap(operation));
820 }
821 
822 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
823                                                       MlirStringRef newSymbol,
824                                                       MlirOperation from) {
825   auto *cppFrom = unwrap(from);
826   auto *context = cppFrom->getContext();
827   auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
828   auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
829   return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
830                                                 unwrap(from)));
831 }
832 
833 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
834                                      void (*callback)(MlirOperation, bool,
835                                                       void *userData),
836                                      void *userData) {
837   SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible,
838                                 [&](Operation *foundOpCpp, bool isVisible) {
839                                   callback(wrap(foundOpCpp), isVisible,
840                                            userData);
841                                 });
842 }
843