1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
6
def run(f):
  """Run one test function under a labelled header and check for leaks.

  Prints a blank line followed by "TEST: <function name>" so the label
  directives in this file have an anchor, invokes `f`, and finally asserts
  that `Context._get_live_count()` reports zero live contexts.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Force a collection first so objects reachable only through garbage
  # cycles are released before the live contexts are counted.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
12
13
14# CHECK-LABEL: TEST: testAffineMapCapsule
def testAffineMapCapsule():
  """Round-trip an AffineMap through its C API capsule.

  The capsule is obtained after the `with` block has exited; the map rebuilt
  from it must compare equal to the original and report the very same
  Context object.
  """
  with Context() as context:
    original = AffineMap.get_empty(context)
  # CHECK: mlir.ir.AffineMap._CAPIPtr
  capsule = original._CAPIPtr
  print(capsule)
  rebuilt = AffineMap._CAPICreate(capsule)
  assert rebuilt == original
  assert rebuilt.context is context


run(testAffineMapCapsule)
27
28
29# CHECK-LABEL: TEST: testAffineMapGet
def testAffineMapGet():
  """Build AffineMaps through each factory method and check printing/equality.

  The second half exercises error paths: every call there is expected to
  raise, and its message is printed for FileCheck to match. A call that
  unexpectedly succeeds now fails the test immediately instead of leaving
  only an unmatched FileCheck pattern behind.
  """

  def expect_error(exc_type, fn, *args):
    # Invoke fn(*args) and print the message of the expected exception type.
    # Raise if nothing was raised; other exception types propagate as-is.
    try:
      fn(*args)
    except exc_type as e:
      print(e)
    else:
      raise AssertionError(
          "expected {} to be raised".format(exc_type.__name__))

  with Context():
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    c2 = AffineConstantExpr.get(2)

    # CHECK: (d0, d1)[s0, s1, s2] -> ()
    map0 = AffineMap.get(2, 3, [])
    print(map0)

    # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
    map1 = AffineMap.get(2, 3, [d1, c2])
    print(map1)

    # CHECK: () -> (2)
    map2 = AffineMap.get(0, 0, [c2])
    print(map2)

    # CHECK: (d0, d1) -> (d0, d1)
    map3 = AffineMap.get(2, 0, [d0, d1])
    print(map3)

    # CHECK: (d0, d1) -> (d1)
    map4 = AffineMap.get(2, 0, [d1])
    print(map4)

    # CHECK: (d0, d1, d2) -> (d2, d0, d1)
    map5 = AffineMap.get_permutation([2, 0, 1])
    print(map5)

    assert map1 == AffineMap.get(2, 3, [d1, c2])
    assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
    assert map2 == AffineMap.get_constant(2)
    assert map3 == AffineMap.get_identity(2)
    assert map4 == AffineMap.get_minor_identity(2, 1)

    # CHECK: Invalid expression when attempting to create an AffineMap
    expect_error(RuntimeError, AffineMap.get, 1, 1, [1])

    # CHECK: Invalid expression (None?) when attempting to create an AffineMap
    expect_error(RuntimeError, AffineMap.get, 1, 1, [None])

    # CHECK: Invalid permutation when attempting to create an AffineMap
    expect_error(RuntimeError, AffineMap.get_permutation, [1, 0, 1])

    # CHECK: result position out of bounds
    expect_error(ValueError, map3.get_submap, [42])

    # CHECK: number of results out of bounds
    expect_error(ValueError, map3.get_minor_submap, 42)

    # CHECK: number of results out of bounds
    expect_error(ValueError, map3.get_major_submap, 42)


run(testAffineMapGet)
104
105
106# CHECK-LABEL: TEST: testAffineMapDerive
def testAffineMapDerive():
  """Derive submaps from a 5-D identity map.

  Results are selected by explicit positions, by taking the leading results
  (major submap) and by taking the trailing results (minor submap).
  """
  with Context():
    identity5 = AffineMap.get_identity(5)

    # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
    print(identity5.get_submap([1, 2, 3]))

    # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
    print(identity5.get_major_submap(2))

    # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
    print(identity5.get_minor_submap(2))


run(testAffineMapDerive)
125
126
127# CHECK-LABEL: TEST: testAffineMapProperties
def testAffineMapProperties():
  """Query is_permutation / is_projected_permutation on three maps.

  The maps are a 2-result selection over 3 dims, a full reordering of 3 dims,
  and the same reordering with one extra symbol. For each map in that order
  both flags are printed.
  """
  with Context():
    d0, d1, d2 = (AffineDimExpr.get(pos) for pos in range(3))
    candidates = (
        AffineMap.get(3, 0, [d2, d0]),
        AffineMap.get(3, 0, [d2, d0, d1]),
        AffineMap.get(3, 1, [d2, d0, d1]),
    )
    # CHECK: False
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: False
    # CHECK: False
    for candidate in candidates:
      print(candidate.is_permutation)
      print(candidate.is_projected_permutation)


run(testAffineMapProperties)
151
152
153# CHECK-LABEL: TEST: testAffineMapExprs
def testAffineMapExprs():
  """Inspect the arity counters and the result expressions of an AffineMap.

  The map has 3 dims and 1 symbol. Its `results` are checked via len(),
  forward iteration, a negative-step slice and conversion to a list.
  """
  with Context():
    d0, d1, d2 = (AffineDimExpr.get(pos) for pos in range(3))
    amap = AffineMap.get(3, 1, [d2, d0, d1])

    # CHECK: 3
    print(amap.n_dims)
    # CHECK: 4
    print(amap.n_inputs)
    # CHECK: 1
    print(amap.n_symbols)
    assert amap.n_inputs == amap.n_dims + amap.n_symbols

    # CHECK: 3
    print(len(amap.results))
    # CHECK: d2
    # CHECK: d0
    # CHECK: d1
    for result in amap.results:
      print(result)
    # CHECK: d1
    # CHECK: d0
    # CHECK: d2
    for result in amap.results[-1:-4:-1]:
      print(result)
    assert list(amap.results) == [d2, d0, d1]


run(testAffineMapExprs)
185
186
187# CHECK-LABEL: TEST: testCompressUnusedSymbols
def testCompressUnusedSymbols():
  """Drop the symbols that none of a group of maps use.

  Only s2 appears (in the second map), so after compression every map keeps
  a single symbol and s2 is renumbered to s0. The input maps themselves are
  printed unchanged afterwards.
  """
  with Context() as ctx:
    d0, d1, d2 = (AffineDimExpr.get(pos) for pos in range(3))
    s0, s1, s2 = (AffineSymbolExpr.get(pos) for pos in range(3))
    maps = [
        AffineMap.get(3, 3, [d2, d0, d1]),
        AffineMap.get(3, 3, [d2, d0 + s2, d1]),
        AffineMap.get(3, 3, [d1, d2, d0]),
    ]

    compressed = AffineMap.compress_unused_symbols(maps, ctx)

    #      CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
    print(maps)

    #      CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
    print(compressed)


run(testCompressUnusedSymbols)
214
215
216# CHECK-LABEL: TEST: testReplace
def testReplace():
  """Substitute the constant 42 for individual symbols of a map.

  As the expected output shows, the trailing arguments of replace() give the
  dim and symbol counts of the resulting map. The last case shrinks the
  symbol count to 2, dropping s2 once it has been replaced.
  """
  with Context():
    d0, d1, d2 = (AffineDimExpr.get(pos) for pos in range(3))
    s0, s1, s2 = (AffineSymbolExpr.get(pos) for pos in range(3))
    base = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])

    def with_42(symbol, n_symbols):
      # Replace `symbol` by the constant 42, keeping all three dims.
      return base.replace(symbol, AffineConstantExpr.get(42), 3, n_symbols)

    replace0 = with_42(s0, 3)
    replace1 = with_42(s1, 3)
    replace2 = with_42(s2, 2)

    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
    print(replace0)

    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
    print(replace1)

    # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
    print(replace2)


run(testReplace)
240