1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects import pdl
6from mlir.dialects.transform import structured
7
8
def run(f):
  """Decorator: run ``f`` while building a fresh module, then print it.

  A new ``Context`` is activated together with an unknown default
  ``Location``. An empty ``Module`` is created and ``f`` is called with the
  insertion point set to that module's body, so ops ``f`` builds without an
  explicit insertion point are appended there. A ``TEST: <name>`` banner is
  printed before ``f`` runs and the finished module is printed afterwards;
  the FileCheck directives in this file are matched against that output.

  Returns ``f`` unchanged, so the decorated name still refers to the test.
  """
  ctx = Context()
  with ctx, Location.unknown():
    mod = Module.create()
    with InsertionPoint(mod.body):
      banner_name = f.__name__
      print("\nTEST:", banner_name)
      f()
    # Printed after leaving the insertion point but while the context is live.
    print(mod)
  return f
17
18
@run
def testDecompose():
  """Build a transform sequence that decomposes its ``bodyTarget``."""
  seq = transform.SequenceOp()
  # Reading bodyTarget only inspects the block argument; it creates no ops.
  target = seq.bodyTarget
  with InsertionPoint(seq.body):
    structured.DecomposeOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testDecompose
  # CHECK: transform.sequence
  # CHECK: transform.structured.decompose
28
29
@run
def testGeneralize():
  """Build a transform sequence that generalizes its ``bodyTarget``."""
  seq = transform.SequenceOp()
  # Reading bodyTarget only inspects the block argument; it creates no ops.
  target = seq.bodyTarget
  with InsertionPoint(seq.body):
    structured.GeneralizeOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testGeneralize
  # CHECK: transform.sequence
  # CHECK: transform.structured.generalize
39
40
@run
def testInterchange():
  """Build an interchange op whose permutation mixes attribute and int forms."""
  seq = transform.SequenceOp()
  i64 = IntegerType.get_signless(64)
  # One entry is a prebuilt IntegerAttr, the other a plain Python int;
  # presumably this exercises both element forms the builder accepts.
  permutation = [IntegerAttr.get(i64, 1), 0]
  with InsertionPoint(seq.body):
    structured.InterchangeOp(seq.bodyTarget, iterator_interchange=permutation)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testInterchange
  # CHECK: transform.sequence
  # CHECK: transform.structured.interchange
  # CHECK: iterator_interchange = [1, 0]
55
56
@run
def testPad():
  """Build a pad op with explicit padding values, dimensions and transposes."""
  seq = transform.SequenceOp()
  pad_value = FloatAttr.get_f32(42.0)
  with InsertionPoint(seq.body):
    structured.PadOp(
        seq.bodyTarget,
        padding_values=[pad_value],
        padding_dimensions=[1],
        transpose_paddings=[[1, 0]],
    )
    transform.YieldOp()
  # hoist_paddings and pack_paddings are not passed above; the expected output
  # below shows them printed as empty arrays.
  # CHECK-LABEL: TEST: testPad
  # CHECK: transform.sequence
  # CHECK: transform.structured.pad
  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
  # CHECK-DAG: padding_dimensions = [1]
  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
  # CHECK-DAG: hoist_paddings = []
  # CHECK-DAG: pack_paddings = []
75
76
@run
def testScalarize():
  """Build a transform sequence that scalarizes its ``bodyTarget``.

  Like the sibling single-op tests, the expected output first anchors on the
  enclosing ``transform.sequence`` so the op is verified to sit inside it.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.ScalarizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testScalarize
  # CHECK: transform.sequence
  # CHECK: transform.structured.scalarize
85
86
@run
def testSplit():
  """Chain two split ops: one with a static split point, one value-driven.

  The first split uses the integer split point 42. The second split takes the
  first split's result 0 as its target and its result 1 as ``split_point``.
  The expected output captures those two SSA names and requires the second op
  to reuse them. Like its siblings, it first anchors on the enclosing
  ``transform.sequence``.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
    structured.SplitOp(
        split.results[0], dimension=3, split_point=split.results[1])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testSplit
  # CHECK: transform.sequence
  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
98
99
@run
def testTileCompact():
  """Tile using plain Python int lists for ``sizes`` and ``interchange``."""
  seq = transform.SequenceOp()
  tile_sizes = [4, 8]
  interchange_order = [0, 1]
  with InsertionPoint(seq.body):
    structured.TileOp(
        seq.bodyTarget, sizes=tile_sizes, interchange=interchange_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileCompact
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1]
  # CHECK-DAG: sizes = [4, 8]
111
112
@run
def testTileAttributes():
  """Tile using prebuilt ``ArrayAttr`` values for ``sizes`` and ``interchange``.

  This is the attribute-typed counterpart of testTileCompact, which passes
  the same values as Python lists. The expected printed form is therefore
  identical, including the ``%..., %...:2`` result shape. That shape checks
  that the attribute path yields the same number of results as the list path.
  """
  sequence = transform.SequenceOp()
  i64 = IntegerType.get_signless(64)
  sizes_attr = ArrayAttr.get([IntegerAttr.get(i64, x) for x in [4, 8]])
  interchange_attr = ArrayAttr.get([IntegerAttr.get(i64, x) for x in [0, 1]])
  with InsertionPoint(sequence.body):
    structured.TileOp(
        sequence.bodyTarget, sizes=sizes_attr, interchange=interchange_attr)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileAttributes
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1]
  # CHECK-DAG: sizes = [4, 8]
128
129
@run
def testTileZero():
  """Tile with zero entries interleaved in ``sizes`` and an identity interchange."""
  seq = transform.SequenceOp()
  tile_sizes = [4, 0, 2, 0]
  # Identity permutation over all four dimensions: [0, 1, 2, 3].
  identity = list(range(len(tile_sizes)))
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=identity)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileZero
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1, 2, 3]
  # CHECK-DAG: sizes = [4, 0, 2, 0]
142
143
@run
def testVectorize():
  """Vectorize ``bodyTarget`` with the ``vectorize_padding`` flag enabled."""
  seq = transform.SequenceOp()
  # Reading bodyTarget only inspects the block argument; it creates no ops.
  target = seq.bodyTarget
  with InsertionPoint(seq.body):
    structured.VectorizeOp(target, vectorize_padding=True)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testVectorize
  # CHECK: transform.sequence
  # CHECK: = transform.structured.vectorize
  # CHECK: vectorize_padding = true
154