# RUN: %PYTHON %s | FileCheck %s
"""Tests for the Python bindings of the MLIR linalg dialect.

Each test function is executed at import time by the ``run`` decorator, which
first prints a ``TEST: <name>`` header. The test bodies build IR inside a
fresh ``Context`` and print it; this file is also the FileCheck pattern file,
so the directive comments placed next to each builder are matched against
that printed output. Their exact text and relative order are part of the
test, as are the names of the functions whose IR is printed.
"""

from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import linalg
from mlir.dialects import std
from mlir.dialects import arith


def run(f):
  """Run test function ``f`` immediately after printing a header line.

  Used as a decorator. The printed ``TEST: <name>`` line is what the per-test
  ``TEST:`` labels in this file anchor to. ``f`` is returned unchanged so the
  decorated name stays bound to the function.
  """
  print("\nTEST:", f.__name__)
  f()
  return f


# CHECK-LABEL: TEST: testInitTensor
@run
def testInitTensor():
  """Build ``linalg.init_tensor`` ops with static, dynamic and empty sizes."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # CHECK-LABEL: func @static_sizes
      # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32>
      @builtin.FuncOp.from_py_func()
      def static_sizes():
        return linalg.InitTensorOp([3, 4], f32)

      # CHECK-LABEL: func @dynamic_sizes
      # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
      @builtin.FuncOp.from_py_func(IndexType.get(), IndexType.get())
      def dynamic_sizes(d0, d1):
        # Sizes passed as index-typed block arguments print as `?` dims.
        return linalg.InitTensorOp([d0, d1], f32)

      # CHECK-LABEL: func @zero_d
      # CHECK: %0 = linalg.init_tensor [] : tensor<f32>
      @builtin.FuncOp.from_py_func()
      def zero_d():
        return linalg.InitTensorOp([], f32)

  print(module)


# CHECK-LABEL: TEST: testInitTensorStaticSizesAttribute
@run
def testInitTensorStaticSizesAttribute():
  """Read back the ``static_sizes`` attribute of a ``linalg.init_tensor``."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      op = linalg.InitTensorOp([3, 4], f32)
      # CHECK: [3, 4]
      print(op.attributes["static_sizes"])


# CHECK-LABEL: TEST: testFill
@run
def testFill():
  """Build ``linalg.fill`` on a tensor (which yields a result) and a memref."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # CHECK-LABEL: func @fill_tensor
      # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
      # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
      # CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[CST]], %[[OUT]]) : f32, tensor<12x?xf32> -> tensor<12x?xf32>
      # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
      @builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
      def fill_tensor(out):
        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        return linalg.FillOp(output=out, value=zero).result

      # CHECK-LABEL: func @fill_buffer
      # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
      # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
      # CHECK-NEXT: linalg.fill(%[[CST]], %[[OUT]]) : f32, memref<12x?xf32>
      # CHECK-NEXT: return
      @builtin.FuncOp.from_py_func(MemRefType.get((12, -1), f32))
      def fill_buffer(out):
        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        # The memref form writes in place, so nothing is returned.
        linalg.FillOp(output=out, value=zero)

  print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
  """Print a named ``linalg.matmul`` in its custom assembly format."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
                                                                   f32))
      def named_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        # First check the named form with custom format
        # CHECK: linalg.matmul
        # CHECK-NOT: linalg.memoized_indexing_maps
        # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
        # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
        # CHECK-SAME: -> tensor<4x8xf32>
        # CHECK-NEXT: return
        return linalg.matmul(lhs, rhs, outs=[init_result.result])

  print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
  """Print the same named ``linalg.matmul`` in generic operation form."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
                                                                   f32))
      def named_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        # CHECK: "linalg.matmul"(%{{.*}})
        # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
        # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
        # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
        # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
        return linalg.matmul(lhs, rhs, outs=[init_result.result])

  # Generic form exposes the op's body region and attribute dictionary.
  module.operation.print(print_generic_op_form=True)


# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
  """Emit a named matmul as a ``linalg.generic`` via ``emit_generic=True``."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
                                                                   f32))
      def generic_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        # CHECK: linalg.generic
        return linalg.matmul(
            lhs, rhs, outs=[init_result.result], emit_generic=True)

  print(module)


# CHECK-LABEL: TEST: testOpResultFromOtherOp
@run
def testOpResultFromOtherOp():
  """Pass operation objects, not their ``.result`` values, as operands."""
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
                                                                   f32))
      def pass_an_op_directly(arg0, arg1):
        # Reuse the already-built f32 type instead of constructing it again.
        one = arith.ConstantOp(f32, 1.0)
        # CHECK: %[[LHS:.*]] = linalg.fill
        lhs = linalg.FillOp(arg0, one)
        # CHECK: %[[RHS:.*]] = linalg.fill
        rhs = linalg.FillOp(arg1, one)
        # CHECK: %[[INIT:.*]] = linalg.init_tensor
        init = linalg.InitTensorOp([4, 8], f32)
        # CHECK: linalg.matmul
        # CHECK: ins(%[[LHS]], %[[RHS]]
        # CHECK: outs(%[[INIT]]
        # `outs` is a single op here rather than a list of values.
        return linalg.matmul(lhs, rhs, outs=init)

  print(module)