1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects import pdl
6from mlir.dialects.transform import structured
7
8
def run(f):
  """Execute a test builder inside a fresh MLIR module and print the result.

  Used as a decorator, so each decorated function runs once at definition
  time. A new Context is entered with an unknown default Location, an empty
  Module is created, and `f` is called with the insertion point set to the
  module body. A "\\nTEST: <name>" header is printed before `f` runs, and the
  whole module is printed afterwards. `f` is returned unchanged so the
  decorated name remains bound to the original function.
  """
  with Context():
    with Location.unknown():
      test_module = Module.create()
      body_ip = InsertionPoint(test_module.body)
      with body_ip:
        print("\nTEST:", f.__name__)
        f()
      print(test_module)
  return f
17
18
@run
def testDecompose():
  """Build a transform sequence that applies `decompose` to its target."""
  seq = transform.SequenceOp()
  target = seq.bodyTarget
  with InsertionPoint(seq.body):
    structured.DecomposeOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testDecompose
  # CHECK: transform.sequence
  # CHECK: transform.structured.decompose
28
29
@run
def testGeneralize():
  """Build a transform sequence that applies `generalize` to its target."""
  seq = transform.SequenceOp()
  target = seq.bodyTarget
  with InsertionPoint(seq.body):
    structured.GeneralizeOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testGeneralize
  # CHECK: transform.sequence
  # CHECK: transform.structured.generalize
39
40
@run
def testInterchange():
  """Apply `interchange` with a permutation mixing an attribute and an int.

  The first permutation entry is an explicit i64 IntegerAttr and the second
  is a plain Python int. This is presumably meant to cover both element
  forms accepted by the builder.
  """
  i64 = IntegerType.get_signless(64)
  seq = transform.SequenceOp()
  permutation = [IntegerAttr.get(i64, 1), 0]
  with InsertionPoint(seq.body):
    structured.InterchangeOp(seq.bodyTarget, iterator_interchange=permutation)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testInterchange
  # CHECK: transform.sequence
  # CHECK: transform.structured.interchange
  # CHECK: iterator_interchange = [1, 0]
55
56
@run
def testMultitileSizes():
  """Build a `multitile_sizes` op on dimension 1 with target size 42."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    structured.MultiTileSizesOp(seq.bodyTarget, target_size=42, dimension=1)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testMultitileSizes
  # CHECK: transform.sequence
  # CHECK: transform.structured.multitile_sizes
  # CHECK-DAG: dimension = 1
  # CHECK-DAG: target_size = 42
69
70
@run
def testPad():
  """Build a `pad` op with an f32 padding value, one dimension and a transpose.

  `hoist_paddings` and `pack_paddings` are not passed here. The expected
  output below records both of them as empty lists.
  """
  seq = transform.SequenceOp()
  pad_value = FloatAttr.get_f32(42.0)
  with InsertionPoint(seq.body):
    structured.PadOp(
        seq.bodyTarget,
        padding_dimensions=[1],
        padding_values=[pad_value],
        transpose_paddings=[[1, 0]])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testPad
  # CHECK: transform.sequence
  # CHECK: transform.structured.pad
  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
  # CHECK-DAG: padding_dimensions = [1]
  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
  # CHECK-DAG: hoist_paddings = []
  # CHECK-DAG: pack_paddings = []
89
90
@run
def testScalarize():
  """Build a transform sequence that applies `scalarize` to its target.

  Like the other single-op builder tests in this file, the expected output
  also pins the enclosing transform.sequence. This ensures the op is matched
  inside the sequence body rather than anywhere after the test header.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.ScalarizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testScalarize
  # CHECK: transform.sequence
  # CHECK: transform.structured.scalarize
99
100
@run
def testSplit():
  """Chain two `split` ops, the second using a dynamic split point.

  The first split uses a static point (42) on dimension 1. The second split
  operates on the first result of the first split, on dimension 3. Its split
  point is the first split's second result.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    outer = structured.SplitOp(seq.bodyTarget, dimension=1, split_point=42)
    head = outer.results[0]
    tail = outer.results[1]
    structured.SplitOp(head, dimension=3, split_point=tail)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testSplit
  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
112
113
@run
def testTileCompact():
  """Build a `tile` op from plain Python lists of sizes and interchange."""
  tile_sizes = [4, 8]
  loop_order = [0, 1]
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileCompact
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]
124
125
@run
def testTileAttributes():
  """Build a `tile` op from pre-built ArrayAttr values of i64 integers."""
  seq = transform.SequenceOp()
  i64 = IntegerType.get_signless(64)

  def as_i64_array(values):
    # Wrap each Python int in an i64 IntegerAttr and collect them in an array.
    return ArrayAttr.get([IntegerAttr.get(i64, v) for v in values])

  sizes_attr = as_i64_array([4, 8])
  order_attr = as_i64_array([0, 1])
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=sizes_attr, interchange=order_attr)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileAttributes
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]
140
141
@run
def testTileZero():
  """Build a `tile` op whose size list contains zero entries.

  Four sizes are given, two of them zero. The expected output has two loop
  handles, one per non-zero size.
  """
  tile_sizes = [4, 0, 2, 0]
  loop_order = [0, 1, 2, 3]
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileZero
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
  # CHECK: interchange = [0, 1, 2, 3]
153
154
@run
def testTileDynamic():
  """Build a `tile` op whose sizes mix handle values and static ints.

  The sequence is nested in a with_pdl_patterns op. Two of the tile sizes
  come from `pdl_match` results inside the sequence; the others are
  constants.
  """
  outer = transform.WithPDLPatternsOp()
  with InsertionPoint(outer.body):
    seq = transform.SequenceOp(outer.bodyTarget)
    with InsertionPoint(seq.body):
      target = seq.bodyTarget
      first_match = transform.PDLMatchOp(target, "first")
      second_match = transform.PDLMatchOp(target, "second")
      structured.TileOp(target, sizes=[first_match, 3, second_match, 0])
      transform.YieldOp()
  # CHECK-LABEL: TEST: testTileDynamic
  # CHECK: %[[FIRST:.+]] = pdl_match
  # CHECK: %[[SECOND:.+]] = pdl_match
  # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
169
170
@run
def testVectorize():
  """Build a `vectorize` op with padding vectorization switched on."""
  seq = transform.SequenceOp()
  target = seq.bodyTarget
  with InsertionPoint(seq.body):
    structured.VectorizeOp(target, vectorize_padding=True)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testVectorize
  # CHECK: transform.sequence
  # CHECK: = transform.structured.vectorize
  # CHECK: vectorize_padding = true
181