xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision dd1d5488)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"

#include "mlir/CAPI/IR.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"
#include "llvm/Support/raw_ostream.h"

#include <cstdlib>
#include <cstring>
#include <functional>
#include <utility>
19 
20 using namespace mlir;
21 
22 /* ========================================================================== */
23 /* Printing helper.                                                           */
24 /* ========================================================================== */
25 
26 namespace {
27 /// A simple raw ostream subclass that forwards write_impl calls to the
28 /// user-supplied callback together with opaque user-supplied data.
29 class CallbackOstream : public llvm::raw_ostream {
30 public:
31   CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
32                   void *opaqueData)
33       : callback(callback), opaqueData(opaqueData), pos(0u) {}
34 
35   void write_impl(const char *ptr, size_t size) override {
36     callback(ptr, size, opaqueData);
37     pos += size;
38   }
39 
40   uint64_t current_pos() const override { return pos; }
41 
42 private:
43   std::function<void(const char *, intptr_t, void *)> callback;
44   void *opaqueData;
45   uint64_t pos;
46 };
47 } // end namespace
48 
49 /* ========================================================================== */
50 /* Context API.                                                               */
51 /* ========================================================================== */
52 
53 MlirContext mlirContextCreate() {
54   auto *context = new MLIRContext(/*loadAllDialects=*/false);
55   return wrap(context);
56 }
57 
58 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
59   return unwrap(ctx1) == unwrap(ctx2);
60 }
61 
62 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
63 
64 /* ========================================================================== */
65 /* Location API.                                                              */
66 /* ========================================================================== */
67 
68 MlirLocation mlirLocationFileLineColGet(MlirContext context,
69                                         const char *filename, unsigned line,
70                                         unsigned col) {
71   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
72 }
73 
74 MlirLocation mlirLocationUnknownGet(MlirContext context) {
75   return wrap(UnknownLoc::get(unwrap(context)));
76 }
77 
78 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
79                        void *userData) {
80   CallbackOstream stream(callback, userData);
81   unwrap(location).print(stream);
82   stream.flush();
83 }
84 
85 /* ========================================================================== */
86 /* Module API.                                                                */
87 /* ========================================================================== */
88 
89 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
90   return wrap(ModuleOp::create(unwrap(location)));
91 }
92 
93 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
94   OwningModuleRef owning = parseSourceString(module, unwrap(context));
95   if (!owning)
96     return MlirModule{nullptr};
97   return MlirModule{owning.release().getOperation()};
98 }
99 
100 void mlirModuleDestroy(MlirModule module) {
101   // Transfer ownership to an OwningModuleRef so that its destructor is called.
102   OwningModuleRef(unwrap(module));
103 }
104 
105 MlirOperation mlirModuleGetOperation(MlirModule module) {
106   return wrap(unwrap(module).getOperation());
107 }
108 
109 /* ========================================================================== */
110 /* Operation state API.                                                       */
111 /* ========================================================================== */
112 
113 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
114   MlirOperationState state;
115   state.name = name;
116   state.location = loc;
117   state.nResults = 0;
118   state.results = nullptr;
119   state.nOperands = 0;
120   state.operands = nullptr;
121   state.nRegions = 0;
122   state.regions = nullptr;
123   state.nSuccessors = 0;
124   state.successors = nullptr;
125   state.nAttributes = 0;
126   state.attributes = nullptr;
127   return state;
128 }
129 
/// Appends `n` elements from `elems` to the C-allocated array `storage`.
/// `size` is the array's current length; the array grows with realloc and
/// `size` is bumped by `n`. The arrays of MlirOperationState are released
/// with free() in mlirOperationCreate, so they must stay on the C allocator
/// rather than new/delete.
///
/// A non-positive `n` leaves both `storage` and `size` untouched. That avoids
/// realloc(ptr, 0), memcpy from a possibly-null source, and a negative count
/// being converted to a huge size_t. Allocation failure aborts, instead of
/// leaking the old buffer and writing through a null pointer.
template <typename T, typename SizeT>
static void appendElems(T *&storage, SizeT &size, intptr_t n, const T *elems) {
  if (n <= 0)
    return;
  void *grown =
      realloc(storage, static_cast<size_t>(size + n) * sizeof(T));
  if (!grown)
    abort();
  storage = static_cast<T *>(grown);
  memcpy(storage + size, elems, static_cast<size_t>(n) * sizeof(T));
  size += n;
}

/// Appends the `n` elements of the array named `elemName` to the field of the
/// same name in `*state`, updating `state->sizeName`. Expects `state` and `n`
/// to be in scope at the expansion site. It expands to a single expression,
/// so it is safe to use as one statement.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  appendElems<type>(state->elemName, state->sizeName, n, elemName)
135 
136 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
137                                   MlirType *results) {
138   APPEND_ELEMS(MlirType, nResults, results);
139 }
140 
141 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
142                                    MlirValue *operands) {
143   APPEND_ELEMS(MlirValue, nOperands, operands);
144 }
145 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
146                                        MlirRegion *regions) {
147   APPEND_ELEMS(MlirRegion, nRegions, regions);
148 }
149 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
150                                      MlirBlock *successors) {
151   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
152 }
153 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
154                                      MlirNamedAttribute *attributes) {
155   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
156 }
157 
158 /* ========================================================================== */
159 /* Operation API.                                                             */
160 /* ========================================================================== */
161 
162 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
163   assert(state);
164   OperationState cppState(unwrap(state->location), state->name);
165   SmallVector<Type, 4> resultStorage;
166   SmallVector<Value, 8> operandStorage;
167   SmallVector<Block *, 2> successorStorage;
168   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
169   cppState.addOperands(
170       unwrapList(state->nOperands, state->operands, operandStorage));
171   cppState.addSuccessors(
172       unwrapList(state->nSuccessors, state->successors, successorStorage));
173 
174   cppState.attributes.reserve(state->nAttributes);
175   for (intptr_t i = 0; i < state->nAttributes; ++i)
176     cppState.addAttribute(state->attributes[i].name,
177                           unwrap(state->attributes[i].attribute));
178 
179   for (intptr_t i = 0; i < state->nRegions; ++i)
180     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
181 
182   MlirOperation result = wrap(Operation::create(cppState));
183   free(state->results);
184   free(state->operands);
185   free(state->successors);
186   free(state->regions);
187   free(state->attributes);
188   return result;
189 }
190 
191 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
192 
193 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
194 
195 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
196   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
197 }
198 
199 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
200   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
201 }
202 
203 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
204   return wrap(unwrap(op)->getNextNode());
205 }
206 
207 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
208   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
209 }
210 
211 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
212   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
213 }
214 
215 intptr_t mlirOperationGetNumResults(MlirOperation op) {
216   return static_cast<intptr_t>(unwrap(op)->getNumResults());
217 }
218 
219 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
220   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
221 }
222 
223 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
224   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
225 }
226 
227 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
228   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
229 }
230 
231 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
232   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
233 }
234 
235 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
236   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
237   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
238 }
239 
240 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
241                                               const char *name) {
242   return wrap(unwrap(op)->getAttr(name));
243 }
244 
245 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
246                         void *userData) {
247   CallbackOstream stream(callback, userData);
248   unwrap(op)->print(stream);
249   stream.flush();
250 }
251 
252 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
253 
254 /* ========================================================================== */
255 /* Region API.                                                                */
256 /* ========================================================================== */
257 
258 MlirRegion mlirRegionCreate() { return wrap(new Region); }
259 
260 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
261   Region *cppRegion = unwrap(region);
262   if (cppRegion->empty())
263     return wrap(static_cast<Block *>(nullptr));
264   return wrap(&cppRegion->front());
265 }
266 
267 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
268   unwrap(region)->push_back(unwrap(block));
269 }
270 
271 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
272                                 MlirBlock block) {
273   auto &blockList = unwrap(region)->getBlocks();
274   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
275 }
276 
277 void mlirRegionDestroy(MlirRegion region) {
278   delete static_cast<Region *>(region.ptr);
279 }
280 
281 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
282 
283 /* ========================================================================== */
284 /* Block API.                                                                 */
285 /* ========================================================================== */
286 
287 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
288   Block *b = new Block;
289   for (intptr_t i = 0; i < nArgs; ++i)
290     b->addArgument(unwrap(args[i]));
291   return wrap(b);
292 }
293 
294 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
295   return wrap(unwrap(block)->getNextNode());
296 }
297 
298 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
299   Block *cppBlock = unwrap(block);
300   if (cppBlock->empty())
301     return wrap(static_cast<Operation *>(nullptr));
302   return wrap(&cppBlock->front());
303 }
304 
305 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
306   unwrap(block)->push_back(unwrap(operation));
307 }
308 
309 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
310                                    MlirOperation operation) {
311   auto &opList = unwrap(block)->getOperations();
312   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
313 }
314 
315 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
316 
317 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
318 
319 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
320   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
321 }
322 
323 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
324   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
325 }
326 
327 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
328                     void *userData) {
329   CallbackOstream stream(callback, userData);
330   unwrap(block)->print(stream);
331   stream.flush();
332 }
333 
334 /* ========================================================================== */
335 /* Value API.                                                                 */
336 /* ========================================================================== */
337 
338 MlirType mlirValueGetType(MlirValue value) {
339   return wrap(unwrap(value).getType());
340 }
341 
342 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
343                     void *userData) {
344   CallbackOstream stream(callback, userData);
345   unwrap(value).print(stream);
346   stream.flush();
347 }
348 
349 /* ========================================================================== */
350 /* Type API.                                                                  */
351 /* ========================================================================== */
352 
353 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
354   return wrap(mlir::parseType(type, unwrap(context)));
355 }
356 
357 MlirContext mlirTypeGetContext(MlirType type) {
358   return wrap(unwrap(type).getContext());
359 }
360 
361 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
362 
363 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
364   CallbackOstream stream(callback, userData);
365   unwrap(type).print(stream);
366   stream.flush();
367 }
368 
369 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
370 
371 /* ========================================================================== */
372 /* Attribute API.                                                             */
373 /* ========================================================================== */
374 
375 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
376   return wrap(mlir::parseAttribute(attr, unwrap(context)));
377 }
378 
379 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
380   return unwrap(a1) == unwrap(a2);
381 }
382 
383 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
384                         void *userData) {
385   CallbackOstream stream(callback, userData);
386   unwrap(attr).print(stream);
387   stream.flush();
388 }
389 
390 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
391 
392 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
393   return MlirNamedAttribute{name, attr};
394 }
395