1 //===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/BuiltinTypes.h"
10 #include "mlir-c/AffineMap.h"
11 #include "mlir-c/IR.h"
12 #include "mlir/CAPI/AffineMap.h"
13 #include "mlir/CAPI/IR.h"
14 #include "mlir/CAPI/Support.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Types.h"
18
19 using namespace mlir;
20
21 //===----------------------------------------------------------------------===//
22 // Integer types.
23 //===----------------------------------------------------------------------===//
24
mlirTypeIsAInteger(MlirType type)25 bool mlirTypeIsAInteger(MlirType type) {
26 return unwrap(type).isa<IntegerType>();
27 }
28
mlirIntegerTypeGet(MlirContext ctx,unsigned bitwidth)29 MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
30 return wrap(IntegerType::get(unwrap(ctx), bitwidth));
31 }
32
mlirIntegerTypeSignedGet(MlirContext ctx,unsigned bitwidth)33 MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
34 return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
35 }
36
mlirIntegerTypeUnsignedGet(MlirContext ctx,unsigned bitwidth)37 MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
38 return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned));
39 }
40
mlirIntegerTypeGetWidth(MlirType type)41 unsigned mlirIntegerTypeGetWidth(MlirType type) {
42 return unwrap(type).cast<IntegerType>().getWidth();
43 }
44
mlirIntegerTypeIsSignless(MlirType type)45 bool mlirIntegerTypeIsSignless(MlirType type) {
46 return unwrap(type).cast<IntegerType>().isSignless();
47 }
48
mlirIntegerTypeIsSigned(MlirType type)49 bool mlirIntegerTypeIsSigned(MlirType type) {
50 return unwrap(type).cast<IntegerType>().isSigned();
51 }
52
mlirIntegerTypeIsUnsigned(MlirType type)53 bool mlirIntegerTypeIsUnsigned(MlirType type) {
54 return unwrap(type).cast<IntegerType>().isUnsigned();
55 }
56
57 //===----------------------------------------------------------------------===//
58 // Index type.
59 //===----------------------------------------------------------------------===//
60
mlirTypeIsAIndex(MlirType type)61 bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa<IndexType>(); }
62
mlirIndexTypeGet(MlirContext ctx)63 MlirType mlirIndexTypeGet(MlirContext ctx) {
64 return wrap(IndexType::get(unwrap(ctx)));
65 }
66
67 //===----------------------------------------------------------------------===//
68 // Floating-point types.
69 //===----------------------------------------------------------------------===//
70
mlirTypeIsABF16(MlirType type)71 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
72
mlirBF16TypeGet(MlirContext ctx)73 MlirType mlirBF16TypeGet(MlirContext ctx) {
74 return wrap(FloatType::getBF16(unwrap(ctx)));
75 }
76
mlirTypeIsAF16(MlirType type)77 bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
78
mlirF16TypeGet(MlirContext ctx)79 MlirType mlirF16TypeGet(MlirContext ctx) {
80 return wrap(FloatType::getF16(unwrap(ctx)));
81 }
82
mlirTypeIsAF32(MlirType type)83 bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
84
mlirF32TypeGet(MlirContext ctx)85 MlirType mlirF32TypeGet(MlirContext ctx) {
86 return wrap(FloatType::getF32(unwrap(ctx)));
87 }
88
mlirTypeIsAF64(MlirType type)89 bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
90
mlirF64TypeGet(MlirContext ctx)91 MlirType mlirF64TypeGet(MlirContext ctx) {
92 return wrap(FloatType::getF64(unwrap(ctx)));
93 }
94
95 //===----------------------------------------------------------------------===//
96 // None type.
97 //===----------------------------------------------------------------------===//
98
mlirTypeIsANone(MlirType type)99 bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa<NoneType>(); }
100
mlirNoneTypeGet(MlirContext ctx)101 MlirType mlirNoneTypeGet(MlirContext ctx) {
102 return wrap(NoneType::get(unwrap(ctx)));
103 }
104
105 //===----------------------------------------------------------------------===//
106 // Complex type.
107 //===----------------------------------------------------------------------===//
108
mlirTypeIsAComplex(MlirType type)109 bool mlirTypeIsAComplex(MlirType type) {
110 return unwrap(type).isa<ComplexType>();
111 }
112
mlirComplexTypeGet(MlirType elementType)113 MlirType mlirComplexTypeGet(MlirType elementType) {
114 return wrap(ComplexType::get(unwrap(elementType)));
115 }
116
mlirComplexTypeGetElementType(MlirType type)117 MlirType mlirComplexTypeGetElementType(MlirType type) {
118 return wrap(unwrap(type).cast<ComplexType>().getElementType());
119 }
120
121 //===----------------------------------------------------------------------===//
122 // Shaped type.
123 //===----------------------------------------------------------------------===//
124
mlirTypeIsAShaped(MlirType type)125 bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }
126
mlirShapedTypeGetElementType(MlirType type)127 MlirType mlirShapedTypeGetElementType(MlirType type) {
128 return wrap(unwrap(type).cast<ShapedType>().getElementType());
129 }
130
mlirShapedTypeHasRank(MlirType type)131 bool mlirShapedTypeHasRank(MlirType type) {
132 return unwrap(type).cast<ShapedType>().hasRank();
133 }
134
mlirShapedTypeGetRank(MlirType type)135 int64_t mlirShapedTypeGetRank(MlirType type) {
136 return unwrap(type).cast<ShapedType>().getRank();
137 }
138
mlirShapedTypeHasStaticShape(MlirType type)139 bool mlirShapedTypeHasStaticShape(MlirType type) {
140 return unwrap(type).cast<ShapedType>().hasStaticShape();
141 }
142
mlirShapedTypeIsDynamicDim(MlirType type,intptr_t dim)143 bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
144 return unwrap(type).cast<ShapedType>().isDynamicDim(
145 static_cast<unsigned>(dim));
146 }
147
mlirShapedTypeGetDimSize(MlirType type,intptr_t dim)148 int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
149 return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
150 }
151
mlirShapedTypeGetDynamicSize()152 int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamicSize; }
153
mlirShapedTypeIsDynamicSize(int64_t size)154 bool mlirShapedTypeIsDynamicSize(int64_t size) {
155 return ShapedType::isDynamic(size);
156 }
157
mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)158 bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
159 return ShapedType::isDynamicStrideOrOffset(val);
160 }
161
mlirShapedTypeGetDynamicStrideOrOffset()162 int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
163 return ShapedType::kDynamicStrideOrOffset;
164 }
165
166 //===----------------------------------------------------------------------===//
167 // Vector type.
168 //===----------------------------------------------------------------------===//
169
mlirTypeIsAVector(MlirType type)170 bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa<VectorType>(); }
171
mlirVectorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType)172 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
173 MlirType elementType) {
174 return wrap(
175 VectorType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
176 unwrap(elementType)));
177 }
178
mlirVectorTypeGetChecked(MlirLocation loc,intptr_t rank,const int64_t * shape,MlirType elementType)179 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
180 const int64_t *shape, MlirType elementType) {
181 return wrap(VectorType::getChecked(
182 unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
183 unwrap(elementType)));
184 }
185
186 //===----------------------------------------------------------------------===//
187 // Ranked / Unranked tensor type.
188 //===----------------------------------------------------------------------===//
189
mlirTypeIsATensor(MlirType type)190 bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa<TensorType>(); }
191
mlirTypeIsARankedTensor(MlirType type)192 bool mlirTypeIsARankedTensor(MlirType type) {
193 return unwrap(type).isa<RankedTensorType>();
194 }
195
mlirTypeIsAUnrankedTensor(MlirType type)196 bool mlirTypeIsAUnrankedTensor(MlirType type) {
197 return unwrap(type).isa<UnrankedTensorType>();
198 }
199
mlirRankedTensorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType,MlirAttribute encoding)200 MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
201 MlirType elementType, MlirAttribute encoding) {
202 return wrap(RankedTensorType::get(
203 llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
204 unwrap(encoding)));
205 }
206
mlirRankedTensorTypeGetChecked(MlirLocation loc,intptr_t rank,const int64_t * shape,MlirType elementType,MlirAttribute encoding)207 MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
208 const int64_t *shape,
209 MlirType elementType,
210 MlirAttribute encoding) {
211 return wrap(RankedTensorType::getChecked(
212 unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
213 unwrap(elementType), unwrap(encoding)));
214 }
215
mlirRankedTensorTypeGetEncoding(MlirType type)216 MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
217 return wrap(unwrap(type).cast<RankedTensorType>().getEncoding());
218 }
219
mlirUnrankedTensorTypeGet(MlirType elementType)220 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
221 return wrap(UnrankedTensorType::get(unwrap(elementType)));
222 }
223
mlirUnrankedTensorTypeGetChecked(MlirLocation loc,MlirType elementType)224 MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
225 MlirType elementType) {
226 return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
227 }
228
229 //===----------------------------------------------------------------------===//
230 // Ranked / Unranked MemRef type.
231 //===----------------------------------------------------------------------===//
232
mlirTypeIsAMemRef(MlirType type)233 bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
234
mlirMemRefTypeGet(MlirType elementType,intptr_t rank,const int64_t * shape,MlirAttribute layout,MlirAttribute memorySpace)235 MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
236 const int64_t *shape, MlirAttribute layout,
237 MlirAttribute memorySpace) {
238 return wrap(MemRefType::get(
239 llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
240 mlirAttributeIsNull(layout)
241 ? MemRefLayoutAttrInterface()
242 : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
243 unwrap(memorySpace)));
244 }
245
mlirMemRefTypeGetChecked(MlirLocation loc,MlirType elementType,intptr_t rank,const int64_t * shape,MlirAttribute layout,MlirAttribute memorySpace)246 MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
247 intptr_t rank, const int64_t *shape,
248 MlirAttribute layout,
249 MlirAttribute memorySpace) {
250 return wrap(MemRefType::getChecked(
251 unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
252 unwrap(elementType),
253 mlirAttributeIsNull(layout)
254 ? MemRefLayoutAttrInterface()
255 : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
256 unwrap(memorySpace)));
257 }
258
mlirMemRefTypeContiguousGet(MlirType elementType,intptr_t rank,const int64_t * shape,MlirAttribute memorySpace)259 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
260 const int64_t *shape,
261 MlirAttribute memorySpace) {
262 return wrap(MemRefType::get(
263 llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
264 MemRefLayoutAttrInterface(), unwrap(memorySpace)));
265 }
266
mlirMemRefTypeContiguousGetChecked(MlirLocation loc,MlirType elementType,intptr_t rank,const int64_t * shape,MlirAttribute memorySpace)267 MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
268 MlirType elementType, intptr_t rank,
269 const int64_t *shape,
270 MlirAttribute memorySpace) {
271 return wrap(MemRefType::getChecked(
272 unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
273 unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace)));
274 }
275
mlirMemRefTypeGetLayout(MlirType type)276 MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
277 return wrap(unwrap(type).cast<MemRefType>().getLayout());
278 }
279
mlirMemRefTypeGetAffineMap(MlirType type)280 MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
281 return wrap(unwrap(type).cast<MemRefType>().getLayout().getAffineMap());
282 }
283
mlirMemRefTypeGetMemorySpace(MlirType type)284 MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
285 return wrap(unwrap(type).cast<MemRefType>().getMemorySpace());
286 }
287
mlirTypeIsAUnrankedMemRef(MlirType type)288 bool mlirTypeIsAUnrankedMemRef(MlirType type) {
289 return unwrap(type).isa<UnrankedMemRefType>();
290 }
291
mlirUnrankedMemRefTypeGet(MlirType elementType,MlirAttribute memorySpace)292 MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
293 MlirAttribute memorySpace) {
294 return wrap(
295 UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace)));
296 }
297
mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,MlirType elementType,MlirAttribute memorySpace)298 MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
299 MlirType elementType,
300 MlirAttribute memorySpace) {
301 return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
302 unwrap(memorySpace)));
303 }
304
mlirUnrankedMemrefGetMemorySpace(MlirType type)305 MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
306 return wrap(unwrap(type).cast<UnrankedMemRefType>().getMemorySpace());
307 }
308
309 //===----------------------------------------------------------------------===//
310 // Tuple type.
311 //===----------------------------------------------------------------------===//
312
mlirTypeIsATuple(MlirType type)313 bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }
314
mlirTupleTypeGet(MlirContext ctx,intptr_t numElements,MlirType const * elements)315 MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
316 MlirType const *elements) {
317 SmallVector<Type, 4> types;
318 ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
319 return wrap(TupleType::get(unwrap(ctx), typeRef));
320 }
321
mlirTupleTypeGetNumTypes(MlirType type)322 intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
323 return unwrap(type).cast<TupleType>().size();
324 }
325
mlirTupleTypeGetType(MlirType type,intptr_t pos)326 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
327 return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
328 }
329
330 //===----------------------------------------------------------------------===//
331 // Function type.
332 //===----------------------------------------------------------------------===//
333
mlirTypeIsAFunction(MlirType type)334 bool mlirTypeIsAFunction(MlirType type) {
335 return unwrap(type).isa<FunctionType>();
336 }
337
mlirFunctionTypeGet(MlirContext ctx,intptr_t numInputs,MlirType const * inputs,intptr_t numResults,MlirType const * results)338 MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
339 MlirType const *inputs, intptr_t numResults,
340 MlirType const *results) {
341 SmallVector<Type, 4> inputsList;
342 SmallVector<Type, 4> resultsList;
343 (void)unwrapList(numInputs, inputs, inputsList);
344 (void)unwrapList(numResults, results, resultsList);
345 return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
346 }
347
mlirFunctionTypeGetNumInputs(MlirType type)348 intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
349 return unwrap(type).cast<FunctionType>().getNumInputs();
350 }
351
mlirFunctionTypeGetNumResults(MlirType type)352 intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
353 return unwrap(type).cast<FunctionType>().getNumResults();
354 }
355
mlirFunctionTypeGetInput(MlirType type,intptr_t pos)356 MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
357 assert(pos >= 0 && "pos in array must be positive");
358 return wrap(
359 unwrap(type).cast<FunctionType>().getInput(static_cast<unsigned>(pos)));
360 }
361
mlirFunctionTypeGetResult(MlirType type,intptr_t pos)362 MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
363 assert(pos >= 0 && "pos in array must be positive");
364 return wrap(
365 unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
366 }
367
368 //===----------------------------------------------------------------------===//
369 // Opaque type.
370 //===----------------------------------------------------------------------===//
371
mlirTypeIsAOpaque(MlirType type)372 bool mlirTypeIsAOpaque(MlirType type) { return unwrap(type).isa<OpaqueType>(); }
373
mlirOpaqueTypeGet(MlirContext ctx,MlirStringRef dialectNamespace,MlirStringRef typeData)374 MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
375 MlirStringRef typeData) {
376 return wrap(
377 OpaqueType::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
378 unwrap(typeData)));
379 }
380
mlirOpaqueTypeGetDialectNamespace(MlirType type)381 MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) {
382 return wrap(unwrap(type).cast<OpaqueType>().getDialectNamespace().strref());
383 }
384
mlirOpaqueTypeGetData(MlirType type)385 MlirStringRef mlirOpaqueTypeGetData(MlirType type) {
386 return wrap(unwrap(type).cast<OpaqueType>().getTypeData());
387 }
388