1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects import pdl
6from mlir.dialects.transform import structured
7
8
def run(f):
  """Decorator: build a fresh module by running `f`, then print it.

  `f` is invoked inside a new MLIR context with an unknown location and
  with the insertion point at the body of a newly created module. A
  "TEST:" header naming `f` is printed before `f` runs, and the populated
  module is printed afterwards. `f` itself is returned unchanged so the
  decorated name stays bound at module level.
  """
  with Context():
    with Location.unknown():
      test_module = Module.create()
      with InsertionPoint(test_module.body):
        print("\nTEST:", f.__name__)
        f()
      print(test_module)
  return f
17
18
@run
def testDecompose():
  """Emit a `transform.structured.decompose` op on the sequence's target."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    structured.DecomposeOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testDecompose
  # CHECK: transform.sequence
  # CHECK: transform.structured.decompose
28
29
@run
def testGeneralize():
  """Emit a `transform.structured.generalize` op on the sequence's target."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    structured.GeneralizeOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testGeneralize
  # CHECK: transform.sequence
  # CHECK: transform.structured.generalize
39
40
@run
def testInterchange():
  """Emit an interchange op with the iterator permutation [1, 0].

  The permutation list passed to the builder contains both an i64
  `IntegerAttr` and a plain Python int.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    i64 = IntegerType.get_signless(64)
    permutation = [IntegerAttr.get(i64, 1), 0]
    structured.InterchangeOp(seq.bodyTarget, iterator_interchange=permutation)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testInterchange
  # CHECK: transform.sequence
  # CHECK: transform.structured.interchange
  # CHECK: iterator_interchange = [1, 0]
55
56
@run
def testPad():
  """Emit a pad op with explicit padding values, dimensions and transposes.

  `hoist_paddings` and `pack_paddings` are not passed to the builder; the
  expected output lists them as empty arrays.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    pad_value = FloatAttr.get_f32(42.0)
    structured.PadOp(
        seq.bodyTarget,
        padding_values=[pad_value],
        padding_dimensions=[1],
        transpose_paddings=[[1, 0]])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testPad
  # CHECK: transform.sequence
  # CHECK: transform.structured.pad
  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
  # CHECK-DAG: padding_dimensions = [1]
  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
  # CHECK-DAG: hoist_paddings = []
  # CHECK-DAG: pack_paddings = []
75
76
@run
def testScalarize():
  """Emit a `transform.structured.scalarize` op inside a transform sequence."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.ScalarizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # Like the sibling tests, also require the enclosing sequence op to print.
  # CHECK-LABEL: TEST: testScalarize
  # CHECK: transform.sequence
  # CHECK: transform.structured.scalarize
85
86
@run
def testSplit():
  """Chain two split ops, one with a static and one with a dynamic split point.

  The first split targets the sequence's body argument on dimension 1 at
  the static point 42. The second split is applied to the first result of
  the first split, on dimension 3, using the first split's second result
  as its split point.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
    structured.SplitOp(
        split.results[0], dimension=3, split_point=split.results[1])
    transform.YieldOp()
  # Like the sibling tests, also require the enclosing sequence op to print.
  # CHECK-LABEL: TEST: testSplit
  # CHECK: transform.sequence
  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
98
99
@run
def testTileCompact():
  """Emit a tile op whose sizes and interchange are given as Python lists."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    tile_sizes, loop_order = [4, 8], [0, 1]
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileCompact
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]
110
111
@run
def testTileAttributes():
  """Emit a tile op whose sizes and interchange are prebuilt `ArrayAttr`s."""
  seq = transform.SequenceOp()
  i64 = IntegerType.get_signless(64)

  def i64_array(values):
    # Wrap each Python int as an i64 IntegerAttr inside one ArrayAttr.
    return ArrayAttr.get([IntegerAttr.get(i64, v) for v in values])

  sizes_attr = i64_array([4, 8])
  interchange_attr = i64_array([0, 1])
  with InsertionPoint(seq.body):
    structured.TileOp(
        seq.bodyTarget, sizes=sizes_attr, interchange=interchange_attr)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileAttributes
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]
126
127
@run
def testTileZero():
  """Emit a tile op whose size list contains zero entries."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    tile_sizes = [4, 0, 2, 0]
    loop_order = [0, 1, 2, 3]
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileZero
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
  # CHECK: interchange = [0, 1, 2, 3]
139
140
@run
def testTileDynamic():
  """Emit a tile op whose sizes mix PDL match handles with static integers.

  The sequence is nested in a `WithPDLPatternsOp`. Two `PDLMatchOp`s (for
  the pattern names "first" and "second") are created in that order, and
  their results are used as the first and third tile sizes.
  """
  outer = transform.WithPDLPatternsOp()
  with InsertionPoint(outer.body):
    seq = transform.SequenceOp(outer.bodyTarget)
    with InsertionPoint(seq.body):
      root = seq.bodyTarget
      first, second = [
          transform.PDLMatchOp(root, name) for name in ("first", "second")
      ]
      structured.TileOp(root, sizes=[first, 3, second, 0])
      transform.YieldOp()
  # CHECK-LABEL: TEST: testTileDynamic
  # CHECK: %[[FIRST:.+]] = pdl_match
  # CHECK: %[[SECOND:.+]] = pdl_match
  # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
155
156
@run
def testVectorize():
  """Emit a vectorize op with `vectorize_padding` enabled."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    structured.VectorizeOp(target, vectorize_padding=True)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testVectorize
  # CHECK: transform.sequence
  # CHECK: = transform.structured.vectorize
  # CHECK: vectorize_padding = true
167