# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
    """Run test function ``f`` under a FileCheck label and check for leaks.

    Prints the "TEST: <name>" line that each test's label directive matches,
    calls ``f``, forces a garbage collection, and asserts that no MLIR
    ``Context`` object is still alive afterwards.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


def _expect_error(exc_type, fn, *args):
    """Call ``fn(*args)`` and print the message of the ``exc_type`` it raises.

    The printed message is what the FileCheck directives in the caller match.
    If ``fn`` returns normally, an ``AssertionError`` is raised so the test
    fails at the offending call instead of later, as an unmatched FileCheck
    pattern. Exceptions of any other type propagate unchanged.
    """
    try:
        fn(*args)
    except exc_type as e:
        print(e)
    else:
        raise AssertionError(
            "expected {} to be raised".format(exc_type.__name__))


# CHECK-LABEL: TEST: testAffineMapCapsule
def testAffineMapCapsule():
    """Round-trip an AffineMap through its C API capsule."""
    with Context() as ctx:
        am1 = AffineMap.get_empty(ctx)
        # CHECK: mlir.ir.AffineMap._CAPIPtr
        affine_map_capsule = am1._CAPIPtr
        print(affine_map_capsule)
        am2 = AffineMap._CAPICreate(affine_map_capsule)
        # The recreated map must compare equal and keep the same context.
        assert am2 == am1
        assert am2.context is ctx


run(testAffineMapCapsule)


# CHECK-LABEL: TEST: testAffineMapGet
def testAffineMapGet():
    """Build maps with the AffineMap factories and check invalid inputs raise."""
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)

        # CHECK: (d0, d1)[s0, s1, s2] -> ()
        map0 = AffineMap.get(2, 3, [])
        print(map0)

        # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
        map1 = AffineMap.get(2, 3, [d1, c2])
        print(map1)

        # CHECK: () -> (2)
        map2 = AffineMap.get(0, 0, [c2])
        print(map2)

        # CHECK: (d0, d1) -> (d0, d1)
        map3 = AffineMap.get(2, 0, [d0, d1])
        print(map3)

        # CHECK: (d0, d1) -> (d1)
        map4 = AffineMap.get(2, 0, [d1])
        print(map4)

        # CHECK: (d0, d1, d2) -> (d2, d0, d1)
        map5 = AffineMap.get_permutation([2, 0, 1])
        print(map5)

        # The specialized factories must agree with the generic AffineMap.get.
        assert map1 == AffineMap.get(2, 3, [d1, c2])
        assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
        assert map2 == AffineMap.get_constant(2)
        assert map3 == AffineMap.get_identity(2)
        assert map4 == AffineMap.get_minor_identity(2, 1)

        # A plain integer and None are rejected as result expressions.
        # CHECK: Invalid expression when attempting to create an AffineMap
        _expect_error(RuntimeError, AffineMap.get, 1, 1, [1])

        # CHECK: Invalid expression (None?) when attempting to create an AffineMap
        _expect_error(RuntimeError, AffineMap.get, 1, 1, [None])

        # Index 1 appears twice, so this is not a permutation.
        # CHECK: Invalid permutation when attempting to create an AffineMap
        _expect_error(RuntimeError, AffineMap.get_permutation, [1, 0, 1])

        # map3 has two results, so 42 is out of range for each submap query.
        # CHECK: result position out of bounds
        _expect_error(ValueError, map3.get_submap, [42])

        # CHECK: number of results out of bounds
        _expect_error(ValueError, map3.get_minor_submap, 42)

        # CHECK: number of results out of bounds
        _expect_error(ValueError, map3.get_major_submap, 42)


run(testAffineMapGet)


# CHECK-LABEL: TEST: testAffineMapDerive
def testAffineMapDerive():
    """Derive submaps from the 5-D identity map."""
    with Context():
        map5 = AffineMap.get_identity(5)

        # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
        map123 = map5.get_submap([1, 2, 3])
        print(map123)

        # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
        map01 = map5.get_major_submap(2)
        print(map01)

        # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
        map34 = map5.get_minor_submap(2)
        print(map34)


run(testAffineMapDerive)


# CHECK-LABEL: TEST: testAffineMapProperties
def testAffineMapProperties():
    """Check the (projected) permutation predicates, with and without symbols."""
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        d2 = AffineDimExpr.get(2)
        map1 = AffineMap.get(3, 0, [d2, d0])
        map2 = AffineMap.get(3, 0, [d2, d0, d1])
        map3 = AffineMap.get(3, 1, [d2, d0, d1])
        # CHECK: False
        print(map1.is_permutation)
        # CHECK: True
        print(map1.is_projected_permutation)
        # CHECK: True
        print(map2.is_permutation)
        # CHECK: True
        print(map2.is_projected_permutation)
        # CHECK: False
        print(map3.is_permutation)
        # CHECK: False
        print(map3.is_projected_permutation)


run(testAffineMapProperties)


# CHECK-LABEL: TEST: testAffineMapExprs
def testAffineMapExprs():
    """Check dim/symbol/input counts and indexing into the result expressions."""
    with Context():
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        d2 = AffineDimExpr.get(2)
        map3 = AffineMap.get(3, 1, [d2, d0, d1])

        # CHECK: 3
        print(map3.n_dims)
        # CHECK: 4
        print(map3.n_inputs)
        # CHECK: 1
        print(map3.n_symbols)
        assert map3.n_inputs == map3.n_dims + map3.n_symbols

        # CHECK: 3
        print(len(map3.results))
        for expr in map3.results:
            # CHECK: d2
            # CHECK: d0
            # CHECK: d1
            print(expr)
        # Negative-step slicing walks the results in reverse order.
        for expr in map3.results[-1:-4:-1]:
            # CHECK: d1
            # CHECK: d0
            # CHECK: d2
            print(expr)
        assert list(map3.results) == [d2, d0, d1]


run(testAffineMapExprs)


# CHECK-LABEL: TEST: testCompressUnusedSymbols
def testCompressUnusedSymbols():
    """Check that compress_unused_symbols drops symbols none of the maps use."""
    with Context() as ctx:
        d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
                      AffineDimExpr.get(2))
        s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
                      AffineSymbolExpr.get(2))
        # Only s2 is referenced, and only by the second map.
        maps = [
            AffineMap.get(3, 3, [d2, d0, d1]),
            AffineMap.get(3, 3, [d2, d0 + s2, d1]),
            AffineMap.get(3, 3, [d1, d2, d0])
        ]

        compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)

        # CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
        print(maps)

        # CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
        print(compressed_maps)


run(testCompressUnusedSymbols)


# CHECK-LABEL: TEST: testReplace
def testReplace():
    """Check AffineMap.replace, substituting a constant for one symbol.

    The last two arguments set the dim/symbol counts of the resulting map, as
    the third case (3 dims, 2 symbols) shows in its expected output.
    """
    with Context():
        d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
                      AffineDimExpr.get(2))
        s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
                      AffineSymbolExpr.get(2))
        map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])

        replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3)
        replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3)
        replace2 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2)

        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
        print(replace0)

        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
        print(replace1)

        # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
        print(replace2)


run(testReplace)