# RUN: %PYTHON %s | FileCheck %s
"""Tests for the `mlir.ir` attribute Python bindings.

Every test below is executed at import time by the `run` decorator; its
printed output is matched, in file order, against the pattern comments that
precede each print.
"""

import gc
from mlir.ir import *


def run(f):
  """Execute test function `f` immediately and return it unchanged.

  Prints a "TEST: <name>" label line, calls `f`, forces a garbage collection,
  and asserts that no `Context` instances are still alive afterwards, so a
  test that leaks a context fails here.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0
  return f


# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parse a string attribute and print it after dropping the context ref."""
  with Context() as ctx:
    t = Attribute.parse('"hello"')
  assert t.context is ctx
  # Drop the only Python-level reference to the context; the attribute must
  # still be printable afterwards.
  ctx = None
  gc.collect()
  # CHECK: "hello"
  print(str(t))
  # CHECK: Attribute("hello")
  print(repr(t))


# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
@run
def testParseError():
  """Parsing malformed attribute text raises ValueError."""
  with Context():
    try:
      Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
    except ValueError as e:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
  """Attributes compare equal by value and unequal to None."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: a1 == a1: True
    print("a1 == a1:", a1 == a1)
    # CHECK: a1 == a2: False
    print("a1 == a2:", a1 == a2)
    # CHECK: a1 == a3: True
    print("a1 == a3:", a1 == a3)
    # `== None` (not `is None`) is deliberate: it exercises Attribute.__eq__.
    # CHECK: a1 == None: False
    print("a1 == None:", a1 == None)  # noqa: E711


# CHECK-LABEL: TEST: testAttrHash
@run
def testAttrHash():
  """Equal attributes hash equally and deduplicate inside a set."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: hash(a1) == hash(a3): True
    print("hash(a1) == hash(a3):", hash(a1) == hash(a3))

    s = set()
    s.add(a1)
    s.add(a2)
    s.add(a3)
    # CHECK: len(s): 2
    print("len(s): ", len(s))


# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
  """Constructing an Attribute from an Attribute yields an equal value."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute(a1)
    # CHECK: a1 == a2: True
    print("a1 == a2:", a1 == a2)


# CHECK-LABEL: TEST: testAttrIsInstance
@run
def testAttrIsInstance():
  """`isinstance` class methods distinguish integer and array attributes."""
  with Context():
    a1 = Attribute.parse("42")
    a2 = Attribute.parse("[42]")
    assert IntegerAttr.isinstance(a1)
    assert not IntegerAttr.isinstance(a2)
    assert not ArrayAttr.isinstance(a1)
    assert ArrayAttr.isinstance(a2)


# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
  """Comparing an attribute with non-attributes returns False, not raises."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    not_an_attr = "foo"
    # CHECK: False
    print(a1 == not_an_attr)
    # `== None` / `!= None` are deliberate: they exercise __eq__ / __ne__.
    # CHECK: False
    print(a1 == None)  # noqa: E711
    # CHECK: True
    print(a1 != None)  # noqa: E711


# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
  """An attribute round-trips through its C API capsule."""
  with Context() as ctx:
    a1 = Attribute.parse('"attr1"')
  # CHECK: mlir.ir.Attribute._CAPIPtr
  attr_capsule = a1._CAPIPtr
  print(attr_capsule)
  a2 = Attribute._CAPICreate(attr_capsule)
  assert a2 == a1
  assert a2.context is ctx


# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
  """Casting to StringAttr works for strings and raises for other kinds."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    astr = StringAttr(a1)
    # Casting an already-typed StringAttr must succeed as well.
    StringAttr(astr)
    # CHECK: Attribute("attr1")
    print(repr(astr))
    try:
      StringAttr(Attribute.parse("1.0"))
    except ValueError as e:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
  """An AffineMapAttr built from a map prints and re-parses to itself."""
  with Context():
    # The expressions below are only constructed, not used by the map.
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    c2 = AffineConstantExpr.get(2)
    map0 = AffineMap.get(2, 3, [])

    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
    attr_built = AffineMapAttr.get(map0)
    print(str(attr_built))

    attr_parsed = Attribute.parse(str(attr_built))
    assert attr_built == attr_parsed


# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
  """FloatAttr value access, factory methods, and non-float type rejection."""
  with Context(), Location.unknown():
    fattr = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", fattr.value)

    # Test factory methods.
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", FloatAttr.get(F32Type.get(), 42.0))
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", FloatAttr.get_f32(42.0))
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", FloatAttr.get_f64(42.0))
    try:
      FloatAttr.get(IntegerType.get_signless(32), 42)
    except ValueError as e:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
  """IntegerAttr values for signless, signed, unsigned and index types."""
  with Context():
    i_attr = IntegerAttr(Attribute.parse("42"))
    # CHECK: i_attr value: 42
    print("i_attr value:", i_attr.value)
    # CHECK: i_attr type: i64
    print("i_attr type:", i_attr.type)
    si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
    # CHECK: si_attr value: -1
    print("si_attr value:", si_attr.value)
    ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
    # CHECK: ui_attr value: 255
    print("ui_attr value:", ui_attr.value)
    idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
    # CHECK: idx_attr value: -1
    print("idx_attr value:", idx_attr.value)

    # Test factory methods.
    # CHECK: default_get: 42 : i32
    print("default_get:",
          IntegerAttr.get(IntegerType.get_signless(32), 42))


# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
  """BoolAttr value access and factory method."""
  with Context():
    battr = BoolAttr(Attribute.parse("true"))
    # CHECK: iattr value: True
    print("iattr value:", battr.value)

    # Test factory methods.
    # CHECK: default_get: true
    print("default_get:", BoolAttr.get(True))


# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
  """FlatSymbolRefAttr value access and factory method."""
  with Context():
    sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
    # CHECK: symattr value: symbol
    print("symattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: @foobar
    print("default_get:", FlatSymbolRefAttr.get("foobar"))


# CHECK-LABEL: TEST: testOpaqueAttr
@run
def testOpaqueAttr():
  """OpaqueAttr namespace/data access and factory method."""
  with Context() as ctx:
    # Needed so the unregistered `pytest_dummy` dialect attribute parses.
    ctx.allow_unregistered_dialects = True
    oattr = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>"))
    # CHECK: oattr value: pytest_dummy
    print("oattr value:", oattr.dialect_namespace)
    # CHECK: oattr value: dummyattr<>
    print("oattr value:", oattr.data)

    # Test factory methods.
    # CHECK: default_get: #foobar<123>
    print(
        "default_get:",
        OpaqueAttr.get("foobar", bytes("123", "utf-8"), NoneType.get()))


# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
  """StringAttr value access and untyped/typed factory methods."""
  with Context():
    sattr = StringAttr(Attribute.parse('"stringattr"'))
    # CHECK: sattr value: stringattr
    print("sattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: "foobar"
    print("default_get:", StringAttr.get("foobar"))
    # CHECK: typed_get: "12345" : i32
    print("typed_get:",
          StringAttr.get_typed(IntegerType.get_signless(32), "12345"))


# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
  """`get_named` wraps an attribute together with a name."""
  with Context():
    a = Attribute.parse('"stringattr"')
    named = a.get_named("foobar")  # Note: under the small object threshold
    # CHECK: attr: "stringattr"
    print("attr:", named.attr)
    # CHECK: name: foobar
    print("name:", named.name)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named)


# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
  """DenseIntElementsAttr length, iteration and element type (i32 and i1)."""
  with Context():
    raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 6

    # CHECK: 0 1 2 3 4 5
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i32
    print(ShapedType(a.type).element_type)

    raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 1 0 1 0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i1
    print(ShapedType(a.type).element_type)


# CHECK-LABEL: TEST: testDenseIntAttrGetItem
@run
def testDenseIntAttrGetItem():
  """Indexing DenseIntElementsAttr for each supported integer element type."""

  def print_item(attr_asm):
    """Print "<dtype>: <first element>", or the TypeError message instead."""
    attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
    dtype = ShapedType(attr.type).element_type
    try:
      item = attr[0]
      print(f"{dtype}:", item)
    except TypeError as e:
      print(f"{dtype}:", e)

  with Context():
    # CHECK: i1: 1
    print_item("dense<true> : tensor<i1>")
    # CHECK: i8: 123
    print_item("dense<123> : tensor<i8>")
    # CHECK: i16: 123
    print_item("dense<123> : tensor<i16>")
    # CHECK: i32: 123
    print_item("dense<123> : tensor<i32>")
    # CHECK: i64: 123
    print_item("dense<123> : tensor<i64>")
    # CHECK: ui8: 123
    print_item("dense<123> : tensor<ui8>")
    # CHECK: ui16: 123
    print_item("dense<123> : tensor<ui16>")
    # CHECK: ui32: 123
    print_item("dense<123> : tensor<ui32>")
    # CHECK: ui64: 123
    print_item("dense<123> : tensor<ui64>")
    # CHECK: si8: -123
    print_item("dense<-123> : tensor<si8>")
    # CHECK: si16: -123
    print_item("dense<-123> : tensor<si16>")
    # CHECK: si32: -123
    print_item("dense<-123> : tensor<si32>")
    # CHECK: si64: -123
    print_item("dense<-123> : tensor<si64>")

    # CHECK: i7: Unsupported integer type
    print_item("dense<123> : tensor<i7>")


# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
  """DenseFPElementsAttr length, iteration and element type."""
  with Context():
    raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>

    print("attr:", raw)

    a = DenseFPElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 0.0 1.0 2.0 3.0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: f32
    print(ShapedType(a.type).element_type)


# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
  """DictAttr construction, printing, lookup, membership and errors."""
  with Context():
    dict_attr = {
        'stringattr': StringAttr.get('string'),
        'integerattr': IntegerAttr.get(IntegerType.get_signless(32), 42),
    }

    a = DictAttr.get(dict_attr)

    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
    print("attr:", a)

    assert len(a) == 2

    # CHECK: 42 : i32
    print(a['integerattr'])

    # CHECK: "string"
    print(a['stringattr'])

    # CHECK: True
    print('stringattr' in a)

    # CHECK: False
    print('not_in_dict' in a)

    # Check that exceptions are raised as expected.
    try:
      _ = a['does_not_exist']
    except KeyError:
      pass
    else:
      assert False, "Exception not produced"

    # An integer index is positional; 42 is past the two stored entries.
    try:
      _ = a[42]
    except IndexError:
      pass
    else:
      assert False, "expected IndexError on accessing an out-of-bounds attribute"

    # CHECK: empty: {}
    print("empty:", DictAttr.get())


# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
  """A type parsed as an attribute exposes its type through TypeAttr."""
  with Context():
    raw = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", raw)
    type_attr = TypeAttr(raw)
    # CHECK: f32
    print(ShapedType(type_attr.value).element_type)


# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
  """ArrayAttr parsing, construction, indexing, errors and concatenation."""
  with Context():
    raw = Attribute.parse("[42, true, vector<4xf32>]")
    # CHECK: attr: [42, true, vector<4xf32>]
    print("raw attr:", raw)
    # CHECK: - 42
    # CHECK: - true
    # CHECK: - vector<4xf32>
    for attr in ArrayAttr(raw):
      print("- ", attr)

  with Context():
    intAttr = Attribute.parse("42")
    vecAttr = Attribute.parse("vector<4xf32>")
    boolAttr = BoolAttr.get(True)
    raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
    # CHECK: attr: [vector<4xf32>, true, 42]
    print("raw attr:", raw)
    # CHECK: - vector<4xf32>
    # CHECK: - true
    # CHECK: - 42
    arr = ArrayAttr(raw)
    for attr in arr:
      print("- ", attr)
    # CHECK: attr[0]: vector<4xf32>
    print("attr[0]:", arr[0])
    # CHECK: attr[1]: true
    print("attr[1]:", arr[1])
    # CHECK: attr[2]: 42
    print("attr[2]:", arr[2])
    try:
      print("attr[3]:", arr[3])
    except IndexError as e:
      # CHECK: Error: ArrayAttribute index out of range
      print("Error: ", e)
    else:
      print("Exception not produced")

  with Context():
    try:
      ArrayAttr.get([None])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")
    try:
      ArrayAttr.get([42])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")

  with Context():
    array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
    array = array + [StringAttr.get("c")]
    # CHECK: concat: ["a", "b", "c"]
    print("concat: ", array)