18c1b785cSAlex Zinenko# RUN: %PYTHON %s | FileCheck %s
28c1b785cSAlex Zinenko
38c1b785cSAlex Zinenkofrom mlir.ir import *
4a54f4eaeSMogballfrom mlir.dialects import arith
523aa5a74SRiver Riddlefrom mlir.dialects import func
68c1b785cSAlex Zinenkofrom mlir.dialects import scf
78c1b785cSAlex Zinenkofrom mlir.dialects import builtin
88c1b785cSAlex Zinenko
98c1b785cSAlex Zinenko
def constructAndPrintInModule(f):
  """Decorator that immediately builds ``f``'s IR into a fresh module.

  Prints a ``TEST: <name>`` banner (matched by the CHECK-LABEL directives in
  this file), creates a new Context with an unknown default Location, runs
  ``f`` with the insertion point at the body of a newly created module, and
  prints that module. ``f`` is returned unchanged so the decorated name keeps
  referring to the original function.
  """
  print("\nTEST:", f.__name__)
  with Context(), Location.unknown():
    new_module = Module.create()
    module_body_ip = InsertionPoint(new_module.body)
    with module_body_ip:
      f()
    print(new_module)
  return f
188c1b785cSAlex Zinenko
198c1b785cSAlex Zinenko
# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
  """Build an scf.for whose body forwards its two iter_args unchanged."""
  idx = IndexType.get()

  @func.FuncOp.from_py_func(idx, idx, idx)
  def simple_loop(lower, upper, stride):
    # Two loop-carried values, both seeded with the lower bound...
    for_op = scf.ForOp(lower, upper, stride, [lower, lower])
    # ...and the body yields its iteration block arguments straight back.
    with InsertionPoint(for_op.body):
      scf.YieldOp(for_op.inner_iter_args)
    return


# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
# CHECK: scf.yield %[[I1]], %[[I2]]
378c1b785cSAlex Zinenko
388c1b785cSAlex Zinenko
# CHECK-LABEL: TEST: testInductionVar
@constructAndPrintInModule
def testInductionVar():
  """Build an scf.for whose body yields the induction variable as its iter_arg."""
  idx = IndexType.get()

  @func.FuncOp.from_py_func(idx, idx, idx)
  def induction_var(lower, upper, stride):
    # A single loop-carried value, seeded with the lower bound.
    for_op = scf.ForOp(lower, upper, stride, [lower])
    # The body yields the induction variable as the next carried value.
    with InsertionPoint(for_op.body):
      iv = for_op.induction_variable
      scf.YieldOp([iv])
    return


# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: scf.yield %[[IV]]
55b164f23cSAlex Zinenko
56b164f23cSAlex Zinenko
# CHECK-LABEL: TEST: testOpsAsArguments
@constructAndPrintInModule
def testOpsAsArguments():
  """Pass ops (rather than values) as scf.ForOp bounds and iter_args.

  The bounds are arith.constant ops and the iter_args are a two-result
  func.CallOp; the expectations below verify that the loop uses exactly those
  ops' results as its bounds and initial iteration values.
  """
  index_type = IndexType.get()
  callee = func.FuncOp(
      "callee", ([], [index_type, index_type]), visibility="private")
  f = func.FuncOp("ops_as_arguments", ([], []))
  with InsertionPoint(f.add_entry_block()):
    lb = arith.ConstantOp.create_index(0)
    ub = arith.ConstantOp.create_index(42)
    step = arith.ConstantOp.create_index(2)
    iter_args = func.CallOp(callee, [])
    loop = scf.ForOp(lb, ub, step, iter_args)
    with InsertionPoint(loop.body):
      scf.YieldOp(loop.inner_iter_args)
    func.ReturnOp([])


# The loop bounds are matched through the captured constant names rather than
# the printer's auto-generated SSA names, so the wiring is actually verified.
# CHECK: func private @callee() -> (index, index)
# CHECK: func @ops_as_arguments() {
# CHECK:   %[[LB:.*]] = arith.constant 0
# CHECK:   %[[UB:.*]] = arith.constant 42
# CHECK:   %[[STEP:.*]] = arith.constant 2
# CHECK:   %[[ARGS:.*]]:2 = call @callee()
# CHECK:   scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]]
# CHECK:   iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK:     scf.yield %{{.*}}, %{{.*}}
# CHECK:   return
85036088fdSchhzh123
86036088fdSchhzh123
# CHECK-LABEL: TEST: testIfWithoutElse
@constructAndPrintInModule
def testIfWithoutElse():
  """Build a result-less scf.if that has only a then-region."""
  i1 = IntegerType.get_signless(1)
  i32 = IntegerType.get_signless(32)

  @func.FuncOp.from_py_func(i1)
  def simple_if(cond):
    if_op = scf.IfOp(cond)
    with InsertionPoint(if_op.then_block):
      one = arith.ConstantOp(i32, 1)
      # Only the op's presence in the IR matters; its result is not used.
      arith.AddIOp(one, one)
      scf.YieldOp([])
    return


# The scf.if condition reuses ARG0 (instead of redefining it) so the test
# verifies that the function argument is the condition operand.
# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK:   %[[ONE:.*]] = arith.constant 1
# CHECK:   %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return
107036088fdSchhzh123
108036088fdSchhzh123
# CHECK-LABEL: TEST: testIfWithElse
@constructAndPrintInModule
def testIfWithElse():
  """Build a two-result scf.if with then/else regions and consume its results."""
  i1 = IntegerType.get_signless(1)
  i32 = IntegerType.get_signless(32)

  @func.FuncOp.from_py_func(i1)
  def simple_if_else(cond):
    if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
    with InsertionPoint(if_op.then_block):
      x_true = arith.ConstantOp(i32, 0)
      y_true = arith.ConstantOp(i32, 1)
      scf.YieldOp([x_true, y_true])
    with InsertionPoint(if_op.else_block):
      x_false = arith.ConstantOp(i32, 2)
      y_false = arith.ConstantOp(i32, 3)
      scf.YieldOp([x_false, y_false])
    # Consume both results so the printed IR shows them used after the op.
    arith.AddIOp(if_op.results[0], if_op.results[1])
    return


# The scf.if condition reuses ARG0 (instead of redefining it) so the test
# verifies that the function argument is the condition operand.
# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
# CHECK:   %[[ZERO:.*]] = arith.constant 0
# CHECK:   %[[ONE:.*]] = arith.constant 1
# CHECK:   scf.yield %[[ZERO]], %[[ONE]]
# CHECK: } else {
# CHECK:   %[[TWO:.*]] = arith.constant 2
# CHECK:   %[[THREE:.*]] = arith.constant 3
# CHECK:   scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return
140