xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision 396e7f45)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/IR.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Utils.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Module.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/Types.h"
18 #include "mlir/Parser.h"
19 
20 using namespace mlir;
21 
22 /* ========================================================================== */
23 /* Context API.                                                               */
24 /* ========================================================================== */
25 
26 MlirContext mlirContextCreate() {
27   auto *context = new MLIRContext(/*loadAllDialects=*/false);
28   return wrap(context);
29 }
30 
31 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
32   return unwrap(ctx1) == unwrap(ctx2);
33 }
34 
35 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
36 
37 void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
38   unwrap(context)->allowUnregisteredDialects(allow);
39 }
40 
41 int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
42   return unwrap(context)->allowsUnregisteredDialects();
43 }
44 
45 /* ========================================================================== */
46 /* Location API.                                                              */
47 /* ========================================================================== */
48 
49 MlirLocation mlirLocationFileLineColGet(MlirContext context,
50                                         const char *filename, unsigned line,
51                                         unsigned col) {
52   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
53 }
54 
55 MlirLocation mlirLocationUnknownGet(MlirContext context) {
56   return wrap(UnknownLoc::get(unwrap(context)));
57 }
58 
59 MlirContext mlirLocationGetContext(MlirLocation location) {
60   return wrap(unwrap(location).getContext());
61 }
62 
63 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
64                        void *userData) {
65   detail::CallbackOstream stream(callback, userData);
66   unwrap(location).print(stream);
67   stream.flush();
68 }
69 
70 /* ========================================================================== */
71 /* Module API.                                                                */
72 /* ========================================================================== */
73 
74 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
75   return wrap(ModuleOp::create(unwrap(location)));
76 }
77 
78 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
79   OwningModuleRef owning = parseSourceString(module, unwrap(context));
80   if (!owning)
81     return MlirModule{nullptr};
82   return MlirModule{owning.release().getOperation()};
83 }
84 
85 MlirContext mlirModuleGetContext(MlirModule module) {
86   return wrap(unwrap(module).getContext());
87 }
88 
89 void mlirModuleDestroy(MlirModule module) {
90   // Transfer ownership to an OwningModuleRef so that its destructor is called.
91   OwningModuleRef(unwrap(module));
92 }
93 
94 MlirOperation mlirModuleGetOperation(MlirModule module) {
95   return wrap(unwrap(module).getOperation());
96 }
97 
98 /* ========================================================================== */
99 /* Operation state API.                                                       */
100 /* ========================================================================== */
101 
102 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
103   MlirOperationState state;
104   state.name = name;
105   state.location = loc;
106   state.nResults = 0;
107   state.results = nullptr;
108   state.nOperands = 0;
109   state.operands = nullptr;
110   state.nRegions = 0;
111   state.regions = nullptr;
112   state.nSuccessors = 0;
113   state.successors = nullptr;
114   state.nAttributes = 0;
115   state.attributes = nullptr;
116   return state;
117 }
118 
// Appends `n` elements of `type`, read from the local array `elemName`, to the
// heap array `state->elemName`, whose current length is `state->sizeName`.
// Expects `state` and `n` to be in scope at the expansion site.
//
// The body is wrapped in do { ... } while (false) so that the expansion is a
// single statement and behaves correctly under an unbraced `if`/`else`.
// Nothing happens when `n` is not positive. The reallocated pointer is only
// stored after it has been checked, so the previous buffer is never lost; on
// allocation failure the process is aborted instead of writing through null.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    type *appendElemsGrown = static_cast<type *>(                              \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    if (!appendElemsGrown)                                                     \
      abort();                                                                 \
    memcpy(appendElemsGrown + state->sizeName, elemName, n * sizeof(type));    \
    state->elemName = appendElemsGrown;                                        \
    state->sizeName += n;                                                      \
  } while (false)
124 
125 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
126                                   MlirType *results) {
127   APPEND_ELEMS(MlirType, nResults, results);
128 }
129 
130 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
131                                    MlirValue *operands) {
132   APPEND_ELEMS(MlirValue, nOperands, operands);
133 }
134 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
135                                        MlirRegion *regions) {
136   APPEND_ELEMS(MlirRegion, nRegions, regions);
137 }
138 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
139                                      MlirBlock *successors) {
140   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
141 }
142 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
143                                      MlirNamedAttribute *attributes) {
144   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
145 }
146 
147 /* ========================================================================== */
148 /* Operation API.                                                             */
149 /* ========================================================================== */
150 
151 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
152   assert(state);
153   OperationState cppState(unwrap(state->location), state->name);
154   SmallVector<Type, 4> resultStorage;
155   SmallVector<Value, 8> operandStorage;
156   SmallVector<Block *, 2> successorStorage;
157   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
158   cppState.addOperands(
159       unwrapList(state->nOperands, state->operands, operandStorage));
160   cppState.addSuccessors(
161       unwrapList(state->nSuccessors, state->successors, successorStorage));
162 
163   cppState.attributes.reserve(state->nAttributes);
164   for (intptr_t i = 0; i < state->nAttributes; ++i)
165     cppState.addAttribute(state->attributes[i].name,
166                           unwrap(state->attributes[i].attribute));
167 
168   for (intptr_t i = 0; i < state->nRegions; ++i)
169     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
170 
171   MlirOperation result = wrap(Operation::create(cppState));
172   free(state->results);
173   free(state->operands);
174   free(state->successors);
175   free(state->regions);
176   free(state->attributes);
177   return result;
178 }
179 
180 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
181 
182 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
183 
184 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
185   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
186 }
187 
188 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
189   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
190 }
191 
192 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
193   return wrap(unwrap(op)->getNextNode());
194 }
195 
196 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
197   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
198 }
199 
200 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
201   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
202 }
203 
204 intptr_t mlirOperationGetNumResults(MlirOperation op) {
205   return static_cast<intptr_t>(unwrap(op)->getNumResults());
206 }
207 
208 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
209   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
210 }
211 
212 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
213   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
214 }
215 
216 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
217   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
218 }
219 
220 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
221   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
222 }
223 
224 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
225   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
226   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
227 }
228 
229 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
230                                               const char *name) {
231   return wrap(unwrap(op)->getAttr(name));
232 }
233 
234 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
235                         void *userData) {
236   detail::CallbackOstream stream(callback, userData);
237   unwrap(op)->print(stream);
238   stream.flush();
239 }
240 
241 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
242 
243 /* ========================================================================== */
244 /* Region API.                                                                */
245 /* ========================================================================== */
246 
247 MlirRegion mlirRegionCreate() { return wrap(new Region); }
248 
249 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
250   Region *cppRegion = unwrap(region);
251   if (cppRegion->empty())
252     return wrap(static_cast<Block *>(nullptr));
253   return wrap(&cppRegion->front());
254 }
255 
256 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
257   unwrap(region)->push_back(unwrap(block));
258 }
259 
260 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
261                                 MlirBlock block) {
262   auto &blockList = unwrap(region)->getBlocks();
263   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
264 }
265 
266 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
267                                      MlirBlock block) {
268   Region *cppRegion = unwrap(region);
269   if (mlirBlockIsNull(reference)) {
270     cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
271     return;
272   }
273 
274   assert(unwrap(reference)->getParent() == unwrap(region) &&
275          "expected reference block to belong to the region");
276   cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
277                                      unwrap(block));
278 }
279 
280 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
281                                       MlirBlock block) {
282   if (mlirBlockIsNull(reference))
283     return mlirRegionAppendOwnedBlock(region, block);
284 
285   assert(unwrap(reference)->getParent() == unwrap(region) &&
286          "expected reference block to belong to the region");
287   unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
288                                      unwrap(block));
289 }
290 
291 void mlirRegionDestroy(MlirRegion region) {
292   delete static_cast<Region *>(region.ptr);
293 }
294 
295 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
296 
297 /* ========================================================================== */
298 /* Block API.                                                                 */
299 /* ========================================================================== */
300 
301 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
302   Block *b = new Block;
303   for (intptr_t i = 0; i < nArgs; ++i)
304     b->addArgument(unwrap(args[i]));
305   return wrap(b);
306 }
307 
308 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
309   return wrap(unwrap(block)->getNextNode());
310 }
311 
312 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
313   Block *cppBlock = unwrap(block);
314   if (cppBlock->empty())
315     return wrap(static_cast<Operation *>(nullptr));
316   return wrap(&cppBlock->front());
317 }
318 
319 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
320   unwrap(block)->push_back(unwrap(operation));
321 }
322 
323 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
324                                    MlirOperation operation) {
325   auto &opList = unwrap(block)->getOperations();
326   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
327 }
328 
329 void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
330                                         MlirOperation reference,
331                                         MlirOperation operation) {
332   Block *cppBlock = unwrap(block);
333   if (mlirOperationIsNull(reference)) {
334     cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
335     return;
336   }
337 
338   assert(unwrap(reference)->getBlock() == unwrap(block) &&
339          "expected reference operation to belong to the block");
340   cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
341                                         unwrap(operation));
342 }
343 
344 void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
345                                          MlirOperation reference,
346                                          MlirOperation operation) {
347   if (mlirOperationIsNull(reference))
348     return mlirBlockAppendOwnedOperation(block, operation);
349 
350   assert(unwrap(reference)->getBlock() == unwrap(block) &&
351          "expected reference operation to belong to the block");
352   unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
353                                         unwrap(operation));
354 }
355 
356 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
357 
358 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
359 
360 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
361   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
362 }
363 
364 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
365   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
366 }
367 
368 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
369                     void *userData) {
370   detail::CallbackOstream stream(callback, userData);
371   unwrap(block)->print(stream);
372   stream.flush();
373 }
374 
375 /* ========================================================================== */
376 /* Value API.                                                                 */
377 /* ========================================================================== */
378 
379 MlirType mlirValueGetType(MlirValue value) {
380   return wrap(unwrap(value).getType());
381 }
382 
383 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
384                     void *userData) {
385   detail::CallbackOstream stream(callback, userData);
386   unwrap(value).print(stream);
387   stream.flush();
388 }
389 
390 /* ========================================================================== */
391 /* Type API.                                                                  */
392 /* ========================================================================== */
393 
394 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
395   return wrap(mlir::parseType(type, unwrap(context)));
396 }
397 
398 MlirContext mlirTypeGetContext(MlirType type) {
399   return wrap(unwrap(type).getContext());
400 }
401 
402 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
403 
404 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
405   detail::CallbackOstream stream(callback, userData);
406   unwrap(type).print(stream);
407   stream.flush();
408 }
409 
410 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
411 
412 /* ========================================================================== */
413 /* Attribute API.                                                             */
414 /* ========================================================================== */
415 
416 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
417   return wrap(mlir::parseAttribute(attr, unwrap(context)));
418 }
419 
420 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
421   return wrap(unwrap(attribute).getContext());
422 }
423 
424 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
425   return unwrap(a1) == unwrap(a2);
426 }
427 
428 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
429                         void *userData) {
430   detail::CallbackOstream stream(callback, userData);
431   unwrap(attr).print(stream);
432   stream.flush();
433 }
434 
435 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
436 
437 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
438   return MlirNamedAttribute{name, attr};
439 }
440