xref: /llvm-project-15.0.7/mlir/lib/CAPI/IR/IR.cpp (revision dd0fdf80)
1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/IR.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"

#include <cassert>
#include <cstdlib>
#include <cstring>
16 
17 using namespace mlir;
18 
19 /* ========================================================================== */
20 /* Definitions of methods for non-owning structures used in C API.            */
21 /* ========================================================================== */
22 
// Defines a file-local wrap/unwrap pair for a C handle type `name` whose
// `ptr` member stores a plain pointer to the C++ class `cpptype`.
//   wrap:   packs a `cpptype *` into a `name` handle.
//   unwrap: casts the handle's `ptr` member back to `cpptype *`.
#define DEFINE_C_API_PTR_METHODS(name, cpptype)                                \
  static name wrap(cpptype *object) { return name{object}; }                   \
  static cpptype *unwrap(name handle) {                                        \
    return static_cast<cpptype *>(handle.ptr);                                 \
  }
26 
27 DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext)
28 DEFINE_C_API_PTR_METHODS(MlirOperation, Operation)
29 DEFINE_C_API_PTR_METHODS(MlirBlock, Block)
30 DEFINE_C_API_PTR_METHODS(MlirRegion, Region)
31 
// Defines a file-local wrap/unwrap pair for a C handle type `name` that
// mirrors a C++ value type `cpptype` convertible to and from an opaque
// pointer.
//   wrap:   stores `getAsOpaquePointer()` of the C++ value in the handle.
//   unwrap: rebuilds the C++ value with `cpptype::getFromOpaquePointer`.
#define DEFINE_C_API_METHODS(name, cpptype)                                    \
  static name wrap(cpptype object) {                                           \
    return name{object.getAsOpaquePointer()};                                  \
  }                                                                            \
  static cpptype unwrap(name handle) {                                         \
    return cpptype::getFromOpaquePointer(handle.ptr);                          \
  }
35 
36 DEFINE_C_API_METHODS(MlirAttribute, Attribute)
37 DEFINE_C_API_METHODS(MlirLocation, Location);
38 DEFINE_C_API_METHODS(MlirType, Type)
39 DEFINE_C_API_METHODS(MlirValue, Value)
40 DEFINE_C_API_METHODS(MlirModule, ModuleOp)
41 
42 template <typename CppTy, typename CTy>
43 static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
44                                   SmallVectorImpl<CppTy> &storage) {
45   static_assert(
46       std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
47       "incompatible C and C++ types");
48 
49   if (size == 0)
50     return llvm::None;
51 
52   assert(storage.empty() && "expected to populate storage");
53   storage.reserve(size);
54   for (intptr_t i = 0; i < size; ++i)
55     storage.push_back(unwrap(*(first + i)));
56   return storage;
57 }
58 
59 /* ========================================================================== */
60 /* Context API.                                                               */
61 /* ========================================================================== */
62 
63 MlirContext mlirContextCreate() {
64   auto *context = new MLIRContext;
65   return wrap(context);
66 }
67 
68 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
69 
70 /* ========================================================================== */
71 /* Location API.                                                              */
72 /* ========================================================================== */
73 
74 MlirLocation mlirLocationFileLineColGet(MlirContext context,
75                                         const char *filename, unsigned line,
76                                         unsigned col) {
77   return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
78 }
79 
80 MlirLocation mlirLocationUnknownGet(MlirContext context) {
81   return wrap(UnknownLoc::get(unwrap(context)));
82 }
83 
84 /* ========================================================================== */
85 /* Module API.                                                                */
86 /* ========================================================================== */
87 
88 MlirModule mlirModuleCreateEmpty(MlirLocation location) {
89   return wrap(ModuleOp::create(unwrap(location)));
90 }
91 
92 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
93   OwningModuleRef owning = parseSourceString(module, unwrap(context));
94   return MlirModule{owning.release().getOperation()};
95 }
96 
97 void mlirModuleDestroy(MlirModule module) {
98   // Transfer ownership to an OwningModuleRef so that its destructor is called.
99   OwningModuleRef(unwrap(module));
100 }
101 
102 MlirOperation mlirModuleGetOperation(MlirModule module) {
103   return wrap(unwrap(module).getOperation());
104 }
105 
106 /* ========================================================================== */
107 /* Operation state API.                                                       */
108 /* ========================================================================== */
109 
110 MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
111   MlirOperationState state;
112   state.name = name;
113   state.location = loc;
114   state.nResults = 0;
115   state.results = nullptr;
116   state.nOperands = 0;
117   state.operands = nullptr;
118   state.nRegions = 0;
119   state.regions = nullptr;
120   state.nSuccessors = 0;
121   state.successors = nullptr;
122   state.nAttributes = 0;
123   state.attributes = nullptr;
124   return state;
125 }
126 
// Appends `n` elements of `type`, read from the local array `elemName`, to
// the list `state->elemName` whose current length is `state->sizeName`.
// Expects `state`, `n` and a source pointer named `elemName` to be in scope
// at the point of use.
//
// - A non-positive `n` is a no-op, so a null source pointer with a zero count
//   is never passed to memcpy.
// - The realloc result is checked before it replaces the old pointer; the
//   process is aborted if the allocation fails instead of writing through a
//   null pointer.
// - The body is wrapped in do/while so the macro acts as a single statement;
//   the invocation supplies the terminating semicolon.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n <= 0)                                                                \
      break;                                                                   \
    void *appendElemsGrown_ =                                                  \
        realloc(state->elemName, (state->sizeName + n) * sizeof(type));        \
    if (!appendElemsGrown_)                                                    \
      abort();                                                                 \
    state->elemName = static_cast<type *>(appendElemsGrown_);                  \
    memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));     \
    state->sizeName += n;                                                      \
  } while (false)
132 
133 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
134                                   MlirType *results) {
135   APPEND_ELEMS(MlirType, nResults, results);
136 }
137 
138 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
139                                    MlirValue *operands) {
140   APPEND_ELEMS(MlirValue, nOperands, operands);
141 }
142 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
143                                        MlirRegion *regions) {
144   APPEND_ELEMS(MlirRegion, nRegions, regions);
145 }
146 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
147                                      MlirBlock *successors) {
148   APPEND_ELEMS(MlirBlock, nSuccessors, successors);
149 }
150 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
151                                      MlirNamedAttribute *attributes) {
152   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
153 }
154 
155 /* ========================================================================== */
156 /* Operation API.                                                             */
157 /* ========================================================================== */
158 
159 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
160   assert(state);
161   OperationState cppState(unwrap(state->location), state->name);
162   SmallVector<Type, 4> resultStorage;
163   SmallVector<Value, 8> operandStorage;
164   SmallVector<Block *, 2> successorStorage;
165   cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
166   cppState.addOperands(
167       unwrapList(state->nOperands, state->operands, operandStorage));
168   cppState.addSuccessors(
169       unwrapList(state->nSuccessors, state->successors, successorStorage));
170 
171   cppState.attributes.reserve(state->nAttributes);
172   for (intptr_t i = 0; i < state->nAttributes; ++i)
173     cppState.addAttribute(state->attributes[i].name,
174                           unwrap(state->attributes[i].attribute));
175 
176   for (intptr_t i = 0; i < state->nRegions; ++i)
177     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
178 
179   MlirOperation result = wrap(Operation::create(cppState));
180   free(state->results);
181   free(state->operands);
182   free(state->successors);
183   free(state->regions);
184   free(state->attributes);
185   return result;
186 }
187 
188 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
189 
190 int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
191 
192 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
193   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
194 }
195 
196 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
197   return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos)));
198 }
199 
200 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
201   return wrap(unwrap(op)->getNextNode());
202 }
203 
204 intptr_t mlirOperationGetNumOperands(MlirOperation op) {
205   return static_cast<intptr_t>(unwrap(op)->getNumOperands());
206 }
207 
208 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
209   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
210 }
211 
212 intptr_t mlirOperationGetNumResults(MlirOperation op) {
213   return static_cast<intptr_t>(unwrap(op)->getNumResults());
214 }
215 
216 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
217   return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos)));
218 }
219 
220 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
221   return static_cast<intptr_t>(unwrap(op)->getNumSuccessors());
222 }
223 
224 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
225   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
226 }
227 
228 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
229   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
230 }
231 
232 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
233   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
234   return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
235 }
236 
237 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
238                                               const char *name) {
239   return wrap(unwrap(op)->getAttr(name));
240 }
241 
242 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
243 
244 /* ========================================================================== */
245 /* Region API.                                                                */
246 /* ========================================================================== */
247 
248 MlirRegion mlirRegionCreate() { return wrap(new Region); }
249 
250 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
251   Region *cppRegion = unwrap(region);
252   if (cppRegion->empty())
253     return wrap(static_cast<Block *>(nullptr));
254   return wrap(&cppRegion->front());
255 }
256 
257 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
258   unwrap(region)->push_back(unwrap(block));
259 }
260 
261 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
262                                 MlirBlock block) {
263   auto &blockList = unwrap(region)->getBlocks();
264   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
265 }
266 
267 void mlirRegionDestroy(MlirRegion region) {
268   delete static_cast<Region *>(region.ptr);
269 }
270 
271 int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
272 
273 /* ========================================================================== */
274 /* Block API.                                                                 */
275 /* ========================================================================== */
276 
277 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
278   Block *b = new Block;
279   for (intptr_t i = 0; i < nArgs; ++i)
280     b->addArgument(unwrap(args[i]));
281   return wrap(b);
282 }
283 
284 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
285   return wrap(unwrap(block)->getNextNode());
286 }
287 
288 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
289   Block *cppBlock = unwrap(block);
290   if (cppBlock->empty())
291     return wrap(static_cast<Operation *>(nullptr));
292   return wrap(&cppBlock->front());
293 }
294 
295 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
296   unwrap(block)->push_back(unwrap(operation));
297 }
298 
299 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
300                                    MlirOperation operation) {
301   auto &opList = unwrap(block)->getOperations();
302   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
303 }
304 
305 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
306 
307 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
308 
309 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
310   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
311 }
312 
313 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
314   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
315 }
316 
317 /* ========================================================================== */
318 /* Value API.                                                                 */
319 /* ========================================================================== */
320 
321 MlirType mlirValueGetType(MlirValue value) {
322   return wrap(unwrap(value).getType());
323 }
324 
325 /* ========================================================================== */
326 /* Type API.                                                                  */
327 /* ========================================================================== */
328 
329 MlirType mlirTypeParseGet(MlirContext context, const char *type) {
330   return wrap(mlir::parseType(type, unwrap(context)));
331 }
332 
333 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
334 
335 /* ========================================================================== */
336 /* Attribute API.                                                             */
337 /* ========================================================================== */
338 
339 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
340   return wrap(mlir::parseAttribute(attr, unwrap(context)));
341 }
342 
343 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
344 
345 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
346   return MlirNamedAttribute{name, attr};
347 }
348