xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision fa5fa63f)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/IR.h"
10 #include "mlir-c/Support.h"
11 
12 #include "mlir/CAPI/IR.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Module.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/Parser.h"
21 
22 using namespace mlir;
23 
24 /* ========================================================================== */
25 /* Context API.                                                               */
26 /* ========================================================================== */
27 
28 MlirContext mlirContextCreate() {
29   auto *context = new MLIRContext(/*loadAllDialects=*/false);
30   return wrap(context);
31 }
32 
33 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
34   return unwrap(ctx1) == unwrap(ctx2);
35 }
36 
37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
38 
39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
40   unwrap(context)->allowUnregisteredDialects(allow);
41 }
42 
43 int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
44   return unwrap(context)->allowsUnregisteredDialects();
45 }
46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
47   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
48 }
49 
50 // TODO: expose a cheaper way than constructing + sorting a vector only to take
51 // its size.
52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
53   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
54 }
55 
56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
57                                         MlirStringRef name) {
58   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
59 }
60 
61 /* ========================================================================== */
62 /* Dialect API.                                                               */
63 /* ========================================================================== */
64 
65 MlirContext mlirDialectGetContext(MlirDialect dialect) {
66   return wrap(unwrap(dialect)->getContext());
67 }
68 
69 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
70   return unwrap(dialect1) == unwrap(dialect2);
71 }
72 
73 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
74   return wrap(unwrap(dialect)->getNamespace());
75 }
76 
77 /* ========================================================================== */
78 /* Location API.                                                              */
79 /* ========================================================================== */
80 
81 MlirLocation mlirLocationFileLineColGet(MlirContext context,
82                                         const char *filename, unsigned line,
83                                         unsigned col) {
84   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
85 }
86 
87 MlirLocation mlirLocationUnknownGet(MlirContext context) {
88   return wrap(UnknownLoc::get(unwrap(context)));
89 }
90 
91 MlirContext mlirLocationGetContext(MlirLocation location) {
92   return wrap(unwrap(location).getContext());
93 }
94 
95 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
96                        void *userData) {
97   detail::CallbackOstream stream(callback, userData);
98   unwrap(location).print(stream);
99   stream.flush();
100 }
101 
102 /* ========================================================================== */
103 /* Module API.                                                                */
104 /* ========================================================================== */
105 
106 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
107   return wrap(ModuleOp::create(unwrap(location)));
108 }
109 
110 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
111   OwningModuleRef owning = parseSourceString(module, unwrap(context));
112   if (!owning)
113     return MlirModule{nullptr};
114   return MlirModule{owning.release().getOperation()};
115 }
116 
117 MlirContext mlirModuleGetContext(MlirModule module) {
118   return wrap(unwrap(module).getContext());
119 }
120 
121 void mlirModuleDestroy(MlirModule module) {
122   // Transfer ownership to an OwningModuleRef so that its destructor is called.
123   OwningModuleRef(unwrap(module));
124 }
125 
126 MlirOperation mlirModuleGetOperation(MlirModule module) {
127   return wrap(unwrap(module).getOperation());
128 }
129 
130 /* ========================================================================== */
131 /* Operation state API.                                                       */
132 /* ========================================================================== */
133 
134 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
135   MlirOperationState state;
136   state.name = name;
137   state.location = loc;
138   state.nResults = 0;
139   state.results = nullptr;
140   state.nOperands = 0;
141   state.operands = nullptr;
142   state.nRegions = 0;
143   state.regions = nullptr;
144   state.nSuccessors = 0;
145   state.successors = nullptr;
146   state.nAttributes = 0;
147   state.attributes = nullptr;
148   return state;
149 }
150 
// APPEND_ELEMS(type, sizeName, elemName) appends the `n` elements of `type`
// pointed to by the local `elemName` to the realloc-managed array
// `state->elemName`, then increases `state->sizeName` by `n`. The expansion
// site must have `state` (pointer to the operation state) and `n` (element
// count) in scope, and its array parameter must share its name with the
// corresponding state member.
//
// The body is wrapped in do { ... } while (false) so that the expansion acts
// as a single statement (e.g. under an unbraced `if`/`else`). A count of zero
// is a no-op, which also keeps memcpy from ever receiving a null pointer when
// nothing has been appended yet. If realloc fails, the process aborts rather
// than leaking the old array and writing through a null pointer. The new
// pointer is stored only after the copy succeeds.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    assert(n >= 0 && "expected a non-negative element count");                 \
    if (n <= 0)                                                                \
      break;                                                                   \
    type *appendElemsGrown = static_cast<type *>(                              \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    if (!appendElemsGrown)                                                     \
      abort();                                                                 \
    memcpy(appendElemsGrown + state->sizeName, elemName, n * sizeof(type));    \
    state->elemName = appendElemsGrown;                                        \
    state->sizeName += n;                                                      \
  } while (false)
156 
157 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
158                                   MlirType *results) {
159   APPEND_ELEMS(MlirType, nResults, results);
160 }
161 
162 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
163                                    MlirValue *operands) {
164   APPEND_ELEMS(MlirValue, nOperands, operands);
165 }
166 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
167                                        MlirRegion *regions) {
168   APPEND_ELEMS(MlirRegion, nRegions, regions);
169 }
170 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
171                                      MlirBlock *successors) {
172   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
173 }
174 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
175                                      MlirNamedAttribute *attributes) {
176   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
177 }
178 
179 /* ========================================================================== */
180 /* Operation API.                                                             */
181 /* ========================================================================== */
182 
183 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
184   assert(state);
185   OperationState cppState(unwrap(state->location), state->name);
186   SmallVector<Type, 4> resultStorage;
187   SmallVector<Value, 8> operandStorage;
188   SmallVector<Block *, 2> successorStorage;
189   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
190   cppState.addOperands(
191       unwrapList(state->nOperands, state->operands, operandStorage));
192   cppState.addSuccessors(
193       unwrapList(state->nSuccessors, state->successors, successorStorage));
194 
195   cppState.attributes.reserve(state->nAttributes);
196   for (intptr_t i = 0; i < state->nAttributes; ++i)
197     cppState.addAttribute(state->attributes[i].name,
198                           unwrap(state->attributes[i].attribute));
199 
200   for (intptr_t i = 0; i < state->nRegions; ++i)
201     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
202 
203   MlirOperation result = wrap(Operation::create(cppState));
204   free(state->results);
205   free(state->operands);
206   free(state->successors);
207   free(state->regions);
208   free(state->attributes);
209   return result;
210 }
211 
212 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
213 
214 int mlirOperationEqual(MlirOperation op, MlirOperation other) {
215   return unwrap(op) == unwrap(other);
216 }
217 
218 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
219   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
220 }
221 
222 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
223   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
224 }
225 
226 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
227   return wrap(unwrap(op)->getNextNode());
228 }
229 
230 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
231   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
232 }
233 
234 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
235   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
236 }
237 
238 intptr_t mlirOperationGetNumResults(MlirOperation op) {
239   return static_cast<intptr_t>(unwrap(op)->getNumResults());
240 }
241 
242 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
243   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
244 }
245 
246 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
247   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
248 }
249 
250 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
251   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
252 }
253 
254 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
255   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
256 }
257 
258 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
259   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
260   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
261 }
262 
263 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
264                                               const char *name) {
265   return wrap(unwrap(op)->getAttr(name));
266 }
267 
268 void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
269                                      MlirAttribute attr) {
270   unwrap(op)->setAttr(name, unwrap(attr));
271 }
272 
273 int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) {
274   auto removeResult = unwrap(op)->removeAttr(name);
275   return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
276 }
277 
278 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
279                         void *userData) {
280   detail::CallbackOstream stream(callback, userData);
281   unwrap(op)->print(stream);
282   stream.flush();
283 }
284 
285 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
286 
287 /* ========================================================================== */
288 /* Region API.                                                                */
289 /* ========================================================================== */
290 
291 MlirRegion mlirRegionCreate() { return wrap(new Region); }
292 
293 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
294   Region *cppRegion = unwrap(region);
295   if (cppRegion->empty())
296     return wrap(static_cast<Block *>(nullptr));
297   return wrap(&cppRegion->front());
298 }
299 
300 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
301   unwrap(region)->push_back(unwrap(block));
302 }
303 
304 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
305                                 MlirBlock block) {
306   auto &blockList = unwrap(region)->getBlocks();
307   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
308 }
309 
310 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
311                                      MlirBlock block) {
312   Region *cppRegion = unwrap(region);
313   if (mlirBlockIsNull(reference)) {
314     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
315     return;
316   }
317 
318   assert(unwrap(reference)->getParent() == unwrap(region) &&
319          "expected reference block to belong to the region");
320   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
321                                      unwrap(block));
322 }
323 
324 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
325                                       MlirBlock block) {
326   if (mlirBlockIsNull(reference))
327     return mlirRegionAppendOwnedBlock(region, block);
328 
329   assert(unwrap(reference)->getParent() == unwrap(region) &&
330          "expected reference block to belong to the region");
331   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
332                                      unwrap(block));
333 }
334 
335 void mlirRegionDestroy(MlirRegion region) {
336   delete static_cast<Region *>(region.ptr);
337 }
338 
339 /* ========================================================================== */
340 /* Block API.                                                                 */
341 /* ========================================================================== */
342 
343 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
344   Block *b = new Block;
345   for (intptr_t i = 0; i < nArgs; ++i)
346     b->addArgument(unwrap(args[i]));
347   return wrap(b);
348 }
349 
350 int mlirBlockEqual(MlirBlock block, MlirBlock other) {
351   return unwrap(block) == unwrap(other);
352 }
353 
354 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
355   return wrap(unwrap(block)->getNextNode());
356 }
357 
358 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
359   Block *cppBlock = unwrap(block);
360   if (cppBlock->empty())
361     return wrap(static_cast<Operation *>(nullptr));
362   return wrap(&cppBlock->front());
363 }
364 
365 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
366   unwrap(block)->push_back(unwrap(operation));
367 }
368 
369 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
370                                    MlirOperation operation) {
371   auto &opList = unwrap(block)->getOperations();
372   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
373 }
374 
375 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
376                                         MlirOperation reference,
377                                         MlirOperation operation) {
378   Block *cppBlock = unwrap(block);
379   if (mlirOperationIsNull(reference)) {
380     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
381     return;
382   }
383 
384   assert(unwrap(reference)->getBlock() == unwrap(block) &&
385          "expected reference operation to belong to the block");
386   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
387                                         unwrap(operation));
388 }
389 
390 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
391                                          MlirOperation reference,
392                                          MlirOperation operation) {
393   if (mlirOperationIsNull(reference))
394     return mlirBlockAppendOwnedOperation(block, operation);
395 
396   assert(unwrap(reference)->getBlock() == unwrap(block) &&
397          "expected reference operation to belong to the block");
398   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
399                                         unwrap(operation));
400 }
401 
402 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
403 
404 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
405   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
406 }
407 
408 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
409   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
410 }
411 
412 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
413                     void *userData) {
414   detail::CallbackOstream stream(callback, userData);
415   unwrap(block)->print(stream);
416   stream.flush();
417 }
418 
419 /* ========================================================================== */
420 /* Value API.                                                                 */
421 /* ========================================================================== */
422 
423 int mlirValueIsABlockArgument(MlirValue value) {
424   return unwrap(value).isa<BlockArgument>();
425 }
426 
427 int mlirValueIsAOpResult(MlirValue value) {
428   return unwrap(value).isa<OpResult>();
429 }
430 
431 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
432   return wrap(unwrap(value).cast<BlockArgument>().getOwner());
433 }
434 
435 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
436   return static_cast<intptr_t>(
437       unwrap(value).cast<BlockArgument>().getArgNumber());
438 }
439 
440 void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
441   unwrap(value).cast<BlockArgument>().setType(unwrap(type));
442 }
443 
444 MlirOperation mlirOpResultGetOwner(MlirValue value) {
445   return wrap(unwrap(value).cast<OpResult>().getOwner());
446 }
447 
448 intptr_t mlirOpResultGetResultNumber(MlirValue value) {
449   return static_cast<intptr_t>(
450       unwrap(value).cast<OpResult>().getResultNumber());
451 }
452 
453 MlirType mlirValueGetType(MlirValue value) {
454   return wrap(unwrap(value).getType());
455 }
456 
457 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
458 
459 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
460                     void *userData) {
461   detail::CallbackOstream stream(callback, userData);
462   unwrap(value).print(stream);
463   stream.flush();
464 }
465 
466 /* ========================================================================== */
467 /* Type API.                                                                  */
468 /* ========================================================================== */
469 
470 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
471   return wrap(mlir::parseType(type, unwrap(context)));
472 }
473 
474 MlirContext mlirTypeGetContext(MlirType type) {
475   return wrap(unwrap(type).getContext());
476 }
477 
478 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
479 
480 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
481   detail::CallbackOstream stream(callback, userData);
482   unwrap(type).print(stream);
483   stream.flush();
484 }
485 
486 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
487 
488 /* ========================================================================== */
489 /* Attribute API.                                                             */
490 /* ========================================================================== */
491 
492 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
493   return wrap(mlir::parseAttribute(attr, unwrap(context)));
494 }
495 
496 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
497   return wrap(unwrap(attribute).getContext());
498 }
499 
500 MlirType mlirAttributeGetType(MlirAttribute attribute) {
501   return wrap(unwrap(attribute).getType());
502 }
503 
504 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
505   return unwrap(a1) == unwrap(a2);
506 }
507 
508 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
509                         void *userData) {
510   detail::CallbackOstream stream(callback, userData);
511   unwrap(attr).print(stream);
512   stream.flush();
513 }
514 
515 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
516 
517 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
518   return MlirNamedAttribute{name, attr};
519 }
520