# RUN: %PYTHON %s | FileCheck %s

# Tests for the MLIR Python attribute bindings (`mlir.ir`). Each test prints
# values to stdout; the FileCheck directive comments written just before the
# prints are matched, in file order, against that output.

import gc
from mlir.ir import *

def run(f):
  """Run a test immediately and verify that it leaked no MLIR contexts.

  Used as a decorator: prints a "TEST: <name>" header (which each test's
  label directive anchors on), calls `f`, forces a garbage collection and
  then asserts that no `Context` objects are still alive.

  Args:
    f: A zero-argument test function.

  Returns:
    `f` itself, so the decorated name stays bound to the function.
  """
  print("\nTEST:", f.__name__)
  f()
  # Collect first so objects that are only reachable from unreachable cycles
  # are freed before the live contexts are counted.
  gc.collect()
  assert Context._get_live_count() == 0
  return f


# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parse a string attribute and verify its str() and repr() forms.

  Also verifies that the attribute reports the context it was parsed in and
  can still be printed after the local `ctx` reference has been dropped.
  """
  with Context() as ctx:
    t = Attribute.parse('"hello"')
  assert t.context is ctx
  # Drop the local reference to the context and collect, so the prints below
  # run with `t` as the only reference to the context visible in this test.
  ctx = None
  gc.collect()
  # CHECK: "hello"
  print(str(t))
  # CHECK: Attribute("hello")
  print(repr(t))


# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
@run
def testParseError():
  """Parsing unknown attribute syntax is expected to raise ValueError.

  The expected message quotes the text that failed to parse.
  """
  with Context():
    try:
      t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
    except ValueError as e:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAttrEq
@run
def testAttrEq():
  """Attributes parsed from identical text compare equal; others do not."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: a1 == a1: True
    print("a1 == a1:", a1 == a1)
    # CHECK: a1 == a2: False
    print("a1 == a2:", a1 == a2)
    # CHECK: a1 == a3: True
    print("a1 == a3:", a1 == a3)
    # `== None` is deliberate: it exercises the binding's __eq__ against None
    # rather than performing an identity test.
    # CHECK: a1 == None: False
    print("a1 == None:", a1 == None)


# CHECK-LABEL: TEST: testAttrCast
@run
def testAttrCast():
  """Constructing `Attribute` from an existing attribute yields an equal one."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute(a1)
    # CHECK: a1 == a2: True
    print("a1 == a2:", a1 == a2)


# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
  """Comparing an attribute with a non-attribute must not raise.

  `==` against a str or None is expected to give False, and `!= None` True.
  """
  with Context():
    a1 = Attribute.parse('"attr1"')
    not_an_attr = "foo"
    # CHECK: False
    print(a1 == not_an_attr)
    # CHECK: False
    print(a1 == None)
    # CHECK: True
    print(a1 != None)


# CHECK-LABEL: TEST: testAttrCapsule
@run
def testAttrCapsule():
  """Round-trip an attribute through its C API capsule.

  `_CAPIPtr` exports the attribute (its printed form is expected to mention
  the capsule name) and `Attribute._CAPICreate` rebuilds an attribute from
  it; the result must equal the original and belong to the same context.
  """
  with Context() as ctx:
    a1 = Attribute.parse('"attr1"')
  # CHECK: mlir.ir.Attribute._CAPIPtr
  attr_capsule = a1._CAPIPtr
  print(attr_capsule)
  a2 = Attribute._CAPICreate(attr_capsule)
  assert a2 == a1
  assert a2.context is ctx


# CHECK-LABEL: TEST: testStandardAttrCasts
@run
def testStandardAttrCasts():
  """Cast a generic Attribute to StringAttr and reject an invalid cast.

  Casting a parsed string attribute, and re-casting the resulting StringAttr,
  succeeds; casting a float attribute to StringAttr must raise ValueError.
  """
  with Context():
    a1 = Attribute.parse('"attr1"')
    astr = StringAttr(a1)
    # Casting an already-cast StringAttr must also be accepted.
    aself = StringAttr(astr)
    # CHECK: Attribute("attr1")
    print(repr(astr))
    try:
      tillegal = StringAttr(Attribute.parse("1.0"))
    except ValueError as e:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testAffineMapAttr
@run
def testAffineMapAttr():
  """Build an AffineMapAttr, print it, and verify it round-trips via parse."""
  with Context() as ctx:
    # NOTE(review): d0, d1 and c2 are constructed but not used below.
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    c2 = AffineConstantExpr.get(2)
    # 2 dimensions, 3 symbols and no result expressions.
    map0 = AffineMap.get(2, 3, [])

    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
    attr_built = AffineMapAttr.get(map0)
    print(str(attr_built))

    # Re-parsing the printed form must give back an equal attribute.
    attr_parsed = Attribute.parse(str(attr_built))
    assert attr_built == attr_parsed


# CHECK-LABEL: TEST: testFloatAttr
@run
def testFloatAttr():
  """Exercise FloatAttr casting, `value`, and its factory methods.

  Also verifies that `FloatAttr.get` rejects a non-floating-point type with
  ValueError.
  """
  # NOTE(review): a Location is entered as well, presumably because one of
  # the FloatAttr factories below needs a current location (e.g. to report
  # its type-check error) — confirm.
  with Context(), Location.unknown():
    fattr = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", fattr.value)

    # Test factory methods.
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", FloatAttr.get(
        F32Type.get(), 42.0))
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", FloatAttr.get_f32(42.0))
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", FloatAttr.get_f64(42.0))
    try:
      # Must fail: i32 is not a floating point type.
      fattr_invalid = FloatAttr.get(
          IntegerType.get_signless(32), 42)
    except ValueError as e:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testIntegerAttr
@run
def testIntegerAttr():
  """Exercise IntegerAttr casting, `value`/`type`, and `IntegerAttr.get`.

  An integer literal parsed without a type suffix is expected to be i64.
  """
  with Context() as ctx:
    iattr = IntegerAttr(Attribute.parse("42"))
    # CHECK: iattr value: 42
    print("iattr value:", iattr.value)
    # CHECK: iattr type: i64
    print("iattr type:", iattr.type)

    # Test factory methods.
    # CHECK: default_get: 42 : i32
    print("default_get:", IntegerAttr.get(
        IntegerType.get_signless(32), 42))


# CHECK-LABEL: TEST: testBoolAttr
@run
def testBoolAttr():
  """Exercise BoolAttr casting, `value` (a Python bool) and `BoolAttr.get`."""
  with Context() as ctx:
    battr = BoolAttr(Attribute.parse("true"))
    # NOTE(review): the printed label reads "iattr" although this is a
    # BoolAttr; the matching directive expects that exact label, so rename
    # both together if cleaning this up.
    # CHECK: iattr value: True
    print("iattr value:", battr.value)

    # Test factory methods.
    # CHECK: default_get: true
    print("default_get:", BoolAttr.get(True))


# CHECK-LABEL: TEST: testFlatSymbolRefAttr
@run
def testFlatSymbolRefAttr():
  """Exercise FlatSymbolRefAttr casting, `value` and the `get` factory.

  `value` is expected to be the symbol name without the leading '@'.
  """
  with Context() as ctx:
    sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
    # CHECK: symattr value: symbol
    print("symattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: @foobar
    print("default_get:", FlatSymbolRefAttr.get("foobar"))


# CHECK-LABEL: TEST: testStringAttr
@run
def testStringAttr():
  """Exercise StringAttr casting, `value`, `get` and `get_typed`.

  `value` is expected to be the unquoted contents; `get_typed` attaches an
  explicit type, which appears in the printed form.
  """
  with Context() as ctx:
    sattr = StringAttr(Attribute.parse('"stringattr"'))
    # CHECK: sattr value: stringattr
    print("sattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: "foobar"
    print("default_get:", StringAttr.get("foobar"))
    # CHECK: typed_get: "12345" : i32
    print("typed_get:", StringAttr.get_typed(
        IntegerType.get_signless(32), "12345"))


# CHECK-LABEL: TEST: testNamedAttr
@run
def testNamedAttr():
  """Wrap an attribute in a NamedAttribute and read back its parts."""
  with Context():
    a = Attribute.parse('"stringattr"')
    # NOTE(review): the short name presumably keeps the string below some
    # small-object size threshold in the bindings — confirm which path this
    # is meant to exercise.
    named = a.get_named("foobar") # Note: under the small object threshold
    # CHECK: attr: "stringattr"
    print("attr:", named.attr)
    # CHECK: name: foobar
    print("name:", named.name)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named)


# CHECK-LABEL: TEST: testDenseIntAttr
@run
def testDenseIntAttr():
  """Read DenseIntElementsAttr contents for i32 and i1 element types.

  Verifies the element count, element iteration (i1 elements are expected to
  iterate as 1/0) and the element type of the attribute's shaped type.
  """
  with Context():
    raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # In the directive below, `{{\[}}` is a regex for a literal '[': a bare
    # '[[' would be read by FileCheck as the start of a variable substitution.
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 6

    # CHECK: 0 1 2 3 4 5
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i32
    print(ShapedType(a.type).element_type)

    raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 1 0 1 0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i1
    print(ShapedType(a.type).element_type)


# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
  """Read DenseFPElementsAttr contents: length, elements and element type."""
  with Context():
    raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>

    print("attr:", raw)

    a = DenseFPElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 0.0 1.0 2.0 3.0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: f32
    print(ShapedType(a.type).element_type)

# CHECK-LABEL: TEST: testDictAttr
@run
def testDictAttr():
  """Build a DictAttr from a Python dict and look up its entries.

  Entries can be read back by name. An unknown name must raise KeyError and
  an out-of-range integer index must raise IndexError.
  """
  with Context():
    dict_attr = {
      'stringattr': StringAttr.get('string'),
      'integerattr' : IntegerAttr.get(
        IntegerType.get_signless(32), 42)
    }

    a = DictAttr.get(dict_attr)

    # The expected printed form lists entries sorted by name rather than in
    # insertion order.
    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
    print("attr:", a)

    assert len(a) == 2

    # CHECK: 42 : i32
    print(a['integerattr'])

    # CHECK: "string"
    print(a['stringattr'])

    # Check that exceptions are raised as expected. Raise AssertionError
    # explicitly so these negative checks are not stripped under `python -O`.
    try:
      _ = a['does_not_exist']
    except KeyError:
      pass
    else:
      raise AssertionError("Exception not produced")

    try:
      _ = a[42]
    except IndexError:
      pass
    else:
      raise AssertionError(
          "expected IndexError on accessing an out-of-bounds attribute")


# CHECK-LABEL: TEST: testTypeAttr
@run
def testTypeAttr():
  """Wrap a parsed type attribute in TypeAttr and inspect the held type."""
  with Context():
    raw = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", raw)
    type_attr = TypeAttr(raw)
    # CHECK: f32
    print(ShapedType(type_attr.value).element_type)


# CHECK-LABEL: TEST: testArrayAttr
@run
def testArrayAttr():
  """Parse, build, iterate and index ArrayAttr values, including errors.

  Covers iteration over a parsed array, construction with `ArrayAttr.get`,
  positional indexing (an out-of-range index must raise IndexError), and
  rejection of list elements that are not attributes (RuntimeError).
  """
  with Context():
    raw = Attribute.parse("[42, true, vector<4xf32>]")
    # CHECK: raw attr: [42, true, vector<4xf32>]
    print("raw attr:", raw)
    # CHECK: - 42
    # CHECK: - true
    # CHECK: - vector<4xf32>
    for attr in ArrayAttr(raw):
      print("- ", attr)

  with Context():
    intAttr = Attribute.parse("42")
    vecAttr = Attribute.parse("vector<4xf32>")
    boolAttr = BoolAttr.get(True)
    raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
    # CHECK: raw attr: [vector<4xf32>, true, 42]
    print("raw attr:", raw)
    # CHECK: - vector<4xf32>
    # CHECK: - true
    # CHECK: - 42
    arr = ArrayAttr(raw)
    for attr in arr:
      print("- ", attr)
    # CHECK: attr[0]: vector<4xf32>
    print("attr[0]:", arr[0])
    # CHECK: attr[1]: true
    print("attr[1]:", arr[1])
    # CHECK: attr[2]: 42
    print("attr[2]:", arr[2])
    try:
      print("attr[3]:", arr[3])
    except IndexError as e:
      # CHECK: Error: ArrayAttribute index out of range
      print("Error: ", e)
    else:
      print("Exception not produced")

  with Context():
    try:
      ArrayAttr.get([None])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")
    try:
      ArrayAttr.get([42])
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
      print("Error: ", e)
    else:
      print("Exception not produced")