xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision 24d3210e)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/IR.h"
10 
11 #include "mlir/IR/Attributes.h"
12 #include "mlir/IR/Module.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Types.h"
15 #include "mlir/Parser.h"
16 #include "llvm/Support/raw_ostream.h"
17 
18 using namespace mlir;
19 
20 /* ========================================================================== */
21 /* Definitions of methods for non-owning structures used in C API.            */
22 /* ========================================================================== */
23 
24 #define DEFINE_C_API_PTR_METHODS(name, cpptype)                                \
25   static name wrap(cpptype *cpp) { return name{cpp}; }                         \
26   static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); }
27 
28 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext)
29 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation)
30 DEFINE_C_API_PTR_METHODS(MlirBlock, Block)
31 DEFINE_C_API_PTR_METHODS(MlirRegion, Region)
32 
33 #define DEFINE_C_API_METHODS(name, cpptype)                                    \
34   static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; }     \
35   static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); }
36 
37 DEFINE_C_API_METHODS(MlirAttribute, Attribute)
38 DEFINE_C_API_METHODS(MlirLocation, Location);
39 DEFINE_C_API_METHODS(MlirType, Type)
40 DEFINE_C_API_METHODS(MlirValue, Value)
41 DEFINE_C_API_METHODS(MlirModule, ModuleOp)
42 
43 template <typename CppTy, typename CTy>
44 static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
45                                   SmallVectorImpl<CppTy> &storage) {
46   static_assert(
47       std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
48       "incompatible C and C++ types");
49 
50   if (size == 0)
51     return llvm::None;
52 
53   assert(storage.empty() && "expected to populate storage");
54   storage.reserve(size);
55   for (intptr_t i = 0; i < size; ++i)
56     storage.push_back(unwrap(*(first + i)));
57   return storage;
58 }
59 
60 /* ========================================================================== */
61 /* Printing helper.                                                           */
62 /* ========================================================================== */
63 
64 namespace {
65 /// A simple raw ostream subclass that forwards write_impl calls to the
66 /// user-supplied callback together with opaque user-supplied data.
67 class CallbackOstream : public llvm::raw_ostream {
68 public:
69   CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
70                   void *opaqueData)
71       : callback(callback), opaqueData(opaqueData), pos(0u) {}
72 
73   void write_impl(const char *ptr, size_t size) override {
74     callback(ptr, size, opaqueData);
75     pos += size;
76   }
77 
78   uint64_t current_pos() const override { return pos; }
79 
80 private:
81   std::function<void(const char *, intptr_t, void *)> callback;
82   void *opaqueData;
83   uint64_t pos;
84 };
85 } // end namespace
86 
87 /* ========================================================================== */
88 /* Context API.                                                               */
89 /* ========================================================================== */
90 
91 MlirContext mlirContextCreate() {
92   auto *context = new MLIRContext;
93   return wrap(context);
94 }
95 
96 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
97 
98 /* ========================================================================== */
99 /* Location API.                                                              */
100 /* ========================================================================== */
101 
102 MlirLocation mlirLocationFileLineColGet(MlirContext context,
103                                         const char *filename, unsigned line,
104                                         unsigned col) {
105   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
106 }
107 
108 MlirLocation mlirLocationUnknownGet(MlirContext context) {
109   return wrap(UnknownLoc::get(unwrap(context)));
110 }
111 
112 void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
113                        void *userData) {
114   CallbackOstream stream(callback, userData);
115   unwrap(location).print(stream);
116   stream.flush();
117 }
118 
119 /* ========================================================================== */
120 /* Module API.                                                                */
121 /* ========================================================================== */
122 
123 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
124   return wrap(ModuleOp::create(unwrap(location)));
125 }
126 
127 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
128   OwningModuleRef owning = parseSourceString(module, unwrap(context));
129   if (!owning)
130     return MlirModule{nullptr};
131   return MlirModule{owning.release().getOperation()};
132 }
133 
134 void mlirModuleDestroy(MlirModule module) {
135   // Transfer ownership to an OwningModuleRef so that its destructor is called.
136   OwningModuleRef(unwrap(module));
137 }
138 
139 MlirOperation mlirModuleGetOperation(MlirModule module) {
140   return wrap(unwrap(module).getOperation());
141 }
142 
143 /* ========================================================================== */
144 /* Operation state API.                                                       */
145 /* ========================================================================== */
146 
147 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
148   MlirOperationState state;
149   state.name = name;
150   state.location = loc;
151   state.nResults = 0;
152   state.results = nullptr;
153   state.nOperands = 0;
154   state.operands = nullptr;
155   state.nRegions = 0;
156   state.regions = nullptr;
157   state.nSuccessors = 0;
158   state.successors = nullptr;
159   state.nAttributes = 0;
160   state.attributes = nullptr;
161   return state;
162 }
163 
/// Appends `n` elements of type `type` from the array `elemName` to the
/// heap-allocated array `state->elemName`, and adds `n` to `state->sizeName`.
///
/// Expects a pointer `state` and an integer count `n` to be in scope at the
/// expansion site, next to a source array whose name matches the destination
/// field. The destination is grown with `realloc`, so it must be null or come
/// from the malloc family.
///
/// A non-positive `n` is a no-op, so neither `realloc` nor `memcpy` is ever
/// called with a zero size or a null source. The body is wrapped in
/// `do { } while (false)` so that the expansion is a single statement; the
/// invocation must be followed by a semicolon.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    type *grown = static_cast<type *>(                                         \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type)));       \
    assert(grown && "failed to grow operation state storage");                 \
    state->elemName = grown;                                                   \
    memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));     \
    state->sizeName += n;                                                      \
  } while (false)
169 
170 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
171                                   MlirType *results) {
172   APPEND_ELEMS(MlirType, nResults, results);
173 }
174 
175 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
176                                    MlirValue *operands) {
177   APPEND_ELEMS(MlirValue, nOperands, operands);
178 }
179 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
180                                        MlirRegion *regions) {
181   APPEND_ELEMS(MlirRegion, nRegions, regions);
182 }
183 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
184                                      MlirBlock *successors) {
185   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
186 }
187 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
188                                      MlirNamedAttribute *attributes) {
189   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
190 }
191 
192 /* ========================================================================== */
193 /* Operation API.                                                             */
194 /* ========================================================================== */
195 
196 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
197   assert(state);
198   OperationState cppState(unwrap(state->location), state->name);
199   SmallVector<Type, 4> resultStorage;
200   SmallVector<Value, 8> operandStorage;
201   SmallVector<Block *, 2> successorStorage;
202   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
203   cppState.addOperands(
204       unwrapList(state->nOperands, state->operands, operandStorage));
205   cppState.addSuccessors(
206       unwrapList(state->nSuccessors, state->successors, successorStorage));
207 
208   cppState.attributes.reserve(state->nAttributes);
209   for (intptr_t i = 0; i < state->nAttributes; ++i)
210     cppState.addAttribute(state->attributes[i].name,
211                           unwrap(state->attributes[i].attribute));
212 
213   for (intptr_t i = 0; i < state->nRegions; ++i)
214     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
215 
216   MlirOperation result = wrap(Operation::create(cppState));
217   free(state->results);
218   free(state->operands);
219   free(state->successors);
220   free(state->regions);
221   free(state->attributes);
222   return result;
223 }
224 
225 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
226 
227 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
228 
229 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
230   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
231 }
232 
233 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
234   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
235 }
236 
237 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
238   return wrap(unwrap(op)->getNextNode());
239 }
240 
241 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
242   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
243 }
244 
245 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
246   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
247 }
248 
249 intptr_t mlirOperationGetNumResults(MlirOperation op) {
250   return static_cast<intptr_t>(unwrap(op)->getNumResults());
251 }
252 
253 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
254   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
255 }
256 
257 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
258   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
259 }
260 
261 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
262   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
263 }
264 
265 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
266   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
267 }
268 
269 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
270   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
271   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
272 }
273 
274 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
275                                               const char *name) {
276   return wrap(unwrap(op)->getAttr(name));
277 }
278 
279 void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
280                         void *userData) {
281   CallbackOstream stream(callback, userData);
282   unwrap(op)->print(stream);
283   stream.flush();
284 }
285 
286 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
287 
288 /* ========================================================================== */
289 /* Region API.                                                                */
290 /* ========================================================================== */
291 
292 MlirRegion mlirRegionCreate() { return wrap(new Region); }
293 
294 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
295   Region *cppRegion = unwrap(region);
296   if (cppRegion->empty())
297     return wrap(static_cast<Block *>(nullptr));
298   return wrap(&cppRegion->front());
299 }
300 
301 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
302   unwrap(region)->push_back(unwrap(block));
303 }
304 
305 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
306                                 MlirBlock block) {
307   auto &blockList = unwrap(region)->getBlocks();
308   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
309 }
310 
311 void mlirRegionDestroy(MlirRegion region) {
312   delete static_cast<Region *>(region.ptr);
313 }
314 
315 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
316 
317 /* ========================================================================== */
318 /* Block API.                                                                 */
319 /* ========================================================================== */
320 
321 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
322   Block *b = new Block;
323   for (intptr_t i = 0; i < nArgs; ++i)
324     b->addArgument(unwrap(args[i]));
325   return wrap(b);
326 }
327 
328 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
329   return wrap(unwrap(block)->getNextNode());
330 }
331 
332 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
333   Block *cppBlock = unwrap(block);
334   if (cppBlock->empty())
335     return wrap(static_cast<Operation *>(nullptr));
336   return wrap(&cppBlock->front());
337 }
338 
339 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
340   unwrap(block)->push_back(unwrap(operation));
341 }
342 
343 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
344                                    MlirOperation operation) {
345   auto &opList = unwrap(block)->getOperations();
346   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
347 }
348 
349 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
350 
351 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
352 
353 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
354   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
355 }
356 
357 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
358   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
359 }
360 
361 void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
362                     void *userData) {
363   CallbackOstream stream(callback, userData);
364   unwrap(block)->print(stream);
365   stream.flush();
366 }
367 
368 /* ========================================================================== */
369 /* Value API.                                                                 */
370 /* ========================================================================== */
371 
372 MlirType mlirValueGetType(MlirValue value) {
373   return wrap(unwrap(value).getType());
374 }
375 
376 void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
377                     void *userData) {
378   CallbackOstream stream(callback, userData);
379   unwrap(value).print(stream);
380   stream.flush();
381 }
382 
383 /* ========================================================================== */
384 /* Type API.                                                                  */
385 /* ========================================================================== */
386 
387 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
388   return wrap(mlir::parseType(type, unwrap(context)));
389 }
390 
391 void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
392   CallbackOstream stream(callback, userData);
393   unwrap(type).print(stream);
394   stream.flush();
395 }
396 
397 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
398 
399 /* ========================================================================== */
400 /* Attribute API.                                                             */
401 /* ========================================================================== */
402 
403 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
404   return wrap(mlir::parseAttribute(attr, unwrap(context)));
405 }
406 
407 void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
408                         void *userData) {
409   CallbackOstream stream(callback, userData);
410   unwrap(attr).print(stream);
411   stream.flush();
412 }
413 
414 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
415 
416 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
417   return MlirNamedAttribute{name, attr};
418 }
419