1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Runs test function `f` immediately and hands it back unchanged.

  Prints a "TEST: <function name>" banner, invokes `f`, forces a garbage
  collection and then asserts that the live Context count is zero.

  Args:
    f: A zero-argument callable; its `__name__` is printed in the banner.

  Returns:
    `f` itself.
  """
  banner = "\nTEST:"
  print(banner, f.__name__)
  f()
  # Collect before counting so unreachable objects do not hold contexts alive.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
12
13
14# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parses a string attribute and prints its str() and repr() forms.

  The local reference to the context is dropped and a collection is forced
  before the attribute is printed.
  """
  with Context() as context:
    attr = Attribute.parse('"hello"')
  assert attr.context is context
  # Release this function's own reference to the context before printing.
  del context
  gc.collect()
  asm_text = str(attr)
  # CHECK: "hello"
  print(asm_text)
  debug_text = repr(attr)
  # CHECK: Attribute("hello")
  print(debug_text)
26
27
28# CHECK-LABEL: TEST: testParseError
29# TODO: Hook the diagnostic manager to capture a more meaningful error
30# message.
@run
def testParseError():
  """Checks that parsing unknown attribute assembly raises ValueError."""
  bad_asm = "BAD_ATTR_DOES_NOT_EXIST"
  with Context():
    try:
      parsed = Attribute.parse(bad_asm)
    except ValueError as err:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", err)
      return
    # Only reached when the parse unexpectedly succeeded.
    print("Exception not produced")
41
42
43# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
  """Prints the result of `==` between an attribute and several operands."""
  with Context():
    lhs = Attribute.parse('"attr1"')
    different = Attribute.parse('"attr2"')
    duplicate = Attribute.parse('"attr1"')
    # Each entry pairs the printed label with the right-hand operand.
    cases = (
        ("a1 == a1:", lhs),
        ("a1 == a2:", different),
        ("a1 == a3:", duplicate),
        ("a1 == None:", None),
    )
    # CHECK: a1 == a1: True
    # CHECK: a1 == a2: False
    # CHECK: a1 == a3: True
    # CHECK: a1 == None: False
    for label, rhs in cases:
      print(label, lhs == rhs)
58
59
60# CHECK-LABEL: TEST: testAttrHash
@run
def testAttrHash():
  """Checks attribute hashing directly and through set membership."""
  with Context():
    first = Attribute.parse('"attr1"')
    second = Attribute.parse('"attr2"')
    first_again = Attribute.parse('"attr1"')
    hash_first = first.__hash__()
    hash_again = first_again.__hash__()
    # CHECK: hash(a1) == hash(a3): True
    print("hash(a1) == hash(a3):", hash_first == hash_again)

    # Attributes parsed from the same text should collapse to one set entry.
    unique = set()
    for attr in (first, second, first_again):
      unique.add(attr)
    # CHECK: len(s): 2
    print("len(s): ", len(unique))
76
77
78# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
  """Constructs an Attribute from another Attribute and compares the two."""
  with Context():
    original = Attribute.parse('"attr1"')
    recast = Attribute(original)
    are_equal = original == recast
    # CHECK: a1 == a2: True
    print("a1 == a2:", are_equal)
86
87
88# CHECK-LABEL: TEST: testAttrIsInstance
@run
def testAttrIsInstance():
  """Asserts the `isinstance` results for an integer and an array attribute."""
  with Context():
    int_attr = Attribute.parse("42")
    array_attr = Attribute.parse("[42]")
    # (attribute class, attribute under test, expected isinstance result)
    expectations = (
        (IntegerAttr, int_attr, True),
        (IntegerAttr, array_attr, False),
        (ArrayAttr, int_attr, False),
        (ArrayAttr, array_attr, True),
    )
    for attr_class, attr, expected in expectations:
      assert bool(attr_class.isinstance(attr)) is expected
98
99
100# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
  """Compares an attribute with a str and with None using == and !=."""
  with Context():
    attr = Attribute.parse('"attr1"')
    plain_string = "foo"
    equals_string = attr == plain_string
    equals_none = attr == None
    differs_from_none = attr != None
    # CHECK: False
    print(equals_string)
    # CHECK: False
    print(equals_none)
    # CHECK: True
    print(differs_from_none)
112
113
114# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
  """Round-trips an attribute through `_CAPIPtr` and `_CAPICreate`."""
  with Context() as context:
    source_attr = Attribute.parse('"attr1"')
  capsule = source_attr._CAPIPtr
  # CHECK: mlir.ir.Attribute._CAPIPtr
  print(capsule)
  rebuilt_attr = Attribute._CAPICreate(capsule)
  # The rebuilt attribute must compare equal and report the same context.
  assert rebuilt_attr == source_attr
  assert rebuilt_attr.context is context
125
126
127# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
  """Casts attributes to StringAttr, including a cast expected to fail."""
  with Context():
    generic = Attribute.parse('"attr1"')
    as_string = StringAttr(generic)
    # Casting an existing StringAttr to StringAttr is exercised as well.
    string_from_string = StringAttr(as_string)
    # CHECK: Attribute("attr1")
    print(repr(as_string))
    try:
      float_attr = Attribute.parse("1.0")
      StringAttr(float_attr)
    except ValueError as err:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", err)
    else:
      print("Exception not produced")
143
144
145# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
  """Builds an AffineMapAttr and checks that its printed form re-parses equal."""
  with Context() as context:
    # These expressions are constructed but not used by the map below.
    dim0 = AffineDimExpr.get(0)
    dim1 = AffineDimExpr.get(1)
    const2 = AffineConstantExpr.get(2)
    empty_map = AffineMap.get(2, 3, [])

    built = AffineMapAttr.get(empty_map)
    built_asm = str(built)
    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
    print(built_asm)

    reparsed = Attribute.parse(str(built))
    assert built == reparsed
160
161
162# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
  """Exercises FloatAttr parsing, `.value` and the `get*` factory methods."""
  with Context(), Location.unknown():
    parsed = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", parsed.value)

    # Factory methods.
    f32_type = F32Type.get()
    from_type = FloatAttr.get(f32_type, 42.0)
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", from_type)
    from_f32 = FloatAttr.get_f32(42.0)
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", from_f32)
    from_f64 = FloatAttr.get_f64(42.0)
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", from_f64)

    # A non-floating-point type is expected to be rejected with ValueError.
    try:
      FloatAttr.get(IntegerType.get_signless(32), 42)
    except ValueError as err:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(err)
    else:
      print("Exception not produced")
186
187
188# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
  """Exercises IntegerAttr `.value` / `.type` and the `get` factory method."""
  with Context() as context:
    default_typed = IntegerAttr(Attribute.parse("42"))
    # CHECK: i_attr value: 42
    print("i_attr value:", default_typed.value)
    # CHECK: i_attr type: i64
    print("i_attr type:", default_typed.type)

    # (printed label, attribute assembly) for the explicitly typed cases.
    typed_cases = (
        ("si_attr value:", "-1 : si8"),
        ("ui_attr value:", "255 : ui8"),
        ("idx_attr value:", "-1 : index"),
    )
    # CHECK: si_attr value: -1
    # CHECK: ui_attr value: 255
    # CHECK: idx_attr value: -1
    for label, asm in typed_cases:
      typed = IntegerAttr(Attribute.parse(asm))
      print(label, typed.value)

    # Factory method.
    i32_type = IntegerType.get_signless(32)
    # CHECK: default_get: 42 : i32
    print("default_get:", IntegerAttr.get(i32_type, 42))
211
212
213# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
  """Exercises BoolAttr `.value` and the `get` factory method."""
  with Context() as context:
    parsed = Attribute.parse("true")
    bool_attr = BoolAttr(parsed)
    # CHECK: iattr value: True
    print("iattr value:", bool_attr.value)

    # Factory method.
    built = BoolAttr.get(True)
    # CHECK: default_get: true
    print("default_get:", built)
224
225
226# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
  """Exercises FlatSymbolRefAttr `.value` and the `get` factory method."""
  with Context() as context:
    parsed = Attribute.parse('@symbol')
    symbol_ref = FlatSymbolRefAttr(parsed)
    # CHECK: symattr value: symbol
    print("symattr value:", symbol_ref.value)

    # Factory method.
    built = FlatSymbolRefAttr.get("foobar")
    # CHECK: default_get: @foobar
    print("default_get:", built)
237
238
239# CHECK-LABEL: TEST: testOpaqueAttr
@run
def testOpaqueAttr():
  """Exercises OpaqueAttr accessors and the `get` factory method."""
  with Context() as context:
    # The attribute below uses a dialect namespace that is not registered.
    context.allow_unregistered_dialects = True
    opaque = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>"))
    # CHECK: oattr value: pytest_dummy
    print("oattr value:", opaque.dialect_namespace)
    # CHECK: oattr value: dummyattr<>
    print("oattr value:", opaque.data)

    # Factory method.
    payload = "123".encode("utf-8")
    built = OpaqueAttr.get("foobar", payload, NoneType.get())
    # CHECK: default_get: #foobar<123>
    print("default_get:", built)
255
256
257# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
  """Exercises StringAttr `.value` and the `get` / `get_typed` factories."""
  with Context() as context:
    parsed = Attribute.parse('"stringattr"')
    string_attr = StringAttr(parsed)
    # CHECK: sattr value: stringattr
    print("sattr value:", string_attr.value)

    # Factory methods.
    untyped = StringAttr.get("foobar")
    # CHECK: default_get: "foobar"
    print("default_get:", untyped)
    i32_type = IntegerType.get_signless(32)
    typed = StringAttr.get_typed(i32_type, "12345")
    # CHECK: typed_get: "12345" : i32
    print("typed_get:", typed)
271
272
273# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
  """Creates a NamedAttribute via `get_named` and prints its parts."""
  with Context():
    string_attr = Attribute.parse('"stringattr"')
    # Note: the name "foobar" is under the small object threshold.
    named_attr = string_attr.get_named("foobar")
    # CHECK: attr: "stringattr"
    print("attr:", named_attr.attr)
    # CHECK: name: foobar
    print("name:", named_attr.name)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named_attr)
285
286
287# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
  """Prints the elements and element type of dense integer attributes."""

  def dump_elements(asm, expected_count):
    # Prints the raw attribute, its elements on one line, then its element type.
    raw = Attribute.parse(asm)
    print("attr:", raw)

    elements = DenseIntElementsAttr(raw)
    assert len(elements) == expected_count

    for element in elements:
      print(element, end=" ")
    print()

    print(ShapedType(elements.type).element_type)

  with Context():
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    # CHECK: 0 1 2 3 4 5
    # CHECK: i32
    dump_elements("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>", 6)

    # CHECK: attr: dense<[true, false, true, false]>
    # CHECK: 1 0 1 0
    # CHECK: i1
    dump_elements("dense<[true,false,true,false]> : vector<4xi1>", 4)
320
321
322# CHECK-LABEL: TEST: testDenseIntAttrGetItem
@run
def testDenseIntAttrGetItem():
  """Indexes element 0 of dense integer attributes of various element types."""

  def print_item(attr_asm):
    # Prints "<element type>: <item 0>", or the TypeError text if indexing fails.
    dense = DenseIntElementsAttr(Attribute.parse(attr_asm))
    dtype = ShapedType(dense.type).element_type
    try:
      first = dense[0]
      print(f"{dtype}:", first)
    except TypeError as err:
      print(f"{dtype}:", err)

  # Assembly strings in the same order as the CHECK lines below.
  assemblies = (
      "dense<true> : tensor<i1>",
      "dense<123> : tensor<i8>",
      "dense<123> : tensor<i16>",
      "dense<123> : tensor<i32>",
      "dense<123> : tensor<i64>",
      "dense<123> : tensor<ui8>",
      "dense<123> : tensor<ui16>",
      "dense<123> : tensor<ui32>",
      "dense<123> : tensor<ui64>",
      "dense<-123> : tensor<si8>",
      "dense<-123> : tensor<si16>",
      "dense<-123> : tensor<si32>",
      "dense<-123> : tensor<si64>",
      "dense<123> : tensor<i7>",
  )

  with Context():
    # CHECK: i1: 1
    # CHECK: i8: 123
    # CHECK: i16: 123
    # CHECK: i32: 123
    # CHECK: i64: 123
    # CHECK: ui8: 123
    # CHECK: ui16: 123
    # CHECK: ui32: 123
    # CHECK: ui64: 123
    # CHECK: si8: -123
    # CHECK: si16: -123
    # CHECK: si32: -123
    # CHECK: si64: -123
    # CHECK: i7: Unsupported integer type
    for asm in assemblies:
      print_item(asm)
364
365
366# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
  """Prints the elements and element type of a dense floating-point attribute."""
  with Context():
    parsed = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
    print("attr:", parsed)

    elements = DenseFPElementsAttr(parsed)
    element_count = len(elements)
    assert element_count == 4

    # CHECK: 0.0 1.0 2.0 3.0
    for element in elements:
      print(element, end=" ")
    print()

    element_type = ShapedType(elements.type).element_type
    # CHECK: f32
    print(element_type)
385
386
387# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
  """Builds a DictAttr from a Python dict and exercises its accessors.

  Covers printing, len(), subscripting by string key and by integer index,
  the `in` operator, the KeyError / IndexError failure paths, and the
  argument-less `DictAttr.get()`.
  """
  with Context():
    dict_attr = {
      'stringattr': StringAttr.get('string'),
      'integerattr': IntegerAttr.get(
        IntegerType.get_signless(32), 42)
    }

    a = DictAttr.get(dict_attr)

    # The directive needs the colon after CHECK; without it FileCheck treats
    # the line as an ordinary comment and never verifies the printed dict.
    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
    print("attr:", a)

    assert len(a) == 2

    # CHECK: 42 : i32
    print(a['integerattr'])

    # CHECK: "string"
    print(a['stringattr'])

    # CHECK: True
    print('stringattr' in a)

    # CHECK: False
    print('not_in_dict' in a)

    # Check that exceptions are raised as expected.
    try:
      _ = a['does_not_exist']
    except KeyError:
      pass
    else:
      assert False, "Exception not produced"

    try:
      _ = a[42]
    except IndexError:
      pass
    else:
      assert False, "expected IndexError on accessing an out-of-bounds attribute"

    # print() adds its own separator after "empty: ", giving two spaces;
    # FileCheck's default whitespace canonicalization still matches this.
    # CHECK: empty: {}
    print("empty: ", DictAttr.get())
433
434
435# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
  """Casts a parsed attribute to TypeAttr and prints its element type."""
  with Context():
    parsed = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", parsed)
    as_type_attr = TypeAttr(parsed)
    wrapped_type = as_type_attr.value
    # CHECK: f32
    print(ShapedType(wrapped_type).element_type)
445
446
447# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
  """Exercises ArrayAttr iteration, indexing, construction errors and `+`."""
  # Parsed array: print it, then iterate over its elements.
  with Context():
    array_attr = Attribute.parse("[42, true, vector<4xf32>]")
  # CHECK: attr: [42, true, vector<4xf32>]
  print("raw attr:", array_attr)
  # CHECK: - 42
  # CHECK: - true
  # CHECK: - vector<4xf32>
  for element in ArrayAttr(array_attr):
    print("- ", element)

  # Array built from a Python list of attributes.
  with Context():
    int_element = Attribute.parse("42")
    vector_element = Attribute.parse("vector<4xf32>")
    bool_element = BoolAttr.get(True)
    array_attr = ArrayAttr.get([vector_element, bool_element, int_element])
  # CHECK: attr: [vector<4xf32>, true, 42]
  print("raw attr:", array_attr)
  # CHECK: - vector<4xf32>
  # CHECK: - true
  # CHECK: - 42
  cast_array = ArrayAttr(array_attr)
  for element in cast_array:
    print("- ", element)

  # In-range indexing.
  # CHECK: attr[0]: vector<4xf32>
  # CHECK: attr[1]: true
  # CHECK: attr[2]: 42
  for index, label in enumerate(("attr[0]:", "attr[1]:", "attr[2]:")):
    print(label, cast_array[index])

  # Indexing one past the last element.
  try:
    print("attr[3]:", cast_array[3])
  except IndexError as err:
    # CHECK: Error: ArrayAttribute index out of range
    print("Error: ", err)

  # List elements that are not attributes.
  with Context():
    # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
    # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
    for bad_element in (None, 42):
      try:
        ArrayAttr.get([bad_element])
      except RuntimeError as err:
        print("Error: ", err)

  # Adding a Python list to an ArrayAttr with `+`.
  with Context():
    prefix = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
    extended = prefix + [StringAttr.get("c")]
    # CHECK: concat: ["a", "b", "c"]
    print("concat: ", extended)
501