# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured


def run(f):
  """Execute one test-case builder and print the module it populates.

  A fresh Context (with an unknown default Location) and an empty Module are
  created. While the module body is the insertion point, a "TEST: <name>"
  banner is printed and `f` is called. The populated module is then printed.
  Because this decorator calls `f` right away, every decorated function runs
  when the script is loaded. `f` itself is returned unchanged.
  """
  with Context(), Location.unknown():
    mod = Module.create()
    with InsertionPoint(mod.body):
      case_name = f.__name__
      print("\nTEST:", case_name)
      f()
    print(mod)
  return f


def _i64_attr(value):
  """Return a signless 64-bit IntegerAttr for `value` in the ambient Context."""
  i64 = IntegerType.get_signless(64)
  return IntegerAttr.get(i64, value)


@run
def testInterchange():
  """Emit a structured.interchange op inside a transform.sequence."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    # The permutation deliberately mixes a pre-built IntegerAttr with a plain
    # Python int; both are handed to the builder in a single list.
    permutation = [_i64_attr(1), 0]
    structured.InterchangeOp(seq.bodyTarget, iterator_interchange=permutation)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testInterchange
  # CHECK: transform.sequence
  # CHECK: transform.structured.interchange
  # CHECK: iterator_interchange = [1, 0]


@run
def testPad():
  """Emit a structured.pad op with explicit values, dimensions and transposes.

  Hoist/pack paddings are not passed; the expectations below require them to
  print as empty arrays.
  """
  seq = transform.SequenceOp()
  pad_value = FloatAttr.get_f32(42.0)
  with InsertionPoint(seq.body):
    structured.PadOp(
        seq.bodyTarget,
        padding_values=[pad_value],
        padding_dimensions=[1],
        transpose_paddings=[[1, 0]])
    transform.YieldOp()
  # CHECK-LABEL: TEST: testPad
  # CHECK: transform.sequence
  # CHECK: transform.structured.pad
  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
  # CHECK-DAG: padding_dimensions = [1]
  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
  # CHECK-DAG: hoist_paddings = []
  # CHECK-DAG: pack_paddings = []


@run
def testScalarize():
  """Emit a structured.scalarize op; it takes only the target handle."""
  seq = transform.SequenceOp()
  target = seq.bodyTarget
  with InsertionPoint(seq.body):
    structured.ScalarizeOp(target)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testScalarize
  # CHECK: transform.structured.scalarize


@run
def testTileCompact():
  """Emit a structured.tile op with sizes/interchange given as Python lists."""
  seq = transform.SequenceOp()
  tile_sizes = [4, 8]
  loop_order = [0, 1]
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileCompact
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1]
  # CHECK-DAG: sizes = [4, 8]


@run
def testTileAttributes():
  """Emit a structured.tile op with sizes/interchange as pre-built ArrayAttrs."""
  seq = transform.SequenceOp()
  sizes_attr = ArrayAttr.get(list(map(_i64_attr, (4, 8))))
  interchange_attr = ArrayAttr.get(list(map(_i64_attr, (0, 1))))
  with InsertionPoint(seq.body):
    structured.TileOp(
        seq.bodyTarget, sizes=sizes_attr, interchange=interchange_attr)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileAttributes
  # CHECK: transform.sequence
  # CHECK: structured.tile
  # CHECK-DAG: interchange = [0, 1]
  # CHECK-DAG: sizes = [4, 8]


@run
def testTileZero():
  """Emit a structured.tile op whose size list contains zero entries.

  The expectations below require two loop results for four sizes, two of which
  are zero. Presumably zero sizes produce no loop; confirm against the TileOp
  builder.
  """
  seq = transform.SequenceOp()
  tile_sizes = [4, 0, 2, 0]
  loop_order = [0, 1, 2, 3]
  with InsertionPoint(seq.body):
    structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testTileZero
  # CHECK: transform.sequence
  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
  # CHECK-DAG: interchange = [0, 1, 2, 3]
  # CHECK-DAG: sizes = [4, 0, 2, 0]


@run
def testVectorize():
  """Emit a structured.vectorize op with padding vectorization enabled."""
  seq = transform.SequenceOp()
  with InsertionPoint(seq.body):
    vectorize_padding = True
    structured.VectorizeOp(seq.bodyTarget, vectorize_padding=vectorize_padding)
    transform.YieldOp()
  # CHECK-LABEL: TEST: testVectorize
  # CHECK: transform.sequence
  # CHECK: = transform.structured.vectorize
  # CHECK: vectorize_padding = true