1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import scf
5from mlir.dialects import builtin
6
7
def run(f):
  """Invoke a test body immediately and hand the same callable back.

  Meant to be used as a decorator, so each test runs when the module is
  loaded. Before running, it prints a banner of the form ``TEST: <name>``
  (preceded by a blank line) that the label directives in this file match on.
  """
  name = f.__name__
  print("\nTEST:", name)
  f()
  return f
12
13
14# CHECK-LABEL: TEST: testSimpleLoop
@run
def testSimpleLoop():
  """Build a func whose scf.for carries two loop values, then print it."""
  with Context():
    with Location.unknown():
      mod = Module.create()
      idx = IndexType.get()
      with InsertionPoint(mod.body):

        @builtin.FuncOp.from_py_func(idx, idx, idx)
        def simple_loop(lower, upper, stride):
          # Both loop-carried values start out as the lower bound.
          for_op = scf.ForOp(lower, upper, stride, [lower] * 2)
          with InsertionPoint(for_op.body):
            # Pass the body's iteration arguments straight back out.
            scf.YieldOp(for_op.inner_iter_args)
          return

  # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
  # CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
  # CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
  # CHECK: scf.yield %[[I1]], %[[I2]]
  print(mod)
34
35
36# CHECK-LABEL: TEST: testInductionVar
@run
def testInductionVar():
  """Build a loop that yields its own induction variable, then print it."""
  with Context():
    with Location.unknown():
      mod = Module.create()
      idx = IndexType.get()
      with InsertionPoint(mod.body):

        @builtin.FuncOp.from_py_func(idx, idx, idx)
        def induction_var(lower, upper, stride):
          # A single loop-carried value, initialized from the lower bound.
          for_op = scf.ForOp(lower, upper, stride, [lower])
          with InsertionPoint(for_op.body):
            # The induction variable becomes the next carried value.
            iv = for_op.induction_variable
            scf.YieldOp([iv])
          return

  # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
  # CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
  # CHECK: scf.yield %[[IV]]
  print(mod)
55