1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Run test function `f` immediately and return it unchanged.

  Prints a newline followed by "TEST: <function name>", calls `f`, forces a
  garbage collection and then asserts that `Context._get_live_count()` is
  zero. Returning `f` allows this function to be used as a decorator.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
12
13
# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parse a string attribute and print its str() and repr() forms."""
  with Context() as context:
    parsed = Attribute.parse('"hello"')
  assert parsed.context is context
  # Drop the local context reference and collect before printing, so the
  # attribute is printed without this function holding the context itself.
  del context
  gc.collect()
  # CHECK: "hello"
  as_text = str(parsed)
  print(as_text)
  # CHECK: Attribute("hello")
  as_repr = repr(parsed)
  print(as_repr)
26
27
# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
@run
def testParseError():
  """Parsing an unknown attribute is expected to raise ValueError."""
  bad_source = "BAD_ATTR_DOES_NOT_EXIST"
  with Context():
    try:
      Attribute.parse(bad_source)
    except ValueError as err:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", err)
    else:
      # Reached only when parsing unexpectedly succeeded.
      print("Exception not produced")
41
42
# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
  """Compare attributes for equality: same object, different, equal, None."""
  with Context():
    # a1 and a3 are parsed from the same text; a2 from a different one.
    a1, a2, a3 = (
        Attribute.parse(text) for text in ('"attr1"', '"attr2"', '"attr1"'))
    # CHECK: a1 == a1: True
    same_object = a1 == a1
    print("a1 == a1:", same_object)
    # CHECK: a1 == a2: False
    different_text = a1 == a2
    print("a1 == a2:", different_text)
    # CHECK: a1 == a3: True
    same_text = a1 == a3
    print("a1 == a3:", same_text)
    # The `== None` comparison is deliberate: it exercises the attribute's
    # equality operator against None.
    # CHECK: a1 == None: False
    against_none = a1 == None
    print("a1 == None:", against_none)
58
59
# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
  """Construct an Attribute from another Attribute and compare the two."""
  with Context():
    original = Attribute.parse('"attr1"')
    recast = Attribute(original)
    # CHECK: a1 == a2: True
    are_equal = original == recast
    print("a1 == a2:", are_equal)
68
69
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
  """Comparing an attribute with a non-attribute prints a bool result."""
  with Context():
    attr = Attribute.parse('"attr1"')
    plain_string = "foo"
    # CHECK: False
    equals_string = attr == plain_string
    print(equals_string)
    # The comparisons with None below intentionally use `==` / `!=` rather
    # than `is`, since the comparison operators are what is being exercised.
    # CHECK: False
    equals_none = attr == None
    print(equals_none)
    # CHECK: True
    differs_from_none = attr != None
    print(differs_from_none)
82
83
# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
  """Round-trip an attribute through `_CAPIPtr` and `_CAPICreate`."""
  with Context() as context:
    source_attr = Attribute.parse('"attr1"')
  # CHECK: mlir.ir.Attribute._CAPIPtr
  capsule = source_attr._CAPIPtr
  print(capsule)
  rebuilt_attr = Attribute._CAPICreate(capsule)
  # The rebuilt attribute must compare equal to the source and report the
  # very same context object.
  assert rebuilt_attr == source_attr
  assert rebuilt_attr.context is context
95
96
# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
  """Cast generic attributes to StringAttr, including an illegal cast."""
  with Context():
    generic = Attribute.parse('"attr1"')
    as_string = StringAttr(generic)
    # Constructing a StringAttr from an existing StringAttr is exercised too.
    StringAttr(as_string)
    # CHECK: Attribute("attr1")
    print(repr(as_string))
    try:
      # A float attribute is not a string attribute, so the cast is expected
      # to raise.
      float_attr = Attribute.parse("1.0")
      StringAttr(float_attr)
    except ValueError as err:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", err)
    else:
      print("Exception not produced")
113
114
# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
  """Build an AffineMapAttr, print it, and round-trip it through the parser."""
  with Context() as ctx:
    # These expressions are constructed but not used by the map below.
    dim0 = AffineDimExpr.get(0)
    dim1 = AffineDimExpr.get(1)
    const2 = AffineConstantExpr.get(2)
    # Map with 2 dims, 3 symbols and an empty result list.
    empty_map = AffineMap.get(2, 3, [])

    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
    built = AffineMapAttr.get(empty_map)
    built_text = str(built)
    print(built_text)

    # Parsing the printed form must produce an equal attribute.
    reparsed = Attribute.parse(str(built))
    assert built == reparsed
130
131
# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
  """Exercise FloatAttr casting, its value accessor and factory methods."""
  with Context(), Location.unknown():
    parsed = Attribute.parse("42.0 : f32")
    float_attr = FloatAttr(parsed)
    # CHECK: fattr value: 42.0
    print("fattr value:", float_attr.value)

    # Test factory methods.
    # CHECK: default_get: 4.200000e+01 : f32
    f32_type = F32Type.get()
    from_type = FloatAttr.get(f32_type, 42.0)
    print("default_get:", from_type)
    # CHECK: f32_get: 4.200000e+01 : f32
    from_f32 = FloatAttr.get_f32(42.0)
    print("f32_get:", from_f32)
    # CHECK: f64_get: 4.200000e+01 : f64
    from_f64 = FloatAttr.get_f64(42.0)
    print("f64_get:", from_f64)
    try:
      # An integer type is not a valid type for a FloatAttr.
      i32_type = IntegerType.get_signless(32)
      FloatAttr.get(i32_type, 42)
    except ValueError as err:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(err)
    else:
      print("Exception not produced")
156
157
# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
  """Exercise IntegerAttr casting, value/type accessors and `get`."""
  with Context() as ctx:
    parsed = Attribute.parse("42")
    int_attr = IntegerAttr(parsed)
    # CHECK: iattr value: 42
    print("iattr value:", int_attr.value)
    # CHECK: iattr type: i64
    print("iattr type:", int_attr.type)

    # Test factory methods.
    # CHECK: default_get: 42 : i32
    i32_type = IntegerType.get_signless(32)
    from_type = IntegerAttr.get(i32_type, 42)
    print("default_get:", from_type)
172
173
# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
  """Exercise BoolAttr casting, its value accessor and `get`."""
  with Context() as ctx:
    parsed = Attribute.parse("true")
    bool_attr = BoolAttr(parsed)
    # The printed label says "iattr" even though this is a bool attribute;
    # it is kept as-is because the directive below matches it.
    # CHECK: iattr value: True
    print("iattr value:", bool_attr.value)

    # Test factory methods.
    # CHECK: default_get: true
    from_python_bool = BoolAttr.get(True)
    print("default_get:", from_python_bool)
185
186
# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
  """Exercise FlatSymbolRefAttr casting, its value accessor and `get`."""
  with Context() as ctx:
    parsed = Attribute.parse('@symbol')
    symbol_attr = FlatSymbolRefAttr(parsed)
    # CHECK: symattr value: symbol
    print("symattr value:", symbol_attr.value)

    # Test factory methods.
    # CHECK: default_get: @foobar
    from_name = FlatSymbolRefAttr.get("foobar")
    print("default_get:", from_name)
198
199
# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
  """Exercise StringAttr casting, its value accessor and factory methods."""
  with Context() as ctx:
    parsed = Attribute.parse('"stringattr"')
    string_attr = StringAttr(parsed)
    # CHECK: sattr value: stringattr
    print("sattr value:", string_attr.value)

    # Test factory methods.
    # CHECK: default_get: "foobar"
    untyped = StringAttr.get("foobar")
    print("default_get:", untyped)
    # CHECK: typed_get: "12345" : i32
    i32_type = IntegerType.get_signless(32)
    typed = StringAttr.get_typed(i32_type, "12345")
    print("typed_get:", typed)
214
215
# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
  """Create a named attribute and print its attr, name and full form."""
  with Context():
    attr = Attribute.parse('"stringattr"')
    # Note: under the small object threshold
    named_attr = attr.get_named("foobar")
    # CHECK: attr: "stringattr"
    wrapped = named_attr.attr
    print("attr:", wrapped)
    # CHECK: name: foobar
    label = named_attr.name
    print("name:", label)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named_attr)
228
229
# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
  """Exercise DenseIntElementsAttr length, iteration and element type."""
  with Context():
    # 2x3 vector of i32 values.
    parsed_i32 = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", parsed_i32)

    ints = DenseIntElementsAttr(parsed_i32)
    assert len(ints) == 6

    # Each element is followed by a single space, then the line is ended.
    # CHECK: 0 1 2 3 4 5
    print(*ints, end=" \n")

    # CHECK: i32
    ints_shaped = ShapedType(ints.type)
    print(ints_shaped.element_type)

    # 4-element vector of i1 values.
    parsed_i1 = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", parsed_i1)

    bools = DenseIntElementsAttr(parsed_i1)
    assert len(bools) == 4

    # CHECK: 1 0 1 0
    print(*bools, end=" \n")

    # CHECK: i1
    bools_shaped = ShapedType(bools.type)
    print(bools_shaped.element_type)
263
264
# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
  """Exercise DenseFPElementsAttr length, iteration and element type."""
  with Context():
    parsed = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
    print("attr:", parsed)

    floats = DenseFPElementsAttr(parsed)
    assert len(floats) == 4

    # Each element is followed by a single space, then the line is ended.
    # CHECK: 0.0 1.0 2.0 3.0
    print(*floats, end=" \n")

    # CHECK: f32
    shaped = ShapedType(floats.type)
    print(shaped.element_type)
284
285
# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
  """Exercise DictAttr construction, printing, length and item lookup.

  Builds a dictionary attribute from a Python dict of attributes, prints it,
  looks entries up by name, and verifies that a missing name raises KeyError
  and that the integer index 42 raises IndexError.
  """
  with Context():
    dict_attr = {
      'stringattr':  StringAttr.get('string'),
      'integerattr' : IntegerAttr.get(
        IntegerType.get_signless(32), 42)
    }

    a = DictAttr.get(dict_attr)

    # The directive below needs the colon after the prefix; without it
    # FileCheck treats the line as an ordinary comment and the printed
    # dictionary is never verified.
    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
    print("attr:", a)

    assert len(a) == 2

    # CHECK: 42 : i32
    print(a['integerattr'])

    # CHECK: "string"
    print(a['stringattr'])

    # Check that exceptions are raised as expected.
    try:
      _ = a['does_not_exist']
    except KeyError:
      pass
    else:
      assert False, "Exception not produced"

    # Integer subscripts are accepted as well; 42 is past the two entries.
    try:
      _ = a[42]
    except IndexError:
      pass
    else:
      assert False, "expected IndexError on accessing an out-of-bounds attribute"
323
324
# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
  """Cast a parsed type attribute to TypeAttr and inspect the wrapped type."""
  with Context():
    parsed = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", parsed)
    as_type_attr = TypeAttr(parsed)
    # CHECK: f32
    wrapped_type = ShapedType(as_type_attr.value)
    print(wrapped_type.element_type)
335
336
# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
  """Exercise ArrayAttr parsing, construction, iteration and indexing."""
  # Parsed array; it is printed and iterated after the `with` block ended.
  with Context():
    parsed = Attribute.parse("[42, true, vector<4xf32>]")
  # CHECK: attr: [42, true, vector<4xf32>]
  print("raw attr:", parsed)
  # CHECK: - 42
  # CHECK: - true
  # CHECK: - vector<4xf32>
  for element in ArrayAttr(parsed):
    print("- ", element)

  # Array assembled from individually created attributes.
  with Context():
    int_attr = Attribute.parse("42")
    vec_attr = Attribute.parse("vector<4xf32>")
    bool_attr = BoolAttr.get(True)
    built = ArrayAttr.get([vec_attr, bool_attr, int_attr])
  # CHECK: attr: [vector<4xf32>, true, 42]
  print("raw attr:", built)
  # CHECK: - vector<4xf32>
  # CHECK: - true
  # CHECK: - 42
  array = ArrayAttr(built)
  for element in array:
    print("- ", element)
  # CHECK: attr[0]: vector<4xf32>
  print("attr[0]:", array[0])
  # CHECK: attr[1]: true
  print("attr[1]:", array[1])
  # CHECK: attr[2]: 42
  print("attr[2]:", array[2])
  # Index 3 is one past the last of the three elements.
  try:
    print("attr[3]:", array[3])
  except IndexError as err:
    # CHECK: Error: ArrayAttribute index out of range
    print("Error: ", err)
  # Elements that are not attributes are rejected by ArrayAttr.get.
  with Context():
    try:
      ArrayAttr.get([None])
    except RuntimeError as err:
      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
      print("Error: ", err)
    try:
      ArrayAttr.get([42])
    except RuntimeError as err:
      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
      print("Error: ", err)
385
386