xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision 5ee910fe)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/IR.h"
10 #include "mlir-c/Support.h"
11 
12 #include "mlir/AsmParser/AsmParser.h"
13 #include "mlir/CAPI/IR.h"
14 #include "mlir/CAPI/Support.h"
15 #include "mlir/CAPI/Utils.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/Types.h"
22 #include "mlir/IR/Verifier.h"
23 #include "mlir/Interfaces/InferTypeOpInterface.h"
24 #include "mlir/Parser/Parser.h"
25 
#include "llvm/Support/Debug.h"
#include <cstddef>
#include <cstdlib>
#include <cstring>
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Context API.
33 //===----------------------------------------------------------------------===//
34 
35 MlirContext mlirContextCreate() {
36   auto *context = new MLIRContext;
37   return wrap(context);
38 }
39 
40 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
41   return unwrap(ctx1) == unwrap(ctx2);
42 }
43 
44 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
45 
46 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
47   unwrap(context)->allowUnregisteredDialects(allow);
48 }
49 
50 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
51   return unwrap(context)->allowsUnregisteredDialects();
52 }
53 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
54   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
55 }
56 
57 void mlirContextAppendDialectRegistry(MlirContext ctx,
58                                       MlirDialectRegistry registry) {
59   unwrap(ctx)->appendDialectRegistry(*unwrap(registry));
60 }
61 
62 // TODO: expose a cheaper way than constructing + sorting a vector only to take
63 // its size.
64 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
65   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
66 }
67 
68 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
69                                         MlirStringRef name) {
70   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
71 }
72 
73 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
74   return unwrap(context)->isOperationRegistered(unwrap(name));
75 }
76 
77 void mlirContextEnableMultithreading(MlirContext context, bool enable) {
78   return unwrap(context)->enableMultithreading(enable);
79 }
80 
81 void mlirContextLoadAllAvailableDialects(MlirContext context) {
82   unwrap(context)->loadAllAvailableDialects();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Dialect API.
87 //===----------------------------------------------------------------------===//
88 
89 MlirContext mlirDialectGetContext(MlirDialect dialect) {
90   return wrap(unwrap(dialect)->getContext());
91 }
92 
93 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
94   return unwrap(dialect1) == unwrap(dialect2);
95 }
96 
97 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
98   return wrap(unwrap(dialect)->getNamespace());
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // DialectRegistry API.
103 //===----------------------------------------------------------------------===//
104 
105 MlirDialectRegistry mlirDialectRegistryCreate() {
106   return wrap(new DialectRegistry());
107 }
108 
109 void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
110   delete unwrap(registry);
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // Printing flags API.
115 //===----------------------------------------------------------------------===//
116 
117 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
118   return wrap(new OpPrintingFlags());
119 }
120 
121 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
122   delete unwrap(flags);
123 }
124 
125 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
126                                                 intptr_t largeElementLimit) {
127   unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
128 }
129 
130 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
131                                         bool prettyForm) {
132   unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
133 }
134 
135 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
136   unwrap(flags)->printGenericOpForm();
137 }
138 
139 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
140   unwrap(flags)->useLocalScope();
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // Location API.
145 //===----------------------------------------------------------------------===//
146 
147 MlirLocation mlirLocationFileLineColGet(MlirContext context,
148                                         MlirStringRef filename, unsigned line,
149                                         unsigned col) {
150   return wrap(Location(
151       FileLineColLoc::get(unwrap(context), unwrap(filename), line, col)));
152 }
153 
154 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
155   return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
156 }
157 
158 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
159                                   MlirLocation const *locations,
160                                   MlirAttribute metadata) {
161   SmallVector<Location, 4> locs;
162   ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs);
163   return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
164 }
165 
166 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
167                                  MlirLocation childLoc) {
168   if (mlirLocationIsNull(childLoc))
169     return wrap(
170         Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name)))));
171   return wrap(Location(NameLoc::get(
172       StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
173 }
174 
175 MlirLocation mlirLocationUnknownGet(MlirContext context) {
176   return wrap(Location(UnknownLoc::get(unwrap(context))));
177 }
178 
179 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
180   return unwrap(l1) == unwrap(l2);
181 }
182 
183 MlirContext mlirLocationGetContext(MlirLocation location) {
184   return wrap(unwrap(location).getContext());
185 }
186 
187 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
188                        void *userData) {
189   detail::CallbackOstream stream(callback, userData);
190   unwrap(location).print(stream);
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // Module API.
195 //===----------------------------------------------------------------------===//
196 
197 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
198   return wrap(ModuleOp::create(unwrap(location)));
199 }
200 
201 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
202   OwningOpRef<ModuleOp> owning =
203       parseSourceString<ModuleOp>(unwrap(module), unwrap(context));
204   if (!owning)
205     return MlirModule{nullptr};
206   return MlirModule{owning.release().getOperation()};
207 }
208 
209 MlirContext mlirModuleGetContext(MlirModule module) {
210   return wrap(unwrap(module).getContext());
211 }
212 
213 MlirBlock mlirModuleGetBody(MlirModule module) {
214   return wrap(unwrap(module).getBody());
215 }
216 
217 void mlirModuleDestroy(MlirModule module) {
218   // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is
219   // called.
220   OwningOpRef<ModuleOp>(unwrap(module));
221 }
222 
223 MlirOperation mlirModuleGetOperation(MlirModule module) {
224   return wrap(unwrap(module).getOperation());
225 }
226 
227 MlirModule mlirModuleFromOperation(MlirOperation op) {
228   return wrap(dyn_cast<ModuleOp>(unwrap(op)));
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // Operation state API.
233 //===----------------------------------------------------------------------===//
234 
235 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
236   MlirOperationState state;
237   state.name = name;
238   state.location = loc;
239   state.nResults = 0;
240   state.results = nullptr;
241   state.nOperands = 0;
242   state.operands = nullptr;
243   state.nRegions = 0;
244   state.regions = nullptr;
245   state.nSuccessors = 0;
246   state.successors = nullptr;
247   state.nAttributes = 0;
248   state.attributes = nullptr;
249   state.enableResultTypeInference = false;
250   return state;
251 }
252 
253 #define APPEND_ELEMS(type, sizeName, elemName)                                 \
254   state->elemName =                                                            \
255       (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type));  \
256   memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));       \
257   state->sizeName += n;
258 
259 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
260                                   MlirType const *results) {
261   APPEND_ELEMS(MlirType, nResults, results);
262 }
263 
264 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
265                                    MlirValue const *operands) {
266   APPEND_ELEMS(MlirValue, nOperands, operands);
267 }
268 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
269                                        MlirRegion const *regions) {
270   APPEND_ELEMS(MlirRegion, nRegions, regions);
271 }
272 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
273                                      MlirBlock const *successors) {
274   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
275 }
276 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
277                                      MlirNamedAttribute const *attributes) {
278   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
279 }
280 
281 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
282   state->enableResultTypeInference = true;
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // Operation API.
287 //===----------------------------------------------------------------------===//
288 
289 static LogicalResult inferOperationTypes(OperationState &state) {
290   MLIRContext *context = state.getContext();
291   Optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
292   if (!info) {
293     emitError(state.location)
294         << "type inference was requested for the operation " << state.name
295         << ", but the operation was not registered. Ensure that the dialect "
296            "containing the operation is linked into MLIR and registered with "
297            "the context";
298     return failure();
299   }
300 
301   // Fallback to inference via an op interface.
302   auto *inferInterface = info->getInterface<InferTypeOpInterface>();
303   if (!inferInterface) {
304     emitError(state.location)
305         << "type inference was requested for the operation " << state.name
306         << ", but the operation does not support type inference. Result "
307            "types must be specified explicitly.";
308     return failure();
309   }
310 
311   if (succeeded(inferInterface->inferReturnTypes(
312           context, state.location, state.operands,
313           state.attributes.getDictionary(context), state.regions, state.types)))
314     return success();
315 
316   // Diagnostic emitted by interface.
317   return failure();
318 }
319 
320 MlirOperation mlirOperationCreate(MlirOperationState *state) {
321   assert(state);
322   OperationState cppState(unwrap(state->location), unwrap(state->name));
323   SmallVector<Type, 4> resultStorage;
324   SmallVector<Value, 8> operandStorage;
325   SmallVector<Block *, 2> successorStorage;
326   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
327   cppState.addOperands(
328       unwrapList(state->nOperands, state->operands, operandStorage));
329   cppState.addSuccessors(
330       unwrapList(state->nSuccessors, state->successors, successorStorage));
331 
332   cppState.attributes.reserve(state->nAttributes);
333   for (intptr_t i = 0; i < state->nAttributes; ++i)
334     cppState.addAttribute(unwrap(state->attributes[i].name),
335                           unwrap(state->attributes[i].attribute));
336 
337   for (intptr_t i = 0; i < state->nRegions; ++i)
338     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
339 
340   free(state->results);
341   free(state->operands);
342   free(state->successors);
343   free(state->regions);
344   free(state->attributes);
345 
346   // Infer result types.
347   if (state->enableResultTypeInference) {
348     assert(cppState.types.empty() &&
349            "result type inference enabled and result types provided");
350     if (failed(inferOperationTypes(cppState)))
351       return {nullptr};
352   }
353 
354   MlirOperation result = wrap(Operation::create(cppState));
355   return result;
356 }
357 
358 MlirOperation mlirOperationClone(MlirOperation op) {
359   return wrap(unwrap(op)->clone());
360 }
361 
362 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
363 
364 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); }
365 
366 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
367   return unwrap(op) == unwrap(other);
368 }
369 
370 MlirContext mlirOperationGetContext(MlirOperation op) {
371   return wrap(unwrap(op)->getContext());
372 }
373 
374 MlirLocation mlirOperationGetLocation(MlirOperation op) {
375   return wrap(unwrap(op)->getLoc());
376 }
377 
378 MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
379   if (auto info = unwrap(op)->getRegisteredInfo())
380     return wrap(info->getTypeID());
381   return {nullptr};
382 }
383 
384 MlirIdentifier mlirOperationGetName(MlirOperation op) {
385   return wrap(unwrap(op)->getName().getIdentifier());
386 }
387 
388 MlirBlock mlirOperationGetBlock(MlirOperation op) {
389   return wrap(unwrap(op)->getBlock());
390 }
391 
392 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
393   return wrap(unwrap(op)->getParentOp());
394 }
395 
396 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
397   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
398 }
399 
400 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
401   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
402 }
403 
404 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
405   Operation *cppOp = unwrap(op);
406   if (cppOp->getNumRegions() == 0)
407     return wrap(static_cast<Region *>(nullptr));
408   return wrap(&cppOp->getRegion(0));
409 }
410 
411 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
412   Region *cppRegion = unwrap(region);
413   Operation *parent = cppRegion->getParentOp();
414   intptr_t next = cppRegion->getRegionNumber() + 1;
415   if (parent->getNumRegions() > next)
416     return wrap(&parent->getRegion(next));
417   return wrap(static_cast<Region *>(nullptr));
418 }
419 
420 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
421   return wrap(unwrap(op)->getNextNode());
422 }
423 
424 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
425   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
426 }
427 
428 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
429   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
430 }
431 
432 void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
433                              MlirValue newValue) {
434   unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
435 }
436 
437 intptr_t mlirOperationGetNumResults(MlirOperation op) {
438   return static_cast<intptr_t>(unwrap(op)->getNumResults());
439 }
440 
441 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
442   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
443 }
444 
445 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
446   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
447 }
448 
449 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
450   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
451 }
452 
453 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
454   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
455 }
456 
457 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
458   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
459   return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
460 }
461 
462 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
463                                               MlirStringRef name) {
464   return wrap(unwrap(op)->getAttr(unwrap(name)));
465 }
466 
467 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
468                                      MlirAttribute attr) {
469   unwrap(op)->setAttr(unwrap(name), unwrap(attr));
470 }
471 
472 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
473   return !!unwrap(op)->removeAttr(unwrap(name));
474 }
475 
476 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
477                         void *userData) {
478   detail::CallbackOstream stream(callback, userData);
479   unwrap(op)->print(stream);
480 }
481 
482 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
483                                  MlirStringCallback callback, void *userData) {
484   detail::CallbackOstream stream(callback, userData);
485   unwrap(op)->print(stream, *unwrap(flags));
486 }
487 
488 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
489 
490 bool mlirOperationVerify(MlirOperation op) {
491   return succeeded(verify(unwrap(op)));
492 }
493 
494 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
495   return unwrap(op)->moveAfter(unwrap(other));
496 }
497 
498 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
499   return unwrap(op)->moveBefore(unwrap(other));
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // Region API.
504 //===----------------------------------------------------------------------===//
505 
506 MlirRegion mlirRegionCreate() { return wrap(new Region); }
507 
508 bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
509   return unwrap(region) == unwrap(other);
510 }
511 
512 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
513   Region *cppRegion = unwrap(region);
514   if (cppRegion->empty())
515     return wrap(static_cast<Block *>(nullptr));
516   return wrap(&cppRegion->front());
517 }
518 
519 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
520   unwrap(region)->push_back(unwrap(block));
521 }
522 
523 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
524                                 MlirBlock block) {
525   auto &blockList = unwrap(region)->getBlocks();
526   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
527 }
528 
529 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
530                                      MlirBlock block) {
531   Region *cppRegion = unwrap(region);
532   if (mlirBlockIsNull(reference)) {
533     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
534     return;
535   }
536 
537   assert(unwrap(reference)->getParent() == unwrap(region) &&
538          "expected reference block to belong to the region");
539   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
540                                      unwrap(block));
541 }
542 
543 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
544                                       MlirBlock block) {
545   if (mlirBlockIsNull(reference))
546     return mlirRegionAppendOwnedBlock(region, block);
547 
548   assert(unwrap(reference)->getParent() == unwrap(region) &&
549          "expected reference block to belong to the region");
550   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
551                                      unwrap(block));
552 }
553 
554 void mlirRegionDestroy(MlirRegion region) {
555   delete static_cast<Region *>(region.ptr);
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // Block API.
560 //===----------------------------------------------------------------------===//
561 
562 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args,
563                           MlirLocation const *locs) {
564   Block *b = new Block;
565   for (intptr_t i = 0; i < nArgs; ++i)
566     b->addArgument(unwrap(args[i]), unwrap(locs[i]));
567   return wrap(b);
568 }
569 
570 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
571   return unwrap(block) == unwrap(other);
572 }
573 
574 MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
575   return wrap(unwrap(block)->getParentOp());
576 }
577 
578 MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
579   return wrap(unwrap(block)->getParent());
580 }
581 
582 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
583   return wrap(unwrap(block)->getNextNode());
584 }
585 
586 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
587   Block *cppBlock = unwrap(block);
588   if (cppBlock->empty())
589     return wrap(static_cast<Operation *>(nullptr));
590   return wrap(&cppBlock->front());
591 }
592 
593 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
594   Block *cppBlock = unwrap(block);
595   if (cppBlock->empty())
596     return wrap(static_cast<Operation *>(nullptr));
597   Operation &back = cppBlock->back();
598   if (!back.hasTrait<OpTrait::IsTerminator>())
599     return wrap(static_cast<Operation *>(nullptr));
600   return wrap(&back);
601 }
602 
603 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
604   unwrap(block)->push_back(unwrap(operation));
605 }
606 
607 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
608                                    MlirOperation operation) {
609   auto &opList = unwrap(block)->getOperations();
610   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
611 }
612 
613 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
614                                         MlirOperation reference,
615                                         MlirOperation operation) {
616   Block *cppBlock = unwrap(block);
617   if (mlirOperationIsNull(reference)) {
618     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
619     return;
620   }
621 
622   assert(unwrap(reference)->getBlock() == unwrap(block) &&
623          "expected reference operation to belong to the block");
624   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
625                                         unwrap(operation));
626 }
627 
628 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
629                                          MlirOperation reference,
630                                          MlirOperation operation) {
631   if (mlirOperationIsNull(reference))
632     return mlirBlockAppendOwnedOperation(block, operation);
633 
634   assert(unwrap(reference)->getBlock() == unwrap(block) &&
635          "expected reference operation to belong to the block");
636   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
637                                         unwrap(operation));
638 }
639 
640 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
641 
642 void mlirBlockDetach(MlirBlock block) {
643   Block *b = unwrap(block);
644   b->getParent()->getBlocks().remove(b);
645 }
646 
647 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
648   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
649 }
650 
651 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
652                                MlirLocation loc) {
653   return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc)));
654 }
655 
656 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
657   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
658 }
659 
660 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
661                     void *userData) {
662   detail::CallbackOstream stream(callback, userData);
663   unwrap(block)->print(stream);
664 }
665 
666 //===----------------------------------------------------------------------===//
667 // Value API.
668 //===----------------------------------------------------------------------===//
669 
670 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
671   return unwrap(value1) == unwrap(value2);
672 }
673 
674 bool mlirValueIsABlockArgument(MlirValue value) {
675   return unwrap(value).isa<BlockArgument>();
676 }
677 
678 bool mlirValueIsAOpResult(MlirValue value) {
679   return unwrap(value).isa<OpResult>();
680 }
681 
682 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
683   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
684 }
685 
686 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
687   return static_cast<intptr_t>(
688       unwrap(value).cast<BlockArgument>().getArgNumber());
689 }
690 
691 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
692   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
693 }
694 
695 MlirOperation mlirOpResultGetOwner(MlirValue value) {
696   return wrap(unwrap(value).cast<OpResult>().getOwner());
697 }
698 
699 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
700   return static_cast<intptr_t>(
701       unwrap(value).cast<OpResult>().getResultNumber());
702 }
703 
704 MlirType mlirValueGetType(MlirValue value) {
705   return wrap(unwrap(value).getType());
706 }
707 
708 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
709 
710 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
711                     void *userData) {
712   detail::CallbackOstream stream(callback, userData);
713   unwrap(value).print(stream);
714 }
715 
716 //===----------------------------------------------------------------------===//
717 // Type API.
718 //===----------------------------------------------------------------------===//
719 
720 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
721   return wrap(mlir::parseType(unwrap(type), unwrap(context)));
722 }
723 
724 MlirContext mlirTypeGetContext(MlirType type) {
725   return wrap(unwrap(type).getContext());
726 }
727 
728 MlirTypeID mlirTypeGetTypeID(MlirType type) {
729   return wrap(unwrap(type).getTypeID());
730 }
731 
732 bool mlirTypeEqual(MlirType t1, MlirType t2) {
733   return unwrap(t1) == unwrap(t2);
734 }
735 
736 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
737   detail::CallbackOstream stream(callback, userData);
738   unwrap(type).print(stream);
739 }
740 
741 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
742 
743 //===----------------------------------------------------------------------===//
744 // Attribute API.
745 //===----------------------------------------------------------------------===//
746 
747 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
748   return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
749 }
750 
751 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
752   return wrap(unwrap(attribute).getContext());
753 }
754 
755 MlirType mlirAttributeGetType(MlirAttribute attribute) {
756   return wrap(unwrap(attribute).getType());
757 }
758 
759 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
760   return wrap(unwrap(attr).getTypeID());
761 }
762 
763 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
764   return unwrap(a1) == unwrap(a2);
765 }
766 
767 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
768                         void *userData) {
769   detail::CallbackOstream stream(callback, userData);
770   unwrap(attr).print(stream);
771 }
772 
773 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
774 
775 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
776                                          MlirAttribute attr) {
777   return MlirNamedAttribute{name, attr};
778 }
779 
780 //===----------------------------------------------------------------------===//
781 // Identifier API.
782 //===----------------------------------------------------------------------===//
783 
784 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
785   return wrap(StringAttr::get(unwrap(context), unwrap(str)));
786 }
787 
788 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
789   return wrap(unwrap(ident).getContext());
790 }
791 
792 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
793   return unwrap(ident) == unwrap(other);
794 }
795 
796 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
797   return wrap(unwrap(ident).strref());
798 }
799 
800 //===----------------------------------------------------------------------===//
801 // Symbol and SymbolTable API.
802 //===----------------------------------------------------------------------===//
803 
804 MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
805   return wrap(SymbolTable::getSymbolAttrName());
806 }
807 
808 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
809   return wrap(SymbolTable::getVisibilityAttrName());
810 }
811 
812 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
813   if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
814     return wrap(static_cast<SymbolTable *>(nullptr));
815   return wrap(new SymbolTable(unwrap(operation)));
816 }
817 
818 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
819   delete unwrap(symbolTable);
820 }
821 
822 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
823                                     MlirStringRef name) {
824   return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length)));
825 }
826 
827 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
828                                     MlirOperation operation) {
829   return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation)));
830 }
831 
832 void mlirSymbolTableErase(MlirSymbolTable symbolTable,
833                           MlirOperation operation) {
834   unwrap(symbolTable)->erase(unwrap(operation));
835 }
836 
837 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
838                                                       MlirStringRef newSymbol,
839                                                       MlirOperation from) {
840   auto *cppFrom = unwrap(from);
841   auto *context = cppFrom->getContext();
842   auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
843   auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
844   return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
845                                                 unwrap(from)));
846 }
847 
848 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
849                                      void (*callback)(MlirOperation, bool,
850                                                       void *userData),
851                                      void *userData) {
852   SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible,
853                                 [&](Operation *foundOpCpp, bool isVisible) {
854                                   callback(wrap(foundOpCpp), isVisible,
855                                            userData);
856                                 });
857 }
858