1 //===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

#include <cassert>
18 
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // Integer types.
23 //===----------------------------------------------------------------------===//
24 
mlirTypeIsAInteger(MlirType type)25 bool mlirTypeIsAInteger(MlirType type) {
26   return unwrap(type).isa<IntegerType>();
27 }
28 
mlirIntegerTypeGet(MlirContext ctx,unsigned bitwidth)29 MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
30   return wrap(IntegerType::get(unwrap(ctx), bitwidth));
31 }
32 
mlirIntegerTypeSignedGet(MlirContext ctx,unsigned bitwidth)33 MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
34   return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
35 }
36 
mlirIntegerTypeUnsignedGet(MlirContext ctx,unsigned bitwidth)37 MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
38   return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned));
39 }
40 
mlirIntegerTypeGetWidth(MlirType type)41 unsigned mlirIntegerTypeGetWidth(MlirType type) {
42   return unwrap(type).cast<IntegerType>().getWidth();
43 }
44 
mlirIntegerTypeIsSignless(MlirType type)45 bool mlirIntegerTypeIsSignless(MlirType type) {
46   return unwrap(type).cast<IntegerType>().isSignless();
47 }
48 
mlirIntegerTypeIsSigned(MlirType type)49 bool mlirIntegerTypeIsSigned(MlirType type) {
50   return unwrap(type).cast<IntegerType>().isSigned();
51 }
52 
mlirIntegerTypeIsUnsigned(MlirType type)53 bool mlirIntegerTypeIsUnsigned(MlirType type) {
54   return unwrap(type).cast<IntegerType>().isUnsigned();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // Index type.
59 //===----------------------------------------------------------------------===//
60 
mlirTypeIsAIndex(MlirType type)61 bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa<IndexType>(); }
62 
mlirIndexTypeGet(MlirContext ctx)63 MlirType mlirIndexTypeGet(MlirContext ctx) {
64   return wrap(IndexType::get(unwrap(ctx)));
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // Floating-point types.
69 //===----------------------------------------------------------------------===//
70 
mlirTypeIsABF16(MlirType type)71 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
72 
mlirBF16TypeGet(MlirContext ctx)73 MlirType mlirBF16TypeGet(MlirContext ctx) {
74   return wrap(FloatType::getBF16(unwrap(ctx)));
75 }
76 
mlirTypeIsAF16(MlirType type)77 bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
78 
mlirF16TypeGet(MlirContext ctx)79 MlirType mlirF16TypeGet(MlirContext ctx) {
80   return wrap(FloatType::getF16(unwrap(ctx)));
81 }
82 
mlirTypeIsAF32(MlirType type)83 bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
84 
mlirF32TypeGet(MlirContext ctx)85 MlirType mlirF32TypeGet(MlirContext ctx) {
86   return wrap(FloatType::getF32(unwrap(ctx)));
87 }
88 
mlirTypeIsAF64(MlirType type)89 bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
90 
mlirF64TypeGet(MlirContext ctx)91 MlirType mlirF64TypeGet(MlirContext ctx) {
92   return wrap(FloatType::getF64(unwrap(ctx)));
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // None type.
97 //===----------------------------------------------------------------------===//
98 
mlirTypeIsANone(MlirType type)99 bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa<NoneType>(); }
100 
mlirNoneTypeGet(MlirContext ctx)101 MlirType mlirNoneTypeGet(MlirContext ctx) {
102   return wrap(NoneType::get(unwrap(ctx)));
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Complex type.
107 //===----------------------------------------------------------------------===//
108 
mlirTypeIsAComplex(MlirType type)109 bool mlirTypeIsAComplex(MlirType type) {
110   return unwrap(type).isa<ComplexType>();
111 }
112 
mlirComplexTypeGet(MlirType elementType)113 MlirType mlirComplexTypeGet(MlirType elementType) {
114   return wrap(ComplexType::get(unwrap(elementType)));
115 }
116 
mlirComplexTypeGetElementType(MlirType type)117 MlirType mlirComplexTypeGetElementType(MlirType type) {
118   return wrap(unwrap(type).cast<ComplexType>().getElementType());
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // Shaped type.
123 //===----------------------------------------------------------------------===//
124 
mlirTypeIsAShaped(MlirType type)125 bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }
126 
mlirShapedTypeGetElementType(MlirType type)127 MlirType mlirShapedTypeGetElementType(MlirType type) {
128   return wrap(unwrap(type).cast<ShapedType>().getElementType());
129 }
130 
mlirShapedTypeHasRank(MlirType type)131 bool mlirShapedTypeHasRank(MlirType type) {
132   return unwrap(type).cast<ShapedType>().hasRank();
133 }
134 
mlirShapedTypeGetRank(MlirType type)135 int64_t mlirShapedTypeGetRank(MlirType type) {
136   return unwrap(type).cast<ShapedType>().getRank();
137 }
138 
mlirShapedTypeHasStaticShape(MlirType type)139 bool mlirShapedTypeHasStaticShape(MlirType type) {
140   return unwrap(type).cast<ShapedType>().hasStaticShape();
141 }
142 
mlirShapedTypeIsDynamicDim(MlirType type,intptr_t dim)143 bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
144   return unwrap(type).cast<ShapedType>().isDynamicDim(
145       static_cast<unsigned>(dim));
146 }
147 
mlirShapedTypeGetDimSize(MlirType type,intptr_t dim)148 int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
149   return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
150 }
151 
mlirShapedTypeGetDynamicSize()152 int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamicSize; }
153 
mlirShapedTypeIsDynamicSize(int64_t size)154 bool mlirShapedTypeIsDynamicSize(int64_t size) {
155   return ShapedType::isDynamic(size);
156 }
157 
mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)158 bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
159   return ShapedType::isDynamicStrideOrOffset(val);
160 }
161 
mlirShapedTypeGetDynamicStrideOrOffset()162 int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
163   return ShapedType::kDynamicStrideOrOffset;
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // Vector type.
168 //===----------------------------------------------------------------------===//
169 
mlirTypeIsAVector(MlirType type)170 bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa<VectorType>(); }
171 
mlirVectorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType)172 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
173                            MlirType elementType) {
174   return wrap(
175       VectorType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
176                       unwrap(elementType)));
177 }
178 
mlirVectorTypeGetChecked(MlirLocation loc,intptr_t rank,const int64_t * shape,MlirType elementType)179 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
180                                   const int64_t *shape, MlirType elementType) {
181   return wrap(VectorType::getChecked(
182       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
183       unwrap(elementType)));
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // Ranked / Unranked tensor type.
188 //===----------------------------------------------------------------------===//
189 
mlirTypeIsATensor(MlirType type)190 bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa<TensorType>(); }
191 
mlirTypeIsARankedTensor(MlirType type)192 bool mlirTypeIsARankedTensor(MlirType type) {
193   return unwrap(type).isa<RankedTensorType>();
194 }
195 
mlirTypeIsAUnrankedTensor(MlirType type)196 bool mlirTypeIsAUnrankedTensor(MlirType type) {
197   return unwrap(type).isa<UnrankedTensorType>();
198 }
199 
mlirRankedTensorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType,MlirAttribute encoding)200 MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
201                                  MlirType elementType, MlirAttribute encoding) {
202   return wrap(RankedTensorType::get(
203       llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
204       unwrap(encoding)));
205 }
206 
mlirRankedTensorTypeGetChecked(MlirLocation loc,intptr_t rank,const int64_t * shape,MlirType elementType,MlirAttribute encoding)207 MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
208                                         const int64_t *shape,
209                                         MlirType elementType,
210                                         MlirAttribute encoding) {
211   return wrap(RankedTensorType::getChecked(
212       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
213       unwrap(elementType), unwrap(encoding)));
214 }
215 
mlirRankedTensorTypeGetEncoding(MlirType type)216 MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
217   return wrap(unwrap(type).cast<RankedTensorType>().getEncoding());
218 }
219 
mlirUnrankedTensorTypeGet(MlirType elementType)220 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
221   return wrap(UnrankedTensorType::get(unwrap(elementType)));
222 }
223 
mlirUnrankedTensorTypeGetChecked(MlirLocation loc,MlirType elementType)224 MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
225                                           MlirType elementType) {
226   return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // Ranked / Unranked MemRef type.
231 //===----------------------------------------------------------------------===//
232 
mlirTypeIsAMemRef(MlirType type)233 bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
234 
mlirMemRefTypeGet(MlirType elementType,intptr_t rank,const int64_t * shape,MlirAttribute layout,MlirAttribute memorySpace)235 MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
236                            const int64_t *shape, MlirAttribute layout,
237                            MlirAttribute memorySpace) {
238   return wrap(MemRefType::get(
239       llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
240       mlirAttributeIsNull(layout)
241           ? MemRefLayoutAttrInterface()
242           : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
243       unwrap(memorySpace)));
244 }
245 
mlirMemRefTypeGetChecked(MlirLocation loc,MlirType elementType,intptr_t rank,const int64_t * shape,MlirAttribute layout,MlirAttribute memorySpace)246 MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
247                                   intptr_t rank, const int64_t *shape,
248                                   MlirAttribute layout,
249                                   MlirAttribute memorySpace) {
250   return wrap(MemRefType::getChecked(
251       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
252       unwrap(elementType),
253       mlirAttributeIsNull(layout)
254           ? MemRefLayoutAttrInterface()
255           : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
256       unwrap(memorySpace)));
257 }
258 
mlirMemRefTypeContiguousGet(MlirType elementType,intptr_t rank,const int64_t * shape,MlirAttribute memorySpace)259 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
260                                      const int64_t *shape,
261                                      MlirAttribute memorySpace) {
262   return wrap(MemRefType::get(
263       llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
264       MemRefLayoutAttrInterface(), unwrap(memorySpace)));
265 }
266 
mlirMemRefTypeContiguousGetChecked(MlirLocation loc,MlirType elementType,intptr_t rank,const int64_t * shape,MlirAttribute memorySpace)267 MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
268                                             MlirType elementType, intptr_t rank,
269                                             const int64_t *shape,
270                                             MlirAttribute memorySpace) {
271   return wrap(MemRefType::getChecked(
272       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
273       unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace)));
274 }
275 
mlirMemRefTypeGetLayout(MlirType type)276 MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
277   return wrap(unwrap(type).cast<MemRefType>().getLayout());
278 }
279 
mlirMemRefTypeGetAffineMap(MlirType type)280 MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
281   return wrap(unwrap(type).cast<MemRefType>().getLayout().getAffineMap());
282 }
283 
mlirMemRefTypeGetMemorySpace(MlirType type)284 MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
285   return wrap(unwrap(type).cast<MemRefType>().getMemorySpace());
286 }
287 
mlirTypeIsAUnrankedMemRef(MlirType type)288 bool mlirTypeIsAUnrankedMemRef(MlirType type) {
289   return unwrap(type).isa<UnrankedMemRefType>();
290 }
291 
mlirUnrankedMemRefTypeGet(MlirType elementType,MlirAttribute memorySpace)292 MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
293                                    MlirAttribute memorySpace) {
294   return wrap(
295       UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace)));
296 }
297 
mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,MlirType elementType,MlirAttribute memorySpace)298 MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
299                                           MlirType elementType,
300                                           MlirAttribute memorySpace) {
301   return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
302                                              unwrap(memorySpace)));
303 }
304 
mlirUnrankedMemrefGetMemorySpace(MlirType type)305 MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
306   return wrap(unwrap(type).cast<UnrankedMemRefType>().getMemorySpace());
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // Tuple type.
311 //===----------------------------------------------------------------------===//
312 
mlirTypeIsATuple(MlirType type)313 bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }
314 
mlirTupleTypeGet(MlirContext ctx,intptr_t numElements,MlirType const * elements)315 MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
316                           MlirType const *elements) {
317   SmallVector<Type, 4> types;
318   ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
319   return wrap(TupleType::get(unwrap(ctx), typeRef));
320 }
321 
mlirTupleTypeGetNumTypes(MlirType type)322 intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
323   return unwrap(type).cast<TupleType>().size();
324 }
325 
mlirTupleTypeGetType(MlirType type,intptr_t pos)326 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
327   return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // Function type.
332 //===----------------------------------------------------------------------===//
333 
mlirTypeIsAFunction(MlirType type)334 bool mlirTypeIsAFunction(MlirType type) {
335   return unwrap(type).isa<FunctionType>();
336 }
337 
mlirFunctionTypeGet(MlirContext ctx,intptr_t numInputs,MlirType const * inputs,intptr_t numResults,MlirType const * results)338 MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
339                              MlirType const *inputs, intptr_t numResults,
340                              MlirType const *results) {
341   SmallVector<Type, 4> inputsList;
342   SmallVector<Type, 4> resultsList;
343   (void)unwrapList(numInputs, inputs, inputsList);
344   (void)unwrapList(numResults, results, resultsList);
345   return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
346 }
347 
mlirFunctionTypeGetNumInputs(MlirType type)348 intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
349   return unwrap(type).cast<FunctionType>().getNumInputs();
350 }
351 
mlirFunctionTypeGetNumResults(MlirType type)352 intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
353   return unwrap(type).cast<FunctionType>().getNumResults();
354 }
355 
mlirFunctionTypeGetInput(MlirType type,intptr_t pos)356 MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
357   assert(pos >= 0 && "pos in array must be positive");
358   return wrap(
359       unwrap(type).cast<FunctionType>().getInput(static_cast<unsigned>(pos)));
360 }
361 
mlirFunctionTypeGetResult(MlirType type,intptr_t pos)362 MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
363   assert(pos >= 0 && "pos in array must be positive");
364   return wrap(
365       unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // Opaque type.
370 //===----------------------------------------------------------------------===//
371 
mlirTypeIsAOpaque(MlirType type)372 bool mlirTypeIsAOpaque(MlirType type) { return unwrap(type).isa<OpaqueType>(); }
373 
mlirOpaqueTypeGet(MlirContext ctx,MlirStringRef dialectNamespace,MlirStringRef typeData)374 MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
375                            MlirStringRef typeData) {
376   return wrap(
377       OpaqueType::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
378                       unwrap(typeData)));
379 }
380 
mlirOpaqueTypeGetDialectNamespace(MlirType type)381 MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) {
382   return wrap(unwrap(type).cast<OpaqueType>().getDialectNamespace().strref());
383 }
384 
mlirOpaqueTypeGetData(MlirType type)385 MlirStringRef mlirOpaqueTypeGetData(MlirType type) {
386   return wrap(unwrap(type).cast<OpaqueType>().getTypeData());
387 }
388