xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision 129d6e55)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
#include "llvm/Support/MemAlloc.h"

#include <cstdlib>
#include <cstring>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // Context API.
27 //===----------------------------------------------------------------------===//
28 
29 MlirContext mlirContextCreate() {
30   auto *context = new MLIRContext;
31   return wrap(context);
32 }
33 
34 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
35   return unwrap(ctx1) == unwrap(ctx2);
36 }
37 
38 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
39 
40 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
41   unwrap(context)->allowUnregisteredDialects(allow);
42 }
43 
44 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
45   return unwrap(context)->allowsUnregisteredDialects();
46 }
47 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
48   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
49 }
50 
51 // TODO: expose a cheaper way than constructing + sorting a vector only to take
52 // its size.
53 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
54   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
55 }
56 
57 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
58                                         MlirStringRef name) {
59   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // Dialect API.
64 //===----------------------------------------------------------------------===//
65 
66 MlirContext mlirDialectGetContext(MlirDialect dialect) {
67   return wrap(unwrap(dialect)->getContext());
68 }
69 
70 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
71   return unwrap(dialect1) == unwrap(dialect2);
72 }
73 
74 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
75   return wrap(unwrap(dialect)->getNamespace());
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // Printing flags API.
80 //===----------------------------------------------------------------------===//
81 
82 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
83   return wrap(new OpPrintingFlags());
84 }
85 
86 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
87   delete unwrap(flags);
88 }
89 
90 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
91                                                 intptr_t largeElementLimit) {
92   unwrap(flags)->elideLargeElementsAttrs(largeElementLimit);
93 }
94 
95 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
96                                         bool prettyForm) {
97   unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm);
98 }
99 
100 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
101   unwrap(flags)->printGenericOpForm();
102 }
103 
104 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
105   unwrap(flags)->useLocalScope();
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // Location API.
110 //===----------------------------------------------------------------------===//
111 
112 MlirLocation mlirLocationFileLineColGet(MlirContext context,
113                                         MlirStringRef filename, unsigned line,
114                                         unsigned col) {
115   return wrap(
116       FileLineColLoc::get(unwrap(filename), line, col, unwrap(context)));
117 }
118 
119 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
120   return wrap(CallSiteLoc::get(unwrap(callee), unwrap(caller)));
121 }
122 
123 MlirLocation mlirLocationUnknownGet(MlirContext context) {
124   return wrap(UnknownLoc::get(unwrap(context)));
125 }
126 
127 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
128   return unwrap(l1) == unwrap(l2);
129 }
130 
131 MlirContext mlirLocationGetContext(MlirLocation location) {
132   return wrap(unwrap(location).getContext());
133 }
134 
135 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
136                        void *userData) {
137   detail::CallbackOstream stream(callback, userData);
138   unwrap(location).print(stream);
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // Module API.
143 //===----------------------------------------------------------------------===//
144 
145 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
146   return wrap(ModuleOp::create(unwrap(location)));
147 }
148 
149 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
150   OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context));
151   if (!owning)
152     return MlirModule{nullptr};
153   return MlirModule{owning.release().getOperation()};
154 }
155 
156 MlirContext mlirModuleGetContext(MlirModule module) {
157   return wrap(unwrap(module).getContext());
158 }
159 
160 MlirBlock mlirModuleGetBody(MlirModule module) {
161   return wrap(unwrap(module).getBody());
162 }
163 
164 void mlirModuleDestroy(MlirModule module) {
165   // Transfer ownership to an OwningModuleRef so that its destructor is called.
166   OwningModuleRef(unwrap(module));
167 }
168 
169 MlirOperation mlirModuleGetOperation(MlirModule module) {
170   return wrap(unwrap(module).getOperation());
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // Operation state API.
175 //===----------------------------------------------------------------------===//
176 
177 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
178   MlirOperationState state;
179   state.name = name;
180   state.location = loc;
181   state.nResults = 0;
182   state.results = nullptr;
183   state.nOperands = 0;
184   state.operands = nullptr;
185   state.nRegions = 0;
186   state.regions = nullptr;
187   state.nSuccessors = 0;
188   state.successors = nullptr;
189   state.nAttributes = 0;
190   state.attributes = nullptr;
191   return state;
192 }
193 
194 #define APPEND_ELEMS(type, sizeName, elemName)                                 \
195   state->elemName =                                                            \
196       (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type));  \
197   memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));       \
198   state->sizeName += n;
199 
200 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
201                                   MlirType const *results) {
202   APPEND_ELEMS(MlirType, nResults, results);
203 }
204 
205 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
206                                    MlirValue const *operands) {
207   APPEND_ELEMS(MlirValue, nOperands, operands);
208 }
209 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
210                                        MlirRegion const *regions) {
211   APPEND_ELEMS(MlirRegion, nRegions, regions);
212 }
213 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
214                                      MlirBlock const *successors) {
215   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
216 }
217 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
218                                      MlirNamedAttribute const *attributes) {
219   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // Operation API.
224 //===----------------------------------------------------------------------===//
225 
226 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
227   assert(state);
228   OperationState cppState(unwrap(state->location), unwrap(state->name));
229   SmallVector<Type, 4> resultStorage;
230   SmallVector<Value, 8> operandStorage;
231   SmallVector<Block *, 2> successorStorage;
232   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
233   cppState.addOperands(
234       unwrapList(state->nOperands, state->operands, operandStorage));
235   cppState.addSuccessors(
236       unwrapList(state->nSuccessors, state->successors, successorStorage));
237 
238   cppState.attributes.reserve(state->nAttributes);
239   for (intptr_t i = 0; i < state->nAttributes; ++i)
240     cppState.addAttribute(unwrap(state->attributes[i].name),
241                           unwrap(state->attributes[i].attribute));
242 
243   for (intptr_t i = 0; i < state->nRegions; ++i)
244     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
245 
246   MlirOperation result = wrap(Operation::create(cppState));
247   free(state->results);
248   free(state->operands);
249   free(state->successors);
250   free(state->regions);
251   free(state->attributes);
252   return result;
253 }
254 
255 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
256 
257 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
258   return unwrap(op) == unwrap(other);
259 }
260 
261 MlirIdentifier mlirOperationGetName(MlirOperation op) {
262   return wrap(unwrap(op)->getName().getIdentifier());
263 }
264 
265 MlirBlock mlirOperationGetBlock(MlirOperation op) {
266   return wrap(unwrap(op)->getBlock());
267 }
268 
269 MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
270   return wrap(unwrap(op)->getParentOp());
271 }
272 
273 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
274   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
275 }
276 
277 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
278   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
279 }
280 
281 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
282   return wrap(unwrap(op)->getNextNode());
283 }
284 
285 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
286   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
287 }
288 
289 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
290   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
291 }
292 
293 intptr_t mlirOperationGetNumResults(MlirOperation op) {
294   return static_cast<intptr_t>(unwrap(op)->getNumResults());
295 }
296 
297 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
298   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
299 }
300 
301 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
302   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
303 }
304 
305 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
306   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
307 }
308 
309 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
310   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
311 }
312 
313 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
314   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
315   return MlirNamedAttribute{wrap(attr.first), wrap(attr.second)};
316 }
317 
318 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
319                                               MlirStringRef name) {
320   return wrap(unwrap(op)->getAttr(unwrap(name)));
321 }
322 
323 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
324                                      MlirAttribute attr) {
325   unwrap(op)->setAttr(unwrap(name), unwrap(attr));
326 }
327 
328 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
329   auto removeResult = unwrap(op)->removeAttr(unwrap(name));
330   return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
331 }
332 
333 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
334                         void *userData) {
335   detail::CallbackOstream stream(callback, userData);
336   unwrap(op)->print(stream);
337 }
338 
339 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
340                                  MlirStringCallback callback, void *userData) {
341   detail::CallbackOstream stream(callback, userData);
342   unwrap(op)->print(stream, *unwrap(flags));
343 }
344 
345 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
346 
347 bool mlirOperationVerify(MlirOperation op) {
348   return succeeded(verify(unwrap(op)));
349 }
350 
351 //===----------------------------------------------------------------------===//
352 // Region API.
353 //===----------------------------------------------------------------------===//
354 
355 MlirRegion mlirRegionCreate() { return wrap(new Region); }
356 
357 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
358   Region *cppRegion = unwrap(region);
359   if (cppRegion->empty())
360     return wrap(static_cast<Block *>(nullptr));
361   return wrap(&cppRegion->front());
362 }
363 
364 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
365   unwrap(region)->push_back(unwrap(block));
366 }
367 
368 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
369                                 MlirBlock block) {
370   auto &blockList = unwrap(region)->getBlocks();
371   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
372 }
373 
374 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
375                                      MlirBlock block) {
376   Region *cppRegion = unwrap(region);
377   if (mlirBlockIsNull(reference)) {
378     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
379     return;
380   }
381 
382   assert(unwrap(reference)->getParent() == unwrap(region) &&
383          "expected reference block to belong to the region");
384   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
385                                      unwrap(block));
386 }
387 
388 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
389                                       MlirBlock block) {
390   if (mlirBlockIsNull(reference))
391     return mlirRegionAppendOwnedBlock(region, block);
392 
393   assert(unwrap(reference)->getParent() == unwrap(region) &&
394          "expected reference block to belong to the region");
395   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
396                                      unwrap(block));
397 }
398 
399 void mlirRegionDestroy(MlirRegion region) {
400   delete static_cast<Region *>(region.ptr);
401 }
402 
403 //===----------------------------------------------------------------------===//
404 // Block API.
405 //===----------------------------------------------------------------------===//
406 
407 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) {
408   Block *b = new Block;
409   for (intptr_t i = 0; i < nArgs; ++i)
410     b->addArgument(unwrap(args[i]));
411   return wrap(b);
412 }
413 
414 bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
415   return unwrap(block) == unwrap(other);
416 }
417 
418 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
419   return wrap(unwrap(block)->getNextNode());
420 }
421 
422 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
423   Block *cppBlock = unwrap(block);
424   if (cppBlock->empty())
425     return wrap(static_cast<Operation *>(nullptr));
426   return wrap(&cppBlock->front());
427 }
428 
429 MlirOperation mlirBlockGetTerminator(MlirBlock block) {
430   Block *cppBlock = unwrap(block);
431   if (cppBlock->empty())
432     return wrap(static_cast<Operation *>(nullptr));
433   Operation &back = cppBlock->back();
434   if (!back.isKnownTerminator())
435     return wrap(static_cast<Operation *>(nullptr));
436   return wrap(&back);
437 }
438 
439 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
440   unwrap(block)->push_back(unwrap(operation));
441 }
442 
443 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
444                                    MlirOperation operation) {
445   auto &opList = unwrap(block)->getOperations();
446   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
447 }
448 
449 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
450                                         MlirOperation reference,
451                                         MlirOperation operation) {
452   Block *cppBlock = unwrap(block);
453   if (mlirOperationIsNull(reference)) {
454     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
455     return;
456   }
457 
458   assert(unwrap(reference)->getBlock() == unwrap(block) &&
459          "expected reference operation to belong to the block");
460   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
461                                         unwrap(operation));
462 }
463 
464 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
465                                          MlirOperation reference,
466                                          MlirOperation operation) {
467   if (mlirOperationIsNull(reference))
468     return mlirBlockAppendOwnedOperation(block, operation);
469 
470   assert(unwrap(reference)->getBlock() == unwrap(block) &&
471          "expected reference operation to belong to the block");
472   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
473                                         unwrap(operation));
474 }
475 
476 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
477 
478 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
479   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
480 }
481 
482 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
483   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
484 }
485 
486 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
487                     void *userData) {
488   detail::CallbackOstream stream(callback, userData);
489   unwrap(block)->print(stream);
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // Value API.
494 //===----------------------------------------------------------------------===//
495 
496 bool mlirValueEqual(MlirValue value1, MlirValue value2) {
497   return unwrap(value1) == unwrap(value2);
498 }
499 
500 bool mlirValueIsABlockArgument(MlirValue value) {
501   return unwrap(value).isa<BlockArgument>();
502 }
503 
504 bool mlirValueIsAOpResult(MlirValue value) {
505   return unwrap(value).isa<OpResult>();
506 }
507 
508 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
509   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
510 }
511 
512 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
513   return static_cast<intptr_t>(
514       unwrap(value).cast<BlockArgument>().getArgNumber());
515 }
516 
517 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
518   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
519 }
520 
521 MlirOperation mlirOpResultGetOwner(MlirValue value) {
522   return wrap(unwrap(value).cast<OpResult>().getOwner());
523 }
524 
525 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
526   return static_cast<intptr_t>(
527       unwrap(value).cast<OpResult>().getResultNumber());
528 }
529 
530 MlirType mlirValueGetType(MlirValue value) {
531   return wrap(unwrap(value).getType());
532 }
533 
534 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
535 
536 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
537                     void *userData) {
538   detail::CallbackOstream stream(callback, userData);
539   unwrap(value).print(stream);
540 }
541 
542 //===----------------------------------------------------------------------===//
543 // Type API.
544 //===----------------------------------------------------------------------===//
545 
546 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
547   return wrap(mlir::parseType(unwrap(type), unwrap(context)));
548 }
549 
550 MlirContext mlirTypeGetContext(MlirType type) {
551   return wrap(unwrap(type).getContext());
552 }
553 
554 bool mlirTypeEqual(MlirType t1, MlirType t2) {
555   return unwrap(t1) == unwrap(t2);
556 }
557 
558 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
559   detail::CallbackOstream stream(callback, userData);
560   unwrap(type).print(stream);
561 }
562 
563 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
564 
565 //===----------------------------------------------------------------------===//
566 // Attribute API.
567 //===----------------------------------------------------------------------===//
568 
569 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
570   return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
571 }
572 
573 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
574   return wrap(unwrap(attribute).getContext());
575 }
576 
577 MlirType mlirAttributeGetType(MlirAttribute attribute) {
578   return wrap(unwrap(attribute).getType());
579 }
580 
581 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
582   return unwrap(a1) == unwrap(a2);
583 }
584 
585 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
586                         void *userData) {
587   detail::CallbackOstream stream(callback, userData);
588   unwrap(attr).print(stream);
589 }
590 
591 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
592 
593 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
594                                          MlirAttribute attr) {
595   return MlirNamedAttribute{name, attr};
596 }
597 
598 //===----------------------------------------------------------------------===//
599 // Identifier API.
600 //===----------------------------------------------------------------------===//
601 
602 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
603   return wrap(Identifier::get(unwrap(str), unwrap(context)));
604 }
605 
606 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
607   return unwrap(ident) == unwrap(other);
608 }
609 
610 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
611   return wrap(unwrap(ident).strref());
612 }
613