1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Runs the test function `f` immediately and returns it unchanged.

  Meant to be applied as a decorator. A banner carrying the function name is
  printed first so the output of each test can be located. After the call, a
  garbage collection is forced and the number of live contexts is asserted to
  be zero.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect eagerly so that objects only kept alive by garbage do not count as
  # live contexts in the assertion below.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
12
13
14# CHECK-LABEL: TEST: testAffineExprCapsule
@run
def testAffineExprCapsule():
  """Round-trips an affine expression through its C API capsule."""
  with Context() as context:
    original = AffineExpr.get_constant(42)

  capsule = original._CAPIPtr
  # CHECK: capsule object
  # CHECK: mlir.ir.AffineExpr._CAPIPtr
  print(capsule)

  # The expression rebuilt from the capsule must compare equal to the original
  # and report the same context.
  restored = AffineExpr._CAPICreate(capsule)
  assert original == restored
  assert restored.context == context
28
29
30# CHECK-LABEL: TEST: testAffineExprEq
@run
def testAffineExprEq():
  """Exercises `==` between constant expressions and against other objects."""
  with Context():
    forty_two = AffineExpr.get_constant(42)
    forty_two_again = AffineExpr.get_constant(42)
    forty_three = AffineExpr.get_constant(43)
    # CHECK: True
    print(forty_two == forty_two)
    # CHECK: True
    print(forty_two == forty_two_again)
    # CHECK: False
    print(forty_two == forty_three)
    # The `== None` spelling is deliberate: it is the equality operator itself
    # that is under test here, not identity.
    # CHECK: False
    print(forty_two == None)
    # CHECK: False
    print(forty_two == "foo")
47
48
49# CHECK-LABEL: TEST: testAffineExprContext
@run
def testAffineExprContext():
  """Checks that equal constants created in different contexts compare unequal.

  The `@run` decorator already executes this test once; no additional explicit
  invocation is needed (a second one would print the test banner twice).
  """
  with Context():
    a1 = AffineExpr.get_constant(42)
  with Context():
    a2 = AffineExpr.get_constant(42)

  # CHECK: False
  print(a1 == a2)
61
62
63# CHECK-LABEL: TEST: testAffineExprConstant
@run
def testAffineExprConstant():
  """Builds the constant 42 through both factory entry points and compares."""
  with Context():
    via_base = AffineExpr.get_constant(42)
    # CHECK: 42
    print(via_base.value)
    # CHECK: 42
    print(via_base)

    via_subclass = AffineConstantExpr.get(42)
    # CHECK: 42
    print(via_subclass.value)
    # CHECK: 42
    print(via_subclass)

    # Both construction paths must yield equal expressions.
    assert via_base == via_subclass
80
81
82# CHECK-LABEL: TEST: testAffineExprDim
@run
def testAffineExprDim():
  """Checks position, printing and equality of dimension expressions."""
  with Context():
    dim1_via_base = AffineExpr.get_dim(1)
    dim1_via_subclass = AffineDimExpr.get(1)
    dim2 = AffineDimExpr.get(2)

    # CHECK: 1
    print(dim1_via_base.position)
    # CHECK: d1
    print(dim1_via_base)

    # CHECK: 2
    print(dim2.position)
    # CHECK: d2
    print(dim2)

    # Same position compares equal regardless of the factory used; a different
    # position compares unequal.
    assert dim1_via_base == dim1_via_subclass
    assert dim1_via_base != dim2
102
103
104# CHECK-LABEL: TEST: testAffineExprSymbol
@run
def testAffineExprSymbol():
  """Checks position, printing and equality of symbol expressions."""
  with Context():
    s1 = AffineExpr.get_symbol(1)
    s11 = AffineSymbolExpr.get(1)
    s2 = AffineSymbolExpr.get(2)

    # CHECK: 1
    print(s1.position)
    # CHECK: s1
    print(s1)

    # CHECK: 2
    print(s2.position)
    # The directive below was previously misspelled and therefore never
    # enforced; with the correct prefix the printed symbol is verified.
    # CHECK: s2
    print(s2)

    # Same position compares equal regardless of the factory used; a different
    # position compares unequal.
    assert s1 == s11
    assert s1 != s2
124
125
126# CHECK-LABEL: TEST: testAffineAddExpr
@run
def testAffineAddExpr():
  """Builds sums via the factory and the `+` operator and compares them."""
  with Context():
    lhs_dim = AffineDimExpr.get(1)
    rhs_dim = AffineDimExpr.get(2)
    total = AffineExpr.get_add(lhs_dim, rhs_dim)
    # CHECK: d1 + d2
    print(total)

    total_via_operator = lhs_dim + rhs_dim
    # CHECK: d1 + d2
    print(total_via_operator)

    int_on_right = lhs_dim + 2
    # CHECK: d1 + 2
    print(int_on_right)

    # With the integer on the left, the printed form still has the constant
    # last.
    int_on_left = 2 + lhs_dim
    # CHECK: d1 + 2
    print(int_on_left)

    assert total == total_via_operator
    assert total.lhs == lhs_dim
    assert total.rhs == rhs_dim
151
152
153# CHECK-LABEL: TEST: testAffineMulExpr
@run
def testAffineMulExpr():
  """Builds `d1 * 2` via the factory and the `*` operator and compares them."""
  with Context():
    dim = AffineDimExpr.get(1)
    two = AffineConstantExpr.get(2)
    product = AffineExpr.get_mul(dim, two)
    # CHECK: d1 * 2
    print(product)

    via_operator = dim * two
    # CHECK: d1 * 2
    print(via_operator)

    int_on_right = dim * 2
    # CHECK: d1 * 2
    print(int_on_right)

    # With the integer on the left, the printed form still has the constant
    # last.
    int_on_left = 2 * dim
    # CHECK: d1 * 2
    print(int_on_left)

    assert product == via_operator
    assert product == int_on_right
    assert product.lhs == dim
    assert product.rhs == two
179
180
181# CHECK-LABEL: TEST: testAffineModExpr
@run
def testAffineModExpr():
  """Checks `mod` built via the factory and `%`, with expr and int operands."""
  with Context():
    dim = AffineDimExpr.get(1)
    two = AffineConstantExpr.get(2)
    remainder = AffineExpr.get_mod(dim, two)
    # CHECK: d1 mod 2
    print(remainder)

    via_operator = dim % two
    # CHECK: d1 mod 2
    print(via_operator)

    int_on_right = dim % 2
    # CHECK: d1 mod 2
    print(int_on_right)

    # Operand order is kept when the integer is on the left.
    int_on_left = 2 % dim
    # CHECK: 2 mod d1
    print(int_on_left)

    assert remainder == via_operator
    assert remainder == int_on_right
    assert remainder.lhs == dim
    assert remainder.rhs == two

    # The factory also takes plain integers for either operand.
    cst_mod_dim = AffineExpr.get_mod(two, dim)
    int_mod_dim = AffineExpr.get_mod(2, dim)
    dim_mod_int = AffineExpr.get_mod(dim, 2)

    # CHECK: 2 mod d1
    print(cst_mod_dim)
    # CHECK: 2 mod d1
    print(int_mod_dim)
    # CHECK: d1 mod 2
    print(dim_mod_int)

    assert cst_mod_dim == int_mod_dim
    assert dim_mod_int == remainder
220
221
222# CHECK-LABEL: TEST: testAffineFloorDivExpr
@run
def testAffineFloorDivExpr():
  """Checks `get_floor_div` with expression and plain-integer operands."""
  with Context():
    dim = AffineDimExpr.get(1)
    two = AffineConstantExpr.get(2)
    quotient = AffineExpr.get_floor_div(dim, two)
    # CHECK: d1 floordiv 2
    print(quotient)

    assert quotient.lhs == dim
    assert quotient.rhs == two

    # The factory also takes plain integers for either operand.
    cst_over_dim = AffineExpr.get_floor_div(two, dim)
    int_over_dim = AffineExpr.get_floor_div(2, dim)
    dim_over_int = AffineExpr.get_floor_div(dim, 2)

    # CHECK: 2 floordiv d1
    print(cst_over_dim)
    # CHECK: 2 floordiv d1
    print(int_over_dim)
    # CHECK: d1 floordiv 2
    print(dim_over_int)

    assert cst_over_dim == int_over_dim
    assert dim_over_int == quotient
248
249
250# CHECK-LABEL: TEST: testAffineCeilDivExpr
@run
def testAffineCeilDivExpr():
  """Checks `get_ceil_div` with expression and plain-integer operands."""
  with Context():
    dim = AffineDimExpr.get(1)
    two = AffineConstantExpr.get(2)
    quotient = AffineExpr.get_ceil_div(dim, two)
    # CHECK: d1 ceildiv 2
    print(quotient)

    assert quotient.lhs == dim
    assert quotient.rhs == two

    # The factory also takes plain integers for either operand.
    cst_over_dim = AffineExpr.get_ceil_div(two, dim)
    int_over_dim = AffineExpr.get_ceil_div(2, dim)
    dim_over_int = AffineExpr.get_ceil_div(dim, 2)

    # CHECK: 2 ceildiv d1
    print(cst_over_dim)
    # CHECK: 2 ceildiv d1
    print(int_over_dim)
    # CHECK: d1 ceildiv 2
    print(dim_over_int)

    assert cst_over_dim == int_over_dim
    assert dim_over_int == quotient
276
277
278# CHECK-LABEL: TEST: testAffineExprSub
@run
def testAffineExprSub():
  """Checks the `-` operator with expression and plain-integer operands."""
  with Context():
    minuend = AffineDimExpr.get(1)
    subtrahend = AffineDimExpr.get(2)
    difference = minuend - subtrahend
    # CHECK: d1 - d2
    print(difference)

    assert difference.lhs == minuend
    # The right operand of the difference casts to a multiplication whose
    # operands print as `d2` and `-1`.
    negated = AffineMulExpr(difference.rhs)
    # CHECK: d2
    print(negated.lhs)
    # CHECK: -1
    print(negated.rhs)

    dim_minus_int = minuend - 42
    # CHECK: d1 - 42
    print(dim_minus_int)
    int_minus_dim = 42 - minuend
    # CHECK: -d1 + 42
    print(int_minus_dim)

    # A plain integer operand must behave like the matching constant
    # expression on either side of the operator.
    forty_two = AffineConstantExpr.get(42)
    assert minuend - 42 == minuend - forty_two
    assert 42 - minuend == forty_two - minuend
303
304# CHECK-LABEL: TEST: testClassHierarchy
@run
def testClassHierarchy():
  """Checks which expression kinds are instances of `AffineBinaryExpr`."""
  with Context():
    dim = AffineDimExpr.get(1)
    cst = AffineConstantExpr.get(2)
    add_expr = AffineAddExpr.get(dim, cst)
    mul_expr = AffineMulExpr.get(dim, cst)
    mod_expr = AffineModExpr.get(dim, cst)
    floor_div_expr = AffineFloorDivExpr.get(dim, cst)
    ceil_div_expr = AffineCeilDivExpr.get(dim, cst)

    # Leaf expressions are not binary expressions.
    # CHECK: False
    print(isinstance(dim, AffineBinaryExpr))
    # CHECK: False
    print(isinstance(cst, AffineBinaryExpr))
    # Every two-operand expression kind is.
    # CHECK: True
    print(isinstance(add_expr, AffineBinaryExpr))
    # CHECK: True
    print(isinstance(mul_expr, AffineBinaryExpr))
    # CHECK: True
    print(isinstance(mod_expr, AffineBinaryExpr))
    # CHECK: True
    print(isinstance(floor_div_expr, AffineBinaryExpr))
    # CHECK: True
    print(isinstance(ceil_div_expr, AffineBinaryExpr))

    # Casting a leaf expression to the binary class is expected to raise; the
    # printed message is what gets verified.
    try:
      AffineBinaryExpr(dim)
    except ValueError as err:
      # CHECK: Cannot cast affine expression to AffineBinaryExpr
      print(err)

    try:
      AffineBinaryExpr(cst)
    except ValueError as err:
      # CHECK: Cannot cast affine expression to AffineBinaryExpr
      print(err)
342
343# CHECK-LABEL: TEST: testIsInstance
@run
def testIsInstance():
  """Checks the per-class `isinstance` static helpers on several expressions."""
  with Context():
    dim = AffineDimExpr.get(1)
    cst = AffineConstantExpr.get(2)
    add_expr = AffineAddExpr.get(dim, cst)
    mul_expr = AffineMulExpr.get(dim, cst)

    # Each expression is probed once with its own class and once with a
    # different one.
    # CHECK: True
    print(AffineDimExpr.isinstance(dim))
    # CHECK: False
    print(AffineConstantExpr.isinstance(dim))
    # CHECK: True
    print(AffineConstantExpr.isinstance(cst))
    # CHECK: False
    print(AffineMulExpr.isinstance(cst))
    # CHECK: True
    print(AffineAddExpr.isinstance(add_expr))
    # CHECK: False
    print(AffineMulExpr.isinstance(add_expr))
    # CHECK: True
    print(AffineMulExpr.isinstance(mul_expr))
    # CHECK: False
    print(AffineAddExpr.isinstance(mul_expr))
368
369
370# CHECK-LABEL: TEST: testCompose
@run
def testCompose():
  """Composes the expression `d0 + d2` with a three-result affine map."""
  with Context():
    # d0 + d2.
    dim0_plus_dim2 = AffineAddExpr.get(AffineDimExpr.get(0),
                                       AffineDimExpr.get(2))

    # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)
    result0 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1))
    result1 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0))
    dim0_plus_dim1 = AffineAddExpr.get(AffineDimExpr.get(0),
                                       AffineDimExpr.get(1))
    result2 = AffineAddExpr.get(dim0_plus_dim1, AffineDimExpr.get(2))
    affine_map = AffineMap.get(3, 2, [result0, result1, result2])

    composed = dim0_plus_dim2.compose(affine_map)
    # CHECK: d0 + s1 + d0 + d1 + d2
    print(composed)
387
388
389# CHECK-LABEL: TEST: testHash
@run
def testHash():
  """Checks that equal expressions hash equally and work as dict keys."""
  with Context():
    dim0 = AffineDimExpr.get(0)
    sym1 = AffineSymbolExpr.get(1)
    assert hash(dim0) == hash(AffineDimExpr.get(0))
    assert hash(dim0 + sym1) == hash(AffineAddExpr.get(dim0, sym1))

    # Expressions must be usable as dictionary keys and be found again.
    by_expr = {dim0: 0, sym1: 1}
    assert dim0 in by_expr
    assert sym1 in by_expr
403