# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg

from mlir.dialects.linalg.opdsl.lang import *


def run(f):
    """Print a banner naming test `f`, execute it immediately, and return it.

    Intended to be used as a decorator on each test function below. The
    banner (a newline, then "TEST:" and the function name) is the text the
    per-test label directives in this file anchor on.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testInitTensor
@run
def testInitTensor():
    """Build init_tensor ops from static, dynamic and empty size lists."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # CHECK-LABEL: func @static_sizes
            # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32>
            @func.FuncOp.from_py_func()
            def static_sizes():
                return linalg.InitTensorOp([3, 4], f32)

            # CHECK-LABEL: func @dynamic_sizes
            # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
            @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
            def dynamic_sizes(d0, d1):
                return linalg.InitTensorOp([d0, d1], f32)

            # CHECK-LABEL: func @zero_d
            # CHECK: %0 = linalg.init_tensor [] : tensor<f32>
            @func.FuncOp.from_py_func()
            def zero_d():
                return linalg.InitTensorOp([], f32)

        print(module)


# CHECK-LABEL: TEST: testInitTensorStaticSizesAttribute
@run
def testInitTensorStaticSizesAttribute():
    """Print the "static_sizes" attribute of a statically sized init_tensor."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            op = linalg.InitTensorOp([3, 4], f32)
            # CHECK: [3, 4]
            print(op.attributes["static_sizes"])


# CHECK-LABEL: TEST: testFill
@run
def testFill():
    """Emit linalg.fill into a tensor argument and into a memref argument."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # CHECK-LABEL: func @fill_tensor
            # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
            # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
            # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
            # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
            @func.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
            def fill_tensor(out):
                zero = arith.ConstantOp(
                    value=FloatAttr.get(f32, 0.), result=f32).result
                return linalg.fill(zero, outs=[out])

            # CHECK-LABEL: func @fill_buffer
            # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
            # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
            # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
            # CHECK-NEXT: return
            @func.FuncOp.from_py_func(MemRefType.get((12, -1), f32))
            def fill_buffer(out):
                zero = arith.ConstantOp(
                    value=FloatAttr.get(f32, 0.), result=f32).result
                linalg.fill(zero, outs=[out])

        print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
    """Print elemwise_unary / elemwise_binary named ops in custom form."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 8), f32),
                RankedTensorType.get((4, 8), f32))
            def named_form(lhs, rhs):
                init_result = linalg.InitTensorOp([4, 8], f32)
                # Check for the named form with custom format
                # CHECK: linalg.elemwise_unary
                # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
                # CHECK-SAME: fun = #linalg.unary_fn<exp>
                # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
                unary_result = linalg.elemwise_unary(
                    lhs, outs=[init_result.result])
                # CHECK: linalg.elemwise_binary
                # CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
                # CHECK-SAME: fun = #linalg.binary_fn<mul>
                # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
                # CHECK: return
                binary_result = linalg.elemwise_binary(
                    lhs,
                    rhs,
                    outs=[init_result.result],
                    fun=BinaryFn.mul,
                    cast=TypeFn.cast_unsigned)
                return unary_result, binary_result

        print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
    """Print a linalg.matmul named op using the generic operation form."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 16), f32),
                RankedTensorType.get((16, 8), f32))
            def named_form(lhs, rhs):
                init_result = linalg.InitTensorOp([4, 8], f32)
                # CHECK: "linalg.matmul"(%{{.*}})
                # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
                # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
                # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
                # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
                # CHECK-NEXT: cast = #linalg.type_fn<cast_signed>
                # CHECK-SAME: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
                # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
                return linalg.matmul(lhs, rhs, outs=[init_result.result])

        # Request the generic op form instead of the default custom assembly.
        module.operation.print(print_generic_op_form=True)


# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
    """Build matmul with emit_generic=True and print the resulting module."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 16), f32),
                RankedTensorType.get((16, 8), f32))
            def generic_form(lhs, rhs):
                init_result = linalg.InitTensorOp([4, 8], f32)
                # CHECK: linalg.generic
                return linalg.matmul(
                    lhs, rhs, outs=[init_result.result], emit_generic=True)

        print(module)


# CHECK-LABEL: TEST: testOpResultFromOtherOp
@run
def testOpResultFromOtherOp():
    """Pass op objects (not their `.result` values) as linalg operands."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 16), f32),
                RankedTensorType.get((16, 8), f32))
            def pass_an_op_directly(arg0, arg1):
                # `one`, `lhs`, `rhs` and `init` are deliberately handed on
                # without taking `.result`; `outs=init` is likewise passed
                # unwrapped (no list), unlike the other tests in this file.
                one = arith.ConstantOp(F32Type.get(), 1.0)
                # CHECK: %[[LHS:.*]] = linalg.fill
                lhs = linalg.fill(one, outs=[arg0])
                # CHECK: %[[RHS:.*]] = linalg.fill
                rhs = linalg.fill(one, outs=[arg1])
                # CHECK: %[[INIT:.*]] = linalg.init_tensor
                init = linalg.InitTensorOp([4, 8], f32)
                # CHECK: linalg.matmul
                # CHECK: ins(%[[LHS]], %[[RHS]]
                # CHECK: outs(%[[INIT]]
                return linalg.matmul(lhs, rhs, outs=init)

        print(module)