// xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision b8cc449b)
// NOTE(review): the code below appears to predate that release (it includes
// mlir/IR/Module.h and mlir/Parser.h and uses OwningModuleRef); confirm the
// revision this listing was actually taken from.
//===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"

#include "mlir/CAPI/IR.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"
#include "llvm/Support/raw_ostream.h"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iterator>
#include <memory>
#include <utility>

using namespace mlir;
21 
/* ========================================================================== */
/* Printing helper.                                                           */
/* ========================================================================== */
25 
26 namespace {
27 /// A simple raw ostream subclass that forwards write_impl calls to the
28 /// user-supplied callback together with opaque user-supplied data.
29 class CallbackOstream : public llvm::raw_ostream {
30 public:
31   CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
32                   void *opaqueData)
33       : callback(callback), opaqueData(opaqueData), pos(0u) {}
34 
35   void write_impl(const char *ptr, size_t size) override {
36     callback(ptr, size, opaqueData);
37     pos += size;
38   }
39 
40   uint64_t current_pos() const override { return pos; }
41 
42 private:
43   std::function<void(const char *, intptr_t, void *)> callback;
44   void *opaqueData;
45   uint64_t pos;
46 };
47 } // end namespace
48 
/* ========================================================================== */
/* Context API.                                                               */
/* ========================================================================== */
52 
53 MlirContext mlirContextCreate() {
54   auto *context = new MLIRContext(/*loadAllDialects=*/false);
55   return wrap(context);
56 }
57 
58 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
59 
/* ========================================================================== */
/* Location API.                                                              */
/* ========================================================================== */
63 
64 MlirLocation mlirLocationFileLineColGet(MlirContext context,
65                                         const char *filename, unsigned line,
66                                         unsigned col) {
67   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
68 }
69 
70 MlirLocation mlirLocationUnknownGet(MlirContext context) {
71   return wrap(UnknownLoc::get(unwrap(context)));
72 }
73 
74 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
75                        void *userData) {
76   CallbackOstream stream(callback, userData);
77   unwrap(location).print(stream);
78   stream.flush();
79 }
80 
/* ========================================================================== */
/* Module API.                                                                */
/* ========================================================================== */
84 
85 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
86   return wrap(ModuleOp::create(unwrap(location)));
87 }
88 
89 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
90   OwningModuleRef owning = parseSourceString(module, unwrap(context));
91   if (!owning)
92     return MlirModule{nullptr};
93   return MlirModule{owning.release().getOperation()};
94 }
95 
96 void mlirModuleDestroy(MlirModule module) {
97   // Transfer ownership to an OwningModuleRef so that its destructor is called.
98   OwningModuleRef(unwrap(module));
99 }
100 
101 MlirOperation mlirModuleGetOperation(MlirModule module) {
102   return wrap(unwrap(module).getOperation());
103 }
104 
/* ========================================================================== */
/* Operation state API.                                                       */
/* ========================================================================== */
108 
109 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
110   MlirOperationState state;
111   state.name = name;
112   state.location = loc;
113   state.nResults = 0;
114   state.results = nullptr;
115   state.nOperands = 0;
116   state.operands = nullptr;
117   state.nRegions = 0;
118   state.regions = nullptr;
119   state.nSuccessors = 0;
120   state.successors = nullptr;
121   state.nAttributes = 0;
122   state.attributes = nullptr;
123   return state;
124 }
125 
namespace {
/// Appends `n` elements copied from `elems` to the malloc-managed array
/// `storage`, whose current length is `size`, growing it with realloc.
/// The array must later be released with free(); mlirOperationCreate does so.
///
/// Compared with a bare realloc + memcpy, this helper:
///  - does nothing when n == 0. This avoids realloc(p, 0), whose result is
///    implementation-defined, and memcpy from a null `elems`, which is
///    undefined even for zero bytes.
///  - rejects negative counts in debug builds.
///  - aborts when the allocation fails or its byte size would overflow.
///    Otherwise the old buffer would leak and the copy would write through a
///    null pointer.
template <typename T, typename SizeT>
void appendOperationStateElems(T *&storage, SizeT &size, intptr_t n,
                               const T *elems) {
  assert(n >= 0 && "cannot append a negative number of elements");
  if (n <= 0)
    return;
  size_t oldCount = static_cast<size_t>(size);
  size_t extra = static_cast<size_t>(n);
  if (extra > SIZE_MAX / sizeof(T) - oldCount)
    std::abort();
  void *grown = std::realloc(storage, (oldCount + extra) * sizeof(T));
  if (!grown)
    std::abort();
  storage = static_cast<T *>(grown);
  std::memcpy(storage + oldCount, elems, extra * sizeof(T));
  size += static_cast<SizeT>(n);
}
} // end namespace

/// Appends the `n` elements pointed to by the variable `elemName` to the
/// `state->elemName` array and increases `state->sizeName` accordingly.
/// Expands in terms of the identifiers `state` and `n`, which must be in
/// scope at the use site. Expands to a single expression, so it is safe to use
/// as the body of an unbraced `if`.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  appendOperationStateElems<type>(state->elemName, state->sizeName, n,         \
                                  elemName)
131 
132 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
133                                   MlirType *results) {
134   APPEND_ELEMS(MlirType, nResults, results);
135 }
136 
137 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
138                                    MlirValue *operands) {
139   APPEND_ELEMS(MlirValue, nOperands, operands);
140 }
141 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
142                                        MlirRegion *regions) {
143   APPEND_ELEMS(MlirRegion, nRegions, regions);
144 }
145 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
146                                      MlirBlock *successors) {
147   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
148 }
149 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
150                                      MlirNamedAttribute *attributes) {
151   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
152 }
153 
/* ========================================================================== */
/* Operation API.                                                             */
/* ========================================================================== */
157 
158 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
159   assert(state);
160   OperationState cppState(unwrap(state->location), state->name);
161   SmallVector<Type, 4> resultStorage;
162   SmallVector<Value, 8> operandStorage;
163   SmallVector<Block *, 2> successorStorage;
164   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
165   cppState.addOperands(
166       unwrapList(state->nOperands, state->operands, operandStorage));
167   cppState.addSuccessors(
168       unwrapList(state->nSuccessors, state->successors, successorStorage));
169 
170   cppState.attributes.reserve(state->nAttributes);
171   for (intptr_t i = 0; i < state->nAttributes; ++i)
172     cppState.addAttribute(state->attributes[i].name,
173                           unwrap(state->attributes[i].attribute));
174 
175   for (intptr_t i = 0; i < state->nRegions; ++i)
176     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
177 
178   MlirOperation result = wrap(Operation::create(cppState));
179   free(state->results);
180   free(state->operands);
181   free(state->successors);
182   free(state->regions);
183   free(state->attributes);
184   return result;
185 }
186 
187 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
188 
189 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
190 
191 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
192   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
193 }
194 
195 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
196   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
197 }
198 
199 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
200   return wrap(unwrap(op)->getNextNode());
201 }
202 
203 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
204   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
205 }
206 
207 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
208   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
209 }
210 
211 intptr_t mlirOperationGetNumResults(MlirOperation op) {
212   return static_cast<intptr_t>(unwrap(op)->getNumResults());
213 }
214 
215 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
216   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
217 }
218 
219 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
220   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
221 }
222 
223 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
224   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
225 }
226 
227 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
228   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
229 }
230 
231 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
232   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
233   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
234 }
235 
236 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
237                                               const char *name) {
238   return wrap(unwrap(op)->getAttr(name));
239 }
240 
241 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
242                         void *userData) {
243   CallbackOstream stream(callback, userData);
244   unwrap(op)->print(stream);
245   stream.flush();
246 }
247 
248 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
249 
/* ========================================================================== */
/* Region API.                                                                */
/* ========================================================================== */
253 
254 MlirRegion mlirRegionCreate() { return wrap(new Region); }
255 
256 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
257   Region *cppRegion = unwrap(region);
258   if (cppRegion->empty())
259     return wrap(static_cast<Block *>(nullptr));
260   return wrap(&cppRegion->front());
261 }
262 
263 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
264   unwrap(region)->push_back(unwrap(block));
265 }
266 
267 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
268                                 MlirBlock block) {
269   auto &blockList = unwrap(region)->getBlocks();
270   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
271 }
272 
273 void mlirRegionDestroy(MlirRegion region) {
274   delete static_cast<Region *>(region.ptr);
275 }
276 
277 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
278 
/* ========================================================================== */
/* Block API.                                                                 */
/* ========================================================================== */
282 
283 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
284   Block *b = new Block;
285   for (intptr_t i = 0; i < nArgs; ++i)
286     b->addArgument(unwrap(args[i]));
287   return wrap(b);
288 }
289 
290 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
291   return wrap(unwrap(block)->getNextNode());
292 }
293 
294 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
295   Block *cppBlock = unwrap(block);
296   if (cppBlock->empty())
297     return wrap(static_cast<Operation *>(nullptr));
298   return wrap(&cppBlock->front());
299 }
300 
301 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
302   unwrap(block)->push_back(unwrap(operation));
303 }
304 
305 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
306                                    MlirOperation operation) {
307   auto &opList = unwrap(block)->getOperations();
308   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
309 }
310 
311 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
312 
313 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
314 
315 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
316   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
317 }
318 
319 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
320   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
321 }
322 
323 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
324                     void *userData) {
325   CallbackOstream stream(callback, userData);
326   unwrap(block)->print(stream);
327   stream.flush();
328 }
329 
/* ========================================================================== */
/* Value API.                                                                 */
/* ========================================================================== */
333 
334 MlirType mlirValueGetType(MlirValue value) {
335   return wrap(unwrap(value).getType());
336 }
337 
338 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
339                     void *userData) {
340   CallbackOstream stream(callback, userData);
341   unwrap(value).print(stream);
342   stream.flush();
343 }
344 
/* ========================================================================== */
/* Type API.                                                                  */
/* ========================================================================== */
348 
349 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
350   return wrap(mlir::parseType(type, unwrap(context)));
351 }
352 
353 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
354 
355 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
356   CallbackOstream stream(callback, userData);
357   unwrap(type).print(stream);
358   stream.flush();
359 }
360 
361 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
362 
/* ========================================================================== */
/* Attribute API.                                                             */
/* ========================================================================== */
366 
367 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
368   return wrap(mlir::parseAttribute(attr, unwrap(context)));
369 }
370 
371 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
372   return unwrap(a1) == unwrap(a2);
373 }
374 
375 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
376                         void *userData) {
377   CallbackOstream stream(callback, userData);
378   unwrap(attr).print(stream);
379   stream.flush();
380 }
381 
382 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
383 
384 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
385   return MlirNamedAttribute{name, attr};
386 }
387