13f71765aSAlex Zinenko# RUN: %PYTHON %s | FileCheck %s
23f71765aSAlex Zinenko
33f71765aSAlex Zinenkofrom mlir.ir import *
43f71765aSAlex Zinenkofrom mlir.dialects import transform
53f71765aSAlex Zinenkofrom mlir.dialects import pdl
63f71765aSAlex Zinenkofrom mlir.dialects.transform import structured
73f71765aSAlex Zinenko
83f71765aSAlex Zinenko
def run(f):
  """Execute `f` once inside a freshly created module and print the result.

  Intended to be used as a decorator, so every decorated test runs at import
  time. The order of work is:
    1. enter an MLIR `Context` with `Location.unknown()` as the location;
    2. create a module and point the insertion point at its body;
    3. print a `TEST: <name>` banner (FileCheck labels anchor on it) and call
       `f`, which emits IR at that insertion point;
    4. after leaving the insertion point, but still inside the context,
       print the populated module.

  Args:
    f: Zero-argument callable that builds IR at the current insertion point.

  Returns:
    `f` itself, so the decorated name keeps referring to the function.
  """
  with Context(), Location.unknown():
    test_module = Module.create()
    body_ip = InsertionPoint(test_module.body)
    with body_ip:
      print("\nTEST:", f.__name__)
      f()
    print(test_module)
  return f
173f71765aSAlex Zinenko
183f71765aSAlex Zinenko
@run
def testDecompose():
  """Emit `transform.structured.decompose` on the sequence body target."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    structured.DecomposeOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testDecompose
  # CHECK: transform.sequence
  # CHECK: transform.structured.decompose
28ce2e198bSAlex Zinenko
29ce2e198bSAlex Zinenko
@run
def testGeneralize():
  """Emit `transform.structured.generalize` on the sequence body target."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    structured.GeneralizeOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testGeneralize
  # CHECK: transform.sequence
  # CHECK: transform.structured.generalize
39ce2e198bSAlex Zinenko
40ce2e198bSAlex Zinenko
@run
def testInterchange():
  """Emit `transform.structured.interchange` with a mixed-form permutation.

  The permutation passes one entry as an explicit i64 `IntegerAttr` and the
  other as a plain Python int. The directives below expect both to print as
  `[1, 0]`.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    i64 = IntegerType.get_signless(64)
    permutation = [IntegerAttr.get(i64, 1), 0]
    structured.InterchangeOp(seq.bodyTarget, iterator_interchange=permutation)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testInterchange
  # CHECK: transform.sequence
  # CHECK: transform.structured.interchange
  # CHECK: iterator_interchange = [1, 0]
553f71765aSAlex Zinenko
563f71765aSAlex Zinenko
@run
def testMultitileSizes():
  """Emit `transform.structured.multitile_sizes` with dimension and target.

  The two attributes may print in either order, hence the DAG-style matches.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    structured.MultiTileSizesOp(
        seq.bodyTarget,
        dimension=1,
        target_size=42,
    )
    transform.YieldOp()
  # CHECK-LABEL: TEST: testMultitileSizes
  # CHECK: transform.sequence
  # CHECK: transform.structured.multitile_sizes
  # CHECK-DAG: dimension = 1
  # CHECK-DAG: target_size = 42
69*3963b4d0SAlex Zinenko
70*3963b4d0SAlex Zinenko
@run
def testPad():
  """Emit `transform.structured.pad` with only some attributes supplied.

  Only padding values, padding dimensions and transpose paddings are passed.
  The directives below also expect `hoist_paddings` and `pack_paddings` to
  be printed as empty arrays.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    pad_value = FloatAttr.get_f32(42.0)
    structured.PadOp(
        seq.bodyTarget,
        padding_values=[pad_value],
        padding_dimensions=[1],
        transpose_paddings=[[1, 0]],
    )
    transform.YieldOp()
  # CHECK-LABEL: TEST: testPad
  # CHECK: transform.sequence
  # CHECK: transform.structured.pad
  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
  # CHECK-DAG: padding_dimensions = [1]
  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
  # CHECK-DAG: hoist_paddings = []
  # CHECK-DAG: pack_paddings = []
893f71765aSAlex Zinenko
903f71765aSAlex Zinenko
@run
def testScalarize():
  """Emit `transform.structured.scalarize` on the sequence body target.

  Like the sibling tests, this first verifies that the enclosing
  `transform.sequence` is printed, then checks for the scalarize op itself.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.ScalarizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testScalarize
  # CHECK: transform.sequence
  # CHECK: transform.structured.scalarize
993f71765aSAlex Zinenko
1003f71765aSAlex Zinenko
@run
def testSplit():
  """Chain two `transform.structured.split` ops.

  The first split uses a static split point of 42 along dimension 1. The
  second split applies to the first result of the first split, along
  dimension 3, and uses the first split's second result as a dynamic split
  point.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    first_split = structured.SplitOp(
        seq.bodyTarget, dimension=1, split_point=42)
    lhs = first_split.results[0]
    rhs = first_split.results[1]
    structured.SplitOp(lhs, dimension=3, split_point=rhs)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testSplit
  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
112ff6e5508SAlex Zinenko
113ff6e5508SAlex Zinenko
@run
def testTileCompact():
  """Emit `transform.structured.tile` from plain Python int lists."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    tile_sizes = [4, 8]
    permutation = [0, 1]
    structured.TileOp(
        seq.bodyTarget,
        sizes=tile_sizes,
        interchange=permutation,
    )
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileCompact
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]
1243f71765aSAlex Zinenko
1253f71765aSAlex Zinenko
@run
def testTileAttributes():
  """Emit `transform.structured.tile` from prebuilt `ArrayAttr`s of i64.

  This produces the same IR as the plain-int-list variant of this test, but
  passes sizes and interchange as explicit attributes.
  """
  seq = transform.SequenceOp()
  i64 = IntegerType.get_signless(64)

  def i64_array(values):
    # Wrap each Python int in an i64 IntegerAttr and pack them into an array.
    return ArrayAttr.get([IntegerAttr.get(i64, v) for v in values])

  tile_sizes = i64_array([4, 8])
  permutation = i64_array([0, 1])
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=permutation)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileAttributes
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]
1403f71765aSAlex Zinenko
1413f71765aSAlex Zinenko
@run
def testTileZero():
  """Emit `transform.structured.tile` where some tile sizes are zero.

  With two non-zero sizes out of four, the directives below expect two loop
  handles (`:2`) alongside the tiled-op handle.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    tile_sizes = [4, 0, 2, 0]
    permutation = [0, 1, 2, 3]
    structured.TileOp(
        seq.bodyTarget,
        sizes=tile_sizes,
        interchange=permutation,
    )
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileZero
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
  # CHECK: interchange = [0, 1, 2, 3]
1534e4a4c05SAlex Zinenko
1544e4a4c05SAlex Zinenko
@run
def testTileDynamic():
  """Tile with a mix of static sizes and handles produced by `pdl_match`.

  The sequence is nested in a `WithPDLPatternsOp` region. The two
  `PDLMatchOp` results occupy positions 0 and 2 of the size list. The
  directives below expect them to print as SSA operands, with three loop
  handles for the three non-zero sizes.
  """
  outer = transform.WithPDLPatternsOp()
  with InsertionPoint(outer.body):
    seq = transform.SequenceOp(outer.bodyTarget)
    with InsertionPoint(seq.body):
      root = seq.bodyTarget
      first = transform.PDLMatchOp(root, "first")
      second = transform.PDLMatchOp(root, "second")
      mixed_sizes = [first, 3, second, 0]
      structured.TileOp(root, sizes=mixed_sizes)
      transform.YieldOp()
  # CHECK-LABEL: TEST: testTileDynamic
  # CHECK: %[[FIRST:.+]] = pdl_match
  # CHECK: %[[SECOND:.+]] = pdl_match
  # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
1693f71765aSAlex Zinenko
1703f71765aSAlex Zinenko
@run
def testVectorize():
  """Emit `transform.structured.vectorize` with padding vectorization on."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    structured.VectorizeOp(target, vectorize_padding=True)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testVectorize
  # CHECK: transform.sequence
  # CHECK: = transform.structured.vectorize
  # CHECK: vectorize_padding = true
181