# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
  """Runs the test function `f` immediately and returns it unchanged.

  Used as a decorator, so each test executes at definition time and the
  printed output follows the order of the definitions in this file. After
  the test, a garbage collection is forced and `Context._get_live_count()`
  is asserted to be zero, so a test that leaves a Context alive fails.
  """
  print("\nTEST:", f.__name__)
  f()
  # Collect reference cycles first so that only Contexts that are still
  # reachable count towards the live total.
  gc.collect()
  assert Context._get_live_count() == 0
  return f


# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
  """Parses a type, drops the local Context reference, then prints the type."""
  ctx = Context()
  t = Type.parse("i32", ctx)
  assert t.context is ctx
  # Drop our reference and collect: the Type must remain printable.
  ctx = None
  gc.collect()
  # CHECK: i32
  print(str(t))
  # CHECK: Type(i32)
  print(repr(t))


# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
@run
def testParseError():
  """Parsing an unknown type name raises ValueError."""
  ctx = Context()
  try:
    Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
  except ValueError as e:
    # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
    print("testParseError:", e)
  else:
    print("Exception not produced")


# CHECK-LABEL: TEST: testTypeEq
@run
def testTypeEq():
  """Types parsed from the same text compare equal; different ones do not."""
  ctx = Context()
  t1 = Type.parse("i32", ctx)
  t2 = Type.parse("f32", ctx)
  t3 = Type.parse("i32", ctx)
  # CHECK: t1 == t1: True
  print("t1 == t1:", t1 == t1)
  # CHECK: t1 == t2: False
  print("t1 == t2:", t1 == t2)
  # CHECK: t1 == t3: True
  print("t1 == t3:", t1 == t3)
  # `== None` is deliberate here: it exercises Type.__eq__ with a None operand.
  # CHECK: t1 == None: False
  print("t1 == None:", t1 == None)  # noqa: E711


# CHECK-LABEL: TEST: testTypeHash
@run
def testTypeHash():
  """Equal types hash equally and collapse to one set entry."""
  ctx = Context()
  t1 = Type.parse("i32", ctx)
  t2 = Type.parse("f32", ctx)
  t3 = Type.parse("i32", ctx)

  # CHECK: hash(t1) == hash(t3): True
  print("hash(t1) == hash(t3):", hash(t1) == hash(t3))

  # t1 and t3 are equal, so the set holds only two distinct types.
  s = {t1, t2, t3}
  # CHECK: len(s): 2
  print("len(s): ", len(s))


# CHECK-LABEL: TEST: testTypeCast
@run
def testTypeCast():
  """Constructing a Type from another Type yields an equal Type."""
  ctx = Context()
  t1 = Type.parse("i32", ctx)
  t2 = Type(t1)
  # CHECK: t1 == t2: True
  print("t1 == t2:", t1 == t2)


# CHECK-LABEL: TEST: testTypeIsInstance
@run
def testTypeIsInstance():
  """Checks the per-class `isinstance` predicates on parsed types."""
  ctx = Context()
  t1 = Type.parse("i32", ctx)
  t2 = Type.parse("f32", ctx)
  # CHECK: True
  print(IntegerType.isinstance(t1))
  # CHECK: False
  print(F32Type.isinstance(t1))
  # CHECK: True
  print(F32Type.isinstance(t2))


# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@run
def testTypeEqDoesNotRaise():
  """Comparing a Type with non-Type operands returns a bool instead of raising."""
  ctx = Context()
  t1 = Type.parse("i32", ctx)
  not_a_type = "foo"
  # CHECK: False
  print(t1 == not_a_type)
  # The None comparisons are the point of this test, not a style slip.
  # CHECK: False
  print(t1 == None)  # noqa: E711
  # CHECK: True
  print(t1 != None)  # noqa: E711


# CHECK-LABEL: TEST: testTypeCapsule
@run
def testTypeCapsule():
  """Round-trips a Type through its C API capsule."""
  with Context() as ctx:
    t1 = Type.parse("i32", ctx)
    # CHECK: mlir.ir.Type._CAPIPtr
    type_capsule = t1._CAPIPtr
    print(type_capsule)
    t2 = Type._CAPICreate(type_capsule)
    assert t2 == t1
    assert t2.context is ctx


# CHECK-LABEL: TEST: testStandardTypeCasts
@run
def testStandardTypeCasts():
  """Casts generic Types to IntegerType; an incompatible type raises ValueError."""
  ctx = Context()
  t1 = Type.parse("i32", ctx)
  tint = IntegerType(t1)
  # Re-casting an IntegerType to IntegerType must succeed and be equal.
  tself = IntegerType(tint)
  assert tself == tint
  # CHECK: Type(i32)
  print(repr(tint))
  try:
    IntegerType(Type.parse("f32", ctx))
  except ValueError as e:
    # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
    print("ValueError:", e)
  else:
    print("Exception not produced")


# CHECK-LABEL: TEST: testIntegerType
@run
def testIntegerType():
  """Checks width/signedness properties and the IntegerType factories."""
  with Context():
    i32 = IntegerType(Type.parse("i32"))
    # CHECK: i32 width: 32
    print("i32 width:", i32.width)
    # CHECK: i32 signless: True
    print("i32 signless:", i32.is_signless)
    # CHECK: i32 signed: False
    print("i32 signed:", i32.is_signed)
    # CHECK: i32 unsigned: False
    print("i32 unsigned:", i32.is_unsigned)

    s32 = IntegerType(Type.parse("si32"))
    # CHECK: s32 signless: False
    print("s32 signless:", s32.is_signless)
    # CHECK: s32 signed: True
    print("s32 signed:", s32.is_signed)
    # CHECK: s32 unsigned: False
    print("s32 unsigned:", s32.is_unsigned)

    u32 = IntegerType(Type.parse("ui32"))
    # CHECK: u32 signless: False
    print("u32 signless:", u32.is_signless)
    # CHECK: u32 signed: False
    print("u32 signed:", u32.is_signed)
    # CHECK: u32 unsigned: True
    print("u32 unsigned:", u32.is_unsigned)

    # CHECK: signless: i16
    print("signless:", IntegerType.get_signless(16))
    # CHECK: signed: si8
    print("signed:", IntegerType.get_signed(8))
    # CHECK: unsigned: ui64
    print("unsigned:", IntegerType.get_unsigned(64))


# CHECK-LABEL: TEST: testIndexType
@run
def testIndexType():
  """Builds the index type in the active Context."""
  with Context():
    # CHECK: index type: index
    print("index type:", IndexType.get())


# CHECK-LABEL: TEST: testFloatType
@run
def testFloatType():
  """Builds each builtin float type in the active Context."""
  with Context():
    # CHECK: float: bf16
    print("float:", BF16Type.get())
    # CHECK: float: f16
    print("float:", F16Type.get())
    # CHECK: float: f32
    print("float:", F32Type.get())
    # CHECK: float: f64
    print("float:", F64Type.get())


# CHECK-LABEL: TEST: testNoneType
@run
def testNoneType():
  """Builds the none type in the active Context."""
  with Context():
    # CHECK: none type: none
    print("none type:", NoneType.get())


# CHECK-LABEL: TEST: testComplexType
@run
def testComplexType():
  """Checks ComplexType element access, construction, and rejection of index."""
  with Context():
    complex_i32 = ComplexType(Type.parse("complex<i32>"))
    # CHECK: complex type element: i32
    print("complex type element:", complex_i32.element_type)

    f32 = F32Type.get()
    # CHECK: complex type: complex<f32>
    print("complex type:", ComplexType.get(f32))

    index = IndexType.get()
    try:
      ComplexType.get(index)
    except ValueError as e:
      # CHECK: invalid 'Type(index)' and expected floating point or integer type.
      print(e)
    else:
      print("Exception not produced")


# CHECK-LABEL: TEST: testConcreteShapedType
# ShapedType is not itself a concrete builtin type; it is the base class for
# vectors, memrefs and tensors. This test therefore uses a vector instance to
# exercise the ShapedType API. The class hierarchy is preserved on the Python
# side.
@run
def testConcreteShapedType():
  """Exercises the ShapedType API through a concrete VectorType."""
  with Context():
    vector = VectorType(Type.parse("vector<2x3xf32>"))
    # CHECK: element type: f32
    print("element type:", vector.element_type)
    # CHECK: whether the given shaped type is ranked: True
    print("whether the given shaped type is ranked:", vector.has_rank)
    # CHECK: rank: 2
    print("rank:", vector.rank)
    # CHECK: whether the shaped type has a static shape: True
    print("whether the shaped type has a static shape:", vector.has_static_shape)
    # CHECK: whether the dim-th dimension is dynamic: False
    print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
    # CHECK: dim size: 3
    print("dim size:", vector.get_dim_size(1))
    # CHECK: is_dynamic_size: False
    print("is_dynamic_size:", vector.is_dynamic_size(3))
    # CHECK: is_dynamic_stride_or_offset: False
    print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
    # CHECK: isinstance(ShapedType): True
    print("isinstance(ShapedType):", isinstance(vector, ShapedType))


# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType operates as an abstract base class of a concrete
# shaped type (using vector as an example).
267@run 268def testAbstractShapedType(): 269 ctx = Context() 270 vector = ShapedType(Type.parse("vector<2x3xf32>", ctx)) 271 # CHECK: element type: f32 272 print("element type:", vector.element_type) 273 274 275# CHECK-LABEL: TEST: testVectorType 276@run 277def testVectorType(): 278 with Context(), Location.unknown(): 279 f32 = F32Type.get() 280 shape = [2, 3] 281 # CHECK: vector type: vector<2x3xf32> 282 print("vector type:", VectorType.get(shape, f32)) 283 284 none = NoneType.get() 285 try: 286 vector_invalid = VectorType.get(shape, none) 287 except ValueError as e: 288 # CHECK: invalid 'Type(none)' and expected floating point or integer type. 289 print(e) 290 else: 291 print("Exception not produced") 292 293 294# CHECK-LABEL: TEST: testRankedTensorType 295@run 296def testRankedTensorType(): 297 with Context(), Location.unknown(): 298 f32 = F32Type.get() 299 shape = [2, 3] 300 loc = Location.unknown() 301 # CHECK: ranked tensor type: tensor<2x3xf32> 302 print("ranked tensor type:", 303 RankedTensorType.get(shape, f32)) 304 305 none = NoneType.get() 306 try: 307 tensor_invalid = RankedTensorType.get(shape, none) 308 except ValueError as e: 309 # CHECK: invalid 'Type(none)' and expected floating point, integer, vector 310 # CHECK: or complex type. 311 print(e) 312 else: 313 print("Exception not produced") 314 315 # Encoding should be None. 316 assert RankedTensorType.get(shape, f32).encoding is None 317 318 tensor = RankedTensorType.get(shape, f32) 319 assert tensor.shape == shape 320 321 322# CHECK-LABEL: TEST: testUnrankedTensorType 323@run 324def testUnrankedTensorType(): 325 with Context(), Location.unknown(): 326 f32 = F32Type.get() 327 loc = Location.unknown() 328 unranked_tensor = UnrankedTensorType.get(f32) 329 # CHECK: unranked tensor type: tensor<*xf32> 330 print("unranked tensor type:", unranked_tensor) 331 try: 332 invalid_rank = unranked_tensor.rank 333 except ValueError as e: 334 # CHECK: calling this method requires that the type has a rank. 
335 print(e) 336 else: 337 print("Exception not produced") 338 try: 339 invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0) 340 except ValueError as e: 341 # CHECK: calling this method requires that the type has a rank. 342 print(e) 343 else: 344 print("Exception not produced") 345 try: 346 invalid_get_dim_size = unranked_tensor.get_dim_size(1) 347 except ValueError as e: 348 # CHECK: calling this method requires that the type has a rank. 349 print(e) 350 else: 351 print("Exception not produced") 352 353 none = NoneType.get() 354 try: 355 tensor_invalid = UnrankedTensorType.get(none) 356 except ValueError as e: 357 # CHECK: invalid 'Type(none)' and expected floating point, integer, vector 358 # CHECK: or complex type. 359 print(e) 360 else: 361 print("Exception not produced") 362 363 364# CHECK-LABEL: TEST: testMemRefType 365@run 366def testMemRefType(): 367 with Context(), Location.unknown(): 368 f32 = F32Type.get() 369 shape = [2, 3] 370 loc = Location.unknown() 371 memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) 372 # CHECK: memref type: memref<2x3xf32, 2> 373 print("memref type:", memref) 374 # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)> 375 print("memref layout:", memref.layout) 376 # CHECK: memref affine map: (d0, d1) -> (d0, d1) 377 print("memref affine map:", memref.affine_map) 378 # CHECK: memory space: 2 379 print("memory space:", memref.memory_space) 380 381 layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0])) 382 memref_layout = MemRefType.get(shape, f32, layout=layout) 383 # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>> 384 print("memref type:", memref_layout) 385 # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)> 386 print("memref layout:", memref_layout.layout) 387 # CHECK: memref affine map: (d0, d1) -> (d1, d0) 388 print("memref affine map:", memref_layout.affine_map) 389 # CHECK: memory space: <<NULL ATTRIBUTE>> 390 print("memory space:", memref_layout.memory_space) 
391 392 none = NoneType.get() 393 try: 394 memref_invalid = MemRefType.get(shape, none) 395 except ValueError as e: 396 # CHECK: invalid 'Type(none)' and expected floating point, integer, vector 397 # CHECK: or complex type. 398 print(e) 399 else: 400 print("Exception not produced") 401 402 assert memref.shape == shape 403 404 405# CHECK-LABEL: TEST: testUnrankedMemRefType 406@run 407def testUnrankedMemRefType(): 408 with Context(), Location.unknown(): 409 f32 = F32Type.get() 410 loc = Location.unknown() 411 unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2")) 412 # CHECK: unranked memref type: memref<*xf32, 2> 413 print("unranked memref type:", unranked_memref) 414 try: 415 invalid_rank = unranked_memref.rank 416 except ValueError as e: 417 # CHECK: calling this method requires that the type has a rank. 418 print(e) 419 else: 420 print("Exception not produced") 421 try: 422 invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0) 423 except ValueError as e: 424 # CHECK: calling this method requires that the type has a rank. 425 print(e) 426 else: 427 print("Exception not produced") 428 try: 429 invalid_get_dim_size = unranked_memref.get_dim_size(1) 430 except ValueError as e: 431 # CHECK: calling this method requires that the type has a rank. 432 print(e) 433 else: 434 print("Exception not produced") 435 436 none = NoneType.get() 437 try: 438 memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2")) 439 except ValueError as e: 440 # CHECK: invalid 'Type(none)' and expected floating point, integer, vector 441 # CHECK: or complex type. 
442 print(e) 443 else: 444 print("Exception not produced") 445 446 447# CHECK-LABEL: TEST: testTupleType 448@run 449def testTupleType(): 450 with Context() as ctx: 451 i32 = IntegerType(Type.parse("i32")) 452 f32 = F32Type.get() 453 vector = VectorType(Type.parse("vector<2x3xf32>")) 454 l = [i32, f32, vector] 455 tuple_type = TupleType.get_tuple(l) 456 # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>> 457 print("tuple type:", tuple_type) 458 # CHECK: number of types: 3 459 print("number of types:", tuple_type.num_types) 460 # CHECK: pos-th type in the tuple type: f32 461 print("pos-th type in the tuple type:", tuple_type.get_type(1)) 462 463 464# CHECK-LABEL: TEST: testFunctionType 465@run 466def testFunctionType(): 467 with Context() as ctx: 468 input_types = [IntegerType.get_signless(32), 469 IntegerType.get_signless(16)] 470 result_types = [IndexType.get()] 471 func = FunctionType.get(input_types, result_types) 472 # CHECK: INPUTS: [Type(i32), Type(i16)] 473 print("INPUTS:", func.inputs) 474 # CHECK: RESULTS: [Type(index)] 475 print("RESULTS:", func.results) 476 477 478# CHECK-LABEL: TEST: testOpaqueType 479@run 480def testOpaqueType(): 481 with Context() as ctx: 482 ctx.allow_unregistered_dialects = True 483 opaque = OpaqueType.get("dialect", "type") 484 # CHECK: opaque type: !dialect.type 485 print("opaque type:", opaque) 486 # CHECK: dialect namespace: dialect 487 print("dialect namespace:", opaque.dialect_namespace) 488 # CHECK: data: type 489 print("data:", opaque.data) 490