1 //===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 using llvm::SmallVector;
21 using llvm::Twine;
22 
23 namespace {
24 
25 /// Checks whether the given type is an integer or float type.
mlirTypeIsAIntegerOrFloat(MlirType type)26 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
27   return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
28          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
29 }
30 
31 class PyIntegerType : public PyConcreteType<PyIntegerType> {
32 public:
33   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
34   static constexpr const char *pyClassName = "IntegerType";
35   using PyConcreteType::PyConcreteType;
36 
bindDerived(ClassTy & c)37   static void bindDerived(ClassTy &c) {
38     c.def_static(
39         "get_signless",
40         [](unsigned width, DefaultingPyMlirContext context) {
41           MlirType t = mlirIntegerTypeGet(context->get(), width);
42           return PyIntegerType(context->getRef(), t);
43         },
44         py::arg("width"), py::arg("context") = py::none(),
45         "Create a signless integer type");
46     c.def_static(
47         "get_signed",
48         [](unsigned width, DefaultingPyMlirContext context) {
49           MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
50           return PyIntegerType(context->getRef(), t);
51         },
52         py::arg("width"), py::arg("context") = py::none(),
53         "Create a signed integer type");
54     c.def_static(
55         "get_unsigned",
56         [](unsigned width, DefaultingPyMlirContext context) {
57           MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
58           return PyIntegerType(context->getRef(), t);
59         },
60         py::arg("width"), py::arg("context") = py::none(),
61         "Create an unsigned integer type");
62     c.def_property_readonly(
63         "width",
64         [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
65         "Returns the width of the integer type");
66     c.def_property_readonly(
67         "is_signless",
68         [](PyIntegerType &self) -> bool {
69           return mlirIntegerTypeIsSignless(self);
70         },
71         "Returns whether this is a signless integer");
72     c.def_property_readonly(
73         "is_signed",
74         [](PyIntegerType &self) -> bool {
75           return mlirIntegerTypeIsSigned(self);
76         },
77         "Returns whether this is a signed integer");
78     c.def_property_readonly(
79         "is_unsigned",
80         [](PyIntegerType &self) -> bool {
81           return mlirIntegerTypeIsUnsigned(self);
82         },
83         "Returns whether this is an unsigned integer");
84   }
85 };
86 
87 /// Index Type subclass - IndexType.
88 class PyIndexType : public PyConcreteType<PyIndexType> {
89 public:
90   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
91   static constexpr const char *pyClassName = "IndexType";
92   using PyConcreteType::PyConcreteType;
93 
bindDerived(ClassTy & c)94   static void bindDerived(ClassTy &c) {
95     c.def_static(
96         "get",
97         [](DefaultingPyMlirContext context) {
98           MlirType t = mlirIndexTypeGet(context->get());
99           return PyIndexType(context->getRef(), t);
100         },
101         py::arg("context") = py::none(), "Create a index type.");
102   }
103 };
104 
105 /// Floating Point Type subclass - BF16Type.
106 class PyBF16Type : public PyConcreteType<PyBF16Type> {
107 public:
108   static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
109   static constexpr const char *pyClassName = "BF16Type";
110   using PyConcreteType::PyConcreteType;
111 
bindDerived(ClassTy & c)112   static void bindDerived(ClassTy &c) {
113     c.def_static(
114         "get",
115         [](DefaultingPyMlirContext context) {
116           MlirType t = mlirBF16TypeGet(context->get());
117           return PyBF16Type(context->getRef(), t);
118         },
119         py::arg("context") = py::none(), "Create a bf16 type.");
120   }
121 };
122 
123 /// Floating Point Type subclass - F16Type.
124 class PyF16Type : public PyConcreteType<PyF16Type> {
125 public:
126   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
127   static constexpr const char *pyClassName = "F16Type";
128   using PyConcreteType::PyConcreteType;
129 
bindDerived(ClassTy & c)130   static void bindDerived(ClassTy &c) {
131     c.def_static(
132         "get",
133         [](DefaultingPyMlirContext context) {
134           MlirType t = mlirF16TypeGet(context->get());
135           return PyF16Type(context->getRef(), t);
136         },
137         py::arg("context") = py::none(), "Create a f16 type.");
138   }
139 };
140 
141 /// Floating Point Type subclass - F32Type.
142 class PyF32Type : public PyConcreteType<PyF32Type> {
143 public:
144   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
145   static constexpr const char *pyClassName = "F32Type";
146   using PyConcreteType::PyConcreteType;
147 
bindDerived(ClassTy & c)148   static void bindDerived(ClassTy &c) {
149     c.def_static(
150         "get",
151         [](DefaultingPyMlirContext context) {
152           MlirType t = mlirF32TypeGet(context->get());
153           return PyF32Type(context->getRef(), t);
154         },
155         py::arg("context") = py::none(), "Create a f32 type.");
156   }
157 };
158 
159 /// Floating Point Type subclass - F64Type.
160 class PyF64Type : public PyConcreteType<PyF64Type> {
161 public:
162   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
163   static constexpr const char *pyClassName = "F64Type";
164   using PyConcreteType::PyConcreteType;
165 
bindDerived(ClassTy & c)166   static void bindDerived(ClassTy &c) {
167     c.def_static(
168         "get",
169         [](DefaultingPyMlirContext context) {
170           MlirType t = mlirF64TypeGet(context->get());
171           return PyF64Type(context->getRef(), t);
172         },
173         py::arg("context") = py::none(), "Create a f64 type.");
174   }
175 };
176 
177 /// None Type subclass - NoneType.
178 class PyNoneType : public PyConcreteType<PyNoneType> {
179 public:
180   static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
181   static constexpr const char *pyClassName = "NoneType";
182   using PyConcreteType::PyConcreteType;
183 
bindDerived(ClassTy & c)184   static void bindDerived(ClassTy &c) {
185     c.def_static(
186         "get",
187         [](DefaultingPyMlirContext context) {
188           MlirType t = mlirNoneTypeGet(context->get());
189           return PyNoneType(context->getRef(), t);
190         },
191         py::arg("context") = py::none(), "Create a none type.");
192   }
193 };
194 
195 /// Complex Type subclass - ComplexType.
196 class PyComplexType : public PyConcreteType<PyComplexType> {
197 public:
198   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
199   static constexpr const char *pyClassName = "ComplexType";
200   using PyConcreteType::PyConcreteType;
201 
bindDerived(ClassTy & c)202   static void bindDerived(ClassTy &c) {
203     c.def_static(
204         "get",
205         [](PyType &elementType) {
206           // The element must be a floating point or integer scalar type.
207           if (mlirTypeIsAIntegerOrFloat(elementType)) {
208             MlirType t = mlirComplexTypeGet(elementType);
209             return PyComplexType(elementType.getContext(), t);
210           }
211           throw SetPyError(
212               PyExc_ValueError,
213               Twine("invalid '") +
214                   py::repr(py::cast(elementType)).cast<std::string>() +
215                   "' and expected floating point or integer type.");
216         },
217         "Create a complex type");
218     c.def_property_readonly(
219         "element_type",
220         [](PyComplexType &self) -> PyType {
221           MlirType t = mlirComplexTypeGetElementType(self);
222           return PyType(self.getContext(), t);
223         },
224         "Returns element type.");
225   }
226 };
227 
228 class PyShapedType : public PyConcreteType<PyShapedType> {
229 public:
230   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
231   static constexpr const char *pyClassName = "ShapedType";
232   using PyConcreteType::PyConcreteType;
233 
bindDerived(ClassTy & c)234   static void bindDerived(ClassTy &c) {
235     c.def_property_readonly(
236         "element_type",
237         [](PyShapedType &self) {
238           MlirType t = mlirShapedTypeGetElementType(self);
239           return PyType(self.getContext(), t);
240         },
241         "Returns the element type of the shaped type.");
242     c.def_property_readonly(
243         "has_rank",
244         [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
245         "Returns whether the given shaped type is ranked.");
246     c.def_property_readonly(
247         "rank",
248         [](PyShapedType &self) {
249           self.requireHasRank();
250           return mlirShapedTypeGetRank(self);
251         },
252         "Returns the rank of the given ranked shaped type.");
253     c.def_property_readonly(
254         "has_static_shape",
255         [](PyShapedType &self) -> bool {
256           return mlirShapedTypeHasStaticShape(self);
257         },
258         "Returns whether the given shaped type has a static shape.");
259     c.def(
260         "is_dynamic_dim",
261         [](PyShapedType &self, intptr_t dim) -> bool {
262           self.requireHasRank();
263           return mlirShapedTypeIsDynamicDim(self, dim);
264         },
265         py::arg("dim"),
266         "Returns whether the dim-th dimension of the given shaped type is "
267         "dynamic.");
268     c.def(
269         "get_dim_size",
270         [](PyShapedType &self, intptr_t dim) {
271           self.requireHasRank();
272           return mlirShapedTypeGetDimSize(self, dim);
273         },
274         py::arg("dim"),
275         "Returns the dim-th dimension of the given ranked shaped type.");
276     c.def_static(
277         "is_dynamic_size",
278         [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
279         py::arg("dim_size"),
280         "Returns whether the given dimension size indicates a dynamic "
281         "dimension.");
282     c.def(
283         "is_dynamic_stride_or_offset",
284         [](PyShapedType &self, int64_t val) -> bool {
285           self.requireHasRank();
286           return mlirShapedTypeIsDynamicStrideOrOffset(val);
287         },
288         py::arg("dim_size"),
289         "Returns whether the given value is used as a placeholder for dynamic "
290         "strides and offsets in shaped types.");
291     c.def_property_readonly(
292         "shape",
293         [](PyShapedType &self) {
294           self.requireHasRank();
295 
296           std::vector<int64_t> shape;
297           int64_t rank = mlirShapedTypeGetRank(self);
298           shape.reserve(rank);
299           for (int64_t i = 0; i < rank; ++i)
300             shape.push_back(mlirShapedTypeGetDimSize(self, i));
301           return shape;
302         },
303         "Returns the shape of the ranked shaped type as a list of integers.");
304     c.def_static(
305         "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
306         "Returns the value used to indicate dynamic dimensions in shaped "
307         "types.");
308     c.def_static(
309         "_get_dynamic_stride_or_offset",
310         []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
311         "Returns the value used to indicate dynamic strides or offsets in "
312         "shaped types.");
313   }
314 
315 private:
requireHasRank()316   void requireHasRank() {
317     if (!mlirShapedTypeHasRank(*this)) {
318       throw SetPyError(
319           PyExc_ValueError,
320           "calling this method requires that the type has a rank.");
321     }
322   }
323 };
324 
325 /// Vector Type subclass - VectorType.
326 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
327 public:
328   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
329   static constexpr const char *pyClassName = "VectorType";
330   using PyConcreteType::PyConcreteType;
331 
bindDerived(ClassTy & c)332   static void bindDerived(ClassTy &c) {
333     c.def_static(
334         "get",
335         [](std::vector<int64_t> shape, PyType &elementType,
336            DefaultingPyLocation loc) {
337           MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
338                                                 elementType);
339           // TODO: Rework error reporting once diagnostic engine is exposed
340           // in C API.
341           if (mlirTypeIsNull(t)) {
342             throw SetPyError(
343                 PyExc_ValueError,
344                 Twine("invalid '") +
345                     py::repr(py::cast(elementType)).cast<std::string>() +
346                     "' and expected floating point or integer type.");
347           }
348           return PyVectorType(elementType.getContext(), t);
349         },
350         py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
351         "Create a vector type");
352   }
353 };
354 
355 /// Ranked Tensor Type subclass - RankedTensorType.
356 class PyRankedTensorType
357     : public PyConcreteType<PyRankedTensorType, PyShapedType> {
358 public:
359   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
360   static constexpr const char *pyClassName = "RankedTensorType";
361   using PyConcreteType::PyConcreteType;
362 
bindDerived(ClassTy & c)363   static void bindDerived(ClassTy &c) {
364     c.def_static(
365         "get",
366         [](std::vector<int64_t> shape, PyType &elementType,
367            llvm::Optional<PyAttribute> &encodingAttr,
368            DefaultingPyLocation loc) {
369           MlirType t = mlirRankedTensorTypeGetChecked(
370               loc, shape.size(), shape.data(), elementType,
371               encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
372           // TODO: Rework error reporting once diagnostic engine is exposed
373           // in C API.
374           if (mlirTypeIsNull(t)) {
375             throw SetPyError(
376                 PyExc_ValueError,
377                 Twine("invalid '") +
378                     py::repr(py::cast(elementType)).cast<std::string>() +
379                     "' and expected floating point, integer, vector or "
380                     "complex "
381                     "type.");
382           }
383           return PyRankedTensorType(elementType.getContext(), t);
384         },
385         py::arg("shape"), py::arg("element_type"),
386         py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
387         "Create a ranked tensor type");
388     c.def_property_readonly(
389         "encoding",
390         [](PyRankedTensorType &self) -> llvm::Optional<PyAttribute> {
391           MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
392           if (mlirAttributeIsNull(encoding))
393             return llvm::None;
394           return PyAttribute(self.getContext(), encoding);
395         });
396   }
397 };
398 
399 /// Unranked Tensor Type subclass - UnrankedTensorType.
400 class PyUnrankedTensorType
401     : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
402 public:
403   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
404   static constexpr const char *pyClassName = "UnrankedTensorType";
405   using PyConcreteType::PyConcreteType;
406 
bindDerived(ClassTy & c)407   static void bindDerived(ClassTy &c) {
408     c.def_static(
409         "get",
410         [](PyType &elementType, DefaultingPyLocation loc) {
411           MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
412           // TODO: Rework error reporting once diagnostic engine is exposed
413           // in C API.
414           if (mlirTypeIsNull(t)) {
415             throw SetPyError(
416                 PyExc_ValueError,
417                 Twine("invalid '") +
418                     py::repr(py::cast(elementType)).cast<std::string>() +
419                     "' and expected floating point, integer, vector or "
420                     "complex "
421                     "type.");
422           }
423           return PyUnrankedTensorType(elementType.getContext(), t);
424         },
425         py::arg("element_type"), py::arg("loc") = py::none(),
426         "Create a unranked tensor type");
427   }
428 };
429 
430 /// Ranked MemRef Type subclass - MemRefType.
431 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
432 public:
433   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
434   static constexpr const char *pyClassName = "MemRefType";
435   using PyConcreteType::PyConcreteType;
436 
bindDerived(ClassTy & c)437   static void bindDerived(ClassTy &c) {
438     c.def_static(
439          "get",
440          [](std::vector<int64_t> shape, PyType &elementType,
441             PyAttribute *layout, PyAttribute *memorySpace,
442             DefaultingPyLocation loc) {
443            MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
444            MlirAttribute memSpaceAttr =
445                memorySpace ? *memorySpace : mlirAttributeGetNull();
446            MlirType t =
447                mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
448                                         shape.data(), layoutAttr, memSpaceAttr);
449            // TODO: Rework error reporting once diagnostic engine is exposed
450            // in C API.
451            if (mlirTypeIsNull(t)) {
452              throw SetPyError(
453                  PyExc_ValueError,
454                  Twine("invalid '") +
455                      py::repr(py::cast(elementType)).cast<std::string>() +
456                      "' and expected floating point, integer, vector or "
457                      "complex "
458                      "type.");
459            }
460            return PyMemRefType(elementType.getContext(), t);
461          },
462          py::arg("shape"), py::arg("element_type"),
463          py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
464          py::arg("loc") = py::none(), "Create a memref type")
465         .def_property_readonly(
466             "layout",
467             [](PyMemRefType &self) -> PyAttribute {
468               MlirAttribute layout = mlirMemRefTypeGetLayout(self);
469               return PyAttribute(self.getContext(), layout);
470             },
471             "The layout of the MemRef type.")
472         .def_property_readonly(
473             "affine_map",
474             [](PyMemRefType &self) -> PyAffineMap {
475               MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
476               return PyAffineMap(self.getContext(), map);
477             },
478             "The layout of the MemRef type as an affine map.")
479         .def_property_readonly(
480             "memory_space",
481             [](PyMemRefType &self) -> PyAttribute {
482               MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
483               return PyAttribute(self.getContext(), a);
484             },
485             "Returns the memory space of the given MemRef type.");
486   }
487 };
488 
489 /// Unranked MemRef Type subclass - UnrankedMemRefType.
490 class PyUnrankedMemRefType
491     : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
492 public:
493   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
494   static constexpr const char *pyClassName = "UnrankedMemRefType";
495   using PyConcreteType::PyConcreteType;
496 
bindDerived(ClassTy & c)497   static void bindDerived(ClassTy &c) {
498     c.def_static(
499          "get",
500          [](PyType &elementType, PyAttribute *memorySpace,
501             DefaultingPyLocation loc) {
502            MlirAttribute memSpaceAttr = {};
503            if (memorySpace)
504              memSpaceAttr = *memorySpace;
505 
506            MlirType t =
507                mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
508            // TODO: Rework error reporting once diagnostic engine is exposed
509            // in C API.
510            if (mlirTypeIsNull(t)) {
511              throw SetPyError(
512                  PyExc_ValueError,
513                  Twine("invalid '") +
514                      py::repr(py::cast(elementType)).cast<std::string>() +
515                      "' and expected floating point, integer, vector or "
516                      "complex "
517                      "type.");
518            }
519            return PyUnrankedMemRefType(elementType.getContext(), t);
520          },
521          py::arg("element_type"), py::arg("memory_space"),
522          py::arg("loc") = py::none(), "Create a unranked memref type")
523         .def_property_readonly(
524             "memory_space",
525             [](PyUnrankedMemRefType &self) -> PyAttribute {
526               MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
527               return PyAttribute(self.getContext(), a);
528             },
529             "Returns the memory space of the given Unranked MemRef type.");
530   }
531 };
532 
533 /// Tuple Type subclass - TupleType.
534 class PyTupleType : public PyConcreteType<PyTupleType> {
535 public:
536   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
537   static constexpr const char *pyClassName = "TupleType";
538   using PyConcreteType::PyConcreteType;
539 
bindDerived(ClassTy & c)540   static void bindDerived(ClassTy &c) {
541     c.def_static(
542         "get_tuple",
543         [](py::list elementList, DefaultingPyMlirContext context) {
544           intptr_t num = py::len(elementList);
545           // Mapping py::list to SmallVector.
546           SmallVector<MlirType, 4> elements;
547           for (auto element : elementList)
548             elements.push_back(element.cast<PyType>());
549           MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
550           return PyTupleType(context->getRef(), t);
551         },
552         py::arg("elements"), py::arg("context") = py::none(),
553         "Create a tuple type");
554     c.def(
555         "get_type",
556         [](PyTupleType &self, intptr_t pos) -> PyType {
557           MlirType t = mlirTupleTypeGetType(self, pos);
558           return PyType(self.getContext(), t);
559         },
560         py::arg("pos"), "Returns the pos-th type in the tuple type.");
561     c.def_property_readonly(
562         "num_types",
563         [](PyTupleType &self) -> intptr_t {
564           return mlirTupleTypeGetNumTypes(self);
565         },
566         "Returns the number of types contained in a tuple.");
567   }
568 };
569 
570 /// Function type.
571 class PyFunctionType : public PyConcreteType<PyFunctionType> {
572 public:
573   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
574   static constexpr const char *pyClassName = "FunctionType";
575   using PyConcreteType::PyConcreteType;
576 
bindDerived(ClassTy & c)577   static void bindDerived(ClassTy &c) {
578     c.def_static(
579         "get",
580         [](std::vector<PyType> inputs, std::vector<PyType> results,
581            DefaultingPyMlirContext context) {
582           SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
583           SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
584           MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
585                                            inputsRaw.data(), resultsRaw.size(),
586                                            resultsRaw.data());
587           return PyFunctionType(context->getRef(), t);
588         },
589         py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
590         "Gets a FunctionType from a list of input and result types");
591     c.def_property_readonly(
592         "inputs",
593         [](PyFunctionType &self) {
594           MlirType t = self;
595           auto contextRef = self.getContext();
596           py::list types;
597           for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
598                ++i) {
599             types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
600           }
601           return types;
602         },
603         "Returns the list of input types in the FunctionType.");
604     c.def_property_readonly(
605         "results",
606         [](PyFunctionType &self) {
607           auto contextRef = self.getContext();
608           py::list types;
609           for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
610                ++i) {
611             types.append(
612                 PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
613           }
614           return types;
615         },
616         "Returns the list of result types in the FunctionType.");
617   }
618 };
619 
toMlirStringRef(const std::string & s)620 static MlirStringRef toMlirStringRef(const std::string &s) {
621   return mlirStringRefCreate(s.data(), s.size());
622 }
623 
624 /// Opaque Type subclass - OpaqueType.
625 class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
626 public:
627   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
628   static constexpr const char *pyClassName = "OpaqueType";
629   using PyConcreteType::PyConcreteType;
630 
bindDerived(ClassTy & c)631   static void bindDerived(ClassTy &c) {
632     c.def_static(
633         "get",
634         [](std::string dialectNamespace, std::string typeData,
635            DefaultingPyMlirContext context) {
636           MlirType type = mlirOpaqueTypeGet(context->get(),
637                                             toMlirStringRef(dialectNamespace),
638                                             toMlirStringRef(typeData));
639           return PyOpaqueType(context->getRef(), type);
640         },
641         py::arg("dialect_namespace"), py::arg("buffer"),
642         py::arg("context") = py::none(),
643         "Create an unregistered (opaque) dialect type.");
644     c.def_property_readonly(
645         "dialect_namespace",
646         [](PyOpaqueType &self) {
647           MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
648           return py::str(stringRef.data, stringRef.length);
649         },
650         "Returns the dialect namespace for the Opaque type as a string.");
651     c.def_property_readonly(
652         "data",
653         [](PyOpaqueType &self) {
654           MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
655           return py::str(stringRef.data, stringRef.length);
656         },
657         "Returns the data for the Opaque type as a string.");
658   }
659 };
660 
661 } // namespace
662 
populateIRTypes(py::module & m)663 void mlir::python::populateIRTypes(py::module &m) {
664   PyIntegerType::bind(m);
665   PyIndexType::bind(m);
666   PyBF16Type::bind(m);
667   PyF16Type::bind(m);
668   PyF32Type::bind(m);
669   PyF64Type::bind(m);
670   PyNoneType::bind(m);
671   PyComplexType::bind(m);
672   PyShapedType::bind(m);
673   PyVectorType::bind(m);
674   PyRankedTensorType::bind(m);
675   PyUnrankedTensorType::bind(m);
676   PyMemRefType::bind(m);
677   PyUnrankedMemRefType::bind(m);
678   PyTupleType::bind(m);
679   PyFunctionType::bind(m);
680   PyOpaqueType::bind(m);
681 }
682