xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision bb09ef95)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/IR.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Utils.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Module.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/Types.h"
18 #include "mlir/Parser.h"
19 
20 using namespace mlir;
21 
22 /* ========================================================================== */
23 /* Context API.                                                               */
24 /* ========================================================================== */
25 
26 MlirContext mlirContextCreate() {
27   auto *context = new MLIRContext(/*loadAllDialects=*/false);
28   return wrap(context);
29 }
30 
31 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
32   return unwrap(ctx1) == unwrap(ctx2);
33 }
34 
35 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
36 
37 /* ========================================================================== */
38 /* Location API.                                                              */
39 /* ========================================================================== */
40 
41 MlirLocation mlirLocationFileLineColGet(MlirContext context,
42                                         const char *filename, unsigned line,
43                                         unsigned col) {
44   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
45 }
46 
47 MlirLocation mlirLocationUnknownGet(MlirContext context) {
48   return wrap(UnknownLoc::get(unwrap(context)));
49 }
50 
51 MlirContext mlirLocationGetContext(MlirLocation location) {
52   return wrap(unwrap(location).getContext());
53 }
54 
55 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
56                        void *userData) {
57   detail::CallbackOstream stream(callback, userData);
58   unwrap(location).print(stream);
59   stream.flush();
60 }
61 
62 /* ========================================================================== */
63 /* Module API.                                                                */
64 /* ========================================================================== */
65 
66 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
67   return wrap(ModuleOp::create(unwrap(location)));
68 }
69 
70 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
71   OwningModuleRef owning = parseSourceString(module, unwrap(context));
72   if (!owning)
73     return MlirModule{nullptr};
74   return MlirModule{owning.release().getOperation()};
75 }
76 
77 MlirContext mlirModuleGetContext(MlirModule module) {
78   return wrap(unwrap(module).getContext());
79 }
80 
81 void mlirModuleDestroy(MlirModule module) {
82   // Transfer ownership to an OwningModuleRef so that its destructor is called.
83   OwningModuleRef(unwrap(module));
84 }
85 
86 MlirOperation mlirModuleGetOperation(MlirModule module) {
87   return wrap(unwrap(module).getOperation());
88 }
89 
90 /* ========================================================================== */
91 /* Operation state API.                                                       */
92 /* ========================================================================== */
93 
94 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
95   MlirOperationState state;
96   state.name = name;
97   state.location = loc;
98   state.nResults = 0;
99   state.results = nullptr;
100   state.nOperands = 0;
101   state.operands = nullptr;
102   state.nRegions = 0;
103   state.regions = nullptr;
104   state.nSuccessors = 0;
105   state.successors = nullptr;
106   state.nAttributes = 0;
107   state.attributes = nullptr;
108   return state;
109 }
110 
/// Appends `n` elements read from `elems` to the realloc-managed array
/// `storage`, which currently holds `size` elements, and advances `size` by
/// `n`.
///
/// A count of zero (or less) is a no-op, so a null `elems` paired with an
/// empty list never reaches memcpy and a negative count can never shrink the
/// buffer. If realloc fails the process is aborted instead of writing through
/// a null pointer; the previous buffer pointer is not overwritten until the
/// growth is known to have succeeded.
template <typename T, typename SizeT>
static void appendElements(T *&storage, SizeT &size, intptr_t n,
                           const T *elems) {
  assert(n >= 0 && "cannot append a negative number of elements");
  if (n <= 0)
    return;
  assert(elems && "expected a non-null source array for a positive count");

  void *grown = realloc(storage, static_cast<size_t>(size + n) * sizeof(T));
  if (!grown)
    abort();
  storage = static_cast<T *>(grown);

  // The elements are plain C handle structs, so a bytewise copy is sufficient.
  memcpy(storage + size, elems, static_cast<size_t>(n) * sizeof(T));
  size += n;
}

/// Appends the `n` elements of the local array `elemName` to
/// `state->elemName` and adds `n` to `state->sizeName`. Expects `state` and
/// `n` to be in scope at the expansion site. Expands to a single expression,
/// so the caller supplies the terminating semicolon and the macro is safe
/// inside an unbraced `if`/`else`.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  appendElements<type>(state->elemName, state->sizeName, n, elemName)
116 
117 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
118                                   MlirType *results) {
119   APPEND_ELEMS(MlirType, nResults, results);
120 }
121 
122 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
123                                    MlirValue *operands) {
124   APPEND_ELEMS(MlirValue, nOperands, operands);
125 }
126 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
127                                        MlirRegion *regions) {
128   APPEND_ELEMS(MlirRegion, nRegions, regions);
129 }
130 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
131                                      MlirBlock *successors) {
132   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
133 }
134 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
135                                      MlirNamedAttribute *attributes) {
136   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
137 }
138 
139 /* ========================================================================== */
140 /* Operation API.                                                             */
141 /* ========================================================================== */
142 
143 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
144   assert(state);
145   OperationState cppState(unwrap(state->location), state->name);
146   SmallVector<Type, 4> resultStorage;
147   SmallVector<Value, 8> operandStorage;
148   SmallVector<Block *, 2> successorStorage;
149   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
150   cppState.addOperands(
151       unwrapList(state->nOperands, state->operands, operandStorage));
152   cppState.addSuccessors(
153       unwrapList(state->nSuccessors, state->successors, successorStorage));
154 
155   cppState.attributes.reserve(state->nAttributes);
156   for (intptr_t i = 0; i < state->nAttributes; ++i)
157     cppState.addAttribute(state->attributes[i].name,
158                           unwrap(state->attributes[i].attribute));
159 
160   for (intptr_t i = 0; i < state->nRegions; ++i)
161     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
162 
163   MlirOperation result = wrap(Operation::create(cppState));
164   free(state->results);
165   free(state->operands);
166   free(state->successors);
167   free(state->regions);
168   free(state->attributes);
169   return result;
170 }
171 
172 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
173 
174 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
175 
176 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
177   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
178 }
179 
180 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
181   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
182 }
183 
184 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
185   return wrap(unwrap(op)->getNextNode());
186 }
187 
188 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
189   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
190 }
191 
192 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
193   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
194 }
195 
196 intptr_t mlirOperationGetNumResults(MlirOperation op) {
197   return static_cast<intptr_t>(unwrap(op)->getNumResults());
198 }
199 
200 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
201   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
202 }
203 
204 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
205   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
206 }
207 
208 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
209   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
210 }
211 
212 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
213   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
214 }
215 
216 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
217   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
218   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
219 }
220 
221 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
222                                               const char *name) {
223   return wrap(unwrap(op)->getAttr(name));
224 }
225 
226 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
227                         void *userData) {
228   detail::CallbackOstream stream(callback, userData);
229   unwrap(op)->print(stream);
230   stream.flush();
231 }
232 
233 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
234 
235 /* ========================================================================== */
236 /* Region API.                                                                */
237 /* ========================================================================== */
238 
239 MlirRegion mlirRegionCreate() { return wrap(new Region); }
240 
241 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
242   Region *cppRegion = unwrap(region);
243   if (cppRegion->empty())
244     return wrap(static_cast<Block *>(nullptr));
245   return wrap(&cppRegion->front());
246 }
247 
248 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
249   unwrap(region)->push_back(unwrap(block));
250 }
251 
252 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
253                                 MlirBlock block) {
254   auto &blockList = unwrap(region)->getBlocks();
255   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
256 }
257 
258 void mlirRegionDestroy(MlirRegion region) {
259   delete static_cast<Region *>(region.ptr);
260 }
261 
262 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
263 
264 /* ========================================================================== */
265 /* Block API.                                                                 */
266 /* ========================================================================== */
267 
268 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
269   Block *b = new Block;
270   for (intptr_t i = 0; i < nArgs; ++i)
271     b->addArgument(unwrap(args[i]));
272   return wrap(b);
273 }
274 
275 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
276   return wrap(unwrap(block)->getNextNode());
277 }
278 
279 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
280   Block *cppBlock = unwrap(block);
281   if (cppBlock->empty())
282     return wrap(static_cast<Operation *>(nullptr));
283   return wrap(&cppBlock->front());
284 }
285 
286 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
287   unwrap(block)->push_back(unwrap(operation));
288 }
289 
290 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
291                                    MlirOperation operation) {
292   auto &opList = unwrap(block)->getOperations();
293   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
294 }
295 
296 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
297 
298 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
299 
300 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
301   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
302 }
303 
304 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
305   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
306 }
307 
308 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
309                     void *userData) {
310   detail::CallbackOstream stream(callback, userData);
311   unwrap(block)->print(stream);
312   stream.flush();
313 }
314 
315 /* ========================================================================== */
316 /* Value API.                                                                 */
317 /* ========================================================================== */
318 
319 MlirType mlirValueGetType(MlirValue value) {
320   return wrap(unwrap(value).getType());
321 }
322 
323 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
324                     void *userData) {
325   detail::CallbackOstream stream(callback, userData);
326   unwrap(value).print(stream);
327   stream.flush();
328 }
329 
330 /* ========================================================================== */
331 /* Type API.                                                                  */
332 /* ========================================================================== */
333 
334 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
335   return wrap(mlir::parseType(type, unwrap(context)));
336 }
337 
338 MlirContext mlirTypeGetContext(MlirType type) {
339   return wrap(unwrap(type).getContext());
340 }
341 
342 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
343 
344 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
345   detail::CallbackOstream stream(callback, userData);
346   unwrap(type).print(stream);
347   stream.flush();
348 }
349 
350 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
351 
352 /* ========================================================================== */
353 /* Attribute API.                                                             */
354 /* ========================================================================== */
355 
356 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
357   return wrap(mlir::parseAttribute(attr, unwrap(context)));
358 }
359 
360 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
361   return wrap(unwrap(attribute).getContext());
362 }
363 
364 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
365   return unwrap(a1) == unwrap(a2);
366 }
367 
368 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
369                         void *userData) {
370   detail::CallbackOstream stream(callback, userData);
371   unwrap(attr).print(stream);
372   stream.flush();
373 }
374 
375 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
376 
377 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
378   return MlirNamedAttribute{name, attr};
379 }
380