# RUN: %PYTHON %s | FileCheck %s
#
# Tests for the attribute classes exposed by the MLIR Python bindings
# (`mlir.ir`). Each test prints results that FileCheck matches against the
# directive comments placed just before the corresponding `print` calls.

import gc
from mlir.ir import *


def run(f):
  """Run test `f` immediately, then verify that no `Context` leaked.

  Intended as a decorator: prints the `TEST: <name>` label line, calls `f`,
  forces a garbage collection and asserts that no live `Context` objects
  remain. Returns `f` unchanged.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0
  return f


# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parse a string attribute and print it after dropping the context ref."""
  with Context() as ctx:
    t = Attribute.parse('"hello"')
  assert t.context is ctx
  # Drop the Python-side reference to the context; the attribute is still
  # expected to print correctly afterwards.
  ctx = None
  gc.collect()
  # CHECK: "hello"
  print(str(t))
  # CHECK: Attribute("hello")
  print(repr(t))


# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
@run
def testParseError():
  """Parsing malformed attribute text raises ValueError."""
  with Context():
    try:
      Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
    except ValueError as e:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
  """Attributes parsed from identical text compare equal."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: a1 == a1: True
    print("a1 == a1:", a1 == a1)
    # CHECK: a1 == a2: False
    print("a1 == a2:", a1 == a2)
    # CHECK: a1 == a3: True
    print("a1 == a3:", a1 == a3)
    # `==` rather than `is` on purpose: this exercises __eq__ against None.
    # CHECK: a1 == None: False
    print("a1 == None:", a1 == None)


# CHECK-LABEL: TEST: testAttrHash
@run
def testAttrHash():
  """Equal attributes hash equally and deduplicate inside a set."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: hash(a1) == hash(a3): True
    print("hash(a1) == hash(a3):", hash(a1) == hash(a3))

    s = set()
    s.add(a1)
    s.add(a2)
    s.add(a3)
    # CHECK: len(s): 2
    print("len(s): ", len(s))


# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
  """Constructing an Attribute from another Attribute yields an equal one."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute(a1)
    # CHECK: a1 == a2: True
    print("a1 == a2:", a1 == a2)


# CHECK-LABEL: TEST: testAttrIsInstance
@run
def testAttrIsInstance():
  """`<Class>.isinstance` distinguishes integer and array attributes."""
  with Context():
    a1 = Attribute.parse("42")
    a2 = Attribute.parse("[42]")
    assert IntegerAttr.isinstance(a1)
    assert not IntegerAttr.isinstance(a2)
    assert not ArrayAttr.isinstance(a1)
    assert ArrayAttr.isinstance(a2)


# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
  """Comparing an attribute with non-attributes returns a bool, not raises."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    not_an_attr = "foo"
    # CHECK: False
    print(a1 == not_an_attr)
    # CHECK: False
    print(a1 == None)
    # CHECK: True
    print(a1 != None)


# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
  """An attribute round-trips through its C-API capsule."""
  with Context() as ctx:
    a1 = Attribute.parse('"attr1"')
    # CHECK: mlir.ir.Attribute._CAPIPtr
    attr_capsule = a1._CAPIPtr
    print(attr_capsule)
    a2 = Attribute._CAPICreate(attr_capsule)
    assert a2 == a1
    assert a2.context is ctx


# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
  """Downcasting to StringAttr succeeds for strings and fails otherwise."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    astr = StringAttr(a1)
    # Casting an already-typed StringAttr to StringAttr must also work.
    aself = StringAttr(astr)
    # CHECK: Attribute("attr1")
    print(repr(astr))
    try:
      StringAttr(Attribute.parse("1.0"))
    except ValueError as e:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
  """An AffineMapAttr round-trips through its printed form."""
  with Context():
    # 2 dimensions, 3 symbols, no result expressions.
    map0 = AffineMap.get(2, 3, [])

    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
    attr_built = AffineMapAttr.get(map0)
    print(str(attr_built))

    attr_parsed = Attribute.parse(str(attr_built))
    assert attr_built == attr_parsed


# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
  """FloatAttr value access, factory methods and non-float type rejection."""
  with Context(), Location.unknown():
    fattr = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", fattr.value)

    # Test factory methods.
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", FloatAttr.get(
        F32Type.get(), 42.0))
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", FloatAttr.get_f32(42.0))
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", FloatAttr.get_f64(42.0))
    try:
      FloatAttr.get(IntegerType.get_signless(32), 42)
    except ValueError as e:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
  """IntegerAttr value/type access and the typed factory method."""
  with Context():
    iattr = IntegerAttr(Attribute.parse("42"))
    # CHECK: iattr value: 42
    print("iattr value:", iattr.value)
    # CHECK: iattr type: i64
    print("iattr type:", iattr.type)

    # Test factory methods.
    # CHECK: default_get: 42 : i32
    print("default_get:", IntegerAttr.get(
        IntegerType.get_signless(32), 42))


# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
  """BoolAttr value access and factory method."""
  with Context():
    battr = BoolAttr(Attribute.parse("true"))
    # CHECK: battr value: True
    print("battr value:", battr.value)

    # Test factory methods.
    # CHECK: default_get: true
    print("default_get:", BoolAttr.get(True))


# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
  """FlatSymbolRefAttr value access and factory method."""
  with Context():
    sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
    # CHECK: symattr value: symbol
    print("symattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: @foobar
    print("default_get:", FlatSymbolRefAttr.get("foobar"))


# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
  """StringAttr value access plus untyped and typed factory methods."""
  with Context():
    sattr = StringAttr(Attribute.parse('"stringattr"'))
    # CHECK: sattr value: stringattr
    print("sattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: "foobar"
    print("default_get:", StringAttr.get("foobar"))
    # CHECK: typed_get: "12345" : i32
    print("typed_get:", StringAttr.get_typed(
        IntegerType.get_signless(32), "12345"))


# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
  """An attribute can be wrapped into a NamedAttribute with a given name."""
  with Context():
    a = Attribute.parse('"stringattr"')
    named = a.get_named("foobar")  # Note: under the small object threshold
    # CHECK: attr: "stringattr"
    print("attr:", named.attr)
    # CHECK: name: foobar
    print("name:", named.name)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named)


# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
  """DenseIntElementsAttr length, iteration and element type (i32 and i1)."""
  with Context():
    raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # `{{\[}}` escapes the literal `[[`, which FileCheck would otherwise read
    # as the start of a variable substitution.
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 6

    # CHECK: 0 1 2 3 4 5
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i32
    print(ShapedType(a.type).element_type)

    raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 1 0 1 0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i1
    print(ShapedType(a.type).element_type)


# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
  """DenseFPElementsAttr length, iteration and element type."""
  with Context():
    raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>

    print("attr:", raw)

    a = DenseFPElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 0.0 1.0 2.0 3.0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: f32
    print(ShapedType(a.type).element_type)


# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
  """DictAttr construction, printing, lookup, membership and error cases."""
  with Context():
    dict_attr = {
      'stringattr': StringAttr.get('string'),
      'integerattr' : IntegerAttr.get(
        IntegerType.get_signless(32), 42)
    }

    a = DictAttr.get(dict_attr)

    # The expected text lists entries by name, not in insertion order.
    # This directive also anchors the two lookups below so they cannot
    # match inside this line.
    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
    print("attr:", a)

    assert len(a) == 2

    # CHECK: 42 : i32
    print(a['integerattr'])

    # CHECK: "string"
    print(a['stringattr'])

    # CHECK: True
    print('stringattr' in a)

    # CHECK: False
    print('not_in_dict' in a)

    # Check that exceptions are raised as expected.
    try:
      _ = a['does_not_exist']
    except KeyError:
      pass
    else:
      assert False, "Exception not produced"

    try:
      _ = a[42]
    except IndexError:
      pass
    else:
      assert False, "expected IndexError on accessing an out-of-bounds attribute"

    # CHECK: empty: {}
    print("empty: ", DictAttr.get())


# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
  """A TypeAttr exposes the wrapped type through `.value`."""
  with Context():
    raw = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", raw)
    type_attr = TypeAttr(raw)
    # CHECK: f32
    print(ShapedType(type_attr.value).element_type)


# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
  """ArrayAttr parsing, construction, indexing, errors and concatenation."""
  with Context():
    raw = Attribute.parse("[42, true, vector<4xf32>]")
    # CHECK: attr: [42, true, vector<4xf32>]
    print("raw attr:", raw)
    # CHECK: - 42
    # CHECK: - true
    # CHECK: - vector<4xf32>
    for attr in ArrayAttr(raw):
      print("- ", attr)

  with Context():
    intAttr = Attribute.parse("42")
    vecAttr = Attribute.parse("vector<4xf32>")
    boolAttr = BoolAttr.get(True)
    raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
    # CHECK: attr: [vector<4xf32>, true, 42]
    print("raw attr:", raw)
    # CHECK: - vector<4xf32>
    # CHECK: - true
    # CHECK: - 42
    arr = ArrayAttr(raw)
    for attr in arr:
      print("- ", attr)
    # CHECK: attr[0]: vector<4xf32>
    print("attr[0]:", arr[0])
    # CHECK: attr[1]: true
    print("attr[1]:", arr[1])
    # CHECK: attr[2]: 42
    print("attr[2]:", arr[2])
    try:
      print("attr[3]:", arr[3])
    except IndexError as e:
      # CHECK: Error: ArrayAttribute index out of range
      print("Error: ", e)
    else:
      print("Exception not produced")

  with Context():
    try:
      ArrayAttr.get([None])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")
    try:
      ArrayAttr.get([42])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")

  with Context():
    array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
    array = array + [StringAttr.get("c")]
    # CHECK: concat: ["a", "b", "c"]
    print("concat: ", array)