1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Runs test function `f` right away and returns it unchanged.

  Meant to be applied as a decorator. A "TEST: <name>" banner is printed
  before the call. After the call a garbage collection is forced and the
  number of live Context objects is asserted to be zero.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
12
13
14# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parses `i32` and prints both its str() and repr() forms.

  The local context reference is dropped and a collection is forced before
  the parsed type is printed.
  """
  context = Context()
  parsed = Type.parse("i32", context)
  assert parsed.context is context
  # Release the only local reference to the context, then collect.
  del context
  gc.collect()
  # CHECK: i32
  print(str(parsed))
  # CHECK: Type(i32)
  print(repr(parsed))
26
27
28# CHECK-LABEL: TEST: testParseError
29# TODO: Hook the diagnostic manager to capture a more meaningful error
30# message.
@run
def testParseError():
  """Expects a ValueError when parsing a type name that does not exist."""
  context = Context()
  bogus_asm = "BAD_TYPE_DOES_NOT_EXIST"
  try:
    parsed = Type.parse(bogus_asm, context)
  except ValueError as err:
    # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
    print("testParseError:", err)
  else:
    print("Exception not produced")
41
42
43# CHECK-LABEL: TEST: testTypeEq
@run
def testTypeEq():
  """Prints the result of `==` between parsed types and against None."""
  context = Context()
  # Parsed in order: i32, f32, and a second, independently parsed i32.
  t1, t2, t3 = [Type.parse(asm, context) for asm in ("i32", "f32", "i32")]
  # CHECK: t1 == t1: True
  print("t1 == t1:", t1 == t1)
  # CHECK: t1 == t2: False
  print("t1 == t2:", t1 == t2)
  # CHECK: t1 == t3: True
  print("t1 == t3:", t1 == t3)
  # The equality operator itself is under test here, so `==` is used on
  # purpose rather than `is`.
  # CHECK: t1 == None: False
  print("t1 == None:", t1 == None)
58
59
60# CHECK-LABEL: TEST: testTypeHash
@run
def testTypeHash():
  """Compares hashes of equal types and counts distinct types in a set."""
  context = Context()
  # Parsed in order: i32, f32, and a second, independently parsed i32.
  t1, t2, t3 = [Type.parse(asm, context) for asm in ("i32", "f32", "i32")]

  hashes_match = t1.__hash__() == t3.__hash__()
  # CHECK: hash(t1) == hash(t3): True
  print("hash(t1) == hash(t3):", hashes_match)

  unique_types = set()
  for ty in (t1, t2, t3):
    unique_types.add(ty)
  # CHECK: len(s): 2
  print("len(s): ", len(unique_types))
77
78# CHECK-LABEL: TEST: testTypeCast
@run
def testTypeCast():
  """Constructs a Type from another Type and prints whether they are equal."""
  context = Context()
  source_type = Type.parse("i32", context)
  casted_type = Type(source_type)
  are_equal = source_type == casted_type
  # CHECK: t1 == t2: True
  print("t1 == t2:", are_equal)
86
87
88# CHECK-LABEL: TEST: testTypeIsInstance
@run
def testTypeIsInstance():
  """Prints the `isinstance` static checks on parsed i32 and f32 types."""
  context = Context()
  int_ty = Type.parse("i32", context)
  float_ty = Type.parse("f32", context)
  # CHECK: True
  print(IntegerType.isinstance(int_ty))
  # CHECK: False
  print(F32Type.isinstance(int_ty))
  # CHECK: True
  print(F32Type.isinstance(float_ty))
100
101
102# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@run
def testTypeEqDoesNotRaise():
  """Compares a type with a str and with None using `==` and `!=`."""
  context = Context()
  int_ty = Type.parse("i32", context)
  unrelated = "foo"
  # CHECK: False
  print(int_ty == unrelated)
  # The operators are what is being exercised, so `==` / `!=` against None
  # are intentional.
  # CHECK: False
  print(int_ty == None)
  # CHECK: True
  print(int_ty != None)
114
115
116# CHECK-LABEL: TEST: testTypeCapsule
@run
def testTypeCapsule():
  """Round-trips a Type through `_CAPIPtr` and `_CAPICreate`.

  The rebuilt type must compare equal to the original and report the same
  context object.
  """
  with Context() as context:
    original = Type.parse("i32", context)
  # CHECK: mlir.ir.Type._CAPIPtr
  capsule = original._CAPIPtr
  print(capsule)
  rebuilt = Type._CAPICreate(capsule)
  assert rebuilt == original
  assert rebuilt.context is context
127
128
129# CHECK-LABEL: TEST: testStandardTypeCasts
@run
def testStandardTypeCasts():
  """Casts to IntegerType from a generic Type and from an IntegerType.

  Also expects a ValueError when the source type is f32.
  """
  context = Context()
  generic_i32 = Type.parse("i32", context)
  as_integer = IntegerType(generic_i32)
  # Casting an IntegerType to IntegerType again is exercised as well.
  recast = IntegerType(as_integer)
  # CHECK: Type(i32)
  print(repr(as_integer))
  try:
    rejected = IntegerType(Type.parse("f32", context))
  except ValueError as err:
    # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
    print("ValueError:", err)
  else:
    print("Exception not produced")
145
146
147# CHECK-LABEL: TEST: testIntegerType
@run
def testIntegerType():
  """Prints width/signedness properties and the IntegerType getters."""
  with Context() as context:
    signless32 = IntegerType(Type.parse("i32"))
    # CHECK: i32 width: 32
    print("i32 width:", signless32.width)
    # CHECK: i32 signless: True
    print("i32 signless:", signless32.is_signless)
    # CHECK: i32 signed: False
    print("i32 signed:", signless32.is_signed)
    # CHECK: i32 unsigned: False
    print("i32 unsigned:", signless32.is_unsigned)

    signed32 = IntegerType(Type.parse("si32"))
    # CHECK: s32 signless: False
    print("s32 signless:", signed32.is_signless)
    # CHECK: s32 signed: True
    print("s32 signed:", signed32.is_signed)
    # CHECK: s32 unsigned: False
    print("s32 unsigned:", signed32.is_unsigned)

    unsigned32 = IntegerType(Type.parse("ui32"))
    # CHECK: u32 signless: False
    print("u32 signless:", unsigned32.is_signless)
    # CHECK: u32 signed: False
    print("u32 signed:", unsigned32.is_signed)
    # CHECK: u32 unsigned: True
    print("u32 unsigned:", unsigned32.is_unsigned)

    # Types built through the static getters instead of the parser.
    signless16 = IntegerType.get_signless(16)
    # CHECK: signless: i16
    print("signless:", signless16)
    signed8 = IntegerType.get_signed(8)
    # CHECK: signed: si8
    print("signed:", signed8)
    unsigned64 = IntegerType.get_unsigned(64)
    # CHECK: unsigned: ui64
    print("unsigned:", unsigned64)
183
184# CHECK-LABEL: TEST: testIndexType
@run
def testIndexType():
  """Prints the type returned by `IndexType.get()`."""
  with Context() as context:
    index_ty = IndexType.get()
    # CHECK: index type: index
    print("index type:", index_ty)
190
191
192# CHECK-LABEL: TEST: testFloatType
@run
def testFloatType():
  """Prints the types returned by the bf16/f16/f32/f64 getters."""
  with Context():
    bf16 = BF16Type.get()
    # CHECK: float: bf16
    print("float:", bf16)
    f16 = F16Type.get()
    # CHECK: float: f16
    print("float:", f16)
    f32 = F32Type.get()
    # CHECK: float: f32
    print("float:", f32)
    f64 = F64Type.get()
    # CHECK: float: f64
    print("float:", f64)
204
205
206# CHECK-LABEL: TEST: testNoneType
@run
def testNoneType():
  """Prints the type returned by `NoneType.get()`."""
  with Context():
    none_ty = NoneType.get()
    # CHECK: none type: none
    print("none type:", none_ty)
212
213
214# CHECK-LABEL: TEST: testComplexType
@run
def testComplexType():
  """Exercises ComplexType: element access, `get`, and a rejected element.

  An index element type is expected to raise ValueError.
  """
  with Context() as context:
    parsed_complex = ComplexType(Type.parse("complex<i32>"))
    # CHECK: complex type element: i32
    print("complex type element:", parsed_complex.element_type)

    float_elem = F32Type.get()
    # CHECK: complex type: complex<f32>
    print("complex type:", ComplexType.get(float_elem))

    index_elem = IndexType.get()
    try:
      rejected = ComplexType.get(index_elem)
    except ValueError as err:
      # CHECK: invalid 'Type(index)' and expected floating point or integer type.
      print(err)
    else:
      print("Exception not produced")
234
235
236# CHECK-LABEL: TEST: testConcreteShapedType
# ShapedType is not itself a builtin type; it is the base class for vectors,
# memrefs and tensors, so this test case uses a vector instance to exercise
# the shaped type API. The class hierarchy is preserved on the Python side.
@run
def testConcreteShapedType():
  """Queries the shaped-type accessors on a parsed `vector<2x3xf32>`."""
  with Context() as context:
    vec_ty = VectorType(Type.parse("vector<2x3xf32>"))
    # CHECK: element type: f32
    print("element type:", vec_ty.element_type)
    # CHECK: whether the given shaped type is ranked: True
    print("whether the given shaped type is ranked:", vec_ty.has_rank)
    # CHECK: rank: 2
    print("rank:", vec_ty.rank)
    # CHECK: whether the shaped type has a static shape: True
    print("whether the shaped type has a static shape:",
          vec_ty.has_static_shape)
    # CHECK: whether the dim-th dimension is dynamic: False
    print("whether the dim-th dimension is dynamic:",
          vec_ty.is_dynamic_dim(0))
    # CHECK: dim size: 3
    print("dim size:", vec_ty.get_dim_size(1))
    # CHECK: is_dynamic_size: False
    print("is_dynamic_size:", vec_ty.is_dynamic_size(3))
    # CHECK: is_dynamic_stride_or_offset: False
    print("is_dynamic_stride_or_offset:",
          vec_ty.is_dynamic_stride_or_offset(1))
    # The Python class of the vector must derive from ShapedType.
    # CHECK: isinstance(ShapedType): True
    print("isinstance(ShapedType):", isinstance(vec_ty, ShapedType))
262
263
264# CHECK-LABEL: TEST: testAbstractShapedType
265# Tests that ShapedType operates as an abstract base class of a concrete
266# shaped type (using vector as an example).
@run
def testAbstractShapedType():
  """Wraps a parsed vector type in ShapedType and prints its element type."""
  context = Context()
  parsed_vector = Type.parse("vector<2x3xf32>", context)
  shaped = ShapedType(parsed_vector)
  # CHECK: element type: f32
  print("element type:", shaped.element_type)
273
274
275# CHECK-LABEL: TEST: testVectorType
@run
def testVectorType():
  """Builds a vector type with `VectorType.get` and a rejected variant.

  A `none` element type is expected to raise ValueError.
  """
  with Context(), Location.unknown():
    elem_ty = F32Type.get()
    dims = [2, 3]
    # CHECK: vector type: vector<2x3xf32>
    print("vector type:", VectorType.get(dims, elem_ty))

    bad_elem_ty = NoneType.get()
    try:
      rejected = VectorType.get(dims, bad_elem_ty)
    except ValueError as err:
      # CHECK: invalid 'Type(none)' and expected floating point or integer type.
      print(err)
    else:
      print("Exception not produced")
292
293
294# CHECK-LABEL: TEST: testRankedTensorType
@run
def testRankedTensorType():
  """Exercises `RankedTensorType.get`, its encoding and its shape.

  A `none` element type is expected to raise ValueError. The location used
  by the getters is the one entered by the `with` statement.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    shape = [2, 3]
    # CHECK: ranked tensor type: tensor<2x3xf32>
    print("ranked tensor type:",
          RankedTensorType.get(shape, f32))

    none = NoneType.get()
    try:
      tensor_invalid = RankedTensorType.get(shape, none)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")

    # Encoding should be None.
    assert RankedTensorType.get(shape, f32).encoding is None

    # The shape read back from the type must equal the list it was built from.
    tensor = RankedTensorType.get(shape, f32)
    assert tensor.shape == shape
320
321
322# CHECK-LABEL: TEST: testUnrankedTensorType
@run
def testUnrankedTensorType():
  """Exercises `UnrankedTensorType.get` and rank-dependent accessors.

  `rank`, `is_dynamic_dim` and `get_dim_size` are each expected to raise
  ValueError on the unranked type, as is a `none` element type. The location
  used by the getters is the one entered by the `with` statement.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    unranked_tensor = UnrankedTensorType.get(f32)
    # CHECK: unranked tensor type: tensor<*xf32>
    print("unranked tensor type:", unranked_tensor)
    try:
      invalid_rank = unranked_tensor.rank
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")
    try:
      invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")
    try:
      invalid_get_dim_size = unranked_tensor.get_dim_size(1)
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")

    none = NoneType.get()
    try:
      tensor_invalid = UnrankedTensorType.get(none)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")
362
363
364# CHECK-LABEL: TEST: testMemRefType
@run
def testMemRefType():
  """Exercises `MemRefType.get` with a memory space and with a layout.

  Prints the layout, affine map and memory space of each memref, expects a
  ValueError for a `none` element type, and asserts the shape round-trips.
  The location used by the getters is the one entered by the `with`
  statement.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    shape = [2, 3]
    memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
    # CHECK: memref type: memref<2x3xf32, 2>
    print("memref type:", memref)
    # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
    print("memref layout:", memref.layout)
    # CHECK: memref affine map: (d0, d1) -> (d0, d1)
    print("memref affine map:", memref.affine_map)
    # CHECK: memory space: 2
    print("memory space:", memref.memory_space)

    # Same shape and element type, but with a permuted layout and no memory
    # space argument.
    layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
    memref_layout = MemRefType.get(shape, f32, layout=layout)
    # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
    print("memref type:", memref_layout)
    # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
    print("memref layout:", memref_layout.layout)
    # CHECK: memref affine map: (d0, d1) -> (d1, d0)
    print("memref affine map:", memref_layout.affine_map)
    # CHECK: memory space: <<NULL ATTRIBUTE>>
    print("memory space:", memref_layout.memory_space)

    none = NoneType.get()
    try:
      memref_invalid = MemRefType.get(shape, none)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")

    assert memref.shape == shape
403
404
405# CHECK-LABEL: TEST: testUnrankedMemRefType
@run
def testUnrankedMemRefType():
  """Exercises `UnrankedMemRefType.get` and rank-dependent accessors.

  `rank`, `is_dynamic_dim` and `get_dim_size` are each expected to raise
  ValueError on the unranked type, as is a `none` element type. The location
  used by the getters is the one entered by the `with` statement.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
    # CHECK: unranked memref type: memref<*xf32, 2>
    print("unranked memref type:", unranked_memref)
    try:
      invalid_rank = unranked_memref.rank
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")
    try:
      invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")
    try:
      invalid_get_dim_size = unranked_memref.get_dim_size(1)
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")

    none = NoneType.get()
    try:
      memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")
445
446
447# CHECK-LABEL: TEST: testTupleType
448@run
449def testTupleType():
450  with Context() as ctx:
451    i32 = IntegerType(Type.parse("i32"))
452    f32 = F32Type.get()
453    vector = VectorType(Type.parse("vector<2x3xf32>"))
454    l = [i32, f32, vector]
455    tuple_type = TupleType.get_tuple(l)
456    # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
457    print("tuple type:", tuple_type)
458    # CHECK: number of types: 3
459    print("number of types:", tuple_type.num_types)
460    # CHECK: pos-th type in the tuple type: f32
461    print("pos-th type in the tuple type:", tuple_type.get_type(1))
462
463
464# CHECK-LABEL: TEST: testFunctionType
@run
def testFunctionType():
  """Builds a function type and prints its `inputs` and `results`."""
  with Context() as context:
    arg_i32 = IntegerType.get_signless(32)
    arg_i16 = IntegerType.get_signless(16)
    result_index = IndexType.get()
    fn_ty = FunctionType.get([arg_i32, arg_i16], [result_index])
    # CHECK: INPUTS: [Type(i32), Type(i16)]
    print("INPUTS:", fn_ty.inputs)
    # CHECK: RESULTS: [Type(index)]
    print("RESULTS:", fn_ty.results)
476
477
478# CHECK-LABEL: TEST: testOpaqueType
@run
def testOpaqueType():
  """Builds an opaque type and prints its namespace and data."""
  with Context() as context:
    # Set before building a type for the "dialect" namespace.
    context.allow_unregistered_dialects = True
    opaque_ty = OpaqueType.get("dialect", "type")
    # CHECK: opaque type: !dialect.type
    print("opaque type:", opaque_ty)
    # CHECK: dialect namespace: dialect
    print("dialect namespace:", opaque_ty.dialect_namespace)
    # CHECK: data: type
    print("data:", opaque_ty.data)
490