# RUN: %PYTHON %s | FileCheck %s
"""Lit tests for the attribute classes exposed by ``mlir.ir``.

Each test function is executed at import time by the ``run`` decorator and
prints to stdout; the FileCheck directive comments preceding each ``print``
describe the expected output.
"""

import gc
from mlir.ir import *


def run(f):
  """Run test ``f`` immediately and verify it leaked no MLIR contexts.

  Prints a ``TEST:`` header (matched by the label directives), calls ``f``,
  forces a garbage collection and asserts that no ``Context`` is still alive.
  Returns ``f`` unchanged so this can be used as a decorator.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0
  return f


# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parse a string attribute and print its str() and repr() forms."""
  with Context() as ctx:
    t = Attribute.parse('"hello"')
  assert t.context is ctx
  # Drop our own reference to the context so the prints below run with only
  # the attribute referring to it.
  ctx = None
  gc.collect()
  # CHECK: "hello"
  print(str(t))
  # CHECK: Attribute("hello")
  print(repr(t))


# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
@run
def testParseError():
  """Parsing an unknown attribute must raise ValueError."""
  with Context():
    try:
      t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
    except ValueError as e:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
  """Attribute equality compares uniqued values; comparing to None is False."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: a1 == a1: True
    print("a1 == a1:", a1 == a1)
    # CHECK: a1 == a2: False
    print("a1 == a2:", a1 == a2)
    # CHECK: a1 == a3: True
    print("a1 == a3:", a1 == a3)
    # Deliberately uses == (not `is`) to exercise Attribute.__eq__ with None.
    # CHECK: a1 == None: False
    print("a1 == None:", a1 == None)


# CHECK-LABEL: TEST: testAttrHash
@run
def testAttrHash():
  """Equal attributes hash equally and deduplicate in a set."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: hash(a1) == hash(a3): True
    print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
    # In general, hashes don't have to be unique. In this case, however, the
    # hash is just the underlying pointer so it will be.
    # CHECK: hash(a1) == hash(a2): False
    print("hash(a1) == hash(a2):", a1.__hash__() == a2.__hash__())

    s = set()
    s.add(a1)
    s.add(a2)
    s.add(a3)
    # CHECK: len(s): 2
    print("len(s): ", len(s))


# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
  """Constructing an Attribute from an Attribute yields an equal value."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute(a1)
    # CHECK: a1 == a2: True
    print("a1 == a2:", a1 == a2)


# CHECK-LABEL: TEST: testAttrIsInstance
@run
def testAttrIsInstance():
  """The static isinstance() helpers discriminate attribute kinds."""
  with Context():
    a1 = Attribute.parse("42")
    a2 = Attribute.parse("[42]")
    assert IntegerAttr.isinstance(a1)
    assert not IntegerAttr.isinstance(a2)
    assert not ArrayAttr.isinstance(a1)
    assert ArrayAttr.isinstance(a2)


# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
  """Comparing with non-attribute objects returns a bool instead of raising."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    not_an_attr = "foo"
    # CHECK: False
    print(a1 == not_an_attr)
    # CHECK: False
    print(a1 == None)
    # CHECK: True
    print(a1 != None)


# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
  """Round-trip an attribute through its C-API capsule."""
  with Context() as ctx:
    a1 = Attribute.parse('"attr1"')
    # CHECK: mlir.ir.Attribute._CAPIPtr
    attr_capsule = a1._CAPIPtr
    print(attr_capsule)
    a2 = Attribute._CAPICreate(attr_capsule)
    assert a2 == a1
    assert a2.context is ctx


# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
  """Casting to StringAttr succeeds for strings and raises otherwise."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    astr = StringAttr(a1)
    aself = StringAttr(astr)
    # CHECK: Attribute("attr1")
    print(repr(astr))
    try:
      tillegal = StringAttr(Attribute.parse("1.0"))
    except ValueError as e:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
  """Build an AffineMapAttr and check it round-trips through the parser."""
  with Context() as ctx:
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    c2 = AffineConstantExpr.get(2)
    map0 = AffineMap.get(2, 3, [])

    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
    attr_built = AffineMapAttr.get(map0)
    print(str(attr_built))

    attr_parsed = Attribute.parse(str(attr_built))
    assert attr_built == attr_parsed


# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
  """FloatAttr value access, factory methods and invalid-type rejection."""
  with Context(), Location.unknown():
    fattr = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", fattr.value)

    # Test factory methods.
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", FloatAttr.get(
        F32Type.get(), 42.0))
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", FloatAttr.get_f32(42.0))
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", FloatAttr.get_f64(42.0))
    try:
      fattr_invalid = FloatAttr.get(
          IntegerType.get_signless(32), 42)
    except ValueError as e:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
  """IntegerAttr value/type access and the typed factory."""
  with Context() as ctx:
    iattr = IntegerAttr(Attribute.parse("42"))
    # CHECK: iattr value: 42
    print("iattr value:", iattr.value)
    # CHECK: iattr type: i64
    print("iattr type:", iattr.type)

    # Test factory methods.
    # CHECK: default_get: 42 : i32
    print("default_get:", IntegerAttr.get(
        IntegerType.get_signless(32), 42))


# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
  """BoolAttr value access and factory."""
  with Context() as ctx:
    battr = BoolAttr(Attribute.parse("true"))
    # CHECK: iattr value: True
    print("iattr value:", battr.value)

    # Test factory methods.
    # CHECK: default_get: true
    print("default_get:", BoolAttr.get(True))


# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
  """FlatSymbolRefAttr value access and factory."""
  with Context() as ctx:
    sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
    # CHECK: symattr value: symbol
    print("symattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: @foobar
    print("default_get:", FlatSymbolRefAttr.get("foobar"))


# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
  """StringAttr value access and the untyped/typed factories."""
  with Context() as ctx:
    sattr = StringAttr(Attribute.parse('"stringattr"'))
    # CHECK: sattr value: stringattr
    print("sattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: "foobar"
    print("default_get:", StringAttr.get("foobar"))
    # CHECK: typed_get: "12345" : i32
    print("typed_get:", StringAttr.get_typed(
        IntegerType.get_signless(32), "12345"))


# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
  """Attribute.get_named wraps an attribute with a name."""
  with Context():
    a = Attribute.parse('"stringattr"')
    named = a.get_named("foobar")  # Note: under the small object threshold
    # CHECK: attr: "stringattr"
    print("attr:", named.attr)
    # CHECK: name: foobar
    print("name:", named.name)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named)


# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
  """Iterate DenseIntElementsAttr for i32 and i1 element types."""
  with Context():
    raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # `{{\[}}` avoids FileCheck reading `[[` as a variable reference.
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 6

    # CHECK: 0 1 2 3 4 5
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i32
    print(ShapedType(a.type).element_type)

    raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 1 0 1 0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i1
    print(ShapedType(a.type).element_type)


# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
  """Iterate a DenseFPElementsAttr of f32 values."""
  with Context():
    raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>

    print("attr:", raw)

    a = DenseFPElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 0.0 1.0 2.0 3.0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: f32
    print(ShapedType(a.type).element_type)


# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
  """DictAttr construction, lookup by name/index, and error cases."""
  with Context():
    dict_attr = {
      'stringattr': StringAttr.get('string'),
      'integerattr' : IntegerAttr.get(
        IntegerType.get_signless(32), 42)
    }

    a = DictAttr.get(dict_attr)

    # The directive needs a colon after the prefix; without it FileCheck
    # silently ignores the line and this output goes unverified.
    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
    print("attr:", a)

    assert len(a) == 2

    # CHECK: 42 : i32
    print(a['integerattr'])

    # CHECK: "string"
    print(a['stringattr'])

    # Check that exceptions are raised as expected.
    try:
      _ = a['does_not_exist']
    except KeyError:
      pass
    else:
      assert False, "Exception not produced"

    try:
      _ = a[42]
    except IndexError:
      pass
    else:
      assert False, "expected IndexError on accessing an out-of-bounds attribute"

    # FileCheck canonicalizes runs of spaces, so the extra separator space
    # emitted by print() still matches.
    # CHECK: empty: {}
    print("empty: ", DictAttr.get())


# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
  """TypeAttr wraps a type that can be inspected via ShapedType."""
  with Context():
    raw = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", raw)
    type_attr = TypeAttr(raw)
    # CHECK: f32
    print(ShapedType(type_attr.value).element_type)


# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
  """ArrayAttr parsing, construction, indexing, errors and concatenation."""
  with Context():
    raw = Attribute.parse("[42, true, vector<4xf32>]")
    # CHECK: raw attr: [42, true, vector<4xf32>]
    print("raw attr:", raw)
    # CHECK: - 42
    # CHECK: - true
    # CHECK: - vector<4xf32>
    for attr in ArrayAttr(raw):
      print("- ", attr)

  with Context():
    intAttr = Attribute.parse("42")
    vecAttr = Attribute.parse("vector<4xf32>")
    boolAttr = BoolAttr.get(True)
    raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
    # CHECK: raw attr: [vector<4xf32>, true, 42]
    print("raw attr:", raw)
    # CHECK: - vector<4xf32>
    # CHECK: - true
    # CHECK: - 42
    arr = ArrayAttr(raw)
    for attr in arr:
      print("- ", attr)
    # CHECK: attr[0]: vector<4xf32>
    print("attr[0]:", arr[0])
    # CHECK: attr[1]: true
    print("attr[1]:", arr[1])
    # CHECK: attr[2]: 42
    print("attr[2]:", arr[2])
    try:
      print("attr[3]:", arr[3])
    except IndexError as e:
      # CHECK: Error: ArrayAttribute index out of range
      print("Error: ", e)
    else:
      print("Exception not produced")
  with Context():
    try:
      ArrayAttr.get([None])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")
    try:
      ArrayAttr.get([42])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")

  with Context():
    array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
    array = array + [StringAttr.get("c")]
    # CHECK: concat: ["a", "b", "c"]
    print("concat: ", array)