xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision 8fcfe286)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/raw_ostream.h"

#include <functional>
#include <utility>
18 
19 using namespace mlir;
20 
21 /* ========================================================================== */
22 /* Definitions of methods for non-owning structures used in C API.            */
23 /* ========================================================================== */
24 
25 #define DEFINE_C_API_PTR_METHODS(name, cpptype)                                \
26   static name wrap(cpptype *cpp) { return name{cpp}; }                         \
27   static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); }
28 
29 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext)
30 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation)
31 DEFINE_C_API_PTR_METHODS(MlirBlock, Block)
32 DEFINE_C_API_PTR_METHODS(MlirRegion, Region)
33 
34 #define DEFINE_C_API_METHODS(name, cpptype)                                    \
35   static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; }     \
36   static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); }
37 
38 DEFINE_C_API_METHODS(MlirAttribute, Attribute)
39 DEFINE_C_API_METHODS(MlirLocation, Location);
40 DEFINE_C_API_METHODS(MlirType, Type)
41 DEFINE_C_API_METHODS(MlirValue, Value)
42 DEFINE_C_API_METHODS(MlirModule, ModuleOp)
43 
44 template <typename CppTy, typename CTy>
45 static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
46                                   SmallVectorImpl<CppTy> &storage) {
47   static_assert(
48       std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
49       "incompatible C and C++ types");
50 
51   if (size == 0)
52     return llvm::None;
53 
54   assert(storage.empty() && "expected to populate storage");
55   storage.reserve(size);
56   for (intptr_t i = 0; i < size; ++i)
57     storage.push_back(unwrap(*(first + i)));
58   return storage;
59 }
60 
61 /* ========================================================================== */
62 /* Printing helper.                                                           */
63 /* ========================================================================== */
64 
65 namespace {
66 /// A simple raw ostream subclass that forwards write_impl calls to the
67 /// user-supplied callback together with opaque user-supplied data.
68 class CallbackOstream : public llvm::raw_ostream {
69 public:
70   CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
71                   void *opaqueData)
72       : callback(callback), opaqueData(opaqueData), pos(0u) {}
73 
74   void write_impl(const char *ptr, size_t size) override {
75     callback(ptr, size, opaqueData);
76     pos += size;
77   }
78 
79   uint64_t current_pos() const override { return pos; }
80 
81 private:
82   std::function<void(const char *, intptr_t, void *)> callback;
83   void *opaqueData;
84   uint64_t pos;
85 };
86 } // end namespace
87 
88 /* ========================================================================== */
89 /* Context API.                                                               */
90 /* ========================================================================== */
91 
92 MlirContext mlirContextCreate() {
93   auto *context = new MLIRContext(false);
94   return wrap(context);
95 }
96 
97 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
98 
99 void mlirContextLoadAllDialects(MlirContext context) {
100   unwrap(context)->loadAllGloballyRegisteredDialects();
101 }
102 
103 /* ========================================================================== */
104 /* Location API.                                                              */
105 /* ========================================================================== */
106 
107 MlirLocation mlirLocationFileLineColGet(MlirContext context,
108                                         const char *filename, unsigned line,
109                                         unsigned col) {
110   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
111 }
112 
113 MlirLocation mlirLocationUnknownGet(MlirContext context) {
114   return wrap(UnknownLoc::get(unwrap(context)));
115 }
116 
117 void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
118                        void *userData) {
119   CallbackOstream stream(callback, userData);
120   unwrap(location).print(stream);
121   stream.flush();
122 }
123 
124 /* ========================================================================== */
125 /* Module API.                                                                */
126 /* ========================================================================== */
127 
128 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
129   return wrap(ModuleOp::create(unwrap(location)));
130 }
131 
132 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
133   OwningModuleRef owning = parseSourceString(module, unwrap(context));
134   return MlirModule{owning.release().getOperation()};
135 }
136 
137 void mlirModuleDestroy(MlirModule module) {
138   // Transfer ownership to an OwningModuleRef so that its destructor is called.
139   OwningModuleRef(unwrap(module));
140 }
141 
142 MlirOperation mlirModuleGetOperation(MlirModule module) {
143   return wrap(unwrap(module).getOperation());
144 }
145 
146 /* ========================================================================== */
147 /* Operation state API.                                                       */
148 /* ========================================================================== */
149 
150 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
151   MlirOperationState state;
152   state.name = name;
153   state.location = loc;
154   state.nResults = 0;
155   state.results = nullptr;
156   state.nOperands = 0;
157   state.operands = nullptr;
158   state.nRegions = 0;
159   state.regions = nullptr;
160   state.nSuccessors = 0;
161   state.successors = nullptr;
162   state.nAttributes = 0;
163   state.attributes = nullptr;
164   return state;
165 }
166 
167 #define APPEND_ELEMS(type, sizeName, elemName)                                 \
168   state->elemName =                                                            \
169       (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type));  \
170   memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));       \
171   state->sizeName += n;
172 
173 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
174                                   MlirType *results) {
175   APPEND_ELEMS(MlirType, nResults, results);
176 }
177 
178 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
179                                    MlirValue *operands) {
180   APPEND_ELEMS(MlirValue, nOperands, operands);
181 }
182 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
183                                        MlirRegion *regions) {
184   APPEND_ELEMS(MlirRegion, nRegions, regions);
185 }
186 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
187                                      MlirBlock *successors) {
188   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
189 }
190 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
191                                      MlirNamedAttribute *attributes) {
192   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
193 }
194 
195 /* ========================================================================== */
196 /* Operation API.                                                             */
197 /* ========================================================================== */
198 
199 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
200   assert(state);
201   OperationState cppState(unwrap(state->location), state->name);
202   SmallVector<Type, 4> resultStorage;
203   SmallVector<Value, 8> operandStorage;
204   SmallVector<Block *, 2> successorStorage;
205   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
206   cppState.addOperands(
207       unwrapList(state->nOperands, state->operands, operandStorage));
208   cppState.addSuccessors(
209       unwrapList(state->nSuccessors, state->successors, successorStorage));
210 
211   cppState.attributes.reserve(state->nAttributes);
212   for (intptr_t i = 0; i < state->nAttributes; ++i)
213     cppState.addAttribute(state->attributes[i].name,
214                           unwrap(state->attributes[i].attribute));
215 
216   for (intptr_t i = 0; i < state->nRegions; ++i)
217     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
218 
219   MlirOperation result = wrap(Operation::create(cppState));
220   free(state->results);
221   free(state->operands);
222   free(state->successors);
223   free(state->regions);
224   free(state->attributes);
225   return result;
226 }
227 
228 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
229 
230 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
231 
232 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
233   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
234 }
235 
236 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
237   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
238 }
239 
240 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
241   return wrap(unwrap(op)->getNextNode());
242 }
243 
244 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
245   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
246 }
247 
248 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
249   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
250 }
251 
252 intptr_t mlirOperationGetNumResults(MlirOperation op) {
253   return static_cast<intptr_t>(unwrap(op)->getNumResults());
254 }
255 
256 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
257   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
258 }
259 
260 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
261   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
262 }
263 
264 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
265   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
266 }
267 
268 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
269   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
270 }
271 
272 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
273   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
274   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
275 }
276 
277 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
278                                               const char *name) {
279   return wrap(unwrap(op)->getAttr(name));
280 }
281 
282 void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
283                         void *userData) {
284   CallbackOstream stream(callback, userData);
285   unwrap(op)->print(stream);
286   stream.flush();
287 }
288 
289 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
290 
291 /* ========================================================================== */
292 /* Region API.                                                                */
293 /* ========================================================================== */
294 
295 MlirRegion mlirRegionCreate() { return wrap(new Region); }
296 
297 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
298   Region *cppRegion = unwrap(region);
299   if (cppRegion->empty())
300     return wrap(static_cast<Block *>(nullptr));
301   return wrap(&cppRegion->front());
302 }
303 
304 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
305   unwrap(region)->push_back(unwrap(block));
306 }
307 
308 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
309                                 MlirBlock block) {
310   auto &blockList = unwrap(region)->getBlocks();
311   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
312 }
313 
314 void mlirRegionDestroy(MlirRegion region) {
315   delete static_cast<Region *>(region.ptr);
316 }
317 
318 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
319 
320 /* ========================================================================== */
321 /* Block API.                                                                 */
322 /* ========================================================================== */
323 
324 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
325   Block *b = new Block;
326   for (intptr_t i = 0; i < nArgs; ++i)
327     b->addArgument(unwrap(args[i]));
328   return wrap(b);
329 }
330 
331 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
332   return wrap(unwrap(block)->getNextNode());
333 }
334 
335 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
336   Block *cppBlock = unwrap(block);
337   if (cppBlock->empty())
338     return wrap(static_cast<Operation *>(nullptr));
339   return wrap(&cppBlock->front());
340 }
341 
342 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
343   unwrap(block)->push_back(unwrap(operation));
344 }
345 
346 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
347                                    MlirOperation operation) {
348   auto &opList = unwrap(block)->getOperations();
349   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
350 }
351 
352 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
353 
354 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
355 
356 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
357   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
358 }
359 
360 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
361   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
362 }
363 
364 void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
365                     void *userData) {
366   CallbackOstream stream(callback, userData);
367   unwrap(block)->print(stream);
368   stream.flush();
369 }
370 
371 /* ========================================================================== */
372 /* Value API.                                                                 */
373 /* ========================================================================== */
374 
375 MlirType mlirValueGetType(MlirValue value) {
376   return wrap(unwrap(value).getType());
377 }
378 
379 void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
380                     void *userData) {
381   CallbackOstream stream(callback, userData);
382   unwrap(value).print(stream);
383   stream.flush();
384 }
385 
386 /* ========================================================================== */
387 /* Type API.                                                                  */
388 /* ========================================================================== */
389 
390 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
391   return wrap(mlir::parseType(type, unwrap(context)));
392 }
393 
394 void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
395   CallbackOstream stream(callback, userData);
396   unwrap(type).print(stream);
397   stream.flush();
398 }
399 
400 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
401 
402 /* ========================================================================== */
403 /* Attribute API.                                                             */
404 /* ========================================================================== */
405 
406 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
407   return wrap(mlir::parseAttribute(attr, unwrap(context)));
408 }
409 
410 void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
411                         void *userData) {
412   CallbackOstream stream(callback, userData);
413   unwrap(attr).print(stream);
414   stream.flush();
415 }
416 
417 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
418 
419 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
420   return MlirNamedAttribute{name, attr};
421 }
422