1 //===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "IRModule.h"
10
11 #include "PybindUtils.h"
12
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19
20 using llvm::SmallVector;
21 using llvm::Twine;
22
23 namespace {
24
25 /// Checks whether the given type is an integer or float type.
mlirTypeIsAIntegerOrFloat(MlirType type)26 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
27 return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
28 mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
29 }
30
31 class PyIntegerType : public PyConcreteType<PyIntegerType> {
32 public:
33 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
34 static constexpr const char *pyClassName = "IntegerType";
35 using PyConcreteType::PyConcreteType;
36
bindDerived(ClassTy & c)37 static void bindDerived(ClassTy &c) {
38 c.def_static(
39 "get_signless",
40 [](unsigned width, DefaultingPyMlirContext context) {
41 MlirType t = mlirIntegerTypeGet(context->get(), width);
42 return PyIntegerType(context->getRef(), t);
43 },
44 py::arg("width"), py::arg("context") = py::none(),
45 "Create a signless integer type");
46 c.def_static(
47 "get_signed",
48 [](unsigned width, DefaultingPyMlirContext context) {
49 MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
50 return PyIntegerType(context->getRef(), t);
51 },
52 py::arg("width"), py::arg("context") = py::none(),
53 "Create a signed integer type");
54 c.def_static(
55 "get_unsigned",
56 [](unsigned width, DefaultingPyMlirContext context) {
57 MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
58 return PyIntegerType(context->getRef(), t);
59 },
60 py::arg("width"), py::arg("context") = py::none(),
61 "Create an unsigned integer type");
62 c.def_property_readonly(
63 "width",
64 [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
65 "Returns the width of the integer type");
66 c.def_property_readonly(
67 "is_signless",
68 [](PyIntegerType &self) -> bool {
69 return mlirIntegerTypeIsSignless(self);
70 },
71 "Returns whether this is a signless integer");
72 c.def_property_readonly(
73 "is_signed",
74 [](PyIntegerType &self) -> bool {
75 return mlirIntegerTypeIsSigned(self);
76 },
77 "Returns whether this is a signed integer");
78 c.def_property_readonly(
79 "is_unsigned",
80 [](PyIntegerType &self) -> bool {
81 return mlirIntegerTypeIsUnsigned(self);
82 },
83 "Returns whether this is an unsigned integer");
84 }
85 };
86
87 /// Index Type subclass - IndexType.
88 class PyIndexType : public PyConcreteType<PyIndexType> {
89 public:
90 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
91 static constexpr const char *pyClassName = "IndexType";
92 using PyConcreteType::PyConcreteType;
93
bindDerived(ClassTy & c)94 static void bindDerived(ClassTy &c) {
95 c.def_static(
96 "get",
97 [](DefaultingPyMlirContext context) {
98 MlirType t = mlirIndexTypeGet(context->get());
99 return PyIndexType(context->getRef(), t);
100 },
101 py::arg("context") = py::none(), "Create a index type.");
102 }
103 };
104
105 /// Floating Point Type subclass - BF16Type.
106 class PyBF16Type : public PyConcreteType<PyBF16Type> {
107 public:
108 static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
109 static constexpr const char *pyClassName = "BF16Type";
110 using PyConcreteType::PyConcreteType;
111
bindDerived(ClassTy & c)112 static void bindDerived(ClassTy &c) {
113 c.def_static(
114 "get",
115 [](DefaultingPyMlirContext context) {
116 MlirType t = mlirBF16TypeGet(context->get());
117 return PyBF16Type(context->getRef(), t);
118 },
119 py::arg("context") = py::none(), "Create a bf16 type.");
120 }
121 };
122
123 /// Floating Point Type subclass - F16Type.
124 class PyF16Type : public PyConcreteType<PyF16Type> {
125 public:
126 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
127 static constexpr const char *pyClassName = "F16Type";
128 using PyConcreteType::PyConcreteType;
129
bindDerived(ClassTy & c)130 static void bindDerived(ClassTy &c) {
131 c.def_static(
132 "get",
133 [](DefaultingPyMlirContext context) {
134 MlirType t = mlirF16TypeGet(context->get());
135 return PyF16Type(context->getRef(), t);
136 },
137 py::arg("context") = py::none(), "Create a f16 type.");
138 }
139 };
140
141 /// Floating Point Type subclass - F32Type.
142 class PyF32Type : public PyConcreteType<PyF32Type> {
143 public:
144 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
145 static constexpr const char *pyClassName = "F32Type";
146 using PyConcreteType::PyConcreteType;
147
bindDerived(ClassTy & c)148 static void bindDerived(ClassTy &c) {
149 c.def_static(
150 "get",
151 [](DefaultingPyMlirContext context) {
152 MlirType t = mlirF32TypeGet(context->get());
153 return PyF32Type(context->getRef(), t);
154 },
155 py::arg("context") = py::none(), "Create a f32 type.");
156 }
157 };
158
159 /// Floating Point Type subclass - F64Type.
160 class PyF64Type : public PyConcreteType<PyF64Type> {
161 public:
162 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
163 static constexpr const char *pyClassName = "F64Type";
164 using PyConcreteType::PyConcreteType;
165
bindDerived(ClassTy & c)166 static void bindDerived(ClassTy &c) {
167 c.def_static(
168 "get",
169 [](DefaultingPyMlirContext context) {
170 MlirType t = mlirF64TypeGet(context->get());
171 return PyF64Type(context->getRef(), t);
172 },
173 py::arg("context") = py::none(), "Create a f64 type.");
174 }
175 };
176
177 /// None Type subclass - NoneType.
178 class PyNoneType : public PyConcreteType<PyNoneType> {
179 public:
180 static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
181 static constexpr const char *pyClassName = "NoneType";
182 using PyConcreteType::PyConcreteType;
183
bindDerived(ClassTy & c)184 static void bindDerived(ClassTy &c) {
185 c.def_static(
186 "get",
187 [](DefaultingPyMlirContext context) {
188 MlirType t = mlirNoneTypeGet(context->get());
189 return PyNoneType(context->getRef(), t);
190 },
191 py::arg("context") = py::none(), "Create a none type.");
192 }
193 };
194
195 /// Complex Type subclass - ComplexType.
196 class PyComplexType : public PyConcreteType<PyComplexType> {
197 public:
198 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
199 static constexpr const char *pyClassName = "ComplexType";
200 using PyConcreteType::PyConcreteType;
201
bindDerived(ClassTy & c)202 static void bindDerived(ClassTy &c) {
203 c.def_static(
204 "get",
205 [](PyType &elementType) {
206 // The element must be a floating point or integer scalar type.
207 if (mlirTypeIsAIntegerOrFloat(elementType)) {
208 MlirType t = mlirComplexTypeGet(elementType);
209 return PyComplexType(elementType.getContext(), t);
210 }
211 throw SetPyError(
212 PyExc_ValueError,
213 Twine("invalid '") +
214 py::repr(py::cast(elementType)).cast<std::string>() +
215 "' and expected floating point or integer type.");
216 },
217 "Create a complex type");
218 c.def_property_readonly(
219 "element_type",
220 [](PyComplexType &self) -> PyType {
221 MlirType t = mlirComplexTypeGetElementType(self);
222 return PyType(self.getContext(), t);
223 },
224 "Returns element type.");
225 }
226 };
227
228 class PyShapedType : public PyConcreteType<PyShapedType> {
229 public:
230 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
231 static constexpr const char *pyClassName = "ShapedType";
232 using PyConcreteType::PyConcreteType;
233
bindDerived(ClassTy & c)234 static void bindDerived(ClassTy &c) {
235 c.def_property_readonly(
236 "element_type",
237 [](PyShapedType &self) {
238 MlirType t = mlirShapedTypeGetElementType(self);
239 return PyType(self.getContext(), t);
240 },
241 "Returns the element type of the shaped type.");
242 c.def_property_readonly(
243 "has_rank",
244 [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
245 "Returns whether the given shaped type is ranked.");
246 c.def_property_readonly(
247 "rank",
248 [](PyShapedType &self) {
249 self.requireHasRank();
250 return mlirShapedTypeGetRank(self);
251 },
252 "Returns the rank of the given ranked shaped type.");
253 c.def_property_readonly(
254 "has_static_shape",
255 [](PyShapedType &self) -> bool {
256 return mlirShapedTypeHasStaticShape(self);
257 },
258 "Returns whether the given shaped type has a static shape.");
259 c.def(
260 "is_dynamic_dim",
261 [](PyShapedType &self, intptr_t dim) -> bool {
262 self.requireHasRank();
263 return mlirShapedTypeIsDynamicDim(self, dim);
264 },
265 py::arg("dim"),
266 "Returns whether the dim-th dimension of the given shaped type is "
267 "dynamic.");
268 c.def(
269 "get_dim_size",
270 [](PyShapedType &self, intptr_t dim) {
271 self.requireHasRank();
272 return mlirShapedTypeGetDimSize(self, dim);
273 },
274 py::arg("dim"),
275 "Returns the dim-th dimension of the given ranked shaped type.");
276 c.def_static(
277 "is_dynamic_size",
278 [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
279 py::arg("dim_size"),
280 "Returns whether the given dimension size indicates a dynamic "
281 "dimension.");
282 c.def(
283 "is_dynamic_stride_or_offset",
284 [](PyShapedType &self, int64_t val) -> bool {
285 self.requireHasRank();
286 return mlirShapedTypeIsDynamicStrideOrOffset(val);
287 },
288 py::arg("dim_size"),
289 "Returns whether the given value is used as a placeholder for dynamic "
290 "strides and offsets in shaped types.");
291 c.def_property_readonly(
292 "shape",
293 [](PyShapedType &self) {
294 self.requireHasRank();
295
296 std::vector<int64_t> shape;
297 int64_t rank = mlirShapedTypeGetRank(self);
298 shape.reserve(rank);
299 for (int64_t i = 0; i < rank; ++i)
300 shape.push_back(mlirShapedTypeGetDimSize(self, i));
301 return shape;
302 },
303 "Returns the shape of the ranked shaped type as a list of integers.");
304 c.def_static(
305 "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
306 "Returns the value used to indicate dynamic dimensions in shaped "
307 "types.");
308 c.def_static(
309 "_get_dynamic_stride_or_offset",
310 []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
311 "Returns the value used to indicate dynamic strides or offsets in "
312 "shaped types.");
313 }
314
315 private:
requireHasRank()316 void requireHasRank() {
317 if (!mlirShapedTypeHasRank(*this)) {
318 throw SetPyError(
319 PyExc_ValueError,
320 "calling this method requires that the type has a rank.");
321 }
322 }
323 };
324
325 /// Vector Type subclass - VectorType.
326 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
327 public:
328 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
329 static constexpr const char *pyClassName = "VectorType";
330 using PyConcreteType::PyConcreteType;
331
bindDerived(ClassTy & c)332 static void bindDerived(ClassTy &c) {
333 c.def_static(
334 "get",
335 [](std::vector<int64_t> shape, PyType &elementType,
336 DefaultingPyLocation loc) {
337 MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
338 elementType);
339 // TODO: Rework error reporting once diagnostic engine is exposed
340 // in C API.
341 if (mlirTypeIsNull(t)) {
342 throw SetPyError(
343 PyExc_ValueError,
344 Twine("invalid '") +
345 py::repr(py::cast(elementType)).cast<std::string>() +
346 "' and expected floating point or integer type.");
347 }
348 return PyVectorType(elementType.getContext(), t);
349 },
350 py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
351 "Create a vector type");
352 }
353 };
354
355 /// Ranked Tensor Type subclass - RankedTensorType.
356 class PyRankedTensorType
357 : public PyConcreteType<PyRankedTensorType, PyShapedType> {
358 public:
359 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
360 static constexpr const char *pyClassName = "RankedTensorType";
361 using PyConcreteType::PyConcreteType;
362
bindDerived(ClassTy & c)363 static void bindDerived(ClassTy &c) {
364 c.def_static(
365 "get",
366 [](std::vector<int64_t> shape, PyType &elementType,
367 llvm::Optional<PyAttribute> &encodingAttr,
368 DefaultingPyLocation loc) {
369 MlirType t = mlirRankedTensorTypeGetChecked(
370 loc, shape.size(), shape.data(), elementType,
371 encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
372 // TODO: Rework error reporting once diagnostic engine is exposed
373 // in C API.
374 if (mlirTypeIsNull(t)) {
375 throw SetPyError(
376 PyExc_ValueError,
377 Twine("invalid '") +
378 py::repr(py::cast(elementType)).cast<std::string>() +
379 "' and expected floating point, integer, vector or "
380 "complex "
381 "type.");
382 }
383 return PyRankedTensorType(elementType.getContext(), t);
384 },
385 py::arg("shape"), py::arg("element_type"),
386 py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
387 "Create a ranked tensor type");
388 c.def_property_readonly(
389 "encoding",
390 [](PyRankedTensorType &self) -> llvm::Optional<PyAttribute> {
391 MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
392 if (mlirAttributeIsNull(encoding))
393 return llvm::None;
394 return PyAttribute(self.getContext(), encoding);
395 });
396 }
397 };
398
399 /// Unranked Tensor Type subclass - UnrankedTensorType.
400 class PyUnrankedTensorType
401 : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
402 public:
403 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
404 static constexpr const char *pyClassName = "UnrankedTensorType";
405 using PyConcreteType::PyConcreteType;
406
bindDerived(ClassTy & c)407 static void bindDerived(ClassTy &c) {
408 c.def_static(
409 "get",
410 [](PyType &elementType, DefaultingPyLocation loc) {
411 MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
412 // TODO: Rework error reporting once diagnostic engine is exposed
413 // in C API.
414 if (mlirTypeIsNull(t)) {
415 throw SetPyError(
416 PyExc_ValueError,
417 Twine("invalid '") +
418 py::repr(py::cast(elementType)).cast<std::string>() +
419 "' and expected floating point, integer, vector or "
420 "complex "
421 "type.");
422 }
423 return PyUnrankedTensorType(elementType.getContext(), t);
424 },
425 py::arg("element_type"), py::arg("loc") = py::none(),
426 "Create a unranked tensor type");
427 }
428 };
429
430 /// Ranked MemRef Type subclass - MemRefType.
431 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
432 public:
433 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
434 static constexpr const char *pyClassName = "MemRefType";
435 using PyConcreteType::PyConcreteType;
436
bindDerived(ClassTy & c)437 static void bindDerived(ClassTy &c) {
438 c.def_static(
439 "get",
440 [](std::vector<int64_t> shape, PyType &elementType,
441 PyAttribute *layout, PyAttribute *memorySpace,
442 DefaultingPyLocation loc) {
443 MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
444 MlirAttribute memSpaceAttr =
445 memorySpace ? *memorySpace : mlirAttributeGetNull();
446 MlirType t =
447 mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
448 shape.data(), layoutAttr, memSpaceAttr);
449 // TODO: Rework error reporting once diagnostic engine is exposed
450 // in C API.
451 if (mlirTypeIsNull(t)) {
452 throw SetPyError(
453 PyExc_ValueError,
454 Twine("invalid '") +
455 py::repr(py::cast(elementType)).cast<std::string>() +
456 "' and expected floating point, integer, vector or "
457 "complex "
458 "type.");
459 }
460 return PyMemRefType(elementType.getContext(), t);
461 },
462 py::arg("shape"), py::arg("element_type"),
463 py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
464 py::arg("loc") = py::none(), "Create a memref type")
465 .def_property_readonly(
466 "layout",
467 [](PyMemRefType &self) -> PyAttribute {
468 MlirAttribute layout = mlirMemRefTypeGetLayout(self);
469 return PyAttribute(self.getContext(), layout);
470 },
471 "The layout of the MemRef type.")
472 .def_property_readonly(
473 "affine_map",
474 [](PyMemRefType &self) -> PyAffineMap {
475 MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
476 return PyAffineMap(self.getContext(), map);
477 },
478 "The layout of the MemRef type as an affine map.")
479 .def_property_readonly(
480 "memory_space",
481 [](PyMemRefType &self) -> PyAttribute {
482 MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
483 return PyAttribute(self.getContext(), a);
484 },
485 "Returns the memory space of the given MemRef type.");
486 }
487 };
488
489 /// Unranked MemRef Type subclass - UnrankedMemRefType.
490 class PyUnrankedMemRefType
491 : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
492 public:
493 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
494 static constexpr const char *pyClassName = "UnrankedMemRefType";
495 using PyConcreteType::PyConcreteType;
496
bindDerived(ClassTy & c)497 static void bindDerived(ClassTy &c) {
498 c.def_static(
499 "get",
500 [](PyType &elementType, PyAttribute *memorySpace,
501 DefaultingPyLocation loc) {
502 MlirAttribute memSpaceAttr = {};
503 if (memorySpace)
504 memSpaceAttr = *memorySpace;
505
506 MlirType t =
507 mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
508 // TODO: Rework error reporting once diagnostic engine is exposed
509 // in C API.
510 if (mlirTypeIsNull(t)) {
511 throw SetPyError(
512 PyExc_ValueError,
513 Twine("invalid '") +
514 py::repr(py::cast(elementType)).cast<std::string>() +
515 "' and expected floating point, integer, vector or "
516 "complex "
517 "type.");
518 }
519 return PyUnrankedMemRefType(elementType.getContext(), t);
520 },
521 py::arg("element_type"), py::arg("memory_space"),
522 py::arg("loc") = py::none(), "Create a unranked memref type")
523 .def_property_readonly(
524 "memory_space",
525 [](PyUnrankedMemRefType &self) -> PyAttribute {
526 MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
527 return PyAttribute(self.getContext(), a);
528 },
529 "Returns the memory space of the given Unranked MemRef type.");
530 }
531 };
532
533 /// Tuple Type subclass - TupleType.
534 class PyTupleType : public PyConcreteType<PyTupleType> {
535 public:
536 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
537 static constexpr const char *pyClassName = "TupleType";
538 using PyConcreteType::PyConcreteType;
539
bindDerived(ClassTy & c)540 static void bindDerived(ClassTy &c) {
541 c.def_static(
542 "get_tuple",
543 [](py::list elementList, DefaultingPyMlirContext context) {
544 intptr_t num = py::len(elementList);
545 // Mapping py::list to SmallVector.
546 SmallVector<MlirType, 4> elements;
547 for (auto element : elementList)
548 elements.push_back(element.cast<PyType>());
549 MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
550 return PyTupleType(context->getRef(), t);
551 },
552 py::arg("elements"), py::arg("context") = py::none(),
553 "Create a tuple type");
554 c.def(
555 "get_type",
556 [](PyTupleType &self, intptr_t pos) -> PyType {
557 MlirType t = mlirTupleTypeGetType(self, pos);
558 return PyType(self.getContext(), t);
559 },
560 py::arg("pos"), "Returns the pos-th type in the tuple type.");
561 c.def_property_readonly(
562 "num_types",
563 [](PyTupleType &self) -> intptr_t {
564 return mlirTupleTypeGetNumTypes(self);
565 },
566 "Returns the number of types contained in a tuple.");
567 }
568 };
569
570 /// Function type.
571 class PyFunctionType : public PyConcreteType<PyFunctionType> {
572 public:
573 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
574 static constexpr const char *pyClassName = "FunctionType";
575 using PyConcreteType::PyConcreteType;
576
bindDerived(ClassTy & c)577 static void bindDerived(ClassTy &c) {
578 c.def_static(
579 "get",
580 [](std::vector<PyType> inputs, std::vector<PyType> results,
581 DefaultingPyMlirContext context) {
582 SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
583 SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
584 MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
585 inputsRaw.data(), resultsRaw.size(),
586 resultsRaw.data());
587 return PyFunctionType(context->getRef(), t);
588 },
589 py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
590 "Gets a FunctionType from a list of input and result types");
591 c.def_property_readonly(
592 "inputs",
593 [](PyFunctionType &self) {
594 MlirType t = self;
595 auto contextRef = self.getContext();
596 py::list types;
597 for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
598 ++i) {
599 types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
600 }
601 return types;
602 },
603 "Returns the list of input types in the FunctionType.");
604 c.def_property_readonly(
605 "results",
606 [](PyFunctionType &self) {
607 auto contextRef = self.getContext();
608 py::list types;
609 for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
610 ++i) {
611 types.append(
612 PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
613 }
614 return types;
615 },
616 "Returns the list of result types in the FunctionType.");
617 }
618 };
619
toMlirStringRef(const std::string & s)620 static MlirStringRef toMlirStringRef(const std::string &s) {
621 return mlirStringRefCreate(s.data(), s.size());
622 }
623
624 /// Opaque Type subclass - OpaqueType.
625 class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
626 public:
627 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
628 static constexpr const char *pyClassName = "OpaqueType";
629 using PyConcreteType::PyConcreteType;
630
bindDerived(ClassTy & c)631 static void bindDerived(ClassTy &c) {
632 c.def_static(
633 "get",
634 [](std::string dialectNamespace, std::string typeData,
635 DefaultingPyMlirContext context) {
636 MlirType type = mlirOpaqueTypeGet(context->get(),
637 toMlirStringRef(dialectNamespace),
638 toMlirStringRef(typeData));
639 return PyOpaqueType(context->getRef(), type);
640 },
641 py::arg("dialect_namespace"), py::arg("buffer"),
642 py::arg("context") = py::none(),
643 "Create an unregistered (opaque) dialect type.");
644 c.def_property_readonly(
645 "dialect_namespace",
646 [](PyOpaqueType &self) {
647 MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
648 return py::str(stringRef.data, stringRef.length);
649 },
650 "Returns the dialect namespace for the Opaque type as a string.");
651 c.def_property_readonly(
652 "data",
653 [](PyOpaqueType &self) {
654 MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
655 return py::str(stringRef.data, stringRef.length);
656 },
657 "Returns the data for the Opaque type as a string.");
658 }
659 };
660
661 } // namespace
662
populateIRTypes(py::module & m)663 void mlir::python::populateIRTypes(py::module &m) {
664 PyIntegerType::bind(m);
665 PyIndexType::bind(m);
666 PyBF16Type::bind(m);
667 PyF16Type::bind(m);
668 PyF32Type::bind(m);
669 PyF64Type::bind(m);
670 PyNoneType::bind(m);
671 PyComplexType::bind(m);
672 PyShapedType::bind(m);
673 PyVectorType::bind(m);
674 PyRankedTensorType::bind(m);
675 PyUnrankedTensorType::bind(m);
676 PyMemRefType::bind(m);
677 PyUnrankedMemRefType::bind(m);
678 PyTupleType::bind(m);
679 PyFunctionType::bind(m);
680 PyOpaqueType::bind(m);
681 }
682