# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
  """Run test function `f` immediately and verify it leaked no contexts.

  Prints the `TEST:` banner that the `CHECK-LABEL` directives anchor on,
  calls `f`, forces a garbage collection, and asserts that no `Context`
  objects remain alive. Returns `f` unchanged so it can be used as a
  decorator.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0
  return f


# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """A parsed attribute keeps its context alive and prints/reprs correctly."""
  with Context() as ctx:
    t = Attribute.parse('"hello"')
  assert t.context is ctx
  # Drop the last user-visible reference to the context; `t` must keep it
  # alive on its own.
  ctx = None
  gc.collect()
  # CHECK: "hello"
  print(str(t))
  # CHECK: Attribute("hello")
  print(repr(t))


# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
@run
def testParseError():
  """Parsing malformed attribute text raises ValueError."""
  with Context():
    try:
      t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
    except ValueError as e:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
  """Attributes compare equal iff they are the same uniqued attribute."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: a1 == a1: True
    print("a1 == a1:", a1 == a1)
    # CHECK: a1 == a2: False
    print("a1 == a2:", a1 == a2)
    # CHECK: a1 == a3: True
    print("a1 == a3:", a1 == a3)
    # Deliberately uses `==` (not `is`) to exercise __eq__ against None.
    # CHECK: a1 == None: False
    print("a1 == None:", a1 == None)


# CHECK-LABEL: TEST: testAttrHash
@run
def testAttrHash():
  """Equal attributes hash equally, so they deduplicate in a set."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: hash(a1) == hash(a3): True
    print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
    # In general, hashes don't have to be unique. In this case, however, the
    # hash is just the underlying pointer so it will be.
    # CHECK: hash(a1) == hash(a2): False
    print("hash(a1) == hash(a2):", a1.__hash__() == a2.__hash__())

    s = set()
    s.add(a1)
    s.add(a2)
    s.add(a3)
    # CHECK: len(s): 2
    print("len(s): ", len(s))


# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
  """Constructing an Attribute from another Attribute yields an equal one."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute(a1)
    # CHECK: a1 == a2: True
    print("a1 == a2:", a1 == a2)


# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
  """Comparing an Attribute with non-Attributes returns a bool, not an error."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    not_an_attr = "foo"
    # CHECK: False
    print(a1 == not_an_attr)
    # CHECK: False
    print(a1 == None)
    # CHECK: True
    print(a1 != None)


# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
  """An Attribute round-trips through its C API capsule."""
  with Context() as ctx:
    a1 = Attribute.parse('"attr1"')
    # CHECK: mlir.ir.Attribute._CAPIPtr
    attr_capsule = a1._CAPIPtr
    print(attr_capsule)
    a2 = Attribute._CAPICreate(attr_capsule)
    assert a2 == a1
    assert a2.context is ctx


# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
  """StringAttr casts accept string attributes and reject other kinds."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    astr = StringAttr(a1)
    aself = StringAttr(astr)
    # CHECK: Attribute("attr1")
    print(repr(astr))
    try:
      tillegal = StringAttr(Attribute.parse("1.0"))
    except ValueError as e:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
  """An AffineMapAttr built from an AffineMap round-trips through the parser."""
  with Context() as ctx:
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    c2 = AffineConstantExpr.get(2)
    map0 = AffineMap.get(2, 3, [])

    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
    attr_built = AffineMapAttr.get(map0)
    print(str(attr_built))

    attr_parsed = Attribute.parse(str(attr_built))
    assert attr_built == attr_parsed


# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
  """FloatAttr value access, factory methods, and non-float type rejection."""
  with Context(), Location.unknown():
    fattr = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", fattr.value)

    # Test factory methods.
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", FloatAttr.get(
        F32Type.get(), 42.0))
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", FloatAttr.get_f32(42.0))
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", FloatAttr.get_f64(42.0))
    try:
      fattr_invalid = FloatAttr.get(
          IntegerType.get_signless(32), 42)
    except ValueError as e:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
  """IntegerAttr value/type access and the typed factory method."""
  with Context() as ctx:
    iattr = IntegerAttr(Attribute.parse("42"))
    # CHECK: iattr value: 42
    print("iattr value:", iattr.value)
    # CHECK: iattr type: i64
    print("iattr type:", iattr.type)

    # Test factory methods.
    # CHECK: default_get: 42 : i32
    print("default_get:", IntegerAttr.get(
        IntegerType.get_signless(32), 42))


# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
  """BoolAttr value access and factory method."""
  with Context() as ctx:
    battr = BoolAttr(Attribute.parse("true"))
    # CHECK: battr value: True
    print("battr value:", battr.value)

    # Test factory methods.
    # CHECK: default_get: true
    print("default_get:", BoolAttr.get(True))


# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
  """FlatSymbolRefAttr value access and factory method."""
  with Context() as ctx:
    sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
    # CHECK: symattr value: symbol
    print("symattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: @foobar
    print("default_get:", FlatSymbolRefAttr.get("foobar"))


# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
  """StringAttr value access plus untyped and typed factory methods."""
  with Context() as ctx:
    sattr = StringAttr(Attribute.parse('"stringattr"'))
    # CHECK: sattr value: stringattr
    print("sattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: "foobar"
    print("default_get:", StringAttr.get("foobar"))
    # CHECK: typed_get: "12345" : i32
    print("typed_get:", StringAttr.get_typed(
        IntegerType.get_signless(32), "12345"))


# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
  """An attribute can be wrapped as a NamedAttribute with a given name."""
  with Context():
    a = Attribute.parse('"stringattr"')
    named = a.get_named("foobar")  # Note: under the small object threshold
    # CHECK: attr: "stringattr"
    print("attr:", named.attr)
    # CHECK: name: foobar
    print("name:", named.name)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named)


# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
  """DenseIntElementsAttr length, iteration, and element type (i32 and i1)."""
  with Context():
    raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 6

    # CHECK: 0 1 2 3 4 5
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i32
    print(ShapedType(a.type).element_type)

    raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 1 0 1 0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i1
    print(ShapedType(a.type).element_type)


# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
  """DenseFPElementsAttr length, iteration, and element type."""
  with Context():
    raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>

    print("attr:", raw)

    a = DenseFPElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 0.0 1.0 2.0 3.0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: f32
    print(ShapedType(a.type).element_type)


# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
  """DictAttr construction, printing, lookup, and lookup error types."""
  with Context():
    dict_attr = {
      'stringattr': StringAttr.get('string'),
      'integerattr' : IntegerAttr.get(
          IntegerType.get_signless(32), 42)
    }

    a = DictAttr.get(dict_attr)

    # Entries print sorted by name. Note the colon after CHECK: without it
    # FileCheck treats the line as a plain comment and verifies nothing.
    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
    print("attr:", a)

    assert len(a) == 2

    # CHECK: 42 : i32
    print(a['integerattr'])

    # CHECK: "string"
    print(a['stringattr'])

    # Check that exceptions are raised as expected.
    try:
      _ = a['does_not_exist']
    except KeyError:
      pass
    else:
      assert False, "Exception not produced"

    try:
      _ = a[42]
    except IndexError:
      pass
    else:
      assert False, "expected IndexError on accessing an out-of-bounds attribute"

    # FileCheck canonicalizes runs of horizontal whitespace by default, so
    # this matches the two spaces that print() emits after "empty:".
    # CHECK: empty: {}
    print("empty: ", DictAttr.get())


# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
  """A type parsed as an attribute can be viewed as a TypeAttr."""
  with Context():
    raw = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", raw)
    type_attr = TypeAttr(raw)
    # CHECK: f32
    print(ShapedType(type_attr.value).element_type)


# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
  """ArrayAttr parsing, construction, iteration, indexing, errors, and concat."""
  with Context():
    raw = Attribute.parse("[42, true, vector<4xf32>]")
    # CHECK: attr: [42, true, vector<4xf32>]
    print("raw attr:", raw)
    # CHECK: - 42
    # CHECK: - true
    # CHECK: - vector<4xf32>
    for attr in ArrayAttr(raw):
      print("- ", attr)

  with Context():
    intAttr = Attribute.parse("42")
    vecAttr = Attribute.parse("vector<4xf32>")
    boolAttr = BoolAttr.get(True)
    raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
    # CHECK: attr: [vector<4xf32>, true, 42]
    print("raw attr:", raw)
    # CHECK: - vector<4xf32>
    # CHECK: - true
    # CHECK: - 42
    arr = ArrayAttr(raw)
    for attr in arr:
      print("- ", attr)
    # CHECK: attr[0]: vector<4xf32>
    print("attr[0]:", arr[0])
    # CHECK: attr[1]: true
    print("attr[1]:", arr[1])
    # CHECK: attr[2]: 42
    print("attr[2]:", arr[2])
    try:
      print("attr[3]:", arr[3])
    except IndexError as e:
      # CHECK: Error: ArrayAttribute index out of range
      print("Error: ", e)
    else:
      print("Exception not produced")

  with Context():
    try:
      ArrayAttr.get([None])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")
    try:
      ArrayAttr.get([42])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")

  with Context():
    array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
    array = array + [StringAttr.get("c")]
    # CHECK: concat: ["a", "b", "c"]
    print("concat: ", array)