xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision e1f613ce)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"

#include <cstdlib>
#include <cstring>
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Context API.
26 //===----------------------------------------------------------------------===//
27 
28 MlirContext mlirContextCreate() {
29   auto *context = new MLIRContext;
30   return wrap(context);
31 }
32 
33 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
34   return unwrap(ctx1) == unwrap(ctx2);
35 }
36 
37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
38 
39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
40   unwrap(context)->allowUnregisteredDialects(allow);
41 }
42 
43 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
44   return unwrap(context)->allowsUnregisteredDialects();
45 }
46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
47   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
48 }
49 
50 // TODO: expose a cheaper way than constructing + sorting a vector only to take
51 // its size.
52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
53   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
54 }
55 
56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
57                                         MlirStringRef name) {
58   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // Dialect API.
63 //===----------------------------------------------------------------------===//
64 
65 MlirContext mlirDialectGetContext(MlirDialect dialect) {
66   return wrap(unwrap(dialect)->getContext());
67 }
68 
69 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
70   return unwrap(dialect1) == unwrap(dialect2);
71 }
72 
73 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
74   return wrap(unwrap(dialect)->getNamespace());
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Printing flags API.
79 //===----------------------------------------------------------------------===//
80 
81 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
82   return wrap(new OpPrintingFlags());
83 }
84 
85 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
86   delete unwrap(flags);
87 }
88 
89 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
90                                                 intptr_t largeElementLimit) {
91   unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
92 }
93 
94 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
95                                         bool prettyForm) {
96   unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
97 }
98 
99 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
100   unwrap(flags)->printGenericOpForm();
101 }
102 
103 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
104   unwrap(flags)->useLocalScope();
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // Location API.
109 //===----------------------------------------------------------------------===//
110 
111 MlirLocation mlirLocationFileLineColGet(MlirContext context,
112                                         MlirStringRef filename, unsigned line,
113                                         unsigned col) {
114   return wrap(
115       FileLineColLoc::get(unwrap(filename), line, col, unwrap(context)));
116 }
117 
118 MlirLocation mlirLocationUnknownGet(MlirContext context) {
119   return wrap(UnknownLoc::get(unwrap(context)));
120 }
121 
122 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
123   return unwrap(l1) == unwrap(l2);
124 }
125 
126 MlirContext mlirLocationGetContext(MlirLocation location) {
127   return wrap(unwrap(location).getContext());
128 }
129 
130 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
131                        void *userData) {
132   detail::CallbackOstream stream(callback, userData);
133   unwrap(location).print(stream);
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // Module API.
138 //===----------------------------------------------------------------------===//
139 
140 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
141   return wrap(ModuleOp::create(unwrap(location)));
142 }
143 
144 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
145   OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context));
146   if (!owning)
147     return MlirModule{nullptr};
148   return MlirModule{owning.release().getOperation()};
149 }
150 
151 MlirContext mlirModuleGetContext(MlirModule module) {
152   return wrap(unwrap(module).getContext());
153 }
154 
155 MlirBlock mlirModuleGetBody(MlirModule module) {
156   return wrap(unwrap(module).getBody());
157 }
158 
159 void mlirModuleDestroy(MlirModule module) {
160   // Transfer ownership to an OwningModuleRef so that its destructor is called.
161   OwningModuleRef(unwrap(module));
162 }
163 
164 MlirOperation mlirModuleGetOperation(MlirModule module) {
165   return wrap(unwrap(module).getOperation());
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // Operation state API.
170 //===----------------------------------------------------------------------===//
171 
172 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
173   MlirOperationState state;
174   state.name = name;
175   state.location = loc;
176   state.nResults = 0;
177   state.results = nullptr;
178   state.nOperands = 0;
179   state.operands = nullptr;
180   state.nRegions = 0;
181   state.regions = nullptr;
182   state.nSuccessors = 0;
183   state.successors = nullptr;
184   state.nAttributes = 0;
185   state.attributes = nullptr;
186   return state;
187 }
188 
// Appends `n` elements of `type` to a malloc-managed array stored in `state`.
//
// Expects two names to be visible at the expansion site: `state` (pointer to
// a struct that has the fields `sizeName` and `elemName`) and `n` (number of
// elements to append). The macro argument `elemName` names both the array
// field inside `state` and the local pointer the new elements are read from.
//
// Differences from a plain realloc + memcpy sequence:
//   - wrapped in do/while so that it expands to exactly one statement and is
//     safe under an unbraced `if`/`else`;
//   - a non-positive `n` is a no-op, so memcpy is never reached with a null
//     source or a negative (huge, once converted) size;
//   - the realloc result is checked before it replaces the stored pointer;
//     on allocation failure the process is aborted instead of writing through
//     a null pointer.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    type *appendElemsStorage = static_cast<type *>(                            \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    if (!appendElemsStorage)                                                   \
      abort();                                                                 \
    memcpy(appendElemsStorage + state->sizeName, elemName, n * sizeof(type));  \
    state->elemName = appendElemsStorage;                                      \
    state->sizeName += n;                                                      \
  } while (false)
194 
/// Appends `n` result types read from `results` to the result list stored in
/// `state`. APPEND_ELEMS grows `state->results` with realloc, copies the new
/// handles behind the existing ones and bumps `state->nResults` by `n`.
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
                                  MlirType const *results) {
  APPEND_ELEMS(MlirType, nResults, results);
}
199 
/// Appends `n` operand values read from `operands` to the operand list stored
/// in `state`. APPEND_ELEMS grows `state->operands` with realloc, copies the
/// new handles behind the existing ones and bumps `state->nOperands` by `n`.
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
                                   MlirValue const *operands) {
  APPEND_ELEMS(MlirValue, nOperands, operands);
}
/// Appends `n` region handles read from `regions` to the region list stored in
/// `state` (grows `state->regions`, bumps `state->nRegions`). Only the handles
/// are copied here; mlirOperationCreate later wraps each stored region in a
/// std::unique_ptr and thereby takes ownership of it.
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
                                       MlirRegion const *regions) {
  APPEND_ELEMS(MlirRegion, nRegions, regions);
}
/// Appends `n` successor blocks read from `successors` to the successor list
/// stored in `state`. APPEND_ELEMS grows `state->successors` with realloc,
/// copies the new handles and bumps `state->nSuccessors` by `n`.
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
                                     MlirBlock const *successors) {
  APPEND_ELEMS(MlirBlock, nSuccessors, successors);
}
/// Appends `n` named attributes read from `attributes` to the attribute list
/// stored in `state`. APPEND_ELEMS grows `state->attributes` with realloc,
/// copies the new entries and bumps `state->nAttributes` by `n`.
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
                                     MlirNamedAttribute const *attributes) {
  APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
}
216 
217 //===----------------------------------------------------------------------===//
218 // Operation API.
219 //===----------------------------------------------------------------------===//
220 
221 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
222   assert(state);
223   OperationState cppState(unwrap(state->location), unwrap(state->name));
224   SmallVector<Type, 4> resultStorage;
225   SmallVector<Value, 8> operandStorage;
226   SmallVector<Block *, 2> successorStorage;
227   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
228   cppState.addOperands(
229       unwrapList(state->nOperands, state->operands, operandStorage));
230   cppState.addSuccessors(
231       unwrapList(state->nSuccessors, state->successors, successorStorage));
232 
233   cppState.attributes.reserve(state->nAttributes);
234   for (intptr_t i = 0; i < state->nAttributes; ++i)
235     cppState.addAttribute(unwrap(state->attributes[i].name),
236                           unwrap(state->attributes[i].attribute));
237 
238   for (intptr_t i = 0; i < state->nRegions; ++i)
239     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
240 
241   MlirOperation result = wrap(Operation::create(cppState));
242   free(state->results);
243   free(state->operands);
244   free(state->successors);
245   free(state->regions);
246   free(state->attributes);
247   return result;
248 }
249 
250 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
251 
252 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
253   return unwrap(op) == unwrap(other);
254 }
255 
256 MlirIdentifier mlirOperationGetName(MlirOperation op) {
257   return wrap(unwrap(op)->getName().getIdentifier());
258 }
259 
260 MlirBlock mlirOperationGetBlock(MlirOperation op) {
261   return wrap(unwrap(op)->getBlock());
262 }
263 
264 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
265   return wrap(unwrap(op)->getParentOp());
266 }
267 
268 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
269   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
270 }
271 
272 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
273   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
274 }
275 
276 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
277   return wrap(unwrap(op)->getNextNode());
278 }
279 
280 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
281   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
282 }
283 
284 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
285   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
286 }
287 
288 intptr_t mlirOperationGetNumResults(MlirOperation op) {
289   return static_cast<intptr_t>(unwrap(op)->getNumResults());
290 }
291 
292 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
293   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
294 }
295 
296 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
297   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
298 }
299 
300 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
301   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
302 }
303 
304 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
305   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
306 }
307 
308 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
309   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
310   return MlirNamedAttribute{wrap(attr.first.strref()), wrap(attr.second)};
311 }
312 
313 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
314                                               MlirStringRef name) {
315   return wrap(unwrap(op)->getAttr(unwrap(name)));
316 }
317 
318 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
319                                      MlirAttribute attr) {
320   unwrap(op)->setAttr(unwrap(name), unwrap(attr));
321 }
322 
323 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
324   auto removeResult = unwrap(op)->removeAttr(unwrap(name));
325   return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
326 }
327 
328 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
329                         void *userData) {
330   detail::CallbackOstream stream(callback, userData);
331   unwrap(op)->print(stream);
332 }
333 
334 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
335                                  MlirStringCallback callback, void *userData) {
336   detail::CallbackOstream stream(callback, userData);
337   unwrap(op)->print(stream, *unwrap(flags));
338 }
339 
340 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
341 
342 //===----------------------------------------------------------------------===//
343 // Region API.
344 //===----------------------------------------------------------------------===//
345 
346 MlirRegion mlirRegionCreate() { return wrap(new Region); }
347 
348 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
349   Region *cppRegion = unwrap(region);
350   if (cppRegion->empty())
351     return wrap(static_cast<Block *>(nullptr));
352   return wrap(&cppRegion->front());
353 }
354 
355 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
356   unwrap(region)->push_back(unwrap(block));
357 }
358 
359 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
360                                 MlirBlock block) {
361   auto &blockList = unwrap(region)->getBlocks();
362   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
363 }
364 
365 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
366                                      MlirBlock block) {
367   Region *cppRegion = unwrap(region);
368   if (mlirBlockIsNull(reference)) {
369     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
370     return;
371   }
372 
373   assert(unwrap(reference)->getParent() == unwrap(region) &&
374          "expected reference block to belong to the region");
375   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
376                                      unwrap(block));
377 }
378 
379 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
380                                       MlirBlock block) {
381   if (mlirBlockIsNull(reference))
382     return mlirRegionAppendOwnedBlock(region, block);
383 
384   assert(unwrap(reference)->getParent() == unwrap(region) &&
385          "expected reference block to belong to the region");
386   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
387                                      unwrap(block));
388 }
389 
390 void mlirRegionDestroy(MlirRegion region) {
391   delete static_cast<Region *>(region.ptr);
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // Block API.
396 //===----------------------------------------------------------------------===//
397 
398 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) {
399   Block *b = new Block;
400   for (intptr_t i = 0; i < nArgs; ++i)
401     b->addArgument(unwrap(args[i]));
402   return wrap(b);
403 }
404 
405 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
406   return unwrap(block) == unwrap(other);
407 }
408 
409 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
410   return wrap(unwrap(block)->getNextNode());
411 }
412 
413 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
414   Block *cppBlock = unwrap(block);
415   if (cppBlock->empty())
416     return wrap(static_cast<Operation *>(nullptr));
417   return wrap(&cppBlock->front());
418 }
419 
420 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
421   Block *cppBlock = unwrap(block);
422   if (cppBlock->empty())
423     return wrap(static_cast<Operation *>(nullptr));
424   Operation &back = cppBlock->back();
425   if (!back.isKnownTerminator())
426     return wrap(static_cast<Operation *>(nullptr));
427   return wrap(&back);
428 }
429 
430 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
431   unwrap(block)->push_back(unwrap(operation));
432 }
433 
434 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
435                                    MlirOperation operation) {
436   auto &opList = unwrap(block)->getOperations();
437   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
438 }
439 
440 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
441                                         MlirOperation reference,
442                                         MlirOperation operation) {
443   Block *cppBlock = unwrap(block);
444   if (mlirOperationIsNull(reference)) {
445     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
446     return;
447   }
448 
449   assert(unwrap(reference)->getBlock() == unwrap(block) &&
450          "expected reference operation to belong to the block");
451   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
452                                         unwrap(operation));
453 }
454 
455 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
456                                          MlirOperation reference,
457                                          MlirOperation operation) {
458   if (mlirOperationIsNull(reference))
459     return mlirBlockAppendOwnedOperation(block, operation);
460 
461   assert(unwrap(reference)->getBlock() == unwrap(block) &&
462          "expected reference operation to belong to the block");
463   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
464                                         unwrap(operation));
465 }
466 
467 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
468 
469 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
470   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
471 }
472 
473 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
474   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
475 }
476 
477 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
478                     void *userData) {
479   detail::CallbackOstream stream(callback, userData);
480   unwrap(block)->print(stream);
481 }
482 
483 //===----------------------------------------------------------------------===//
484 // Value API.
485 //===----------------------------------------------------------------------===//
486 
487 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
488   return unwrap(value1) == unwrap(value2);
489 }
490 
491 bool mlirValueIsABlockArgument(MlirValue value) {
492   return unwrap(value).isa<BlockArgument>();
493 }
494 
495 bool mlirValueIsAOpResult(MlirValue value) {
496   return unwrap(value).isa<OpResult>();
497 }
498 
499 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
500   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
501 }
502 
503 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
504   return static_cast<intptr_t>(
505       unwrap(value).cast<BlockArgument>().getArgNumber());
506 }
507 
508 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
509   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
510 }
511 
512 MlirOperation mlirOpResultGetOwner(MlirValue value) {
513   return wrap(unwrap(value).cast<OpResult>().getOwner());
514 }
515 
516 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
517   return static_cast<intptr_t>(
518       unwrap(value).cast<OpResult>().getResultNumber());
519 }
520 
521 MlirType mlirValueGetType(MlirValue value) {
522   return wrap(unwrap(value).getType());
523 }
524 
525 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
526 
527 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
528                     void *userData) {
529   detail::CallbackOstream stream(callback, userData);
530   unwrap(value).print(stream);
531 }
532 
533 //===----------------------------------------------------------------------===//
534 // Type API.
535 //===----------------------------------------------------------------------===//
536 
537 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
538   return wrap(mlir::parseType(unwrap(type), unwrap(context)));
539 }
540 
541 MlirContext mlirTypeGetContext(MlirType type) {
542   return wrap(unwrap(type).getContext());
543 }
544 
545 bool mlirTypeEqual(MlirType t1, MlirType t2) {
546   return unwrap(t1) == unwrap(t2);
547 }
548 
549 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
550   detail::CallbackOstream stream(callback, userData);
551   unwrap(type).print(stream);
552 }
553 
554 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
555 
556 //===----------------------------------------------------------------------===//
557 // Attribute API.
558 //===----------------------------------------------------------------------===//
559 
560 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
561   return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
562 }
563 
564 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
565   return wrap(unwrap(attribute).getContext());
566 }
567 
568 MlirType mlirAttributeGetType(MlirAttribute attribute) {
569   return wrap(unwrap(attribute).getType());
570 }
571 
572 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
573   return unwrap(a1) == unwrap(a2);
574 }
575 
576 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
577                         void *userData) {
578   detail::CallbackOstream stream(callback, userData);
579   unwrap(attr).print(stream);
580 }
581 
582 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
583 
584 MlirNamedAttribute mlirNamedAttributeGet(MlirStringRef name,
585                                          MlirAttribute attr) {
586   return MlirNamedAttribute{name, attr};
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // Identifier API.
591 //===----------------------------------------------------------------------===//
592 
593 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
594   return wrap(Identifier::get(unwrap(str), unwrap(context)));
595 }
596 
597 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
598   return unwrap(ident) == unwrap(other);
599 }
600 
601 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
602   return wrap(unwrap(ident).strref());
603 }
604