# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
# NOTE(review): `pdl` is not referenced below; presumably kept so the dialect
# bindings get loaded — confirm before removing.
from mlir.dialects import pdl
from mlir.dialects.transform import loop


def run(f):
    """Run one test-case builder inside a fresh MLIR module and print it.

    Applied as a decorator, so ``f`` executes as soon as this script is
    imported. A banner line carrying ``f.__name__`` is printed before ``f``
    builds any IR, which gives each test's label line something to anchor on;
    the finished module is printed afterwards so FileCheck can match it.

    Returns ``f`` itself, so the decorated name still refers to the function.
    """
    with Context(), Location.unknown():
        test_module = Module.create()
        with InsertionPoint(test_module.body):
            print("\nTEST:", f.__name__)
            f()
        # Leave the insertion point before printing, but stay inside the
        # context that owns the module.
        print(test_module)
    return f


def _emit_in_sequence(build_op):
    """Create a ``transform.SequenceOp`` whose body applies ``build_op``.

    ``build_op`` is called with the sequence's ``bodyTarget`` while the
    insertion point is the sequence body; a ``transform.YieldOp`` is appended
    afterwards to terminate that body.
    """
    seq = transform.SequenceOp()
    with InsertionPoint(seq.body):
        build_op(seq.bodyTarget)
        transform.YieldOp()


@run
def getParentLoop():
    """Emit ``loop.GetParentForOp`` on the sequence target with num_loops=2."""
    _emit_in_sequence(
        lambda target: loop.GetParentForOp(target, num_loops=2))
    # CHECK-LABEL: TEST: getParentLoop
    # CHECK: = transform.loop.get_parent_for %
    # CHECK: num_loops = 2


@run
def loopOutline():
    """Emit ``loop.LoopOutlineOp`` on the sequence target, naming it "foo"."""
    _emit_in_sequence(
        lambda target: loop.LoopOutlineOp(target, func_name="foo"))
    # CHECK-LABEL: TEST: loopOutline
    # CHECK: = transform.loop.outline %
    # CHECK: func_name = "foo"


@run
def loopPeel():
    """Emit ``loop.LoopPeelOp`` on the sequence target with no extra options."""
    _emit_in_sequence(lambda target: loop.LoopPeelOp(target))
    # CHECK-LABEL: TEST: loopPeel
    # CHECK: = transform.loop.peel %


@run
def loopPipeline():
    """Emit ``loop.LoopPipelineOp`` with iteration_interval=3.

    ``read_latency`` is not passed here; the expectations below pin the
    value printed for it to 10.
    """
    _emit_in_sequence(
        lambda target: loop.LoopPipelineOp(target, iteration_interval=3))
    # CHECK-LABEL: TEST: loopPipeline
    # CHECK: = transform.loop.pipeline %
    # CHECK-DAG: iteration_interval = 3
    # CHECK-DAG: read_latency = 10


@run
def loopUnroll():
    """Emit ``loop.LoopUnrollOp`` on the sequence target with factor=42."""
    _emit_in_sequence(lambda target: loop.LoopUnrollOp(target, factor=42))
    # CHECK-LABEL: TEST: loopUnroll
    # CHECK: transform.loop.unroll %
    # CHECK: factor = 42