1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
6
def run(f):
  """Execute test function `f` right away and verify it leaked no Context.

  Intended as a decorator. It prints a "TEST: <name>" header, which the
  FileCheck label lines in this file match on. It then invokes `f`, forces a
  garbage collection pass and asserts that no `Context` instances are still
  alive. `f` itself is handed back unchanged, so the decorated name keeps
  referring to the test function.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Free anything kept alive only by reference cycles before counting.
  gc.collect()
  assert Context._get_live_count() == 0
  return f
13
14
15# CHECK-LABEL: TEST: testDialectDescriptor
@run
def testDialectDescriptor():
  """Look up the std dialect descriptor and check that unknown names fail.

  Prints the descriptor and its namespace for FileCheck. Then verifies that
  requesting a dialect name that does not exist raises ValueError.
  """
  context = Context()
  std_descriptor = context.get_dialect_descriptor("std")
  # CHECK: <DialectDescriptor std>
  print(std_descriptor)
  # CHECK: std
  print(std_descriptor.namespace)

  # Record whether looking up a nonexistent dialect raised ValueError.
  raised = False
  try:
    context.get_dialect_descriptor("not_existing")
  except ValueError:
    raised = True
  assert raised, "Expected exception"
30
31
32# CHECK-LABEL: TEST: testUserDialectClass
@run
def testUserDialectClass():
  """Resolve the std dialect class through `ctx.dialects` and `ctx.d`.

  Covers three access paths: attribute access, index access and the short
  `d` alias. It also checks that an unknown dialect name raises
  AttributeError for attribute access and IndexError for index access.
  """
  ctx = Context()

  def expect_raises(exc_type, lookup):
    # Call `lookup` and fail the test unless it raises `exc_type`.
    try:
      lookup()
    except exc_type:
      return
    assert False, "Expected exception"

  # Access using attribute.
  std_dialect = ctx.dialects.std
  # The expected repr has nothing between "<Dialect" and "(class". The std
  # namespace is described as printing as '' here, while other dialects
  # print as "<Dialect %namespace (...".
  # NOTE(review): testDialectDescriptor expects the std namespace to print
  # as "std"; confirm that these two expectations agree.
  # CHECK: <Dialect (class mlir.dialects._std_ops_gen._Dialect)>
  print(std_dialect)
  expect_raises(AttributeError, lambda: ctx.dialects.not_existing)

  # Access using index.
  std_dialect = ctx.dialects["std"]
  # CHECK: <Dialect (class mlir.dialects._std_ops_gen._Dialect)>
  print(std_dialect)
  expect_raises(IndexError, lambda: ctx.dialects["not_existing"])

  # Using the 'd' alias.
  std_dialect = ctx.d["std"]
  # CHECK: <Dialect (class mlir.dialects._std_ops_gen._Dialect)>
  print(std_dialect)
64
65
66# CHECK-LABEL: TEST: testCustomOpView
67# This test uses the standard dialect AddFOp as an example of a user op.
68# TODO: Op creation and access is still quite verbose: simplify this test as
69# additional capabilities come online.
@run
def testCustomOpView():
  """Build two std AddFOps through different entry points and print them.

  The first op is created through the context's dialect collection
  (`ctx.dialects.std.AddFOp`). The second is created through a direct import
  of `mlir.dialects.std.AddFOp`. Their operands come from placeholder
  "pytest_dummy.intinput" operations.
  """

  def make_placeholder(result_type):
    # Create a "pytest_dummy.intinput" op with one result of `result_type`
    # and return that result value.
    # TODO: Auto result cast from operation
    placeholder_op = Operation.create("pytest_dummy.intinput",
                                      results=[result_type])
    return placeholder_op.results[0]

  with Context() as ctx, Location.unknown():
    # The "pytest_dummy" ops come from no registered dialect.
    ctx.allow_unregistered_dialects = True
    module = Module.create()

    with InsertionPoint(module.body):
      f32 = F32Type.get()
      # Create via dialects context collection.
      lhs = make_placeholder(f32)
      rhs = make_placeholder(f32)
      first_sum = ctx.dialects.std.AddFOp(lhs.type, lhs, rhs)

      # Create via an import
      from mlir.dialects.std import AddFOp
      AddFOp(lhs.type, lhs, first_sum.result)

  # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
  # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
  # CHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
  # CHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
  module.operation.print()
98
99
100# CHECK-LABEL: TEST: testIsRegisteredOperation
@run
def testIsRegisteredOperation():
  """Query the context's operation registry for a known and an unknown op."""
  context = Context()
  cond_br_known = context.is_registered_operation('std.cond_br')
  bogus_known = context.is_registered_operation('std.not_existing')

  # CHECK: std.cond_br: True
  print(f"std.cond_br: {cond_br_known}")
  # CHECK: std.not_existing: False
  print(f"std.not_existing: {bogus_known}")
109