xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision 05df9cc7)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"

#include <cassert>
#include <cstdlib>
#include <cstring>
16 
17 using namespace mlir;
18 
19 /* ========================================================================== */
20 /* Definitions of methods for non-owning structures used in C API.            */
21 /* ========================================================================== */
22 
// Generates the wrap/unwrap pair for a C handle struct `CType` whose `ptr`
// member stores a pointer to a C++ object of type `CppType`.
#define DEFINE_C_API_PTR_METHODS(CType, CppType)                               \
  static CType wrap(CppType *cppPtr) { return CType{cppPtr}; }                 \
                                                                               \
  static CppType *unwrap(CType handle) {                                       \
    return static_cast<CppType *>(handle.ptr);                                 \
  }
26 
27 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext)
28 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation)
29 DEFINE_C_API_PTR_METHODS(MlirBlock, Block)
30 DEFINE_C_API_PTR_METHODS(MlirRegion, Region)
31 
// Generates the wrap/unwrap pair for a C handle struct `CType` and a C++ value
// type `CppType` that round-trips through an opaque pointer
// (getAsOpaquePointer / getFromOpaquePointer).
#define DEFINE_C_API_METHODS(CType, CppType)                                   \
  static CType wrap(CppType cppValue) {                                        \
    return CType{cppValue.getAsOpaquePointer()};                               \
  }                                                                            \
                                                                               \
  static CppType unwrap(CType handle) {                                        \
    return CppType::getFromOpaquePointer(handle.ptr);                          \
  }
35 
36 DEFINE_C_API_METHODS(MlirAttribute, Attribute)
37 DEFINE_C_API_METHODS(MlirLocation, Location);
38 DEFINE_C_API_METHODS(MlirType, Type)
39 DEFINE_C_API_METHODS(MlirValue, Value)
40 DEFINE_C_API_METHODS(MlirModule, ModuleOp)
41 
42 template <typename CppTy, typename CTy>
43 static ArrayRef<CppTy> unwrapList(unsigned size, CTy *first,
44                                   SmallVectorImpl<CppTy> &storage) {
45   static_assert(
46       std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
47       "incompatible C and C++ types");
48 
49   if (size == 0)
50     return llvm::None;
51 
52   assert(storage.empty() && "expected to populate storage");
53   storage.reserve(size);
54   for (unsigned i = 0; i < size; ++i)
55     storage.push_back(unwrap(*(first + i)));
56   return storage;
57 }
58 
59 /* ========================================================================== */
60 /* Context API.                                                               */
61 /* ========================================================================== */
62 
63 MlirContext mlirContextCreate() {
64   auto *context = new MLIRContext;
65   return wrap(context);
66 }
67 
68 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
69 
70 /* ========================================================================== */
71 /* Location API.                                                              */
72 /* ========================================================================== */
73 
74 MlirLocation mlirLocationFileLineColGet(MlirContext context,
75                                         const char *filename, unsigned line,
76                                         unsigned col) {
77   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
78 }
79 
80 MlirLocation mlirLocationUnknownGet(MlirContext context) {
81   return wrap(UnknownLoc::get(unwrap(context)));
82 }
83 
84 /* ========================================================================== */
85 /* Module API.                                                                */
86 /* ========================================================================== */
87 
88 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
89   return wrap(ModuleOp::create(unwrap(location)));
90 }
91 
92 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
93   OwningModuleRef owning = parseSourceString(module, unwrap(context));
94   return MlirModule{owning.release().getOperation()};
95 }
96 
97 void mlirModuleDestroy(MlirModule module) {
98   // Transfer ownership to an OwningModuleRef so that its destructor is called.
99   OwningModuleRef(unwrap(module));
100 }
101 
102 MlirOperation mlirModuleGetOperation(MlirModule module) {
103   return wrap(unwrap(module).getOperation());
104 }
105 
106 /* ========================================================================== */
107 /* Operation state API.                                                       */
108 /* ========================================================================== */
109 
110 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
111   MlirOperationState state;
112   state.name = name;
113   state.location = loc;
114   state.nResults = 0;
115   state.results = nullptr;
116   state.nOperands = 0;
117   state.operands = nullptr;
118   state.nRegions = 0;
119   state.regions = nullptr;
120   state.nSuccessors = 0;
121   state.successors = nullptr;
122   state.nAttributes = 0;
123   state.attributes = nullptr;
124   return state;
125 }
126 
// Grows the realloc-managed array `elems` (currently holding `size` entries) by
// `n` entries copied from `newElems`, then increases `size` by `n`.
//
// Details:
//  - Appending zero elements is a no-op. This avoids realloc(ptr, 0), whose
//    result is implementation-defined, and memcpy from a possibly null source.
//  - If realloc fails, the process aborts rather than writing through a null
//    pointer and leaking the old block.
template <typename T, typename SizeT>
static void appendToArray(T *&elems, SizeT &size, unsigned n,
                          const T *newElems) {
  if (n == 0)
    return;
  std::size_t newCount = static_cast<std::size_t>(size) + n;
  void *grown = std::realloc(elems, newCount * sizeof(T));
  if (!grown)
    std::abort();
  elems = static_cast<T *>(grown);
  std::memcpy(elems + size, newElems, n * sizeof(T));
  size += n;
}

// Appends `n` elements of type `type` from `elemName` to `state->elemName` and
// bumps `state->sizeName`. Expects `state` and `n` to be in scope at the
// expansion site. It expands to a single expression, so it is safe in any
// statement position.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  appendToArray<type>(state->elemName, state->sizeName, n, elemName)
132 
133 void mlirOperationStateAddResults(MlirOperationState *state, unsigned n,
134                                   MlirType *results) {
135   APPEND_ELEMS(MlirType, nResults, results);
136 }
137 
138 void mlirOperationStateAddOperands(MlirOperationState *state, unsigned n,
139                                    MlirValue *operands) {
140   APPEND_ELEMS(MlirValue, nOperands, operands);
141 }
142 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, unsigned n,
143                                        MlirRegion *regions) {
144   APPEND_ELEMS(MlirRegion, nRegions, regions);
145 }
146 void mlirOperationStateAddSuccessors(MlirOperationState *state, unsigned n,
147                                      MlirBlock *successors) {
148   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
149 }
150 void mlirOperationStateAddAttributes(MlirOperationState *state, unsigned n,
151                                      MlirNamedAttribute *attributes) {
152   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
153 }
154 
155 /* ========================================================================== */
156 /* Operation API.                                                             */
157 /* ========================================================================== */
158 
159 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
160   assert(state);
161   OperationState cppState(unwrap(state->location), state->name);
162   SmallVector<Type, 4> resultStorage;
163   SmallVector<Value, 8> operandStorage;
164   SmallVector<Block *, 2> successorStorage;
165   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
166   cppState.addOperands(
167       unwrapList(state->nOperands, state->operands, operandStorage));
168   cppState.addSuccessors(
169       unwrapList(state->nSuccessors, state->successors, successorStorage));
170 
171   cppState.attributes.reserve(state->nAttributes);
172   for (unsigned i = 0; i < state->nAttributes; ++i)
173     cppState.addAttribute(state->attributes[i].name,
174                           unwrap(state->attributes[i].attribute));
175 
176   for (unsigned i = 0; i < state->nRegions; ++i)
177     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
178 
179   return wrap(Operation::create(cppState));
180 }
181 
182 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
183 
184 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
185 
186 unsigned mlirOperationGetNumRegions(MlirOperation op) {
187   return unwrap(op)->getNumRegions();
188 }
189 
190 MlirRegion mlirOperationGetRegion(MlirOperation op, unsigned pos) {
191   return wrap(&unwrap(op)->getRegion(pos));
192 }
193 
194 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
195   return wrap(unwrap(op)->getNextNode());
196 }
197 
198 unsigned mlirOperationGetNumOperands(MlirOperation op) {
199   return unwrap(op)->getNumOperands();
200 }
201 
202 MlirValue mlirOperationGetOperand(MlirOperation op, unsigned pos) {
203   return wrap(unwrap(op)->getOperand(pos));
204 }
205 
206 unsigned mlirOperationGetNumResults(MlirOperation op) {
207   return unwrap(op)->getNumResults();
208 }
209 
210 MlirValue mlirOperationGetResult(MlirOperation op, unsigned pos) {
211   return wrap(unwrap(op)->getResult(pos));
212 }
213 
214 unsigned mlirOperationGetNumSuccessors(MlirOperation op) {
215   return unwrap(op)->getNumSuccessors();
216 }
217 
218 MlirBlock mlirOperationGetSuccessor(MlirOperation op, unsigned pos) {
219   return wrap(unwrap(op)->getSuccessor(pos));
220 }
221 
222 unsigned mlirOperationGetNumAttributes(MlirOperation op) {
223   return unwrap(op)->getAttrs().size();
224 }
225 
226 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, unsigned pos) {
227   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
228   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
229 }
230 
231 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
232                                               const char *name) {
233   return wrap(unwrap(op)->getAttr(name));
234 }
235 
236 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
237 
238 /* ========================================================================== */
239 /* Region API.                                                                */
240 /* ========================================================================== */
241 
242 MlirRegion mlirRegionCreate() { return wrap(new Region); }
243 
244 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
245   Region *cppRegion = unwrap(region);
246   if (cppRegion->empty())
247     return wrap(static_cast<Block *>(nullptr));
248   return wrap(&cppRegion->front());
249 }
250 
251 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
252   unwrap(region)->push_back(unwrap(block));
253 }
254 
255 void mlirRegionInsertOwnedBlock(MlirRegion region, unsigned pos,
256                                 MlirBlock block) {
257   auto &blockList = unwrap(region)->getBlocks();
258   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
259 }
260 
261 void mlirRegionDestroy(MlirRegion region) {
262   delete static_cast<Region *>(region.ptr);
263 }
264 
265 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
266 
267 /* ========================================================================== */
268 /* Block API.                                                                 */
269 /* ========================================================================== */
270 
271 MlirBlock mlirBlockCreate(unsigned nArgs, MlirType *args) {
272   Block *b = new Block;
273   for (unsigned i = 0; i < nArgs; ++i)
274     b->addArgument(unwrap(args[i]));
275   return wrap(b);
276 }
277 
278 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
279   return wrap(unwrap(block)->getNextNode());
280 }
281 
282 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
283   Block *cppBlock = unwrap(block);
284   if (cppBlock->empty())
285     return wrap(static_cast<Operation *>(nullptr));
286   return wrap(&cppBlock->front());
287 }
288 
289 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
290   unwrap(block)->push_back(unwrap(operation));
291 }
292 
293 void mlirBlockInsertOwnedOperation(MlirBlock block, unsigned pos,
294                                    MlirOperation operation) {
295   auto &opList = unwrap(block)->getOperations();
296   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
297 }
298 
299 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
300 
301 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
302 
303 unsigned mlirBlockGetNumArguments(MlirBlock block) {
304   return unwrap(block)->getNumArguments();
305 }
306 
307 MlirValue mlirBlockGetArgument(MlirBlock block, unsigned pos) {
308   return wrap(unwrap(block)->getArgument(pos));
309 }
310 
311 /* ========================================================================== */
312 /* Value API.                                                                 */
313 /* ========================================================================== */
314 
315 MlirType mlirValueGetType(MlirValue value) {
316   return wrap(unwrap(value).getType());
317 }
318 
319 /* ========================================================================== */
320 /* Type API.                                                                  */
321 /* ========================================================================== */
322 
323 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
324   return wrap(mlir::parseType(type, unwrap(context)));
325 }
326 
327 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
328 
329 /* ========================================================================== */
330 /* Attribute API.                                                             */
331 /* ========================================================================== */
332 
333 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
334   return wrap(mlir::parseAttribute(attr, unwrap(context)));
335 }
336 
337 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
338 
339 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
340   return MlirNamedAttribute{name, attr};
341 }
342