# RUN: %PYTHON %s | FileCheck %s

# Tests for the Python bindings of the `scf` dialect. Each test builds IR
# inside a fresh module, prints it, and FileCheck verifies the printed text
# against the `# CHECK` directives that accompany the test.

from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import scf
from mlir.dialects import builtin


def constructAndPrintInModule(f):
    """Run `f` inside a fresh module and print the resulting IR.

    Used as a decorator: `f` is invoked immediately, at decoration time, with
    a new `Context`, an unknown `Location` and an insertion point at the body
    of a newly created module all active. A `TEST: <name>` banner is printed
    first so that the FileCheck `CHECK-LABEL` directives can anchor on it.

    Args:
        f: Zero-argument callable that builds IR at the current insertion
            point.

    Returns:
        `f` itself, unchanged.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
    """Build an `scf.for` with two iter_args that are yielded back as-is."""
    index_type = IndexType.get()

    @func.FuncOp.from_py_func(index_type, index_type, index_type)
    def simple_loop(lb, ub, step):
        # Both loop-carried values are initialized from the lower bound.
        loop = scf.ForOp(lb, ub, step, [lb, lb])
        with InsertionPoint(loop.body):
            scf.YieldOp(loop.inner_iter_args)
        return


# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
# CHECK: scf.yield %[[I1]], %[[I2]]


# CHECK-LABEL: TEST: testInductionVar
@constructAndPrintInModule
def testInductionVar():
    """Build an `scf.for` whose body yields the induction variable."""
    index_type = IndexType.get()

    @func.FuncOp.from_py_func(index_type, index_type, index_type)
    def induction_var(lb, ub, step):
        loop = scf.ForOp(lb, ub, step, [lb])
        with InsertionPoint(loop.body):
            scf.YieldOp([loop.induction_variable])
        return


# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: scf.yield %[[IV]]


# CHECK-LABEL: TEST: testOpsAsArguments
@constructAndPrintInModule
def testOpsAsArguments():
    """Pass operations, rather than values, as `scf.ForOp` operands.

    The bounds and step are `arith.ConstantOp` operations and the iter_args
    are given as a two-result `func.CallOp`.
    """
    index_type = IndexType.get()
    callee = func.FuncOp(
        "callee", ([], [index_type, index_type]), visibility="private"
    )
    f = func.FuncOp("ops_as_arguments", ([], []))
    with InsertionPoint(f.add_entry_block()):
        lb = arith.ConstantOp.create_index(0)
        ub = arith.ConstantOp.create_index(42)
        step = arith.ConstantOp.create_index(2)
        iter_args = func.CallOp(callee, [])
        loop = scf.ForOp(lb, ub, step, iter_args)
        with InsertionPoint(loop.body):
            scf.YieldOp(loop.inner_iter_args)
        func.ReturnOp([])


# The loop bounds are matched through the captured constant names instead of
# hard-coded SSA names, so the check does not depend on printer naming.
# CHECK: func private @callee() -> (index, index)
# CHECK: func @ops_as_arguments() {
# CHECK: %[[LB:.*]] = arith.constant 0
# CHECK: %[[UB:.*]] = arith.constant 42
# CHECK: %[[STEP:.*]] = arith.constant 2
# CHECK: %[[ARGS:.*]]:2 = call @callee()
# CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]]
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK: scf.yield %{{.*}}, %{{.*}}
# CHECK: return


# CHECK-LABEL: TEST: testIfWithoutElse
@constructAndPrintInModule
def testIfWithoutElse():
    """Build an `scf.if` that has only a `then` block and no results."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if(cond):
        if_op = scf.IfOp(cond)
        with InsertionPoint(if_op.then_block):
            one = arith.ConstantOp(i32, 1)
            arith.AddIOp(one, one)
            scf.YieldOp([])
        return


# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return


# CHECK-LABEL: TEST: testIfWithElse
@constructAndPrintInModule
def testIfWithElse():
    """Build a two-result `scf.if` with `then` and `else` blocks."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if_else(cond):
        if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
        with InsertionPoint(if_op.then_block):
            x_true = arith.ConstantOp(i32, 0)
            y_true = arith.ConstantOp(i32, 1)
            scf.YieldOp([x_true, y_true])
        with InsertionPoint(if_op.else_block):
            x_false = arith.ConstantOp(i32, 2)
            y_false = arith.ConstantOp(i32, 3)
            scf.YieldOp([x_false, y_false])
        # Consume both results of the `scf.if` after the op itself.
        arith.AddIOp(if_op.results[0], if_op.results[1])
        return


# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
# CHECK: %[[ZERO:.*]] = arith.constant 0
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: scf.yield %[[ZERO]], %[[ONE]]
# CHECK: } else {
# CHECK: %[[TWO:.*]] = arith.constant 2
# CHECK: %[[THREE:.*]] = arith.constant 3
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return