1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
6
def run(f):
  """Decorator that executes test function `f` as soon as it is defined.

  A banner with the function's name is printed first so that the output of
  each test can be told apart. After the test returns, a garbage collection
  is forced and the number of live `Context` objects is required to be zero.

  Returns:
    `f` itself, unchanged.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect before counting so that unreachable contexts are not reported.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
13
14
15# CHECK-LABEL: TEST: testAffineMapCapsule
@run
def testAffineMapCapsule():
  """Round-trips an empty AffineMap through its `_CAPIPtr` capsule."""
  with Context() as ctx:
    original = AffineMap.get_empty(ctx)
  capsule = original._CAPIPtr
  # CHECK: mlir.ir.AffineMap._CAPIPtr
  print(capsule)
  # The map rebuilt from the capsule must compare equal to the source map and
  # report the very same context object.
  restored = AffineMap._CAPICreate(capsule)
  assert restored == original
  assert restored.context is ctx
26
27
28# CHECK-LABEL: TEST: testAffineMapGet
@run
def testAffineMapGet():
  """Exercises the AffineMap factory methods and their error reporting.

  Maps are built with `get` and `get_permutation`, printed for FileCheck, and
  compared against the maps produced by `get_empty`, `get_constant`,
  `get_identity` and `get_minor_identity`.

  Every invalid call below is required to raise. The `else` branches make the
  test fail right at the offending call when no exception is raised, instead
  of only surfacing later as a missing line in the printed output.
  """
  with Context():
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    c2 = AffineConstantExpr.get(2)

    # CHECK: (d0, d1)[s0, s1, s2] -> ()
    map0 = AffineMap.get(2, 3, [])
    print(map0)

    # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
    map1 = AffineMap.get(2, 3, [d1, c2])
    print(map1)

    # CHECK: () -> (2)
    map2 = AffineMap.get(0, 0, [c2])
    print(map2)

    # CHECK: (d0, d1) -> (d0, d1)
    map3 = AffineMap.get(2, 0, [d0, d1])
    print(map3)

    # CHECK: (d0, d1) -> (d1)
    map4 = AffineMap.get(2, 0, [d1])
    print(map4)

    # CHECK: (d0, d1, d2) -> (d2, d0, d1)
    map5 = AffineMap.get_permutation([2, 0, 1])
    print(map5)

    # Maps built from the same arguments, or through the dedicated
    # constructors, must compare equal.
    assert map1 == AffineMap.get(2, 3, [d1, c2])
    assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
    assert map2 == AffineMap.get_constant(2)
    assert map3 == AffineMap.get_identity(2)
    assert map4 == AffineMap.get_minor_identity(2, 1)

    # A plain integer is not an affine expression.
    try:
      AffineMap.get(1, 1, [1])
    except RuntimeError as e:
      # CHECK: Invalid expression when attempting to create an AffineMap
      print(e)
    else:
      assert False, "exception expected"

    try:
      AffineMap.get(1, 1, [None])
    except RuntimeError as e:
      # CHECK: Invalid expression (None?) when attempting to create an AffineMap
      print(e)
    else:
      assert False, "exception expected"

    # The list repeats 1 and omits 2, so it is not a permutation.
    try:
      AffineMap.get_permutation([1, 0, 1])
    except RuntimeError as e:
      # CHECK: Invalid permutation when attempting to create an AffineMap
      print(e)
    else:
      assert False, "exception expected"

    # map3 has only two results, so position/count 42 is out of range.
    try:
      map3.get_submap([42])
    except ValueError as e:
      # CHECK: result position out of bounds
      print(e)
    else:
      assert False, "exception expected"

    try:
      map3.get_minor_submap(42)
    except ValueError as e:
      # CHECK: number of results out of bounds
      print(e)
    else:
      assert False, "exception expected"

    try:
      map3.get_major_submap(42)
    except ValueError as e:
      # CHECK: number of results out of bounds
      print(e)
    else:
      assert False, "exception expected"
101
102
103# CHECK-LABEL: TEST: testAffineMapDerive
@run
def testAffineMapDerive():
  """Prints sub-maps derived from the five-dimensional identity map."""
  with Context():
    identity = AffineMap.get_identity(5)

    # Results selected by explicit position.
    # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
    print(identity.get_submap([1, 2, 3]))

    # The two leading results.
    # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
    print(identity.get_major_submap(2))

    # The two trailing results.
    # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
    print(identity.get_minor_submap(2))
120
121
122# CHECK-LABEL: TEST: testAffineMapProperties
@run
def testAffineMapProperties():
  """Prints the permutation predicates of three maps over three dims."""
  with Context():
    d0, d1, d2 = (AffineDimExpr.get(pos) for pos in range(3))
    two_results = AffineMap.get(3, 0, [d2, d0])
    all_dims = AffineMap.get(3, 0, [d2, d0, d1])
    with_symbol = AffineMap.get(3, 1, [d2, d0, d1])

    # Two of the three dims as results, no symbols.
    # CHECK: False
    print(two_results.is_permutation)
    # CHECK: True
    print(two_results.is_projected_permutation)

    # Every dim appears exactly once, no symbols.
    # CHECK: True
    print(all_dims.is_permutation)
    # CHECK: True
    print(all_dims.is_projected_permutation)

    # Same results as above, but the map also declares one symbol.
    # CHECK: False
    print(with_symbol.is_permutation)
    # CHECK: False
    print(with_symbol.is_projected_permutation)
144
145
146# CHECK-LABEL: TEST: testAffineMapExprs
@run
def testAffineMapExprs():
  """Inspects the input counts and the result list of an AffineMap."""
  with Context():
    d0, d1, d2 = (AffineDimExpr.get(pos) for pos in range(3))
    amap = AffineMap.get(3, 1, [d2, d0, d1])

    # CHECK: 3
    print(amap.n_dims)
    # CHECK: 4
    print(amap.n_inputs)
    # CHECK: 1
    print(amap.n_symbols)
    assert amap.n_dims + amap.n_symbols == amap.n_inputs

    # CHECK: 3
    print(len(amap.results))
    # Forward iteration over the results.
    for result in amap.results:
      # CHECK: d2
      # CHECK: d0
      # CHECK: d1
      print(result)
    # A negative-step slice yields the results back to front.
    for result in amap.results[-1:-4:-1]:
      # CHECK: d1
      # CHECK: d0
      # CHECK: d2
      print(result)
    assert list(amap.results) == [d2, d0, d1]
176
177
178# CHECK-LABEL: TEST: testCompressUnusedSymbols
@run
def testCompressUnusedSymbols():
  """Compresses the symbols of several maps together and prints the result."""
  with Context() as ctx:
    d0, d1, d2 = (AffineDimExpr.get(pos) for pos in range(3))
    # Only the last symbol is referenced by a result expression.
    _s0, _s1, s2 = (AffineSymbolExpr.get(pos) for pos in range(3))
    result_lists = [
        [d2, d0, d1],
        [d2, d0 + s2, d1],
        [d1, d2, d0],
    ]
    # Every map declares three dims and three symbols.
    maps = [AffineMap.get(3, 3, exprs) for exprs in result_lists]

    compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)

    # The input maps are printed after the call and still list all symbols.
    #      CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
    print(maps)

    # The compressed maps keep a single symbol, renumbered to s0.
    #      CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
    print(compressed_maps)
203
204
205# CHECK-LABEL: TEST: testReplace
@run
def testReplace():
  """Replaces each symbol of a map with the constant 42, one at a time."""
  with Context():
    d0, d1, d2 = (AffineDimExpr.get(pos) for pos in range(3))
    s0, s1, s2 = (AffineSymbolExpr.get(pos) for pos in range(3))
    source = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
    forty_two = AffineConstantExpr.get(42)

    # The two trailing integers are the dim and symbol counts of the new map.
    replace0 = source.replace(s0, forty_two, 3, 3)
    replace1 = source.replace(s1, forty_two, 3, 3)
    # Here the new map is asked to declare only two symbols.
    replace2 = source.replace(s2, forty_two, 3, 2)

    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
    print(replace0)

    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
    print(replace1)

    # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
    print(replace2)
227
228
229# CHECK-LABEL: TEST: testHash
@run
def testHash():
  """Verifies that AffineMap objects hash consistently and work as dict keys."""
  with Context():
    d0, d1 = (AffineDimExpr.get(pos) for pos in range(2))
    identity = AffineMap.get(2, 0, [d0, d1])
    swapped = AffineMap.get(2, 0, [d1, d0])
    # A second map built from the same arguments must hash identically.
    rebuilt = AffineMap.get(2, 0, [d0, d1])
    assert hash(identity) == hash(rebuilt)

    table = {identity: 1, swapped: 2}
    assert identity in table
242