1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects import pdl
6from mlir.dialects.transform import loop
7
8
def run(f):
  """Decorator that builds and prints a single IR test case.

  Opens a fresh MLIR `Context` with an unknown default `Location`, creates an
  empty module, and, with the module body as the insertion point, prints a
  "TEST: <function name>" banner and then calls `f` to populate the module.
  The finished module is printed afterwards. `f` itself is returned unchanged,
  so the decorated name still refers to the original function.
  """
  with Context():
    with Location.unknown():
      mod = Module.create()
      with InsertionPoint(mod.body):
        # The banner goes out before f() runs, so each printed module is
        # preceded by the label that the file's FileCheck label directives
        # anchor on.
        print("\nTEST:", f.__name__)
        f()
      print(mod)
  return f
17
18
@run
def getParentLoop():
  """Emit a `transform.loop.get_parent_for` op with num_loops=2."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    loop.GetParentForOp(target, num_loops=2)
    transform.YieldOp()
  # CHECK-LABEL: TEST: getParentLoop
  # CHECK: = transform.loop.get_parent_for %
  # CHECK: num_loops = 2
28
29
@run
def loopOutline():
  """Emit a `transform.loop.outline` op with func_name set to "foo"."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    loop.LoopOutlineOp(target, func_name="foo")
    transform.YieldOp()
  # CHECK-LABEL: TEST: loopOutline
  # CHECK: = transform.loop.outline %
  # CHECK: func_name = "foo"
39
40
@run
def loopPeel():
  """Emit a `transform.loop.peel` op using only its required operand."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    loop.LoopPeelOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: loopPeel
  # CHECK: = transform.loop.peel %
49
50
@run
def loopPipeline():
  """Emit a `transform.loop.pipeline` op with iteration_interval=3.

  read_latency is deliberately not passed; the expected output below requires
  it to print as 10, i.e. it exercises the op's default for that attribute.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    loop.LoopPipelineOp(target, iteration_interval=3)
    transform.YieldOp()
  # CHECK-LABEL: TEST: loopPipeline
  # CHECK: = transform.loop.pipeline %
  # CHECK-DAG: iteration_interval = 3
  # CHECK-DAG: read_latency = 10
61
62
@run
def loopUnroll():
  """Emit a `transform.loop.unroll` op with factor=42.

  NOTE(review): unlike the sibling cases, the expected op line below has no
  leading "= ", presumably because this op produces no result -- confirm
  against the op definition.
  """
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    target = seq.bodyTarget
    loop.LoopUnrollOp(target, factor=42)
    transform.YieldOp()
  # CHECK-LABEL: TEST: loopUnroll
  # CHECK: transform.loop.unroll %
  # CHECK: factor = 42
72