1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects import pdl
6
7
def run(f):
  """Run ``f`` against a fresh MLIR module and print what it built.

  Intended as a decorator: ``f`` is invoked right away, at decoration time.
  While it runs, a new MLIR context, an unknown location and an insertion point
  at the end of a freshly created module's body are all active, so any ops ``f``
  creates without an explicit insertion point land in that module. A
  ``TEST: <name>`` banner is printed before ``f`` runs, and the populated module
  is printed afterwards. ``f`` is returned unchanged, so the decorated name
  still refers to the original function.
  """
  with Context():
    with Location.unknown():
      test_module = Module.create()
      with InsertionPoint(test_module.body):
        # The banner has to be printed before the module dump: each test's
        # FileCheck label directive anchors on it.
        print("\nTEST:", f.__name__)
        f()
      print(test_module)
  return f
16
17
@run
def testSequenceOp():
  """Build a top-level transform sequence that yields its own block argument.

  The sequence is given one ``!pdl.operation`` result type, and its body
  forwards the body target directly to the terminator.
  """
  op_handle_type = pdl.OperationType.get()
  seq = transform.SequenceOp([op_handle_type])
  with InsertionPoint(seq.body):
    transform.YieldOp([seq.bodyTarget])
  # CHECK-LABEL: TEST: testSequenceOp
  # CHECK: = transform.sequence {
  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
  # CHECK:   yield %[[ARG0]] : !pdl.operation
  # CHECK: } : !pdl.operation
28
29
@run
def testNestedSequenceOp():
  """Nest three transform sequences, threading each body target inward.

  Only the innermost sequence is given a result type; it yields its own block
  argument. The two enclosing sequences end with an operand-less yield.
  """
  outer = transform.SequenceOp()
  with InsertionPoint(outer.body):
    middle = transform.SequenceOp(outer.bodyTarget)
    with InsertionPoint(middle.body):
      inner = transform.SequenceOp([pdl.OperationType.get()],
                                   middle.bodyTarget)
      with InsertionPoint(inner.body):
        transform.YieldOp([inner.bodyTarget])
      # Terminate the middle sequence, then the outer one.
      transform.YieldOp()
    transform.YieldOp()
  # CHECK-LABEL: TEST: testNestedSequenceOp
  # CHECK: transform.sequence {
  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
  # CHECK:   sequence %[[ARG0]] {
  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
  # CHECK:     = sequence %[[ARG1]] {
  # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !pdl.operation):
  # CHECK:       yield %[[ARG2]] : !pdl.operation
  # CHECK:     } : !pdl.operation
  # CHECK:   }
  # CHECK: }
53
54
@run
def testTransformPDLOps():
  """Run a named PDL match inside with_pdl_patterns and yield the result.

  A result-producing sequence is nested in a ``with_pdl_patterns`` op. Its body
  applies ``pdl_match`` with the symbol ``pdl_matcher`` to the body target and
  yields the matched handle.
  """
  with_pdl = transform.WithPDLPatternsOp()
  with InsertionPoint(with_pdl.body):
    seq = transform.SequenceOp([pdl.OperationType.get()],
                               with_pdl.bodyTarget)
    with InsertionPoint(seq.body):
      matched = transform.PDLMatchOp(seq.bodyTarget, "pdl_matcher")
      # The match op itself is handed to the terminator as its operand source.
      transform.YieldOp(matched)
  # CHECK-LABEL: TEST: testTransformPDLOps
  # CHECK: transform.with_pdl_patterns {
  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
  # CHECK:   = sequence %[[ARG0]] {
  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
  # CHECK:     yield %[[RES]] : !pdl.operation
  # CHECK:   } : !pdl.operation
  # CHECK: }
73
74
@run
def testGetClosestIsolatedParentOp():
  """Apply get_closest_isolated_parent to a sequence's body target."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    # The op's result is not consumed; the test only checks the printed form.
    transform.GetClosestIsolatedParentOp(seq.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
  # CHECK: transform.sequence
  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
  # CHECK:   = get_closest_isolated_parent %[[ARG1]]
85
86
@run
def testMergeHandlesOp():
  """Apply merge_handles to a one-element list holding the body target."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    handles = [seq.bodyTarget]
    transform.MergeHandlesOp(handles)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testMergeHandlesOp
  # CHECK: transform.sequence
  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
  # CHECK:   = merge_handles %[[ARG1]]
97
98
@run
def testReplicateOp():
  """Build a replicate op driven by two pdl_match handles.

  Both matches are applied to the same sequence body target. The "first" match
  supplies the ``num`` operand and the "second" match is the replicated handle.
  """
  with_pdl = transform.WithPDLPatternsOp()
  with InsertionPoint(with_pdl.body):
    seq = transform.SequenceOp(with_pdl.bodyTarget)
    with InsertionPoint(seq.body):
      target = seq.bodyTarget
      first_match = transform.PDLMatchOp(target, "first")
      second_match = transform.PDLMatchOp(target, "second")
      transform.ReplicateOp(first_match, [second_match])
      transform.YieldOp()
  # CHECK-LABEL: TEST: testReplicateOp
  # CHECK: %[[FIRST:.+]] = pdl_match
  # CHECK: %[[SECOND:.+]] = pdl_match
  # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
113