# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured


def run(f):
  """Build and print a fresh module using `f`; used as a decorator.

  Creates a new context and an empty module, prints a `TEST: <name>` marker
  for the FileCheck label directives to anchor on, calls `f` with the
  insertion point set to the module body, and then prints the module.
  Returns `f` unchanged, so every decorated function runs at import time,
  in definition order.
  """
  with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
      print("\nTEST:", f.__name__)
      f()
    print(module)
  return f


@run
def testDecompose():
  """Emit `transform.structured.decompose` on `sequence.bodyTarget`."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.DecomposeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testDecompose
  # CHECK: transform.sequence
  # CHECK: transform.structured.decompose


@run
def testGeneralize():
  """Emit `transform.structured.generalize` on `sequence.bodyTarget`."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.GeneralizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testGeneralize
  # CHECK: transform.sequence
  # CHECK: transform.structured.generalize


@run
def testInterchange():
  """Mix an explicit IntegerAttr and a plain int in `iterator_interchange`."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.InterchangeOp(
        sequence.bodyTarget,
        iterator_interchange=[
            IntegerAttr.get(IntegerType.get_signless(64), 1), 0
        ])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testInterchange
  # CHECK: transform.sequence
  # CHECK: transform.structured.interchange
  # CHECK: iterator_interchange = [1, 0]


@run
def testPad():
  """Emit a pad op; the omitted hoist/pack paddings print as empty arrays."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.PadOp(
        sequence.bodyTarget,
        padding_values=[FloatAttr.get_f32(42.0)],
        padding_dimensions=[1],
        transpose_paddings=[[1, 0]])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testPad
  # CHECK: transform.sequence
  # CHECK: transform.structured.pad
  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
  # CHECK-DAG: padding_dimensions = [1]
  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
  # CHECK-DAG: hoist_paddings = []
  # CHECK-DAG: pack_paddings = []


@run
def testScalarize():
  """Emit `transform.structured.scalarize` on `sequence.bodyTarget`."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.ScalarizeOp(sequence.bodyTarget)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testScalarize
  # CHECK: transform.sequence
  # CHECK: transform.structured.scalarize


@run
def testTileCompact():
  """Pass tile sizes and interchange as plain Python int lists."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileCompact
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1]
  # CHECK-DAG: sizes = [4, 8]


@run
def testTileAttributes():
  """Pass tile sizes and interchange as prebuilt i64 ArrayAttrs.

  This builds the same op as testTileCompact, so it is held to the same
  result pattern.
  """
  sequence = transform.SequenceOp()
  attr = ArrayAttr.get(
      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
  ichange = ArrayAttr.get(
      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
  with InsertionPoint(sequence.body):
    structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileAttributes
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1]
  # CHECK-DAG: sizes = [4, 8]


@run
def testTileZero():
  """Zero entries in `sizes` add no results: two non-zero sizes give `:2`."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.TileOp(
        sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileZero
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1, 2, 3]
  # CHECK-DAG: sizes = [4, 0, 2, 0]


@run
def testVectorize():
  """Set `vectorize_padding=True` and verify it appears in the printed op."""
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testVectorize
  # CHECK: transform.sequence
  # CHECK: = transform.structured.vectorize
  # CHECK: vectorize_padding = true