# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl


def run(f):
  """Build and print the IR produced by test case `f`; used as a decorator.

  Enters a fresh `Context` with `Location.unknown()` as the default location
  and creates an empty `Module`. Inside `InsertionPoint(module.body)` it
  prints a `TEST: <function name>` header (the per-test FileCheck label lines
  anchor on it) and then calls `f`. After leaving the insertion point, but
  while the context is still active, the module is printed so FileCheck can
  match its textual IR.

  Returns `f` unchanged, so the decorated name still refers to the test
  function.
  """
  with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
      print("\nTEST:", f.__name__)
      f()
    print(module)
  return f


@run
def testSequenceOp():
  """Top-level sequence with one !pdl.operation result that yields its
  block argument."""
  sequence = transform.SequenceOp([pdl.OperationType.get()])
  with InsertionPoint(sequence.body):
    transform.YieldOp([sequence.bodyTarget])
  # CHECK-LABEL: TEST: testSequenceOp
  # CHECK: = transform.sequence {
  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
  # CHECK: yield %[[ARG0]] : !pdl.operation
  # CHECK: } : !pdl.operation


@run
def testNestedSequenceOp():
  """Three nested sequences.

  The outermost sequence takes no operand. Each inner sequence takes its
  parent's block argument as its operand. Only the innermost sequence has a
  result (a !pdl.operation), and it yields its own block argument. The two
  outer bodies are terminated by an empty yield.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    nested = transform.SequenceOp(sequence.bodyTarget)
    with InsertionPoint(nested.body):
      doubly_nested = transform.SequenceOp([pdl.OperationType.get()],
                                           nested.bodyTarget)
      with InsertionPoint(doubly_nested.body):
        transform.YieldOp([doubly_nested.bodyTarget])
      # Terminator of `nested`.
      transform.YieldOp()
    # Terminator of the outermost `sequence`.
    transform.YieldOp()
  # CHECK-LABEL: TEST: testNestedSequenceOp
  # CHECK: transform.sequence {
  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
  # CHECK: sequence %[[ARG0]] {
  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
  # CHECK: = sequence %[[ARG1]] {
  # CHECK: ^{{.*}}(%[[ARG2:.+]]: !pdl.operation):
  # CHECK: yield %[[ARG2]] : !pdl.operation
  # CHECK: } : !pdl.operation
  # CHECK: }
  # CHECK: }


@run
def testTransformPDLOps():
  """Sequence nested in with_pdl_patterns that yields the handle produced by
  a pdl_match referring to the "pdl_matcher" symbol."""
  with_pdl = transform.WithPDLPatternsOp()
  with InsertionPoint(with_pdl.body):
    sequence = transform.SequenceOp([pdl.OperationType.get()],
                                    with_pdl.bodyTarget)
    with InsertionPoint(sequence.body):
      match = transform.PDLMatchOp(sequence.bodyTarget, "pdl_matcher")
      transform.YieldOp(match)
  # CHECK-LABEL: TEST: testTransformPDLOps
  # CHECK: transform.with_pdl_patterns {
  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
  # CHECK: = sequence %[[ARG0]] {
  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
  # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
  # CHECK: yield %[[RES]] : !pdl.operation
  # CHECK: } : !pdl.operation
  # CHECK: }


@run
def testGetClosestIsolatedParentOp():
  """get_closest_isolated_parent applied to the sequence's block argument."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    transform.GetClosestIsolatedParentOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
  # CHECK: transform.sequence
  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
  # CHECK: = get_closest_isolated_parent %[[ARG1]]


@run
def testMergeHandlesOp():
  """merge_handles over a single handle: the sequence's block argument."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    transform.MergeHandlesOp([sequence.bodyTarget])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testMergeHandlesOp
  # CHECK: transform.sequence
  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
  # CHECK: = merge_handles %[[ARG1]]


@run
def testReplicateOp():
  """replicate with the "first" match as its `num` operand and the "second"
  match as its other operand, inside with_pdl_patterns."""
  with_pdl = transform.WithPDLPatternsOp()
  with InsertionPoint(with_pdl.body):
    sequence = transform.SequenceOp(with_pdl.bodyTarget)
    with InsertionPoint(sequence.body):
      m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
      m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
      transform.ReplicateOp(m1, [m2])
      transform.YieldOp()
  # CHECK-LABEL: TEST: testReplicateOp
  # CHECK: %[[FIRST:.+]] = pdl_match
  # CHECK: %[[SECOND:.+]] = pdl_match
  # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]