// xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision fdc6aea3)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "llvm/Support/raw_ostream.h"

#include <cstdlib>
#include <cstring>
#include <functional>
#include <utility>
19 
20 using namespace mlir;
21 
22 /* ========================================================================== */
23 /* Definitions of methods for non-owning structures used in C API.            */
24 /* ========================================================================== */
25 
26 #define DEFINE_C_API_PTR_METHODS(name, cpptype)                                \
27   static name wrap(cpptype *cpp) { return name{cpp}; }                         \
28   static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); }
29 
30 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext)
31 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation)
32 DEFINE_C_API_PTR_METHODS(MlirBlock, Block)
33 DEFINE_C_API_PTR_METHODS(MlirRegion, Region)
34 
35 #define DEFINE_C_API_METHODS(name, cpptype)                                    \
36   static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; }     \
37   static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); }
38 
39 DEFINE_C_API_METHODS(MlirAttribute, Attribute)
40 DEFINE_C_API_METHODS(MlirLocation, Location);
41 DEFINE_C_API_METHODS(MlirType, Type)
42 DEFINE_C_API_METHODS(MlirValue, Value)
43 DEFINE_C_API_METHODS(MlirModule, ModuleOp)
44 
45 template <typename CppTy, typename CTy>
46 static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
47                                   SmallVectorImpl<CppTy> &storage) {
48   static_assert(
49       std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
50       "incompatible C and C++ types");
51 
52   if (size == 0)
53     return llvm::None;
54 
55   assert(storage.empty() && "expected to populate storage");
56   storage.reserve(size);
57   for (intptr_t i = 0; i < size; ++i)
58     storage.push_back(unwrap(*(first + i)));
59   return storage;
60 }
61 
62 /* ========================================================================== */
63 /* Printing helper.                                                           */
64 /* ========================================================================== */
65 
66 namespace {
67 /// A simple raw ostream subclass that forwards write_impl calls to the
68 /// user-supplied callback together with opaque user-supplied data.
69 class CallbackOstream : public llvm::raw_ostream {
70 public:
71   CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
72                   void *opaqueData)
73       : callback(callback), opaqueData(opaqueData), pos(0u) {}
74 
75   void write_impl(const char *ptr, size_t size) override {
76     callback(ptr, size, opaqueData);
77     pos += size;
78   }
79 
80   uint64_t current_pos() const override { return pos; }
81 
82 private:
83   std::function<void(const char *, intptr_t, void *)> callback;
84   void *opaqueData;
85   uint64_t pos;
86 };
87 } // end namespace
88 
89 /* ========================================================================== */
90 /* Context API.                                                               */
91 /* ========================================================================== */
92 
93 MlirContext mlirContextCreate() {
94   auto *context = new MLIRContext(false);
95   return wrap(context);
96 }
97 
98 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
99 
100 void mlirContextLoadAllDialects(MlirContext context) {
101   registerAllDialects(unwrap(context));
102   getGlobalDialectRegistry().loadAll(unwrap(context));
103 }
104 
105 /* ========================================================================== */
106 /* Location API.                                                              */
107 /* ========================================================================== */
108 
109 MlirLocation mlirLocationFileLineColGet(MlirContext context,
110                                         const char *filename, unsigned line,
111                                         unsigned col) {
112   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
113 }
114 
115 MlirLocation mlirLocationUnknownGet(MlirContext context) {
116   return wrap(UnknownLoc::get(unwrap(context)));
117 }
118 
119 void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
120                        void *userData) {
121   CallbackOstream stream(callback, userData);
122   unwrap(location).print(stream);
123   stream.flush();
124 }
125 
126 /* ========================================================================== */
127 /* Module API.                                                                */
128 /* ========================================================================== */
129 
130 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
131   return wrap(ModuleOp::create(unwrap(location)));
132 }
133 
134 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
135   OwningModuleRef owning = parseSourceString(module, unwrap(context));
136   return MlirModule{owning.release().getOperation()};
137 }
138 
139 void mlirModuleDestroy(MlirModule module) {
140   // Transfer ownership to an OwningModuleRef so that its destructor is called.
141   OwningModuleRef(unwrap(module));
142 }
143 
144 MlirOperation mlirModuleGetOperation(MlirModule module) {
145   return wrap(unwrap(module).getOperation());
146 }
147 
148 /* ========================================================================== */
149 /* Operation state API.                                                       */
150 /* ========================================================================== */
151 
152 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
153   MlirOperationState state;
154   state.name = name;
155   state.location = loc;
156   state.nResults = 0;
157   state.results = nullptr;
158   state.nOperands = 0;
159   state.operands = nullptr;
160   state.nRegions = 0;
161   state.regions = nullptr;
162   state.nSuccessors = 0;
163   state.successors = nullptr;
164   state.nAttributes = 0;
165   state.attributes = nullptr;
166   return state;
167 }
168 
/// Appends the `n` elements of the array `elemName` to `state->elemName`,
/// growing it with realloc, and increases `state->sizeName` by `n`.
///
/// The expansion site must have `state` (a pointer to the state being built)
/// and `n` (the element count) in scope. Appending zero or a negative count
/// is a no-op, so `elemName` may be null in that case. Otherwise the source
/// would be a null memcpy argument, and realloc(NULL, 0) could hand out a
/// spurious allocation. If realloc fails the process aborts, instead of
/// memcpy-ing into a null buffer and leaking the old array.
/// Wrapped in do/while so that `APPEND_ELEMS(...);` behaves as one statement.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    type *grown = static_cast<type *>(std::realloc(                            \
        state->elemName, (state->sizeName + n) * sizeof(type)));               \
    if (!grown)                                                                \
      std::abort();                                                            \
    std::memcpy(grown + state->sizeName, elemName, n * sizeof(type));          \
    state->elemName = grown;                                                   \
    state->sizeName += n;                                                      \
  } while (false)
174 
175 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
176                                   MlirType *results) {
177   APPEND_ELEMS(MlirType, nResults, results);
178 }
179 
180 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
181                                    MlirValue *operands) {
182   APPEND_ELEMS(MlirValue, nOperands, operands);
183 }
184 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
185                                        MlirRegion *regions) {
186   APPEND_ELEMS(MlirRegion, nRegions, regions);
187 }
188 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
189                                      MlirBlock *successors) {
190   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
191 }
192 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
193                                      MlirNamedAttribute *attributes) {
194   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
195 }
196 
197 /* ========================================================================== */
198 /* Operation API.                                                             */
199 /* ========================================================================== */
200 
201 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
202   assert(state);
203   OperationState cppState(unwrap(state->location), state->name);
204   SmallVector<Type, 4> resultStorage;
205   SmallVector<Value, 8> operandStorage;
206   SmallVector<Block *, 2> successorStorage;
207   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
208   cppState.addOperands(
209       unwrapList(state->nOperands, state->operands, operandStorage));
210   cppState.addSuccessors(
211       unwrapList(state->nSuccessors, state->successors, successorStorage));
212 
213   cppState.attributes.reserve(state->nAttributes);
214   for (intptr_t i = 0; i < state->nAttributes; ++i)
215     cppState.addAttribute(state->attributes[i].name,
216                           unwrap(state->attributes[i].attribute));
217 
218   for (intptr_t i = 0; i < state->nRegions; ++i)
219     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
220 
221   MlirOperation result = wrap(Operation::create(cppState));
222   free(state->results);
223   free(state->operands);
224   free(state->successors);
225   free(state->regions);
226   free(state->attributes);
227   return result;
228 }
229 
230 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
231 
232 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
233 
234 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
235   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
236 }
237 
238 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
239   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
240 }
241 
242 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
243   return wrap(unwrap(op)->getNextNode());
244 }
245 
246 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
247   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
248 }
249 
250 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
251   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
252 }
253 
254 intptr_t mlirOperationGetNumResults(MlirOperation op) {
255   return static_cast<intptr_t>(unwrap(op)->getNumResults());
256 }
257 
258 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
259   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
260 }
261 
262 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
263   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
264 }
265 
266 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
267   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
268 }
269 
270 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
271   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
272 }
273 
274 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
275   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
276   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
277 }
278 
279 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
280                                               const char *name) {
281   return wrap(unwrap(op)->getAttr(name));
282 }
283 
284 void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
285                         void *userData) {
286   CallbackOstream stream(callback, userData);
287   unwrap(op)->print(stream);
288   stream.flush();
289 }
290 
291 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
292 
293 /* ========================================================================== */
294 /* Region API.                                                                */
295 /* ========================================================================== */
296 
297 MlirRegion mlirRegionCreate() { return wrap(new Region); }
298 
299 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
300   Region *cppRegion = unwrap(region);
301   if (cppRegion->empty())
302     return wrap(static_cast<Block *>(nullptr));
303   return wrap(&cppRegion->front());
304 }
305 
306 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
307   unwrap(region)->push_back(unwrap(block));
308 }
309 
310 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
311                                 MlirBlock block) {
312   auto &blockList = unwrap(region)->getBlocks();
313   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
314 }
315 
316 void mlirRegionDestroy(MlirRegion region) {
317   delete static_cast<Region *>(region.ptr);
318 }
319 
320 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
321 
322 /* ========================================================================== */
323 /* Block API.                                                                 */
324 /* ========================================================================== */
325 
326 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
327   Block *b = new Block;
328   for (intptr_t i = 0; i < nArgs; ++i)
329     b->addArgument(unwrap(args[i]));
330   return wrap(b);
331 }
332 
333 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
334   return wrap(unwrap(block)->getNextNode());
335 }
336 
337 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
338   Block *cppBlock = unwrap(block);
339   if (cppBlock->empty())
340     return wrap(static_cast<Operation *>(nullptr));
341   return wrap(&cppBlock->front());
342 }
343 
344 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
345   unwrap(block)->push_back(unwrap(operation));
346 }
347 
348 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
349                                    MlirOperation operation) {
350   auto &opList = unwrap(block)->getOperations();
351   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
352 }
353 
354 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
355 
356 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
357 
358 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
359   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
360 }
361 
362 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
363   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
364 }
365 
366 void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
367                     void *userData) {
368   CallbackOstream stream(callback, userData);
369   unwrap(block)->print(stream);
370   stream.flush();
371 }
372 
373 /* ========================================================================== */
374 /* Value API.                                                                 */
375 /* ========================================================================== */
376 
377 MlirType mlirValueGetType(MlirValue value) {
378   return wrap(unwrap(value).getType());
379 }
380 
381 void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
382                     void *userData) {
383   CallbackOstream stream(callback, userData);
384   unwrap(value).print(stream);
385   stream.flush();
386 }
387 
388 /* ========================================================================== */
389 /* Type API.                                                                  */
390 /* ========================================================================== */
391 
392 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
393   return wrap(mlir::parseType(type, unwrap(context)));
394 }
395 
396 void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
397   CallbackOstream stream(callback, userData);
398   unwrap(type).print(stream);
399   stream.flush();
400 }
401 
402 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
403 
404 /* ========================================================================== */
405 /* Attribute API.                                                             */
406 /* ========================================================================== */
407 
408 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
409   return wrap(mlir::parseAttribute(attr, unwrap(context)));
410 }
411 
412 void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
413                         void *userData) {
414   CallbackOstream stream(callback, userData);
415   unwrap(attr).print(stream);
416   stream.flush();
417 }
418 
419 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
420 
421 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
422   return MlirNamedAttribute{name, attr};
423 }
424