1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Execute test function `f` immediately and hand it back unchanged.

  Intended as a decorator, so every test runs as soon as it is defined.
  It prints the "TEST:" header that the per-test label lines anchor on, then
  invokes `f`. Afterwards it forces a garbage collection and asserts that no
  `Context` object is still alive, so a test that leaks a context fails right
  away.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect before counting so unreachable objects (e.g. reference cycles)
  # that still hold a context are released first.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
12
13
14# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parse a string attribute and print it once its context is unreferenced."""
  # The attribute must stay usable (printable) after the `with` block exits
  # and the test drops its own reference to the context.
  with Context() as context:
    attr = Attribute.parse('"hello"')
  assert attr.context is context
  del context
  gc.collect()
  # CHECK: "hello"
  print(str(attr))
  # CHECK: Attribute("hello")
  print(repr(attr))
26
27
28# CHECK-LABEL: TEST: testParseError
29# TODO: Hook the diagnostic manager to capture a more meaningful error
30# message.
@run
def testParseError():
  """Parsing an unknown attribute raises ValueError naming the bad input."""
  with Context():
    try:
      Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
    except ValueError as err:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", err)
    else:
      print("Exception not produced")
41
42
43# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
  """Equality is by value: identically parsed attributes compare equal."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # Every comparison uses a1 on the left-hand side, including None.
    # CHECK: a1 == a1: True
    # CHECK: a1 == a2: False
    # CHECK: a1 == a3: True
    # CHECK: a1 == None: False
    for label, rhs in (("a1 == a1:", a1),
                       ("a1 == a2:", a2),
                       ("a1 == a3:", a3),
                       ("a1 == None:", None)):
      print(label, a1 == rhs)
58
59
60# CHECK-LABEL: TEST: testAttrHash
@run
def testAttrHash():
  """Equal attributes hash equally and deduplicate inside a set."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    h1, h2, h3 = a1.__hash__(), a2.__hash__(), a3.__hash__()
    # CHECK: hash(a1) == hash(a3): True
    print("hash(a1) == hash(a3):", h1 == h3)
    # Hash uniqueness is not required in general; here the hash is the
    # underlying pointer, so distinct attributes are expected to differ.
    # CHECK: hash(a1) == hash(a2): False
    print("hash(a1) == hash(a2):", h1 == h2)

    # a1 and a3 are equal, so they collapse into a single set entry.
    unique = {a1, a2, a3}
    # CHECK: len(s): 2
    print("len(s): ", len(unique))
80
81
82# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
  """Constructing an Attribute from an Attribute yields an equal value."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    # CHECK: a1 == a2: True
    print("a1 == a2:", a1 == Attribute(a1))
90
91
92# CHECK-LABEL: TEST: testAttrIsInstance
@run
def testAttrIsInstance():
  """Static `isinstance` checks discriminate integer and array attributes."""
  with Context():
    int_attr = Attribute.parse("42")
    array_attr = Attribute.parse("[42]")
    # Each check accepts its own kind and rejects the other.
    assert IntegerAttr.isinstance(int_attr)
    assert not ArrayAttr.isinstance(int_attr)
    assert ArrayAttr.isinstance(array_attr)
    assert not IntegerAttr.isinstance(array_attr)
102
103
104# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
  """Comparing with non-attribute values yields a bool instead of raising."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    # CHECK: False
    print(a1 == "foo")
    # CHECK: False
    print(a1 == None)
    # CHECK: True
    print(a1 != None)
116
117
118# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
  """Round-trip an attribute through its C-API capsule."""
  with Context() as owner:
    source = Attribute.parse('"attr1"')
  # CHECK: mlir.ir.Attribute._CAPIPtr
  capsule = source._CAPIPtr
  print(capsule)
  # Rebuilding from the capsule must give an equal attribute bound to the
  # very same Python context object.
  rebuilt = Attribute._CAPICreate(capsule)
  assert rebuilt == source
  assert rebuilt.context is owner
129
130
131# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
  """Downcast to StringAttr succeeds for strings and raises for floats."""
  with Context():
    generic = Attribute.parse('"attr1"')
    as_string = StringAttr(generic)
    # Re-casting an already typed StringAttr is accepted as well.
    StringAttr(as_string)
    # CHECK: Attribute("attr1")
    print(repr(as_string))
    try:
      StringAttr(Attribute.parse("1.0"))
    except ValueError as err:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", err)
    else:
      print("Exception not produced")
147
148
149# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
  """AffineMapAttr prints as an affine_map literal and re-parses equal."""
  with Context():
    # NOTE(review): these expressions are built but never used by the map
    # below; presumably they only exercise the expression constructors.
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    c2 = AffineConstantExpr.get(2)
    # Two dims, three symbols, no result expressions.
    empty_map = AffineMap.get(2, 3, [])

    built = AffineMapAttr.get(empty_map)
    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
    print(str(built))

    # Printing and then parsing must round-trip to an equal attribute.
    assert built == Attribute.parse(str(built))
164
165
166# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
  """Wrap a parsed float attribute and exercise the FloatAttr factories."""
  # NOTE(review): the unknown Location is presumably needed by the factory
  # methods (e.g. for error reporting); confirm before removing it.
  with Context(), Location.unknown():
    parsed = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", parsed.value)

    # Factory methods: explicit type first, then the f32/f64 shorthands.
    f32_type = F32Type.get()
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", FloatAttr.get(f32_type, 42.0))
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", FloatAttr.get_f32(42.0))
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", FloatAttr.get_f64(42.0))

    # A non-floating-point type must be rejected.
    try:
      FloatAttr.get(IntegerType.get_signless(32), 42)
    except ValueError as err:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(err)
    else:
      print("Exception not produced")
190
191
192# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
  """Wrap a parsed integer attribute and exercise IntegerAttr.get."""
  with Context():
    parsed = IntegerAttr(Attribute.parse("42"))
    # An untyped integer literal comes back with the default i64 type.
    # CHECK: iattr value: 42
    print("iattr value:", parsed.value)
    # CHECK: iattr type: i64
    print("iattr type:", parsed.type)

    # Factory method with an explicit signless i32 type.
    i32_type = IntegerType.get_signless(32)
    # CHECK: default_get: 42 : i32
    print("default_get:", IntegerAttr.get(i32_type, 42))
206
207
208# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
  """Wrap a parsed `true` as a BoolAttr and exercise BoolAttr.get."""
  with Context():
    battr = BoolAttr(Attribute.parse("true"))
    # The label names the BoolAttr under test (it previously reused the
    # integer test's "iattr" label).
    # CHECK: battr value: True
    print("battr value:", battr.value)

    # Test factory methods.
    # CHECK: default_get: true
    print("default_get:", BoolAttr.get(True))
219
220
221# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
  """Wrap a parsed symbol reference and exercise FlatSymbolRefAttr.get."""
  with Context():
    symbol_ref = FlatSymbolRefAttr(Attribute.parse('@symbol'))
    # `value` is the referenced name without the leading '@'.
    # CHECK: symattr value: symbol
    print("symattr value:", symbol_ref.value)

    # The factory takes the bare name; it prints back with the '@'.
    # CHECK: default_get: @foobar
    print("default_get:", FlatSymbolRefAttr.get("foobar"))
232
233
234# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
  """Wrap a parsed string attribute and exercise the StringAttr factories."""
  with Context():
    wrapped = StringAttr(Attribute.parse('"stringattr"'))
    # `value` yields the string contents without quotes.
    # CHECK: sattr value: stringattr
    print("sattr value:", wrapped.value)

    # Untyped and typed factory methods.
    # CHECK: default_get: "foobar"
    print("default_get:", StringAttr.get("foobar"))
    i32_type = IntegerType.get_signless(32)
    # CHECK: typed_get: "12345" : i32
    print("typed_get:", StringAttr.get_typed(i32_type, "12345"))
248
249
250# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
  """Pair an attribute with a name and inspect the NamedAttribute parts."""
  with Context():
    attr = Attribute.parse('"stringattr"')
    # Note: "foobar" is under the small object threshold.
    named = attr.get_named("foobar")
    # CHECK: attr: "stringattr"
    # CHECK: name: foobar
    # CHECK: named: NamedAttribute(foobar="stringattr")
    for label, part in (("attr:", named.attr),
                        ("name:", named.name),
                        ("named:", named)):
      print(label, part)
262
263
264# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
  """Iterate dense integer attributes (i32 and i1) and report element types."""

  def dump_elements(elements):
    # One line, each element followed by a space, then a newline.
    for element in elements:
      print(element, end=" ")
    print()

  with Context():
    raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # The {{\[}} escape stops FileCheck from reading "[[" as a variable use.
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", raw)

    ints = DenseIntElementsAttr(raw)
    assert len(ints) == 6

    # CHECK: 0 1 2 3 4 5
    dump_elements(ints)

    # CHECK: i32
    print(ShapedType(ints.type).element_type)

    raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", raw)

    bools = DenseIntElementsAttr(raw)
    assert len(bools) == 4

    # i1 elements iterate as integers.
    # CHECK: 1 0 1 0
    dump_elements(bools)

    # CHECK: i1
    print(ShapedType(bools.type).element_type)
297
298
299# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
  """Iterate a dense f32 attribute and report its element type."""
  with Context():
    raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
    print("attr:", raw)

    floats = DenseFPElementsAttr(raw)
    assert len(floats) == 4

    # Each element is followed by a single space, trailing one included.
    # CHECK: 0.0 1.0 2.0 3.0
    print("".join(str(value) + " " for value in floats))

    # CHECK: f32
    print(ShapedType(floats.type).element_type)
318
319
320# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
  """Build a DictAttr from a Python dict and exercise lookup and errors."""
  with Context():
    dict_attr = {
      'stringattr': StringAttr.get('string'),
      'integerattr': IntegerAttr.get(
        IntegerType.get_signless(32), 42)
    }

    a = DictAttr.get(dict_attr)

    # The expected output lists entries sorted by name, not in the insertion
    # order used above. (This directive previously lacked its ':' and was
    # silently ignored by FileCheck.)
    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
    print("attr:", a)

    assert len(a) == 2

    # CHECK: 42 : i32
    print(a['integerattr'])

    # CHECK: "string"
    print(a['stringattr'])

    # A missing name must raise KeyError and an out-of-range index must raise
    # IndexError. Raise explicitly rather than `assert False`, which would be
    # stripped when Python runs with -O.
    try:
      _ = a['does_not_exist']
    except KeyError:
      pass
    else:
      raise AssertionError(
        "expected KeyError on accessing a missing attribute name")

    try:
      _ = a[42]
    except IndexError:
      pass
    else:
      raise AssertionError(
        "expected IndexError on accessing an out-of-bounds attribute")

    # An empty DictAttr prints as "{}".
    # CHECK: empty: {}
    print("empty:", DictAttr.get())
360
361
362# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
  """A parsed bare type is usable as a TypeAttr wrapping that type."""
  with Context():
    raw = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", raw)
    wrapped_type = TypeAttr(raw).value
    # CHECK: f32
    print(ShapedType(wrapped_type).element_type)
372
373
374# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
  """Parse, build, index, and concatenate ArrayAttr values, plus errors."""
  with Context():
    raw = Attribute.parse("[42, true, vector<4xf32>]")
  # Deliberately used after the `with` block exits, which requires the
  # attribute to keep its context alive.
  # CHECK: raw attr: [42, true, vector<4xf32>]
  print("raw attr:", raw)
  # CHECK: - 42
  # CHECK: - true
  # CHECK: - vector<4xf32>
  for attr in ArrayAttr(raw):
    print("- ", attr)

  with Context():
    int_attr = Attribute.parse("42")
    vec_attr = Attribute.parse("vector<4xf32>")
    bool_attr = BoolAttr.get(True)
    raw = ArrayAttr.get([vec_attr, bool_attr, int_attr])
  # CHECK: raw attr: [vector<4xf32>, true, 42]
  print("raw attr:", raw)
  # CHECK: - vector<4xf32>
  # CHECK: - true
  # CHECK: - 42
  arr = ArrayAttr(raw)
  for attr in arr:
    print("- ", attr)
  # CHECK: attr[0]: vector<4xf32>
  print("attr[0]:", arr[0])
  # CHECK: attr[1]: true
  print("attr[1]:", arr[1])
  # CHECK: attr[2]: 42
  print("attr[2]:", arr[2])
  # Only the out-of-range index belongs inside the try.
  try:
    _ = arr[3]
  except IndexError as e:
    # CHECK: Error: ArrayAttribute index out of range
    print("Error: ", e)
  else:
    print("Exception not produced")

  # Non-attribute elements must be rejected when building an ArrayAttr.
  with Context():
    try:
      ArrayAttr.get([None])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")
    try:
      ArrayAttr.get([42])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")

  with Context():
    array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
    # `+` with a Python list of attributes appends them.
    array = array + [StringAttr.get("c")]
    # CHECK: concat: ["a", "b", "c"]
    print("concat: ", array)
428