xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision 37d4d3bb)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/IR.h"
10 #include "mlir-c/Support.h"
11 
12 #include "mlir/CAPI/IR.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Module.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/Parser.h"
21 
22 using namespace mlir;
23 
24 /* ========================================================================== */
25 /* Context API.                                                               */
26 /* ========================================================================== */
27 
28 MlirContext mlirContextCreate() {
29   auto *context = new MLIRContext;
30   return wrap(context);
31 }
32 
33 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
34   return unwrap(ctx1) == unwrap(ctx2);
35 }
36 
37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
38 
39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
40   unwrap(context)->allowUnregisteredDialects(allow);
41 }
42 
43 int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
44   return unwrap(context)->allowsUnregisteredDialects();
45 }
46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
47   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
48 }
49 
50 // TODO: expose a cheaper way than constructing + sorting a vector only to take
51 // its size.
52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
53   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
54 }
55 
56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
57                                         MlirStringRef name) {
58   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
59 }
60 
61 /* ========================================================================== */
62 /* Dialect API.                                                               */
63 /* ========================================================================== */
64 
65 MlirContext mlirDialectGetContext(MlirDialect dialect) {
66   return wrap(unwrap(dialect)->getContext());
67 }
68 
69 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
70   return unwrap(dialect1) == unwrap(dialect2);
71 }
72 
73 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
74   return wrap(unwrap(dialect)->getNamespace());
75 }
76 
77 /* ========================================================================== */
78 /* Printing flags API.                                                        */
79 /* ========================================================================== */
80 
81 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
82   return wrap(new OpPrintingFlags());
83 }
84 
85 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
86   delete unwrap(flags);
87 }
88 
89 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
90                                                 intptr_t largeElementLimit) {
91   unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
92 }
93 
94 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
95                                         int prettyForm) {
96   unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
97 }
98 
99 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
100   unwrap(flags)->printGenericOpForm();
101 }
102 
103 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
104   unwrap(flags)->useLocalScope();
105 }
106 
107 /* ========================================================================== */
108 /* Location API.                                                              */
109 /* ========================================================================== */
110 
111 MlirLocation mlirLocationFileLineColGet(MlirContext context,
112                                         const char *filename, unsigned line,
113                                         unsigned col) {
114   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
115 }
116 
117 MlirLocation mlirLocationUnknownGet(MlirContext context) {
118   return wrap(UnknownLoc::get(unwrap(context)));
119 }
120 
121 MlirContext mlirLocationGetContext(MlirLocation location) {
122   return wrap(unwrap(location).getContext());
123 }
124 
125 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
126                        void *userData) {
127   detail::CallbackOstream stream(callback, userData);
128   unwrap(location).print(stream);
129   stream.flush();
130 }
131 
132 /* ========================================================================== */
133 /* Module API.                                                                */
134 /* ========================================================================== */
135 
136 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
137   return wrap(ModuleOp::create(unwrap(location)));
138 }
139 
140 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
141   OwningModuleRef owning = parseSourceString(module, unwrap(context));
142   if (!owning)
143     return MlirModule{nullptr};
144   return MlirModule{owning.release().getOperation()};
145 }
146 
147 MlirContext mlirModuleGetContext(MlirModule module) {
148   return wrap(unwrap(module).getContext());
149 }
150 
151 void mlirModuleDestroy(MlirModule module) {
152   // Transfer ownership to an OwningModuleRef so that its destructor is called.
153   OwningModuleRef(unwrap(module));
154 }
155 
156 MlirOperation mlirModuleGetOperation(MlirModule module) {
157   return wrap(unwrap(module).getOperation());
158 }
159 
160 /* ========================================================================== */
161 /* Operation state API.                                                       */
162 /* ========================================================================== */
163 
164 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
165   MlirOperationState state;
166   state.name = name;
167   state.location = loc;
168   state.nResults = 0;
169   state.results = nullptr;
170   state.nOperands = 0;
171   state.operands = nullptr;
172   state.nRegions = 0;
173   state.regions = nullptr;
174   state.nSuccessors = 0;
175   state.successors = nullptr;
176   state.nAttributes = 0;
177   state.attributes = nullptr;
178   return state;
179 }
180 
// Appends `n` elements, copied from the array named `elemName` at the
// expansion site, to the heap array `state->elemName`, and bumps the count
// `state->sizeName`. The array is grown with realloc. `state` (a pointer to the
// state struct) and `n` must be in scope where the macro is expanded. The
// source array must carry the same name as the state field.
//
// The body is wrapped in do/while(false) so the expansion is a single statement:
// it is safe after an unbraced `if`/`else` and requires the trailing semicolon.
// Appending zero elements is a no-op. That avoids realloc with size 0 (which
// may legitimately return null) and memcpy from a possibly-null source. A
// failed realloc aborts. It no longer overwrites the only pointer to the old
// array with null and then writes through it.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n == 0)                                                                \
      break;                                                                   \
    type *grown = static_cast<type *>(                                         \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    if (!grown)                                                                \
      abort();                                                                 \
    memcpy(grown + state->sizeName, elemName, n * sizeof(type));               \
    state->elemName = grown;                                                   \
    state->sizeName += n;                                                      \
  } while (false)
186 
187 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
188                                   MlirType *results) {
189   APPEND_ELEMS(MlirType, nResults, results);
190 }
191 
192 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
193                                    MlirValue *operands) {
194   APPEND_ELEMS(MlirValue, nOperands, operands);
195 }
196 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
197                                        MlirRegion *regions) {
198   APPEND_ELEMS(MlirRegion, nRegions, regions);
199 }
200 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
201                                      MlirBlock *successors) {
202   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
203 }
204 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
205                                      MlirNamedAttribute *attributes) {
206   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
207 }
208 
209 /* ========================================================================== */
210 /* Operation API.                                                             */
211 /* ========================================================================== */
212 
213 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
214   assert(state);
215   OperationState cppState(unwrap(state->location), state->name);
216   SmallVector<Type, 4> resultStorage;
217   SmallVector<Value, 8> operandStorage;
218   SmallVector<Block *, 2> successorStorage;
219   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
220   cppState.addOperands(
221       unwrapList(state->nOperands, state->operands, operandStorage));
222   cppState.addSuccessors(
223       unwrapList(state->nSuccessors, state->successors, successorStorage));
224 
225   cppState.attributes.reserve(state->nAttributes);
226   for (intptr_t i = 0; i < state->nAttributes; ++i)
227     cppState.addAttribute(state->attributes[i].name,
228                           unwrap(state->attributes[i].attribute));
229 
230   for (intptr_t i = 0; i < state->nRegions; ++i)
231     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
232 
233   MlirOperation result = wrap(Operation::create(cppState));
234   free(state->results);
235   free(state->operands);
236   free(state->successors);
237   free(state->regions);
238   free(state->attributes);
239   return result;
240 }
241 
242 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
243 
244 int mlirOperationEqual(MlirOperation op, MlirOperation other) {
245   return unwrap(op) == unwrap(other);
246 }
247 
248 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
249   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
250 }
251 
252 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
253   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
254 }
255 
256 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
257   return wrap(unwrap(op)->getNextNode());
258 }
259 
260 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
261   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
262 }
263 
264 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
265   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
266 }
267 
268 intptr_t mlirOperationGetNumResults(MlirOperation op) {
269   return static_cast<intptr_t>(unwrap(op)->getNumResults());
270 }
271 
272 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
273   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
274 }
275 
276 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
277   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
278 }
279 
280 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
281   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
282 }
283 
284 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
285   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
286 }
287 
288 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
289   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
290   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
291 }
292 
293 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
294                                               const char *name) {
295   return wrap(unwrap(op)->getAttr(name));
296 }
297 
298 void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
299                                      MlirAttribute attr) {
300   unwrap(op)->setAttr(name, unwrap(attr));
301 }
302 
303 int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) {
304   auto removeResult = unwrap(op)->removeAttr(name);
305   return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
306 }
307 
308 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
309                         void *userData) {
310   detail::CallbackOstream stream(callback, userData);
311   unwrap(op)->print(stream);
312   stream.flush();
313 }
314 
315 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
316                                  MlirStringCallback callback, void *userData) {
317   detail::CallbackOstream stream(callback, userData);
318   unwrap(op)->print(stream, *unwrap(flags));
319   stream.flush();
320 }
321 
322 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
323 
324 /* ========================================================================== */
325 /* Region API.                                                                */
326 /* ========================================================================== */
327 
328 MlirRegion mlirRegionCreate() { return wrap(new Region); }
329 
330 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
331   Region *cppRegion = unwrap(region);
332   if (cppRegion->empty())
333     return wrap(static_cast<Block *>(nullptr));
334   return wrap(&cppRegion->front());
335 }
336 
337 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
338   unwrap(region)->push_back(unwrap(block));
339 }
340 
341 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
342                                 MlirBlock block) {
343   auto &blockList = unwrap(region)->getBlocks();
344   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
345 }
346 
347 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
348                                      MlirBlock block) {
349   Region *cppRegion = unwrap(region);
350   if (mlirBlockIsNull(reference)) {
351     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
352     return;
353   }
354 
355   assert(unwrap(reference)->getParent() == unwrap(region) &&
356          "expected reference block to belong to the region");
357   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
358                                      unwrap(block));
359 }
360 
361 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
362                                       MlirBlock block) {
363   if (mlirBlockIsNull(reference))
364     return mlirRegionAppendOwnedBlock(region, block);
365 
366   assert(unwrap(reference)->getParent() == unwrap(region) &&
367          "expected reference block to belong to the region");
368   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
369                                      unwrap(block));
370 }
371 
372 void mlirRegionDestroy(MlirRegion region) {
373   delete static_cast<Region *>(region.ptr);
374 }
375 
376 /* ========================================================================== */
377 /* Block API.                                                                 */
378 /* ========================================================================== */
379 
380 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
381   Block *b = new Block;
382   for (intptr_t i = 0; i < nArgs; ++i)
383     b->addArgument(unwrap(args[i]));
384   return wrap(b);
385 }
386 
387 int mlirBlockEqual(MlirBlock block, MlirBlock other) {
388   return unwrap(block) == unwrap(other);
389 }
390 
391 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
392   return wrap(unwrap(block)->getNextNode());
393 }
394 
395 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
396   Block *cppBlock = unwrap(block);
397   if (cppBlock->empty())
398     return wrap(static_cast<Operation *>(nullptr));
399   return wrap(&cppBlock->front());
400 }
401 
402 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
403   unwrap(block)->push_back(unwrap(operation));
404 }
405 
406 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
407                                    MlirOperation operation) {
408   auto &opList = unwrap(block)->getOperations();
409   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
410 }
411 
412 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
413                                         MlirOperation reference,
414                                         MlirOperation operation) {
415   Block *cppBlock = unwrap(block);
416   if (mlirOperationIsNull(reference)) {
417     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
418     return;
419   }
420 
421   assert(unwrap(reference)->getBlock() == unwrap(block) &&
422          "expected reference operation to belong to the block");
423   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
424                                         unwrap(operation));
425 }
426 
427 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
428                                          MlirOperation reference,
429                                          MlirOperation operation) {
430   if (mlirOperationIsNull(reference))
431     return mlirBlockAppendOwnedOperation(block, operation);
432 
433   assert(unwrap(reference)->getBlock() == unwrap(block) &&
434          "expected reference operation to belong to the block");
435   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
436                                         unwrap(operation));
437 }
438 
439 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
440 
441 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
442   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
443 }
444 
445 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
446   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
447 }
448 
449 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
450                     void *userData) {
451   detail::CallbackOstream stream(callback, userData);
452   unwrap(block)->print(stream);
453   stream.flush();
454 }
455 
456 /* ========================================================================== */
457 /* Value API.                                                                 */
458 /* ========================================================================== */
459 
460 int mlirValueIsABlockArgument(MlirValue value) {
461   return unwrap(value).isa<BlockArgument>();
462 }
463 
464 int mlirValueIsAOpResult(MlirValue value) {
465   return unwrap(value).isa<OpResult>();
466 }
467 
468 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
469   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
470 }
471 
472 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
473   return static_cast<intptr_t>(
474       unwrap(value).cast<BlockArgument>().getArgNumber());
475 }
476 
477 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
478   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
479 }
480 
481 MlirOperation mlirOpResultGetOwner(MlirValue value) {
482   return wrap(unwrap(value).cast<OpResult>().getOwner());
483 }
484 
485 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
486   return static_cast<intptr_t>(
487       unwrap(value).cast<OpResult>().getResultNumber());
488 }
489 
490 MlirType mlirValueGetType(MlirValue value) {
491   return wrap(unwrap(value).getType());
492 }
493 
494 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
495 
496 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
497                     void *userData) {
498   detail::CallbackOstream stream(callback, userData);
499   unwrap(value).print(stream);
500   stream.flush();
501 }
502 
503 /* ========================================================================== */
504 /* Type API.                                                                  */
505 /* ========================================================================== */
506 
507 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
508   return wrap(mlir::parseType(type, unwrap(context)));
509 }
510 
511 MlirContext mlirTypeGetContext(MlirType type) {
512   return wrap(unwrap(type).getContext());
513 }
514 
515 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
516 
517 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
518   detail::CallbackOstream stream(callback, userData);
519   unwrap(type).print(stream);
520   stream.flush();
521 }
522 
523 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
524 
525 /* ========================================================================== */
526 /* Attribute API.                                                             */
527 /* ========================================================================== */
528 
529 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
530   return wrap(mlir::parseAttribute(attr, unwrap(context)));
531 }
532 
533 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
534   return wrap(unwrap(attribute).getContext());
535 }
536 
537 MlirType mlirAttributeGetType(MlirAttribute attribute) {
538   return wrap(unwrap(attribute).getType());
539 }
540 
541 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
542   return unwrap(a1) == unwrap(a2);
543 }
544 
545 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
546                         void *userData) {
547   detail::CallbackOstream stream(callback, userData);
548   unwrap(attr).print(stream);
549   stream.flush();
550 }
551 
552 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
553 
554 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
555   return MlirNamedAttribute{name, attr};
556 }
557