1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Runs `f` right away and verifies that no Context is left alive.

  Intended to be used as a decorator on each test function: a banner with
  the function name is printed, the function is called, a garbage
  collection pass is forced, and the live Context count is asserted to be
  zero. The function itself is returned unchanged.
  """
  banner = "\nTEST:"
  print(banner, f.__name__)
  f()
  # Collect first so that objects only kept alive by cycles are released
  # before the live count is inspected.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
12
13
14# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parses a string attribute and prints its str() and repr() forms.

  The local Context reference is cleared and a collection is forced before
  printing, so the printing does not rely on the local context variable.
  """
  with Context() as context:
    parsed = Attribute.parse('"hello"')
  assert parsed.context is context
  # Drop the local handle on the context before touching the attribute.
  context = None
  gc.collect()
  # CHECK: "hello"
  print(str(parsed))
  # CHECK: Attribute("hello")
  print(repr(parsed))
26
27
28# CHECK-LABEL: TEST: testParseError
29# TODO: Hook the diagnostic manager to capture a more meaningful error
30# message.
@run
def testParseError():
  """Verifies that parsing an unparseable attribute raises ValueError."""
  bad_source = "BAD_ATTR_DOES_NOT_EXIST"
  with Context():
    try:
      unexpected = Attribute.parse(bad_source)
    except ValueError as err:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", err)
    else:
      print("Exception not produced")
41
42
43# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
  """Exercises `==` between attributes and between an attribute and None."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    # Parsed from the same text as a1.
    a3 = Attribute.parse('"attr1"')
    # CHECK: a1 == a1: True
    print("a1 == a1:", a1 == a1)
    # CHECK: a1 == a2: False
    print("a1 == a2:", a1 == a2)
    # CHECK: a1 == a3: True
    print("a1 == a3:", a1 == a3)
    # The comparison operator (rather than `is`) is what is being tested.
    # CHECK: a1 == None: False
    print("a1 == None:", a1 == None)
58
59
60# CHECK-LABEL: TEST: testAttrHash
@run
def testAttrHash():
  """Compares attribute hashes and uses attributes as set members."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    same_text_hashes_match = a1.__hash__() == a3.__hash__()
    # CHECK: hash(a1) == hash(a3): True
    print("hash(a1) == hash(a3):", same_text_hashes_match)
    # In general, hashes don't have to be unique. In this case, however, the
    # hash is just the underlying pointer so it will be.
    other_text_hashes_match = a1.__hash__() == a2.__hash__()
    # CHECK: hash(a1) == hash(a2): False
    print("hash(a1) == hash(a2):", other_text_hashes_match)

    # a1 and a3 are expected to collapse into a single set entry.
    unique_attrs = {a1, a2, a3}
    # CHECK: len(s): 2
    print("len(s): ", len(unique_attrs))
80
81
82# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
  """Constructs an Attribute from another Attribute and compares the two."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute(a1)
    are_equal = a1 == a2
    # CHECK: a1 == a2: True
    print("a1 == a2:", are_equal)
90
91
92# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
  """Compares an attribute against a str and None using `==` and `!=`.

  Each comparison is expected to yield a plain boolean rather than raise.
  """
  with Context():
    a1 = Attribute.parse('"attr1"')
    # A value that is not an Attribute at all.
    not_an_attr = "foo"
    # CHECK: False
    print(a1 == not_an_attr)
    # CHECK: False
    print(a1 == None)
    # CHECK: True
    print(a1 != None)
104
105
106# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
  """Round-trips an attribute through its `_CAPIPtr` capsule.

  The attribute rebuilt with `_CAPICreate` must compare equal to the
  original and report the very same Context object.
  """
  with Context() as context:
    original = Attribute.parse('"attr1"')
  capsule = original._CAPIPtr
  # CHECK: mlir.ir.Attribute._CAPIPtr
  print(capsule)
  rebuilt = Attribute._CAPICreate(capsule)
  assert rebuilt == original
  assert rebuilt.context is context
117
118
119# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
  """Casts a generic Attribute to StringAttr and checks the failure path.

  Casting a string attribute (and re-casting the result) must succeed;
  casting a float attribute to StringAttr must raise ValueError.
  """
  with Context():
    generic = Attribute.parse('"attr1"')
    as_string = StringAttr(generic)
    recast = StringAttr(as_string)
    # CHECK: Attribute("attr1")
    print(repr(as_string))
    try:
      float_attr = Attribute.parse("1.0")
      tillegal = StringAttr(float_attr)
    except ValueError as err:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", err)
    else:
      print("Exception not produced")
135
136
137# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
  """Builds an AffineMapAttr and round-trips it through its text form."""
  with Context() as ctx:
    # These expressions are constructed but not used by the map below.
    dim0 = AffineDimExpr.get(0)
    dim1 = AffineDimExpr.get(1)
    const2 = AffineConstantExpr.get(2)
    empty_result_map = AffineMap.get(2, 3, [])

    built = AffineMapAttr.get(empty_result_map)
    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
    print(str(built))

    # Parsing the printed form must give back an equal attribute.
    reparsed = Attribute.parse(str(built))
    assert built == reparsed
152
153
154# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
  """Covers FloatAttr casting, its value accessor and its factory methods.

  Also verifies that requesting a FloatAttr with an integer type raises
  ValueError.
  """
  with Context(), Location.unknown():
    parsed = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", parsed.value)

    # Test factory methods.
    f32 = F32Type.get()
    from_type = FloatAttr.get(f32, 42.0)
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", from_type)
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", FloatAttr.get_f32(42.0))
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", FloatAttr.get_f64(42.0))
    try:
      i32 = IntegerType.get_signless(32)
      fattr_invalid = FloatAttr.get(i32, 42)
    except ValueError as err:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(err)
    else:
      print("Exception not produced")
178
179
180# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
  """Covers IntegerAttr casting, its value/type accessors and `get`."""
  with Context() as ctx:
    parsed = IntegerAttr(Attribute.parse("42"))
    # CHECK: iattr value: 42
    print("iattr value:", parsed.value)
    # CHECK: iattr type: i64
    print("iattr type:", parsed.type)

    # Test factory methods.
    i32 = IntegerType.get_signless(32)
    from_type = IntegerAttr.get(i32, 42)
    # CHECK: default_get: 42 : i32
    print("default_get:", from_type)
194
195
196# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
  """Covers BoolAttr casting, its value accessor and `get`."""
  with Context() as ctx:
    parsed = BoolAttr(Attribute.parse("true"))
    # CHECK: iattr value: True
    print("iattr value:", parsed.value)

    # Test factory methods.
    built = BoolAttr.get(True)
    # CHECK: default_get: true
    print("default_get:", built)
207
208
209# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
  """Covers FlatSymbolRefAttr casting, its value accessor and `get`."""
  with Context() as ctx:
    parsed = FlatSymbolRefAttr(Attribute.parse('@symbol'))
    # CHECK: symattr value: symbol
    print("symattr value:", parsed.value)

    # Test factory methods.
    built = FlatSymbolRefAttr.get("foobar")
    # CHECK: default_get: @foobar
    print("default_get:", built)
220
221
222# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
  """Covers StringAttr casting, its value accessor and both factories."""
  with Context() as ctx:
    parsed = StringAttr(Attribute.parse('"stringattr"'))
    # CHECK: sattr value: stringattr
    print("sattr value:", parsed.value)

    # Test factory methods.
    untyped = StringAttr.get("foobar")
    # CHECK: default_get: "foobar"
    print("default_get:", untyped)
    i32 = IntegerType.get_signless(32)
    typed = StringAttr.get_typed(i32, "12345")
    # CHECK: typed_get: "12345" : i32
    print("typed_get:", typed)
236
237
238# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
  """Wraps an attribute with `get_named` and prints the resulting pieces."""
  with Context():
    wrapped = Attribute.parse('"stringattr"')
    # Note: under the small object threshold
    named = wrapped.get_named("foobar")
    # CHECK: attr: "stringattr"
    print("attr:", named.attr)
    # CHECK: name: foobar
    print("name:", named.name)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named)
250
251
252# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
  """Covers DenseIntElementsAttr length, iteration and element type.

  Two inputs are exercised in turn: a 2x3 vector of i32 and a vector of
  four i1 values.
  """
  with Context():
    int_raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", int_raw)

    int_dense = DenseIntElementsAttr(int_raw)
    assert len(int_dense) == 6

    # CHECK: 0 1 2 3 4 5
    for element in int_dense:
      print(element, end=" ")
    print()

    int_shaped = ShapedType(int_dense.type)
    # CHECK: i32
    print(int_shaped.element_type)

    bool_raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", bool_raw)

    bool_dense = DenseIntElementsAttr(bool_raw)
    assert len(bool_dense) == 4

    # CHECK: 1 0 1 0
    for element in bool_dense:
      print(element, end=" ")
    print()

    bool_shaped = ShapedType(bool_dense.type)
    # CHECK: i1
    print(bool_shaped.element_type)
285
286
287# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
  """Covers DenseFPElementsAttr length, iteration and element type."""
  with Context():
    parsed = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
    print("attr:", parsed)

    dense = DenseFPElementsAttr(parsed)
    assert len(dense) == 4

    # CHECK: 0.0 1.0 2.0 3.0
    for element in dense:
      print(element, end=" ")
    print()

    shaped = ShapedType(dense.type)
    # CHECK: f32
    print(shaped.element_type)
306
307
308# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
  """Covers DictAttr construction, printing, length and item lookup.

  Lookup is exercised with existing string keys, with a missing string key
  (expected to raise KeyError) and with an out-of-bounds integer index
  (expected to raise IndexError). Finally an empty DictAttr is printed.

  Both printouts of a whole dictionary are verified by FileCheck; their
  directives carry the colon after the prefix, without which FileCheck
  silently ignores a directive.
  """
  with Context():
    dict_attr = {
      'stringattr': StringAttr.get('string'),
      'integerattr': IntegerAttr.get(
        IntegerType.get_signless(32), 42)
    }

    a = DictAttr.get(dict_attr)

    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
    print("attr:", a)

    assert len(a) == 2

    # CHECK: 42 : i32
    print(a['integerattr'])

    # CHECK: "string"
    print(a['stringattr'])

    # Check that exceptions are raised as expected.
    try:
      _ = a['does_not_exist']
    except KeyError:
      pass
    else:
      assert False, "Exception not produced"

    try:
      _ = a[42]
    except IndexError:
      pass
    else:
      assert False, "expected IndexError on accessing an out-of-bounds attribute"

    # The print below emits two spaces after the colon; FileCheck ignores
    # differences in horizontal whitespace by default, so this still matches.
    # CHECK: empty: {}
    print("empty: ", DictAttr.get())
348
349
350# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
  """Casts a parsed attribute to TypeAttr and inspects the wrapped type."""
  with Context():
    parsed = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", parsed)
    as_type_attr = TypeAttr(parsed)
    shaped = ShapedType(as_type_attr.value)
    # CHECK: f32
    print(shaped.element_type)
360
361
362# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
  """Covers ArrayAttr parsing, construction, iteration, indexing and `+`.

  Also exercises the error paths: indexing past the end (IndexError) and
  building an array from None or from a non-attribute value (RuntimeError).
  """
  # An array obtained by parsing, iterated outside of its context manager.
  with Context():
    raw = Attribute.parse("[42, true, vector<4xf32>]")
  # CHECK: attr: [42, true, vector<4xf32>]
  print("raw attr:", raw)
  # CHECK: - 42
  # CHECK: - true
  # CHECK: - vector<4xf32>
  for element in ArrayAttr(raw):
    print("- ", element)

  # An array built from individual attributes, in reverse order.
  with Context():
    int_attr = Attribute.parse("42")
    vec_attr = Attribute.parse("vector<4xf32>")
    bool_attr = BoolAttr.get(True)
    raw = ArrayAttr.get([vec_attr, bool_attr, int_attr])
  # CHECK: attr: [vector<4xf32>, true, 42]
  print("raw attr:", raw)
  # CHECK: - vector<4xf32>
  # CHECK: - true
  # CHECK: - 42
  typed = ArrayAttr(raw)
  for element in typed:
    print("- ", element)
  # CHECK: attr[0]: vector<4xf32>
  print("attr[0]:", typed[0])
  # CHECK: attr[1]: true
  print("attr[1]:", typed[1])
  # CHECK: attr[2]: 42
  print("attr[2]:", typed[2])
  try:
    print("attr[3]:", typed[3])
  except IndexError as err:
    # CHECK: Error: ArrayAttribute index out of range
    print("Error: ", err)

  # Invalid elements must be rejected at construction time.
  with Context():
    try:
      ArrayAttr.get([None])
    except RuntimeError as err:
      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
      print("Error: ", err)
    try:
      ArrayAttr.get([42])
    except RuntimeError as err:
      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
      print("Error: ", err)

  # Concatenation of an ArrayAttr with a plain Python list.
  with Context():
    base = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
    extended = base + [StringAttr.get("c")]
    # CHECK: concat: ["a", "b", "c"]
    print("concat: ", extended)
416