# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import builtin
from mlir.dialects import func
from mlir.dialects import linalg

from mlir.dialects.linalg.opdsl.lang import *


# Prints the test name (anchored by the CHECK-LABEL directives below) and
# immediately invokes the decorated test function.
def run(f):
  print("\nTEST:", f.__name__)
  f()
  return f


# CHECK-LABEL: TEST: testInitTensor
@run
def testInitTensor():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # CHECK-LABEL: func @static_sizes
      # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32>
      @func.FuncOp.from_py_func()
      def static_sizes():
        return linalg.InitTensorOp([3, 4], f32)

      # CHECK-LABEL: func @dynamic_sizes
      # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
      @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
      def dynamic_sizes(d0, d1):
        return linalg.InitTensorOp([d0, d1], f32)

      # CHECK-LABEL: func @zero_d
      # CHECK: %0 = linalg.init_tensor [] : tensor<f32>
      @func.FuncOp.from_py_func()
      def zero_d():
        return linalg.InitTensorOp([], f32)

  print(module)


# CHECK-LABEL: TEST: testInitTensorStaticSizesAttribute
@run
def testInitTensorStaticSizesAttribute():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      op = linalg.InitTensorOp([3, 4], f32)
      # CHECK: [3, 4]
      print(op.attributes["static_sizes"])


# CHECK-LABEL: TEST: testFill
@run
def testFill():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # CHECK-LABEL: func @fill_tensor
      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
      #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
      #  CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
      #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
      @func.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
      def fill_tensor(out):
        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        return linalg.fill(zero, outs=[out])

      # CHECK-LABEL: func @fill_buffer
      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
      #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
      #  CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
      #  CHECK-NEXT: return
      @func.FuncOp.from_py_func(MemRefType.get((12, -1), f32))
      def fill_buffer(out):
        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        linalg.fill(zero, outs=[out])

  print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32))
      def named_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        # Check for the named form with custom format.
        #      CHECK: linalg.elemwise_unary
        # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
        # CHECK-SAME:    fun = #linalg.unary_fn<exp>
        # CHECK-SAME:    ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
        unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
        #      CHECK: linalg.elemwise_binary
        # CHECK-SAME:    cast = #linalg.type_fn<cast_unsigned>
        # CHECK-SAME:    fun = #linalg.binary_fn<mul>
        # CHECK-SAME:    ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
        #      CHECK: return
        binary_result = linalg.elemwise_binary(
            lhs,
            rhs,
            outs=[init_result.result],
            fun=BinaryFn.mul,
            cast=TypeFn.cast_unsigned)
        return unary_result, binary_result

  print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
                                                                   f32))
      def named_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        #      CHECK: "linalg.matmul"(%{{.*}})
        # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
        # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
        # CHECK-NEXT:    cast = #linalg.type_fn<cast_signed>
        # CHECK-SAME:    operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
        # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
        return linalg.matmul(lhs, rhs, outs=[init_result.result])

  module.operation.print(print_generic_op_form=True)


# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
                                                                   f32))
      def generic_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        # CHECK: linalg.generic
        return linalg.matmul(
            lhs, rhs, outs=[init_result.result], emit_generic=True)

  print(module)


# CHECK-LABEL: TEST: testOpResultFromOtherOp
@run
def testOpResultFromOtherOp():
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
                                                                   f32))
      def pass_an_op_directly(arg0, arg1):
        one = arith.ConstantOp(F32Type.get(), 1.0)
        # CHECK: %[[LHS:.*]] = linalg.fill
        lhs = linalg.fill(one, outs=[arg0])
        # CHECK: %[[RHS:.*]] = linalg.fill
        rhs = linalg.fill(one, outs=[arg1])
        # CHECK: %[[INIT:.*]] = linalg.init_tensor
        init = linalg.InitTensorOp([4, 8], f32)
        # CHECK: linalg.matmul
        # CHECK: ins(%[[LHS]], %[[RHS]]
        # CHECK: outs(%[[INIT]]
        return linalg.matmul(lhs, rhs, outs=init)

  print(module)