1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2
3import gc, sys
4from mlir.ir import *
5from mlir.passmanager import *
6
7# Log everything to stderr and flush so that we have a unified stream to match
8# errors/info emitted by MLIR to stderr.
def log(*args):
  """Print ``args`` to stderr and flush the stream right away.

  Test output goes to stderr, the same stream MLIR diagnostics are emitted
  on, and is flushed after every message so both end up in one unified
  stream for matching.
  """
  print(*args, file=sys.stderr, flush=True)
12
def run(f):
  """Execute test function ``f`` and assert that it leaked no Context.

  A banner containing the function name is logged first so that the output
  of each test can be located in the combined stream.
  """
  test_name = f.__name__
  log("\nTEST:", test_name)
  f()
  # Collect garbage before counting, so that contexts which are merely
  # awaiting collection are not reported as live.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
18
19# Verify capsule interop.
20# CHECK-LABEL: TEST: testCapsule
def testCapsule():
  """Round-trip a PassManager through its C API capsule."""
  expected_capsule_name = '"mlir.passmanager.PassManager._CAPIPtr"'
  with Context():
    original = PassManager()
    capsule = original._CAPIPtr
    assert expected_capsule_name in repr(capsule)
    # NOTE(review): presumably this makes `original` give up ownership of the
    # underlying pass manager so the wrapper created below can take it over
    # without a double free -- confirm against the bindings.
    original._testing_release()
    rebuilt = PassManager._CAPICreate(capsule)
    assert rebuilt is not None  # And does not crash.
run(testCapsule)
30
31
# Verify that an unregistered pass is rejected and that a registered pipeline round-trips.
33# CHECK-LABEL: TEST: testParseSuccess
def testParseSuccess():
  """Parse one pipeline naming an unknown pass and one naming a known pass.

  The first parse is expected to raise ValueError; the second is expected to
  succeed, and the resulting pass manager is logged for comparison against
  the textual pipeline it was parsed from.
  """
  unregistered_pipeline = "builtin.module(func.func(not-existing-pass{json=false}))"
  registered_pipeline = "builtin.module(func.func(print-op-stats{json=false}))"
  with Context():
    # An unregistered pass should not parse.
    try:
      rejected = PassManager.parse(unregistered_pipeline)
      # TODO: this error should be propagated to Python but the C API does not help right now.
      # CHECK: error: 'not-existing-pass' does not refer to a registered pass or pass pipeline
    except ValueError as err:
      # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(not-existing-pass{json=false}))'.
      log("ValueError exception:", err)
    else:
      log("Exception not produced")

    # A registered pass should parse successfully.
    parsed = PassManager.parse(registered_pipeline)
    # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
    log("Roundtrip: ", parsed)
run(testParseSuccess)
52
53# Verify failure on unregistered pass.
54# CHECK-LABEL: TEST: testParseFail
def testParseFail():
  """Parsing the pipeline string "unknown-pass" must raise ValueError."""
  pipeline = "unknown-pass"
  with Context():
    try:
      unexpected = PassManager.parse(pipeline)
    except ValueError as err:
      # CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'.
      log("ValueError exception:", err)
    else:
      # Reaching this branch means the parse wrongly succeeded; the message
      # below will not satisfy the expected output.
      log("Exception not produced")
run(testParseFail)
65
66
67# Verify failure on incorrect level of nesting.
68# CHECK-LABEL: TEST: testInvalidNesting
def testInvalidNesting():
  """A pass placed at the wrong nesting level must be rejected.

  The pipeline anchors `normalize-memrefs` on `func.func`; the expected
  diagnostic below states that the pass is restricted to `builtin.module`.
  """
  misnested_pipeline = "func.func(normalize-memrefs)"
  with Context():
    try:
      unexpected = PassManager.parse(misnested_pipeline)
    except ValueError as err:
      # CHECK: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
      # CHECK: ValueError exception: invalid pass pipeline 'func.func(normalize-memrefs)'.
      log("ValueError exception:", err)
    else:
      log("Exception not produced")
run(testInvalidNesting)
80
81
82# Verify that a pass manager can execute on IR
# CHECK-LABEL: TEST: testRunPipeline
def testRunPipeline():
  """Run the `print-op-stats` pass over a module holding one empty function.

  The statistics emitted by the pass are matched by the expectations that
  follow this function.
  """
  module_asm = r"""func.func @successfulParse() { return }"""
  with Context():
    stats_pm = PassManager.parse("print-op-stats{json=false}")
    module = Module.parse(module_asm)
    stats_pm.run(module)
# CHECK: Operations encountered:
# CHECK: builtin.module    , 1
# CHECK: func.func      , 1
# CHECK: func.return        , 1
run(testRunPipeline)
94