1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Run one test case and verify that it leaked no MLIR Context.

  Prints a ``TEST: <name>`` banner that the label directive of each test
  anchors on, invokes ``f``, then forces a garbage collection and asserts
  that no Context objects are still alive.
  """
  banner_name = f.__name__
  print("\nTEST:", banner_name)
  f()
  # Run a full collection before counting so that Contexts reachable only
  # through reference cycles have been released.
  gc.collect()
  remaining = Context._get_live_count()
  assert remaining == 0
11
12
# CHECK-LABEL: TEST: testAffineMapCapsule
def testAffineMapCapsule():
  """Round-trip an AffineMap through its C API capsule.

  The map is created inside a Context. The capsule is exported and
  re-imported after the ``with`` block has exited. The rebuilt map must
  compare equal to the original and refer to the very same Context object.
  """
  with Context() as ctx:
    source_map = AffineMap.get_empty(ctx)
  # CHECK: mlir.ir.AffineMap._CAPIPtr
  capsule = source_map._CAPIPtr
  print(capsule)
  rebuilt_map = AffineMap._CAPICreate(capsule)
  assert rebuilt_map == source_map
  assert rebuilt_map.context is ctx

run(testAffineMapCapsule)
25
26
# CHECK-LABEL: TEST: testAffineMapGet
def testAffineMapGet():
  """Build AffineMaps through each factory method and probe the error paths.

  Each constructed map is printed so its textual form can be matched. The
  specialised factories (get_empty, get_constant, get_identity,
  get_minor_identity) are compared for equality against equivalent explicit
  ``AffineMap.get`` calls. Every invalid request must raise the exception
  type named in its ``except`` clause.

  A request that unexpectedly succeeds raises AssertionError. The failure
  is then reported at the offending call, not merely as a missing line of
  output.
  """
  with Context():
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    c2 = AffineConstantExpr.get(2)

    # CHECK: (d0, d1)[s0, s1, s2] -> ()
    map0 = AffineMap.get(2, 3, [])
    print(map0)

    # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
    map1 = AffineMap.get(2, 3, [d1, c2])
    print(map1)

    # CHECK: () -> (2)
    map2 = AffineMap.get(0, 0, [c2])
    print(map2)

    # CHECK: (d0, d1) -> (d0, d1)
    map3 = AffineMap.get(2, 0, [d0, d1])
    print(map3)

    # CHECK: (d0, d1) -> (d1)
    map4 = AffineMap.get(2, 0, [d1])
    print(map4)

    # CHECK: (d0, d1, d2) -> (d2, d0, d1)
    map5 = AffineMap.get_permutation([2, 0, 1])
    print(map5)

    # The specialised factories must agree with the explicit constructions.
    assert map1 == AffineMap.get(2, 3, [d1, c2])
    assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
    assert map2 == AffineMap.get_constant(2)
    assert map3 == AffineMap.get_identity(2)
    assert map4 == AffineMap.get_minor_identity(2, 1)

    # Each request below is invalid and must raise. The ``else`` branches
    # turn an unexpected success into an immediate, descriptive failure.
    try:
      AffineMap.get(1, 1, [1])
    except RuntimeError as e:
      # CHECK: Invalid expression when attempting to create an AffineMap
      print(e)
    else:
      raise AssertionError("AffineMap.get accepted an int as a result expression")

    try:
      AffineMap.get(1, 1, [None])
    except RuntimeError as e:
      # CHECK: Invalid expression (None?) when attempting to create an AffineMap
      print(e)
    else:
      raise AssertionError("AffineMap.get accepted None as a result expression")

    try:
      AffineMap.get_permutation([1, 0, 1])
    except RuntimeError as e:
      # CHECK: Invalid permutation when attempting to create an AffineMap
      print(e)
    else:
      raise AssertionError("AffineMap.get_permutation accepted a non-permutation")

    try:
      map3.get_submap([42])
    except ValueError as e:
      # CHECK: result position out of bounds
      print(e)
    else:
      raise AssertionError("get_submap accepted an out-of-range result position")

    try:
      map3.get_minor_submap(42)
    except ValueError as e:
      # CHECK: number of results out of bounds
      print(e)
    else:
      raise AssertionError("get_minor_submap accepted too many results")

    try:
      map3.get_major_submap(42)
    except ValueError as e:
      # CHECK: number of results out of bounds
      print(e)
    else:
      raise AssertionError("get_major_submap accepted too many results")

run(testAffineMapGet)
101
102
# CHECK-LABEL: TEST: testAffineMapDerive
def testAffineMapDerive():
  """Derive sub-maps from a 5-D identity map and print each one.

  Covers selecting results by explicit position, taking the leading
  (major) results, and taking the trailing (minor) results.
  """
  with Context():
    identity5 = AffineMap.get_identity(5)

    # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
    print(identity5.get_submap([1, 2, 3]))

    # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
    print(identity5.get_major_submap(2))

    # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
    print(identity5.get_minor_submap(2))

run(testAffineMapDerive)
121
122
# CHECK-LABEL: TEST: testAffineMapProperties
def testAffineMapProperties():
  """Print the permutation predicates of three small maps.

  For each map in order, ``is_permutation`` is printed, followed by
  ``is_projected_permutation``:
    (d2, d0)      over 3 dims, no symbols  -> False, True
    (d2, d0, d1)  over 3 dims, no symbols  -> True,  True
    (d2, d0, d1)  over 3 dims, 1 symbol    -> False, False
  """
  with Context():
    dim0, dim1, dim2 = (AffineDimExpr.get(pos) for pos in range(3))
    candidate_maps = (
        AffineMap.get(3, 0, [dim2, dim0]),
        AffineMap.get(3, 0, [dim2, dim0, dim1]),
        AffineMap.get(3, 1, [dim2, dim0, dim1]),
    )
    # The expected lines below come in (is_permutation,
    # is_projected_permutation) pairs, one pair per map.
    # CHECK: False
    # CHECK: True
    # CHECK: True
    # CHECK: True
    # CHECK: False
    # CHECK: False
    for affine_map in candidate_maps:
      print(affine_map.is_permutation)
      print(affine_map.is_projected_permutation)

run(testAffineMapProperties)
146
147
# CHECK-LABEL: TEST: testAffineMapExprs
def testAffineMapExprs():
  """Inspect the input counts and result expressions of a map.

  Uses a map over 3 dims and 1 symbol whose results are (d2, d0, d1). The
  test walks those results forwards and through a negative-step slice.
  """
  with Context():
    d0, d1, d2 = [AffineDimExpr.get(idx) for idx in range(3)]
    map3 = AffineMap.get(3, 1, [d2, d0, d1])

    # CHECK: 3
    print(map3.n_dims)
    # CHECK: 4
    print(map3.n_inputs)
    # CHECK: 1
    print(map3.n_symbols)
    # Inputs are the dims followed by the symbols.
    assert map3.n_inputs == map3.n_dims + map3.n_symbols

    results = map3.results
    # CHECK: 3
    print(len(results))
    # CHECK: d2
    # CHECK: d0
    # CHECK: d1
    for result_expr in results:
      print(result_expr)
    # Slice with a negative step to exercise reverse slicing of results.
    # CHECK: d1
    # CHECK: d0
    # CHECK: d2
    for result_expr in results[-1:-4:-1]:
      print(result_expr)
    assert list(results) == [d2, d0, d1]

run(testAffineMapExprs)
179
# CHECK-LABEL: TEST: testCompressUnusedSymbols
def testCompressUnusedSymbols():
  """Compress away symbols that none of the maps in a list refers to.

  All three input maps declare 3 dims and 3 symbols. Only s2 is used, in
  the second map. After compression each map keeps a single symbol,
  renumbered to s0.
  """
  with Context() as ctx:
    d0, d1, d2 = [AffineDimExpr.get(i) for i in range(3)]
    # Only the last symbol is referenced; the first two are built but unused.
    *_, s2 = [AffineSymbolExpr.get(i) for i in range(3)]
    maps = [
        AffineMap.get(3, 3, [d2, d0, d1]),
        AffineMap.get(3, 3, [d2, d0 + s2, d1]),
        AffineMap.get(3, 3, [d1, d2, d0]),
    ]

    compressed = AffineMap.compress_unused_symbols(maps, ctx)

    #      CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
    print(maps)

    #      CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
    print(compressed)


run(testCompressUnusedSymbols)
210