1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects import pdl
6from mlir.dialects.transform import structured
7
8
def run(f):
  """Build and print the IR produced by test function ``f``.

  Applied as a decorator, so each test executes as soon as it is defined.
  A fresh ``Context`` with an unknown default ``Location`` is entered, an
  empty ``Module`` is created, and ``f`` is invoked with the module body as
  the active insertion point, so ops that ``f`` builds land in that module.
  A ``TEST: <name>`` header is printed before ``f`` executes and the whole
  module is printed afterwards; the labels in this file anchor on that
  header.

  Returns ``f`` unchanged so the decorated name stays bound to the test.
  """
  with Context():
    with Location.unknown():
      mod = Module.create()
      test_name = f.__name__
      with InsertionPoint(mod.body):
        print("\nTEST:", test_name)
        f()
      print(mod)
  return f
17
18
@run
def testInterchange():
  """Interchange op whose permutation mixes an attribute and a plain int."""
  seq = transform.SequenceOp()
  # One permutation entry is a prebuilt i64 IntegerAttr, the other a Python
  # int; the expected printed form treats both identically.
  first = IntegerAttr.get(IntegerType.get_signless(64), 1)
  with InsertionPoint(seq.body):
    structured.InterchangeOp(seq.bodyTarget, iterator_interchange=[first, 0])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testInterchange
  # CHECK: transform.sequence
  # CHECK: transform.structured.interchange
  # CHECK: iterator_interchange = [1, 0]
33
34
@run
def testPad():
  """Pad op with explicit padding values, dimensions and transpositions."""
  seq = transform.SequenceOp()
  pad_value = FloatAttr.get_f32(42.0)
  with InsertionPoint(seq.body):
    structured.PadOp(
        seq.bodyTarget,
        padding_values=[pad_value],
        padding_dimensions=[1],
        transpose_paddings=[[1, 0]],
    )
    transform.YieldOp()
  # Hoist and pack paddings are not passed above; the printed op is
  # expected to carry them as empty arrays.
  # CHECK-LABEL: TEST: testPad
  # CHECK: transform.sequence
  # CHECK: transform.structured.pad
  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
  # CHECK-DAG: padding_dimensions = [1]
  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
  # CHECK-DAG: hoist_paddings = []
  # CHECK-DAG: pack_paddings = []
53
54
@run
def testScalarize():
  """Scalarize op applied to the sequence's body target."""
  seq = transform.SequenceOp()
  body = seq.body
  with InsertionPoint(body):
    structured.ScalarizeOp(seq.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testScalarize
  # CHECK: transform.structured.scalarize
63
64
@run
def testTileCompact():
  """Tile op built from plain Python lists for sizes and interchange."""
  seq = transform.SequenceOp()
  tile_sizes = [4, 8]
  loop_order = [0, 1]
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileCompact
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1]
  # CHECK-DAG: sizes = [4, 8]
76
77
@run
def testTileAttributes():
  """Tile op built from prebuilt ``ArrayAttr`` values instead of lists."""
  seq = transform.SequenceOp()
  i64 = IntegerType.get_signless(64)

  def as_i64_array(values):
    # Wrap each Python int in a 64-bit signless IntegerAttr.
    return ArrayAttr.get([IntegerAttr.get(i64, v) for v in values])

  tile_sizes = as_i64_array([4, 8])
  loop_order = as_i64_array([0, 1])
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileAttributes
  # CHECK: transform.sequence
  # CHECK: structured.tile
  # CHECK-DAG: interchange = [0, 1]
  # CHECK-DAG: sizes = [4, 8]
93
94
@run
def testTileZero():
  """Tile op whose size list contains zero entries.

  Only two sizes are non-zero, and the expected result group below has two
  entries; zero presumably means "do not tile this dimension" (verify
  against the TileOp builder).
  """
  seq = transform.SequenceOp()
  tile_sizes = [4, 0, 2, 0]
  loop_order = list(range(4))
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileZero
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1, 2, 3]
  # CHECK-DAG: sizes = [4, 0, 2, 0]
107
108
@run
def testVectorize():
  """Vectorize op with the ``vectorize_padding`` flag switched on."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    structured.VectorizeOp(target, vectorize_padding=True)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testVectorize
  # CHECK: transform.sequence
  # CHECK: = transform.structured.vectorize
  # CHECK: vectorize_padding = true
119