xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision fbfd831d)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"
#include "llvm/Support/raw_ostream.h"

#include <cstdlib>
#include <cstring>
#include <functional>
#include <utility>
17 
18 using namespace mlir;
19 
20 /* ========================================================================== */
21 /* Definitions of methods for non-owning structures used in C API.            */
22 /* ========================================================================== */
23 
24 #define DEFINE_C_API_PTR_METHODS(name, cpptype)                                \
25   static name wrap(cpptype *cpp) { return name{cpp}; }                         \
26   static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); }
27 
28 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext)
29 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation)
30 DEFINE_C_API_PTR_METHODS(MlirBlock, Block)
31 DEFINE_C_API_PTR_METHODS(MlirRegion, Region)
32 
33 #define DEFINE_C_API_METHODS(name, cpptype)                                    \
34   static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; }     \
35   static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); }
36 
37 DEFINE_C_API_METHODS(MlirAttribute, Attribute)
38 DEFINE_C_API_METHODS(MlirLocation, Location);
39 DEFINE_C_API_METHODS(MlirType, Type)
40 DEFINE_C_API_METHODS(MlirValue, Value)
41 DEFINE_C_API_METHODS(MlirModule, ModuleOp)
42 
43 template <typename CppTy, typename CTy>
44 static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
45                                   SmallVectorImpl<CppTy> &storage) {
46   static_assert(
47       std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
48       "incompatible C and C++ types");
49 
50   if (size == 0)
51     return llvm::None;
52 
53   assert(storage.empty() && "expected to populate storage");
54   storage.reserve(size);
55   for (intptr_t i = 0; i < size; ++i)
56     storage.push_back(unwrap(*(first + i)));
57   return storage;
58 }
59 
60 /* ========================================================================== */
61 /* Printing helper.                                                           */
62 /* ========================================================================== */
63 
64 namespace {
65 /// A simple raw ostream subclass that forwards write_impl calls to the
66 /// user-supplied callback together with opaque user-supplied data.
67 class CallbackOstream : public llvm::raw_ostream {
68 public:
69   CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
70                   void *opaqueData)
71       : callback(callback), opaqueData(opaqueData), pos(0u) {}
72 
73   void write_impl(const char *ptr, size_t size) override {
74     callback(ptr, size, opaqueData);
75     pos += size;
76   }
77 
78   uint64_t current_pos() const override { return pos; }
79 
80 private:
81   std::function<void(const char *, intptr_t, void *)> callback;
82   void *opaqueData;
83   uint64_t pos;
84 };
85 } // end namespace
86 
87 /* ========================================================================== */
88 /* Context API.                                                               */
89 /* ========================================================================== */
90 
91 MlirContext mlirContextCreate() {
92   auto *context = new MLIRContext;
93   return wrap(context);
94 }
95 
96 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
97 
98 /* ========================================================================== */
99 /* Location API.                                                              */
100 /* ========================================================================== */
101 
102 MlirLocation mlirLocationFileLineColGet(MlirContext context,
103                                         const char *filename, unsigned line,
104                                         unsigned col) {
105   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
106 }
107 
108 MlirLocation mlirLocationUnknownGet(MlirContext context) {
109   return wrap(UnknownLoc::get(unwrap(context)));
110 }
111 
112 void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
113                        void *userData) {
114   CallbackOstream stream(callback, userData);
115   unwrap(location).print(stream);
116   stream.flush();
117 }
118 
119 /* ========================================================================== */
120 /* Module API.                                                                */
121 /* ========================================================================== */
122 
123 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
124   return wrap(ModuleOp::create(unwrap(location)));
125 }
126 
127 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
128   OwningModuleRef owning = parseSourceString(module, unwrap(context));
129   return MlirModule{owning.release().getOperation()};
130 }
131 
132 void mlirModuleDestroy(MlirModule module) {
133   // Transfer ownership to an OwningModuleRef so that its destructor is called.
134   OwningModuleRef(unwrap(module));
135 }
136 
137 MlirOperation mlirModuleGetOperation(MlirModule module) {
138   return wrap(unwrap(module).getOperation());
139 }
140 
141 /* ========================================================================== */
142 /* Operation state API.                                                       */
143 /* ========================================================================== */
144 
145 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
146   MlirOperationState state;
147   state.name = name;
148   state.location = loc;
149   state.nResults = 0;
150   state.results = nullptr;
151   state.nOperands = 0;
152   state.operands = nullptr;
153   state.nRegions = 0;
154   state.regions = nullptr;
155   state.nSuccessors = 0;
156   state.successors = nullptr;
157   state.nAttributes = 0;
158   state.attributes = nullptr;
159   return state;
160 }
161 
162 #define APPEND_ELEMS(type, sizeName, elemName)                                 \
163   state->elemName =                                                            \
164       (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type));  \
165   memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));       \
166   state->sizeName += n;
167 
168 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
169                                   MlirType *results) {
170   APPEND_ELEMS(MlirType, nResults, results);
171 }
172 
173 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
174                                    MlirValue *operands) {
175   APPEND_ELEMS(MlirValue, nOperands, operands);
176 }
177 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
178                                        MlirRegion *regions) {
179   APPEND_ELEMS(MlirRegion, nRegions, regions);
180 }
181 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
182                                      MlirBlock *successors) {
183   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
184 }
185 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
186                                      MlirNamedAttribute *attributes) {
187   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
188 }
189 
190 /* ========================================================================== */
191 /* Operation API.                                                             */
192 /* ========================================================================== */
193 
194 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
195   assert(state);
196   OperationState cppState(unwrap(state->location), state->name);
197   SmallVector<Type, 4> resultStorage;
198   SmallVector<Value, 8> operandStorage;
199   SmallVector<Block *, 2> successorStorage;
200   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
201   cppState.addOperands(
202       unwrapList(state->nOperands, state->operands, operandStorage));
203   cppState.addSuccessors(
204       unwrapList(state->nSuccessors, state->successors, successorStorage));
205 
206   cppState.attributes.reserve(state->nAttributes);
207   for (intptr_t i = 0; i < state->nAttributes; ++i)
208     cppState.addAttribute(state->attributes[i].name,
209                           unwrap(state->attributes[i].attribute));
210 
211   for (intptr_t i = 0; i < state->nRegions; ++i)
212     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
213 
214   MlirOperation result = wrap(Operation::create(cppState));
215   free(state->results);
216   free(state->operands);
217   free(state->successors);
218   free(state->regions);
219   free(state->attributes);
220   return result;
221 }
222 
223 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
224 
225 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
226 
227 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
228   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
229 }
230 
231 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
232   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
233 }
234 
235 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
236   return wrap(unwrap(op)->getNextNode());
237 }
238 
239 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
240   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
241 }
242 
243 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
244   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
245 }
246 
247 intptr_t mlirOperationGetNumResults(MlirOperation op) {
248   return static_cast<intptr_t>(unwrap(op)->getNumResults());
249 }
250 
251 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
252   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
253 }
254 
255 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
256   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
257 }
258 
259 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
260   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
261 }
262 
263 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
264   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
265 }
266 
267 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
268   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
269   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
270 }
271 
272 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
273                                               const char *name) {
274   return wrap(unwrap(op)->getAttr(name));
275 }
276 
277 void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
278                         void *userData) {
279   CallbackOstream stream(callback, userData);
280   unwrap(op)->print(stream);
281   stream.flush();
282 }
283 
284 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
285 
286 /* ========================================================================== */
287 /* Region API.                                                                */
288 /* ========================================================================== */
289 
290 MlirRegion mlirRegionCreate() { return wrap(new Region); }
291 
292 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
293   Region *cppRegion = unwrap(region);
294   if (cppRegion->empty())
295     return wrap(static_cast<Block *>(nullptr));
296   return wrap(&cppRegion->front());
297 }
298 
299 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
300   unwrap(region)->push_back(unwrap(block));
301 }
302 
303 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
304                                 MlirBlock block) {
305   auto &blockList = unwrap(region)->getBlocks();
306   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
307 }
308 
309 void mlirRegionDestroy(MlirRegion region) {
310   delete static_cast<Region *>(region.ptr);
311 }
312 
313 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
314 
315 /* ========================================================================== */
316 /* Block API.                                                                 */
317 /* ========================================================================== */
318 
319 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
320   Block *b = new Block;
321   for (intptr_t i = 0; i < nArgs; ++i)
322     b->addArgument(unwrap(args[i]));
323   return wrap(b);
324 }
325 
326 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
327   return wrap(unwrap(block)->getNextNode());
328 }
329 
330 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
331   Block *cppBlock = unwrap(block);
332   if (cppBlock->empty())
333     return wrap(static_cast<Operation *>(nullptr));
334   return wrap(&cppBlock->front());
335 }
336 
337 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
338   unwrap(block)->push_back(unwrap(operation));
339 }
340 
341 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
342                                    MlirOperation operation) {
343   auto &opList = unwrap(block)->getOperations();
344   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
345 }
346 
347 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
348 
349 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
350 
351 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
352   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
353 }
354 
355 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
356   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
357 }
358 
359 void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
360                     void *userData) {
361   CallbackOstream stream(callback, userData);
362   unwrap(block)->print(stream);
363   stream.flush();
364 }
365 
366 /* ========================================================================== */
367 /* Value API.                                                                 */
368 /* ========================================================================== */
369 
370 MlirType mlirValueGetType(MlirValue value) {
371   return wrap(unwrap(value).getType());
372 }
373 
374 void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
375                     void *userData) {
376   CallbackOstream stream(callback, userData);
377   unwrap(value).print(stream);
378   stream.flush();
379 }
380 
381 /* ========================================================================== */
382 /* Type API.                                                                  */
383 /* ========================================================================== */
384 
385 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
386   return wrap(mlir::parseType(type, unwrap(context)));
387 }
388 
389 void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
390   CallbackOstream stream(callback, userData);
391   unwrap(type).print(stream);
392   stream.flush();
393 }
394 
395 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
396 
397 /* ========================================================================== */
398 /* Attribute API.                                                             */
399 /* ========================================================================== */
400 
401 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
402   return wrap(mlir::parseAttribute(attr, unwrap(context)));
403 }
404 
405 void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
406                         void *userData) {
407   CallbackOstream stream(callback, userData);
408   unwrap(attr).print(stream);
409   stream.flush();
410 }
411 
412 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
413 
414 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
415   return MlirNamedAttribute{name, attr};
416 }
417