# RUN: %PYTHON %s | FileCheck %s

# Tests for the Python builders of `transform.structured` ops.
#
# Each test function is executed at definition time by the `@run` decorator.
# The decorator prints a "TEST: <name>" marker followed by the module that the
# test populated. The FileCheck directives written as comments inside each
# function then match that printed IR. Keep the directives next to the
# function they describe, because the printed output follows definition order.

from mlir.ir import *
from mlir.dialects import transform
# NOTE(review): `pdl` is not referenced by name anywhere in this file; confirm
# whether the import is still needed before removing it.
from mlir.dialects import pdl
from mlir.dialects.transform import structured


def run(f):
  """Decorator that builds and prints one test case as soon as it is defined.

  Enters a fresh `Context` with an unknown `Location` and creates an empty
  `Module`. It then prints a blank line and a "TEST: <name>" header, and calls
  `f` with the insertion point set to the module body, so ops created at top
  level inside `f` land in that module. Finally the module is printed so
  FileCheck can match it.

  Args:
    f: Zero-argument test function that builds IR.

  Returns:
    `f` itself, unchanged, so the decorated name still refers to the function.
  """
  with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
      print("\nTEST:", f.__name__)
      f()
    print(module)
  return f


@run
def testDecompose():
  """Builds a `transform.structured.decompose` op inside a transform sequence."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.DecomposeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testDecompose
  # CHECK: transform.sequence
  # CHECK: transform.structured.decompose


@run
def testGeneralize():
  """Builds a `transform.structured.generalize` op inside a transform sequence."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.GeneralizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testGeneralize
  # CHECK: transform.sequence
  # CHECK: transform.structured.generalize


@run
def testInterchange():
  """Builds `transform.structured.interchange` with a mixed-type permutation.

  `iterator_interchange` is given as one pre-built i64 `IntegerAttr` followed
  by one plain Python int. The expected printed form is `[1, 0]`, so this
  presumably exercises the builder's acceptance of both element kinds.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.InterchangeOp(
        sequence.bodyTarget,
        iterator_interchange=[
            IntegerAttr.get(IntegerType.get_signless(64), 1), 0
        ])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testInterchange
  # CHECK: transform.sequence
  # CHECK: transform.structured.interchange
  # CHECK: iterator_interchange = [1, 0]


@run
def testMultitileSizes():
  """Builds `transform.structured.multitile_sizes` from plain int arguments."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.MultiTileSizesOp(
        sequence.bodyTarget, dimension=1, target_size=42)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testMultitileSizes
  # CHECK: transform.sequence
  # CHECK: transform.structured.multitile_sizes
  # CHECK-DAG: dimension = 1
  # CHECK-DAG: target_size = 42


@run
def testPad():
  """Builds `transform.structured.pad` with explicit values, dims and transposes."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    # hoist_paddings and pack_paddings are not passed here. The expectations
    # below require them to still be printed, as empty arrays.
    structured.PadOp(
        sequence.bodyTarget,
        padding_values=[FloatAttr.get_f32(42.0)],
        padding_dimensions=[1],
        transpose_paddings=[[1, 0]])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testPad
  # CHECK: transform.sequence
  # CHECK: transform.structured.pad
  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
  # CHECK-DAG: padding_dimensions = [1]
  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
  # CHECK-DAG: hoist_paddings = []
  # CHECK-DAG: pack_paddings = []


@run
def testScalarize():
  """Builds a `transform.structured.scalarize` op inside a transform sequence."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.ScalarizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testScalarize
  # CHECK: transform.structured.scalarize


@run
def testSplit():
  """Builds two chained `transform.structured.split` ops.

  The first split uses a static split point (42). The second targets the
  first op's `results[0]` and takes the first op's `results[1]` as its
  dynamic split point.
  """
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
    structured.SplitOp(
        split.results[0], dimension=3, split_point=split.results[1])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testSplit
  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3


@run
def testTileCompact():
  """Builds `transform.structured.tile` with sizes/interchange as int lists."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileCompact
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]


@run
def testTileAttributes():
  """Like testTileCompact, but `sizes`/`interchange` are pre-built `ArrayAttr`s.

  Both arrays hold i64 `IntegerAttr`s. The expected printed IR is identical
  to the plain-int-list variant.
  """
  sequence = transform.SequenceOp()
  attr = ArrayAttr.get(
      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
  ichange = ArrayAttr.get(
      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
  with InsertionPoint(sequence.body):
    structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileAttributes
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
  # CHECK: interchange = [0, 1]


@run
def testTileZero():
  """Builds `transform.structured.tile` where some `sizes` entries are zero."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    # Four sizes, two of them nonzero. The expected `:2` result group matches
    # the nonzero count (presumably one result handle per nonzero size).
    structured.TileOp(
        sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileZero
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
  # CHECK: interchange = [0, 1, 2, 3]


@run
def testTileDynamic():
  """Builds `transform.structured.tile` with sizes mixing handles and ints.

  The sequence is nested in a `WithPDLPatternsOp`. Two of the tile sizes are
  the handles produced by `PDLMatchOp`s. The pattern names "first" and
  "second" are only referenced here; no patterns with those names are defined
  in this file.
  """
  with_pdl = transform.WithPDLPatternsOp()
  with InsertionPoint(with_pdl.body):
    sequence = transform.SequenceOp(with_pdl.bodyTarget)
    with InsertionPoint(sequence.body):
      m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
      m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
      # Three nonzero sizes here, matching the expected `:3` result group.
      structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
      transform.YieldOp()
  # CHECK-LABEL: TEST: testTileDynamic
  # CHECK: %[[FIRST:.+]] = pdl_match
  # CHECK: %[[SECOND:.+]] = pdl_match
  # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]


@run
def testVectorize():
  """Builds `transform.structured.vectorize` with `vectorize_padding=True`."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testVectorize
  # CHECK: transform.sequence
  # CHECK: = transform.structured.vectorize
  # CHECK: vectorize_padding = true