xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision beebffa9)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"

#include "llvm/Support/Debug.h"
#include "llvm/Support/MemAlloc.h"
#include <cstddef>
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Context API.
32 //===----------------------------------------------------------------------===//
33 
34 MlirContext mlirContextCreate() {
35   auto *context = new MLIRContext;
36   return wrap(context);
37 }
38 
39 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
40   return unwrap(ctx1) == unwrap(ctx2);
41 }
42 
43 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
44 
45 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
46   unwrap(context)->allowUnregisteredDialects(allow);
47 }
48 
49 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
50   return unwrap(context)->allowsUnregisteredDialects();
51 }
52 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
53   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
54 }
55 
56 void mlirContextAppendDialectRegistry(MlirContext ctx,
57                                       MlirDialectRegistry registry) {
58   unwrap(ctx)->appendDialectRegistry(*unwrap(registry));
59 }
60 
61 // TODO: expose a cheaper way than constructing + sorting a vector only to take
62 // its size.
63 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
64   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
65 }
66 
67 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
68                                         MlirStringRef name) {
69   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
70 }
71 
72 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
73   return unwrap(context)->isOperationRegistered(unwrap(name));
74 }
75 
76 void mlirContextEnableMultithreading(MlirContext context, bool enable) {
77   return unwrap(context)->enableMultithreading(enable);
78 }
79 
80 void mlirContextLoadAllAvailableDialects(MlirContext context) {
81   unwrap(context)->loadAllAvailableDialects();
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // Dialect API.
86 //===----------------------------------------------------------------------===//
87 
88 MlirContext mlirDialectGetContext(MlirDialect dialect) {
89   return wrap(unwrap(dialect)->getContext());
90 }
91 
92 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
93   return unwrap(dialect1) == unwrap(dialect2);
94 }
95 
96 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
97   return wrap(unwrap(dialect)->getNamespace());
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // DialectRegistry API.
102 //===----------------------------------------------------------------------===//
103 
104 MlirDialectRegistry mlirDialectRegistryCreate() {
105   return wrap(new DialectRegistry());
106 }
107 
108 void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
109   delete unwrap(registry);
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Printing flags API.
114 //===----------------------------------------------------------------------===//
115 
116 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
117   return wrap(new OpPrintingFlags());
118 }
119 
120 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
121   delete unwrap(flags);
122 }
123 
124 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
125                                                 intptr_t largeElementLimit) {
126   unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
127 }
128 
129 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
130                                         bool prettyForm) {
131   unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
132 }
133 
134 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
135   unwrap(flags)->printGenericOpForm();
136 }
137 
138 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
139   unwrap(flags)->useLocalScope();
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // Location API.
144 //===----------------------------------------------------------------------===//
145 
146 MlirLocation mlirLocationFileLineColGet(MlirContext context,
147                                         MlirStringRef filename, unsigned line,
148                                         unsigned col) {
149   return wrap(Location(
150       FileLineColLoc::get(unwrap(context), unwrap(filename), line, col)));
151 }
152 
153 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
154   return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
155 }
156 
157 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
158                                   MlirLocation const *locations,
159                                   MlirAttribute metadata) {
160   SmallVector<Location, 4> locs;
161   ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs);
162   return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
163 }
164 
165 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
166                                  MlirLocation childLoc) {
167   if (mlirLocationIsNull(childLoc))
168     return wrap(
169         Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name)))));
170   return wrap(Location(NameLoc::get(
171       StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
172 }
173 
174 MlirLocation mlirLocationUnknownGet(MlirContext context) {
175   return wrap(Location(UnknownLoc::get(unwrap(context))));
176 }
177 
178 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
179   return unwrap(l1) == unwrap(l2);
180 }
181 
182 MlirContext mlirLocationGetContext(MlirLocation location) {
183   return wrap(unwrap(location).getContext());
184 }
185 
186 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
187                        void *userData) {
188   detail::CallbackOstream stream(callback, userData);
189   unwrap(location).print(stream);
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // Module API.
194 //===----------------------------------------------------------------------===//
195 
196 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
197   return wrap(ModuleOp::create(unwrap(location)));
198 }
199 
200 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
201   OwningOpRef<ModuleOp> owning =
202       parseSourceString<ModuleOp>(unwrap(module), unwrap(context));
203   if (!owning)
204     return MlirModule{nullptr};
205   return MlirModule{owning.release().getOperation()};
206 }
207 
208 MlirContext mlirModuleGetContext(MlirModule module) {
209   return wrap(unwrap(module).getContext());
210 }
211 
212 MlirBlock mlirModuleGetBody(MlirModule module) {
213   return wrap(unwrap(module).getBody());
214 }
215 
216 void mlirModuleDestroy(MlirModule module) {
217   // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is
218   // called.
219   OwningOpRef<ModuleOp>(unwrap(module));
220 }
221 
222 MlirOperation mlirModuleGetOperation(MlirModule module) {
223   return wrap(unwrap(module).getOperation());
224 }
225 
226 MlirModule mlirModuleFromOperation(MlirOperation op) {
227   return wrap(dyn_cast<ModuleOp>(unwrap(op)));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Operation state API.
232 //===----------------------------------------------------------------------===//
233 
234 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
235   MlirOperationState state;
236   state.name = name;
237   state.location = loc;
238   state.nResults = 0;
239   state.results = nullptr;
240   state.nOperands = 0;
241   state.operands = nullptr;
242   state.nRegions = 0;
243   state.regions = nullptr;
244   state.nSuccessors = 0;
245   state.successors = nullptr;
246   state.nAttributes = 0;
247   state.attributes = nullptr;
248   state.enableResultTypeInference = false;
249   return state;
250 }
251 
252 #define APPEND_ELEMS(type, sizeName, elemName)                                 \
253   state->elemName =                                                            \
254       (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type));  \
255   memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));       \
256   state->sizeName += n;
257 
258 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
259                                   MlirType const *results) {
260   APPEND_ELEMS(MlirType, nResults, results);
261 }
262 
263 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
264                                    MlirValue const *operands) {
265   APPEND_ELEMS(MlirValue, nOperands, operands);
266 }
267 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
268                                        MlirRegion const *regions) {
269   APPEND_ELEMS(MlirRegion, nRegions, regions);
270 }
271 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
272                                      MlirBlock const *successors) {
273   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
274 }
275 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
276                                      MlirNamedAttribute const *attributes) {
277   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
278 }
279 
280 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
281   state->enableResultTypeInference = true;
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // Operation API.
286 //===----------------------------------------------------------------------===//
287 
288 static LogicalResult inferOperationTypes(OperationState &state) {
289   MLIRContext *context = state.getContext();
290   Optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
291   if (!info) {
292     emitError(state.location)
293         << "type inference was requested for the operation " << state.name
294         << ", but the operation was not registered. Ensure that the dialect "
295            "containing the operation is linked into MLIR and registered with "
296            "the context";
297     return failure();
298   }
299 
300   // Fallback to inference via an op interface.
301   auto *inferInterface = info->getInterface<InferTypeOpInterface>();
302   if (!inferInterface) {
303     emitError(state.location)
304         << "type inference was requested for the operation " << state.name
305         << ", but the operation does not support type inference. Result "
306            "types must be specified explicitly.";
307     return failure();
308   }
309 
310   if (succeeded(inferInterface->inferReturnTypes(
311           context, state.location, state.operands,
312           state.attributes.getDictionary(context), state.regions, state.types)))
313     return success();
314 
315   // Diagnostic emitted by interface.
316   return failure();
317 }
318 
319 MlirOperation mlirOperationCreate(MlirOperationState *state) {
320   assert(state);
321   OperationState cppState(unwrap(state->location), unwrap(state->name));
322   SmallVector<Type, 4> resultStorage;
323   SmallVector<Value, 8> operandStorage;
324   SmallVector<Block *, 2> successorStorage;
325   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
326   cppState.addOperands(
327       unwrapList(state->nOperands, state->operands, operandStorage));
328   cppState.addSuccessors(
329       unwrapList(state->nSuccessors, state->successors, successorStorage));
330 
331   cppState.attributes.reserve(state->nAttributes);
332   for (intptr_t i = 0; i < state->nAttributes; ++i)
333     cppState.addAttribute(unwrap(state->attributes[i].name),
334                           unwrap(state->attributes[i].attribute));
335 
336   for (intptr_t i = 0; i < state->nRegions; ++i)
337     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
338 
339   free(state->results);
340   free(state->operands);
341   free(state->successors);
342   free(state->regions);
343   free(state->attributes);
344 
345   // Infer result types.
346   if (state->enableResultTypeInference) {
347     assert(cppState.types.empty() &&
348            "result type inference enabled and result types provided");
349     if (failed(inferOperationTypes(cppState)))
350       return {nullptr};
351   }
352 
353   MlirOperation result = wrap(Operation::create(cppState));
354   return result;
355 }
356 
357 MlirOperation mlirOperationClone(MlirOperation op) {
358   return wrap(unwrap(op)->clone());
359 }
360 
361 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
362 
363 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); }
364 
365 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
366   return unwrap(op) == unwrap(other);
367 }
368 
369 MlirContext mlirOperationGetContext(MlirOperation op) {
370   return wrap(unwrap(op)->getContext());
371 }
372 
373 MlirLocation mlirOperationGetLocation(MlirOperation op) {
374   return wrap(unwrap(op)->getLoc());
375 }
376 
377 MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
378   if (auto info = unwrap(op)->getRegisteredInfo())
379     return wrap(info->getTypeID());
380   return {nullptr};
381 }
382 
383 MlirIdentifier mlirOperationGetName(MlirOperation op) {
384   return wrap(unwrap(op)->getName().getIdentifier());
385 }
386 
387 MlirBlock mlirOperationGetBlock(MlirOperation op) {
388   return wrap(unwrap(op)->getBlock());
389 }
390 
391 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
392   return wrap(unwrap(op)->getParentOp());
393 }
394 
395 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
396   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
397 }
398 
399 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
400   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
401 }
402 
403 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
404   Operation *cppOp = unwrap(op);
405   if (cppOp->getNumRegions() == 0)
406     return wrap(static_cast<Region *>(nullptr));
407   return wrap(&cppOp->getRegion(0));
408 }
409 
410 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
411   Region *cppRegion = unwrap(region);
412   Operation *parent = cppRegion->getParentOp();
413   intptr_t next = cppRegion->getRegionNumber() + 1;
414   if (parent->getNumRegions() > next)
415     return wrap(&parent->getRegion(next));
416   return wrap(static_cast<Region *>(nullptr));
417 }
418 
419 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
420   return wrap(unwrap(op)->getNextNode());
421 }
422 
423 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
424   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
425 }
426 
427 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
428   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
429 }
430 
431 void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
432                              MlirValue newValue) {
433   unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
434 }
435 
436 intptr_t mlirOperationGetNumResults(MlirOperation op) {
437   return static_cast<intptr_t>(unwrap(op)->getNumResults());
438 }
439 
440 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
441   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
442 }
443 
444 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
445   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
446 }
447 
448 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
449   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
450 }
451 
452 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
453   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
454 }
455 
456 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
457   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
458   return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
459 }
460 
461 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
462                                               MlirStringRef name) {
463   return wrap(unwrap(op)->getAttr(unwrap(name)));
464 }
465 
466 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
467                                      MlirAttribute attr) {
468   unwrap(op)->setAttr(unwrap(name), unwrap(attr));
469 }
470 
471 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
472   return !!unwrap(op)->removeAttr(unwrap(name));
473 }
474 
475 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
476                         void *userData) {
477   detail::CallbackOstream stream(callback, userData);
478   unwrap(op)->print(stream);
479 }
480 
481 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
482                                  MlirStringCallback callback, void *userData) {
483   detail::CallbackOstream stream(callback, userData);
484   unwrap(op)->print(stream, *unwrap(flags));
485 }
486 
487 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
488 
489 bool mlirOperationVerify(MlirOperation op) {
490   return succeeded(verify(unwrap(op)));
491 }
492 
493 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
494   return unwrap(op)->moveAfter(unwrap(other));
495 }
496 
497 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
498   return unwrap(op)->moveBefore(unwrap(other));
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // Region API.
503 //===----------------------------------------------------------------------===//
504 
505 MlirRegion mlirRegionCreate() { return wrap(new Region); }
506 
507 bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
508   return unwrap(region) == unwrap(other);
509 }
510 
511 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
512   Region *cppRegion = unwrap(region);
513   if (cppRegion->empty())
514     return wrap(static_cast<Block *>(nullptr));
515   return wrap(&cppRegion->front());
516 }
517 
518 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
519   unwrap(region)->push_back(unwrap(block));
520 }
521 
522 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
523                                 MlirBlock block) {
524   auto &blockList = unwrap(region)->getBlocks();
525   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
526 }
527 
528 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
529                                      MlirBlock block) {
530   Region *cppRegion = unwrap(region);
531   if (mlirBlockIsNull(reference)) {
532     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
533     return;
534   }
535 
536   assert(unwrap(reference)->getParent() == unwrap(region) &&
537          "expected reference block to belong to the region");
538   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
539                                      unwrap(block));
540 }
541 
542 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
543                                       MlirBlock block) {
544   if (mlirBlockIsNull(reference))
545     return mlirRegionAppendOwnedBlock(region, block);
546 
547   assert(unwrap(reference)->getParent() == unwrap(region) &&
548          "expected reference block to belong to the region");
549   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
550                                      unwrap(block));
551 }
552 
553 void mlirRegionDestroy(MlirRegion region) {
554   delete static_cast<Region *>(region.ptr);
555 }
556 
557 //===----------------------------------------------------------------------===//
558 // Block API.
559 //===----------------------------------------------------------------------===//
560 
561 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args,
562                           MlirLocation const *locs) {
563   Block *b = new Block;
564   for (intptr_t i = 0; i < nArgs; ++i)
565     b->addArgument(unwrap(args[i]), unwrap(locs[i]));
566   return wrap(b);
567 }
568 
569 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
570   return unwrap(block) == unwrap(other);
571 }
572 
573 MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
574   return wrap(unwrap(block)->getParentOp());
575 }
576 
577 MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
578   return wrap(unwrap(block)->getParent());
579 }
580 
581 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
582   return wrap(unwrap(block)->getNextNode());
583 }
584 
585 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
586   Block *cppBlock = unwrap(block);
587   if (cppBlock->empty())
588     return wrap(static_cast<Operation *>(nullptr));
589   return wrap(&cppBlock->front());
590 }
591 
592 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
593   Block *cppBlock = unwrap(block);
594   if (cppBlock->empty())
595     return wrap(static_cast<Operation *>(nullptr));
596   Operation &back = cppBlock->back();
597   if (!back.hasTrait<OpTrait::IsTerminator>())
598     return wrap(static_cast<Operation *>(nullptr));
599   return wrap(&back);
600 }
601 
602 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
603   unwrap(block)->push_back(unwrap(operation));
604 }
605 
606 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
607                                    MlirOperation operation) {
608   auto &opList = unwrap(block)->getOperations();
609   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
610 }
611 
612 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
613                                         MlirOperation reference,
614                                         MlirOperation operation) {
615   Block *cppBlock = unwrap(block);
616   if (mlirOperationIsNull(reference)) {
617     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
618     return;
619   }
620 
621   assert(unwrap(reference)->getBlock() == unwrap(block) &&
622          "expected reference operation to belong to the block");
623   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
624                                         unwrap(operation));
625 }
626 
627 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
628                                          MlirOperation reference,
629                                          MlirOperation operation) {
630   if (mlirOperationIsNull(reference))
631     return mlirBlockAppendOwnedOperation(block, operation);
632 
633   assert(unwrap(reference)->getBlock() == unwrap(block) &&
634          "expected reference operation to belong to the block");
635   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
636                                         unwrap(operation));
637 }
638 
639 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
640 
641 void mlirBlockDetach(MlirBlock block) {
642   Block *b = unwrap(block);
643   b->getParent()->getBlocks().remove(b);
644 }
645 
646 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
647   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
648 }
649 
650 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type,
651                                MlirLocation loc) {
652   return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc)));
653 }
654 
655 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
656   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
657 }
658 
659 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
660                     void *userData) {
661   detail::CallbackOstream stream(callback, userData);
662   unwrap(block)->print(stream);
663 }
664 
665 //===----------------------------------------------------------------------===//
666 // Value API.
667 //===----------------------------------------------------------------------===//
668 
669 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
670   return unwrap(value1) == unwrap(value2);
671 }
672 
673 bool mlirValueIsABlockArgument(MlirValue value) {
674   return unwrap(value).isa<BlockArgument>();
675 }
676 
677 bool mlirValueIsAOpResult(MlirValue value) {
678   return unwrap(value).isa<OpResult>();
679 }
680 
681 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
682   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
683 }
684 
685 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
686   return static_cast<intptr_t>(
687       unwrap(value).cast<BlockArgument>().getArgNumber());
688 }
689 
690 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
691   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
692 }
693 
694 MlirOperation mlirOpResultGetOwner(MlirValue value) {
695   return wrap(unwrap(value).cast<OpResult>().getOwner());
696 }
697 
698 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
699   return static_cast<intptr_t>(
700       unwrap(value).cast<OpResult>().getResultNumber());
701 }
702 
703 MlirType mlirValueGetType(MlirValue value) {
704   return wrap(unwrap(value).getType());
705 }
706 
707 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
708 
709 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
710                     void *userData) {
711   detail::CallbackOstream stream(callback, userData);
712   unwrap(value).print(stream);
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // Type API.
717 //===----------------------------------------------------------------------===//
718 
719 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
720   return wrap(mlir::parseType(unwrap(type), unwrap(context)));
721 }
722 
723 MlirContext mlirTypeGetContext(MlirType type) {
724   return wrap(unwrap(type).getContext());
725 }
726 
727 MlirTypeID mlirTypeGetTypeID(MlirType type) {
728   return wrap(unwrap(type).getTypeID());
729 }
730 
731 bool mlirTypeEqual(MlirType t1, MlirType t2) {
732   return unwrap(t1) == unwrap(t2);
733 }
734 
735 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
736   detail::CallbackOstream stream(callback, userData);
737   unwrap(type).print(stream);
738 }
739 
740 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
741 
742 //===----------------------------------------------------------------------===//
743 // Attribute API.
744 //===----------------------------------------------------------------------===//
745 
746 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
747   return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
748 }
749 
750 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
751   return wrap(unwrap(attribute).getContext());
752 }
753 
754 MlirType mlirAttributeGetType(MlirAttribute attribute) {
755   return wrap(unwrap(attribute).getType());
756 }
757 
758 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
759   return wrap(unwrap(attr).getTypeID());
760 }
761 
762 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
763   return unwrap(a1) == unwrap(a2);
764 }
765 
766 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
767                         void *userData) {
768   detail::CallbackOstream stream(callback, userData);
769   unwrap(attr).print(stream);
770 }
771 
772 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
773 
774 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
775                                          MlirAttribute attr) {
776   return MlirNamedAttribute{name, attr};
777 }
778 
779 //===----------------------------------------------------------------------===//
780 // Identifier API.
781 //===----------------------------------------------------------------------===//
782 
783 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
784   return wrap(StringAttr::get(unwrap(context), unwrap(str)));
785 }
786 
787 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
788   return wrap(unwrap(ident).getContext());
789 }
790 
791 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
792   return unwrap(ident) == unwrap(other);
793 }
794 
795 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
796   return wrap(unwrap(ident).strref());
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // Symbol and SymbolTable API.
801 //===----------------------------------------------------------------------===//
802 
803 MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
804   return wrap(SymbolTable::getSymbolAttrName());
805 }
806 
807 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
808   return wrap(SymbolTable::getVisibilityAttrName());
809 }
810 
811 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
812   if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
813     return wrap(static_cast<SymbolTable *>(nullptr));
814   return wrap(new SymbolTable(unwrap(operation)));
815 }
816 
817 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
818   delete unwrap(symbolTable);
819 }
820 
821 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
822                                     MlirStringRef name) {
823   return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length)));
824 }
825 
826 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
827                                     MlirOperation operation) {
828   return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation)));
829 }
830 
831 void mlirSymbolTableErase(MlirSymbolTable symbolTable,
832                           MlirOperation operation) {
833   unwrap(symbolTable)->erase(unwrap(operation));
834 }
835 
836 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
837                                                       MlirStringRef newSymbol,
838                                                       MlirOperation from) {
839   auto *cppFrom = unwrap(from);
840   auto *context = cppFrom->getContext();
841   auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
842   auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
843   return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
844                                                 unwrap(from)));
845 }
846 
847 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
848                                      void (*callback)(MlirOperation, bool,
849                                                       void *userData),
850                                      void *userData) {
851   SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible,
852                                 [&](Operation *foundOpCpp, bool isVisible) {
853                                   callback(wrap(foundOpCpp), isVisible,
854                                            userData);
855                                 });
856 }
857