# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured


def run(f):
  """Build IR with ``f`` inside a fresh module and print the module.

  Enters a new ``Context`` together with ``Location.unknown()``, creates an
  empty ``Module`` and calls ``f`` with the insertion point set to the module
  body. A ``TEST: <function name>`` header is printed first so that each test's
  expected output can be anchored on that label; the module is printed after
  ``f`` returns.

  Returns ``f`` unchanged, so applying ``@run`` executes the test when the
  decorated function is defined and keeps the name bound to the function.
  """
  with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
      print("\nTEST:", f.__name__)
      f()
    print(module)
  return f


@run
def testDecompose():
  """Emit a decompose op on the sequence's block argument."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.DecomposeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testDecompose
  # CHECK: transform.sequence
  # CHECK: transform.structured.decompose


@run
def testGeneralize():
  """Emit a generalize op on the sequence's block argument."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.GeneralizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testGeneralize
  # CHECK: transform.sequence
  # CHECK: transform.structured.generalize


@run
def testInterchange():
  """Pass ``iterator_interchange`` as a mix of an IntegerAttr and a plain int.

  Both kinds of element are expected to print as plain integers.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.InterchangeOp(
        sequence.bodyTarget,
        iterator_interchange=[
            IntegerAttr.get(IntegerType.get_signless(64), 1), 0
        ])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testInterchange
  # CHECK: transform.sequence
  # CHECK: transform.structured.interchange
  # CHECK: iterator_interchange = [1, 0]


@run
def testPad():
  """Emit a pad op with padding options given as Python lists.

  Options that are not passed (hoist/pack paddings) are expected to print as
  empty arrays.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.PadOp(
        sequence.bodyTarget,
        padding_values=[FloatAttr.get_f32(42.0)],
        padding_dimensions=[1],
        transpose_paddings=[[1, 0]])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testPad
  # CHECK: transform.sequence
  # CHECK: transform.structured.pad
  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
  # CHECK-DAG: padding_dimensions = [1]
  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
  # CHECK-DAG: hoist_paddings = []
  # CHECK-DAG: pack_paddings = []


@run
def testScalarize():
  """Emit a scalarize op on the sequence's block argument."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.ScalarizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testScalarize
  # CHECK: transform.sequence
  # CHECK: transform.structured.scalarize


@run
def testSplit():
  """Chain two splits: a static split point, then a dynamic one.

  The second split operates on the first split's first result and uses the
  first split's second result as its split point.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
    structured.SplitOp(
        split.results[0], dimension=3, split_point=split.results[1])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testSplit
  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3


@run
def testTileCompact():
  """Tile with ``sizes`` and ``interchange`` given as plain int lists."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileCompact
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]


@run
def testTileAttributes():
  """Tile with ``sizes`` and ``interchange`` given as prebuilt ArrayAttrs.

  The expected output is identical to the plain-list form in testTileCompact.
  """
  sequence = transform.SequenceOp()
  attr = ArrayAttr.get(
      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
  ichange = ArrayAttr.get(
      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
  with InsertionPoint(sequence.body):
    structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileAttributes
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]


@run
def testTileZero():
  """Tile with some zero sizes, which are kept in the printed size list.

  The expected ``:2`` result group presumably corresponds to the two non-zero
  tile sizes.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.TileOp(
        sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileZero
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
  # CHECK: interchange = [0, 1, 2, 3]


@run
def testTileDynamic():
  """Tile with a mix of dynamic sizes (match-op handles) and static sizes.

  The sequence is nested inside a WithPDLPatternsOp so that the PDLMatchOp
  ops can refer to patterns by the names "first" and "second"; the patterns
  themselves are not defined in this test.
  """
  with_pdl = transform.WithPDLPatternsOp()
  with InsertionPoint(with_pdl.body):
    sequence = transform.SequenceOp(with_pdl.bodyTarget)
    with InsertionPoint(sequence.body):
      m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
      m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
      structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
      transform.YieldOp()
  # CHECK-LABEL: TEST: testTileDynamic
  # CHECK: %[[FIRST:.+]] = pdl_match
  # CHECK: %[[SECOND:.+]] = pdl_match
  # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]


@run
def testVectorize():
  """Emit a vectorize op with ``vectorize_padding`` enabled."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testVectorize
  # CHECK: transform.sequence
  # CHECK: = transform.structured.vectorize
  # CHECK: vectorize_padding = true