xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision 15ea45f1)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"

#include <cstdlib>
#include <cstring>
21 
22 using namespace mlir;
23 
24 /* ========================================================================== */
25 /* Context API.                                                               */
26 /* ========================================================================== */
27 
28 MlirContext mlirContextCreate() {
29   auto *context = new MLIRContext(/*loadAllDialects=*/false);
30   return wrap(context);
31 }
32 
33 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
34   return unwrap(ctx1) == unwrap(ctx2);
35 }
36 
37 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
38 
39 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
40   unwrap(context)->allowUnregisteredDialects(allow);
41 }
42 
43 int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
44   return unwrap(context)->allowsUnregisteredDialects();
45 }
46 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
47   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
48 }
49 
50 // TODO: expose a cheaper way than constructing + sorting a vector only to take
51 // its size.
52 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
53   return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
54 }
55 
56 MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
57                                         MlirStringRef name) {
58   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
59 }
60 
61 /* ========================================================================== */
62 /* Dialect API.                                                               */
63 /* ========================================================================== */
64 
65 MlirContext mlirDialectGetContext(MlirDialect dialect) {
66   return wrap(unwrap(dialect)->getContext());
67 }
68 
69 int mlirDialectIsNull(MlirDialect dialect) {
70   return unwrap(dialect) == nullptr;
71 }
72 
73 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
74   return unwrap(dialect1) == unwrap(dialect2);
75 }
76 
77 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
78   return wrap(unwrap(dialect)->getNamespace());
79 }
80 
81 /* ========================================================================== */
82 /* Location API.                                                              */
83 /* ========================================================================== */
84 
85 MlirLocation mlirLocationFileLineColGet(MlirContext context,
86                                         const char *filename, unsigned line,
87                                         unsigned col) {
88   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
89 }
90 
91 MlirLocation mlirLocationUnknownGet(MlirContext context) {
92   return wrap(UnknownLoc::get(unwrap(context)));
93 }
94 
95 MlirContext mlirLocationGetContext(MlirLocation location) {
96   return wrap(unwrap(location).getContext());
97 }
98 
99 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
100                        void *userData) {
101   detail::CallbackOstream stream(callback, userData);
102   unwrap(location).print(stream);
103   stream.flush();
104 }
105 
106 /* ========================================================================== */
107 /* Module API.                                                                */
108 /* ========================================================================== */
109 
110 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
111   return wrap(ModuleOp::create(unwrap(location)));
112 }
113 
114 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
115   OwningModuleRef owning = parseSourceString(module, unwrap(context));
116   if (!owning)
117     return MlirModule{nullptr};
118   return MlirModule{owning.release().getOperation()};
119 }
120 
121 MlirContext mlirModuleGetContext(MlirModule module) {
122   return wrap(unwrap(module).getContext());
123 }
124 
125 void mlirModuleDestroy(MlirModule module) {
126   // Transfer ownership to an OwningModuleRef so that its destructor is called.
127   OwningModuleRef(unwrap(module));
128 }
129 
130 MlirOperation mlirModuleGetOperation(MlirModule module) {
131   return wrap(unwrap(module).getOperation());
132 }
133 
134 /* ========================================================================== */
135 /* Operation state API.                                                       */
136 /* ========================================================================== */
137 
138 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
139   MlirOperationState state;
140   state.name = name;
141   state.location = loc;
142   state.nResults = 0;
143   state.results = nullptr;
144   state.nOperands = 0;
145   state.operands = nullptr;
146   state.nRegions = 0;
147   state.regions = nullptr;
148   state.nSuccessors = 0;
149   state.successors = nullptr;
150   state.nAttributes = 0;
151   state.attributes = nullptr;
152   return state;
153 }
154 
// APPEND_ELEMS(type, sizeName, elemName) appends `n` elements of `type`, read
// from the array `elemName`, to the heap array `state->elemName`. The current
// length of that array is `state->sizeName`. Expand it only where `state` (a
// pointer to the struct being grown) and `n` (an intptr_t count) are in scope.
//
// Compared with a bare realloc/memcpy sequence, this version:
//  * treats a non-positive `n` as a no-op. An empty append never calls memcpy
//    with a possibly-null pointer, and a negative `n` can neither shrink the
//    buffer nor become a huge unsigned memcpy length;
//  * only overwrites `state->elemName` after realloc succeeded, and aborts on
//    allocation failure instead of leaking the old array and writing through
//    a null pointer;
//  * is wrapped in do { } while (false) so it acts as a single statement.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    type *grownStorage = static_cast<type *>(                                  \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    if (!grownStorage)                                                         \
      abort();                                                                 \
    memcpy(grownStorage + state->sizeName, elemName, n * sizeof(type));        \
    state->elemName = grownStorage;                                            \
    state->sizeName += n;                                                      \
  } while (false)
160 
161 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
162                                   MlirType *results) {
163   APPEND_ELEMS(MlirType, nResults, results);
164 }
165 
166 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
167                                    MlirValue *operands) {
168   APPEND_ELEMS(MlirValue, nOperands, operands);
169 }
170 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
171                                        MlirRegion *regions) {
172   APPEND_ELEMS(MlirRegion, nRegions, regions);
173 }
174 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
175                                      MlirBlock *successors) {
176   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
177 }
178 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
179                                      MlirNamedAttribute *attributes) {
180   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
181 }
182 
183 /* ========================================================================== */
184 /* Operation API.                                                             */
185 /* ========================================================================== */
186 
187 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
188   assert(state);
189   OperationState cppState(unwrap(state->location), state->name);
190   SmallVector<Type, 4> resultStorage;
191   SmallVector<Value, 8> operandStorage;
192   SmallVector<Block *, 2> successorStorage;
193   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
194   cppState.addOperands(
195       unwrapList(state->nOperands, state->operands, operandStorage));
196   cppState.addSuccessors(
197       unwrapList(state->nSuccessors, state->successors, successorStorage));
198 
199   cppState.attributes.reserve(state->nAttributes);
200   for (intptr_t i = 0; i < state->nAttributes; ++i)
201     cppState.addAttribute(state->attributes[i].name,
202                           unwrap(state->attributes[i].attribute));
203 
204   for (intptr_t i = 0; i < state->nRegions; ++i)
205     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
206 
207   MlirOperation result = wrap(Operation::create(cppState));
208   free(state->results);
209   free(state->operands);
210   free(state->successors);
211   free(state->regions);
212   free(state->attributes);
213   return result;
214 }
215 
216 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
217 
218 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
219 
220 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
221   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
222 }
223 
224 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
225   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
226 }
227 
228 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
229   return wrap(unwrap(op)->getNextNode());
230 }
231 
232 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
233   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
234 }
235 
236 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
237   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
238 }
239 
240 intptr_t mlirOperationGetNumResults(MlirOperation op) {
241   return static_cast<intptr_t>(unwrap(op)->getNumResults());
242 }
243 
244 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
245   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
246 }
247 
248 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
249   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
250 }
251 
252 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
253   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
254 }
255 
256 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
257   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
258 }
259 
260 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
261   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
262   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
263 }
264 
265 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
266                                               const char *name) {
267   return wrap(unwrap(op)->getAttr(name));
268 }
269 
270 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
271                         void *userData) {
272   detail::CallbackOstream stream(callback, userData);
273   unwrap(op)->print(stream);
274   stream.flush();
275 }
276 
277 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
278 
279 /* ========================================================================== */
280 /* Region API.                                                                */
281 /* ========================================================================== */
282 
283 MlirRegion mlirRegionCreate() { return wrap(new Region); }
284 
285 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
286   Region *cppRegion = unwrap(region);
287   if (cppRegion->empty())
288     return wrap(static_cast<Block *>(nullptr));
289   return wrap(&cppRegion->front());
290 }
291 
292 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
293   unwrap(region)->push_back(unwrap(block));
294 }
295 
296 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
297                                 MlirBlock block) {
298   auto &blockList = unwrap(region)->getBlocks();
299   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
300 }
301 
302 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
303                                      MlirBlock block) {
304   Region *cppRegion = unwrap(region);
305   if (mlirBlockIsNull(reference)) {
306     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
307     return;
308   }
309 
310   assert(unwrap(reference)->getParent() == unwrap(region) &&
311          "expected reference block to belong to the region");
312   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
313                                      unwrap(block));
314 }
315 
316 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
317                                       MlirBlock block) {
318   if (mlirBlockIsNull(reference))
319     return mlirRegionAppendOwnedBlock(region, block);
320 
321   assert(unwrap(reference)->getParent() == unwrap(region) &&
322          "expected reference block to belong to the region");
323   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
324                                      unwrap(block));
325 }
326 
327 void mlirRegionDestroy(MlirRegion region) {
328   delete static_cast<Region *>(region.ptr);
329 }
330 
331 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
332 
333 /* ========================================================================== */
334 /* Block API.                                                                 */
335 /* ========================================================================== */
336 
337 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
338   Block *b = new Block;
339   for (intptr_t i = 0; i < nArgs; ++i)
340     b->addArgument(unwrap(args[i]));
341   return wrap(b);
342 }
343 
344 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
345   return wrap(unwrap(block)->getNextNode());
346 }
347 
348 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
349   Block *cppBlock = unwrap(block);
350   if (cppBlock->empty())
351     return wrap(static_cast<Operation *>(nullptr));
352   return wrap(&cppBlock->front());
353 }
354 
355 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
356   unwrap(block)->push_back(unwrap(operation));
357 }
358 
359 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
360                                    MlirOperation operation) {
361   auto &opList = unwrap(block)->getOperations();
362   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
363 }
364 
365 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
366                                         MlirOperation reference,
367                                         MlirOperation operation) {
368   Block *cppBlock = unwrap(block);
369   if (mlirOperationIsNull(reference)) {
370     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
371     return;
372   }
373 
374   assert(unwrap(reference)->getBlock() == unwrap(block) &&
375          "expected reference operation to belong to the block");
376   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
377                                         unwrap(operation));
378 }
379 
380 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
381                                          MlirOperation reference,
382                                          MlirOperation operation) {
383   if (mlirOperationIsNull(reference))
384     return mlirBlockAppendOwnedOperation(block, operation);
385 
386   assert(unwrap(reference)->getBlock() == unwrap(block) &&
387          "expected reference operation to belong to the block");
388   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
389                                         unwrap(operation));
390 }
391 
392 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
393 
394 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
395 
396 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
397   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
398 }
399 
400 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
401   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
402 }
403 
404 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
405                     void *userData) {
406   detail::CallbackOstream stream(callback, userData);
407   unwrap(block)->print(stream);
408   stream.flush();
409 }
410 
411 /* ========================================================================== */
412 /* Value API.                                                                 */
413 /* ========================================================================== */
414 
415 MlirType mlirValueGetType(MlirValue value) {
416   return wrap(unwrap(value).getType());
417 }
418 
419 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
420                     void *userData) {
421   detail::CallbackOstream stream(callback, userData);
422   unwrap(value).print(stream);
423   stream.flush();
424 }
425 
426 /* ========================================================================== */
427 /* Type API.                                                                  */
428 /* ========================================================================== */
429 
430 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
431   return wrap(mlir::parseType(type, unwrap(context)));
432 }
433 
434 MlirContext mlirTypeGetContext(MlirType type) {
435   return wrap(unwrap(type).getContext());
436 }
437 
438 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
439 
440 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
441   detail::CallbackOstream stream(callback, userData);
442   unwrap(type).print(stream);
443   stream.flush();
444 }
445 
446 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
447 
448 /* ========================================================================== */
449 /* Attribute API.                                                             */
450 /* ========================================================================== */
451 
452 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
453   return wrap(mlir::parseAttribute(attr, unwrap(context)));
454 }
455 
456 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
457   return wrap(unwrap(attribute).getContext());
458 }
459 
460 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
461   return unwrap(a1) == unwrap(a2);
462 }
463 
464 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
465                         void *userData) {
466   detail::CallbackOstream stream(callback, userData);
467   unwrap(attr).print(stream);
468   stream.flush();
469 }
470 
471 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
472 
473 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
474   return MlirNamedAttribute{name, attr};
475 }
476