1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Run the test function `f` right away and hand it back unchanged.

  A "TEST: <name>" header (preceded by an empty line) is printed first so the
  label directives in this file have a line to anchor on. After the test body
  returns, a full garbage collection is forced and the number of live MLIR
  Context objects is asserted to be zero, i.e. the test leaked no context.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Free objects kept alive only by reference cycles before the leak check.
  gc.collect()
  assert Context._get_live_count() == 0
  return f
12
13
14# CHECK-LABEL: TEST: testIntegerSetCapsule
@run
def testIntegerSetCapsule():
  """Round-trip an IntegerSet through its C-API capsule.

  The object rebuilt from the capsule must compare equal to the original and
  report the very same Context instance.
  """
  with Context() as ctx:
    original = IntegerSet.get_empty(1, 1, ctx)
  ptr_capsule = original._CAPIPtr
  # CHECK: mlir.ir.IntegerSet._CAPIPtr
  print(ptr_capsule)
  roundtripped = IntegerSet._CAPICreate(ptr_capsule)
  assert original == roundtripped
  assert roundtripped.context is ctx
25
26
27# CHECK-LABEL: TEST: testIntegerSetGet
@run
def testIntegerSetGet():
  """Build IntegerSets via get/get_empty/get_replaced and report bad inputs."""

  def print_error(expected, builder, *args):
    # Invoke builder(*args) and print the message of an `expected` exception.
    # Any other exception propagates. A call that succeeds prints nothing, so
    # the matching directive would then fail.
    try:
      builder(*args)
    except expected as e:
      print(e)

  with Context():
    dim0 = AffineDimExpr.get(0)
    dim1 = AffineDimExpr.get(1)
    sym0 = AffineSymbolExpr.get(0)
    cst42 = AffineConstantExpr.get(42)

    # One equality (d0 - d1) and one inequality (s0 - 42).
    # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
    base = IntegerSet.get(2, 1, [dim0 - dim1, sym0 - cst42], [True, False])
    print(base)

    # CHECK: (d0)[s0] : (1 == 0)
    empty = IntegerSet.get_empty(1, 1)
    print(empty)

    # Replace d1 with s1, giving a set over 1 dimension and 2 symbols.
    # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
    replaced = base.get_replaced([dim0, AffineSymbolExpr.get(1)], [sym0], 1, 2)
    print(replaced)

    # Malformed argument lists are reported as ValueError.
    # CHECK: Expected non-empty list of constraints
    print_error(ValueError, IntegerSet.get, 2, 1, [], [])
    # CHECK: Expected the number of constraints to match that of equality flags
    print_error(ValueError, IntegerSet.get, 2, 1, [dim0 - dim1], [True, False])

    # Non-expression list elements are reported as RuntimeError.
    # CHECK: Invalid expression when attempting to create an IntegerSet
    print_error(RuntimeError, IntegerSet.get, 2, 1, [0], [True])
    # CHECK: Invalid expression (None?) when attempting to create an IntegerSet
    print_error(RuntimeError, IntegerSet.get, 2, 1, [None], [True])

    # The same two error classes, this time from get_replaced.
    # CHECK: Expected the number of dimension replacement expressions to match that of dimensions
    print_error(ValueError, base.get_replaced, [dim0], [sym0], 1, 1)
    # CHECK: Expected the number of symbol replacement expressions to match that of symbols
    print_error(ValueError, base.get_replaced, [dim0, dim1], [sym0, sym0], 1, 1)
    # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
    print_error(RuntimeError, base.get_replaced, [dim0, 1], [sym0], 1, 1)
    # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
    print_error(RuntimeError, base.get_replaced, [dim0, dim1], [None], 1, 1)
95
96
97# CHECK-LABEL: TEST: testIntegerSetProperties
@run
def testIntegerSetProperties():
  """Print the counting properties and the constraints of an IntegerSet."""
  with Context():
    dim0 = AffineDimExpr.get(0)
    dim1 = AffineDimExpr.get(1)
    sym0 = AffineSymbolExpr.get(0)
    cst42 = AffineConstantExpr.get(42)

    exprs = [dim0 - dim1, sym0 - cst42, sym0 - dim0]
    eq_flags = [True, False, False]
    iset = IntegerSet.get(2, 1, exprs, eq_flags)

    # Printed in this order; each directive pairs with the name right below it.
    counters = (
      # CHECK: 2
      "n_dims",
      # CHECK: 1
      "n_symbols",
      # CHECK: 3
      "n_inputs",
      # CHECK: 1
      "n_equalities",
      # CHECK: 2
      "n_inequalities",
    )
    for prop in counters:
      print(getattr(iset, prop))

    # CHECK: 3
    print(len(iset.constraints))

    # CHECK-DAG: d0 - d1 == 0
    # CHECK-DAG: s0 - 42 >= 0
    # CHECK-DAG: -d0 + s0 >= 0
    for constraint in iset.constraints:
      line = str(constraint.expr)
      line += " == 0" if constraint.is_eq else " >= 0"
      print(line)
127
128
# CHECK-LABEL: TEST: testHash
@run
def testHash():
  """Equal IntegerSets must hash equally and be usable as dict keys."""
  with Context():
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    # Named `iset` rather than `set` so the builtin type is not shadowed.
    iset = IntegerSet.get(2, 0, [d0 + d1], [True])

    assert hash(iset) == hash(IntegerSet.get(2, 0, [d0 + d1], [True]))

    dictionary = {}
    dictionary[iset] = 42
    assert iset in dictionary
    assert dictionary[iset] == 42
    # A separately constructed but equal set must find the same entry, which
    # needs both the hash equality above and __eq__ to agree.
    assert IntegerSet.get(2, 0, [d0 + d1], [True]) in dictionary
142