# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import linalg
from mlir.dialects import std
from mlir.dialects import arith


def run(f):
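  """Prints the test name (matched by CHECK-LABEL), runs it, and returns it."""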
  print("\nTEST:", f.__name__)
  f()
  return f


# CHECK-LABEL: TEST: testInitTensor
@run
def testInitTensor():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # CHECK-LABEL: func @static_sizes
      # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32>
      @builtin.FuncOp.from_py_func()
      def static_sizes():
        return linalg.InitTensorOp([3, 4], f32)

      # CHECK-LABEL: func @dynamic_sizes
      # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
      @builtin.FuncOp.from_py_func(IndexType.get(), IndexType.get())
      def dynamic_sizes(d0, d1):
        return linalg.InitTensorOp([d0, d1], f32)

      # CHECK-LABEL: func @zero_d
      # CHECK: %0 = linalg.init_tensor [] : tensor<f32>
      @builtin.FuncOp.from_py_func()
      def zero_d():
        return linalg.InitTensorOp([], f32)

  print(module)


# CHECK-LABEL: TEST: testInitTensorStaticSizesAttribute
@run
def testInitTensorStaticSizesAttribute():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
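      # The constant sizes end up on the op as the "static_sizes" attribute.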
      op = linalg.InitTensorOp([3, 4], f32)
      # CHECK: [3, 4]
      print(op.attributes["static_sizes"])


# CHECK-LABEL: TEST: testFill
@run
def testFill():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # CHECK-LABEL: func @fill_tensor
      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
      #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
      #  CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[CST]], %[[OUT]]) : f32, tensor<12x?xf32> -> tensor<12x?xf32>
      #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
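      # A shape entry of -1 marks that dimension as dynamic (printed as '?').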
      @builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
      def fill_tensor(out):
        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        return linalg.FillOp(output=out, value=zero).result

      # CHECK-LABEL: func @fill_buffer
      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
      #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
      #  CHECK-NEXT: linalg.fill(%[[CST]], %[[OUT]]) : f32, memref<12x?xf32>
      #  CHECK-NEXT: return
      @builtin.FuncOp.from_py_func(MemRefType.get((12, -1), f32))
      def fill_buffer(out):
        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        linalg.FillOp(output=out, value=zero)

  print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32),
          RankedTensorType.get((16, 8), f32))
      def named_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
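        # The named-op builder takes the output (init) tensor via `outs`.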
        # First check the named form with custom format
        #      CHECK: linalg.matmul
        #  CHECK-NOT: linalg.memoized_indexing_maps
        # CHECK-SAME:    ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
        # CHECK-SAME:   outs(%{{.*}} : tensor<4x8xf32>)
        # CHECK-SAME:   -> tensor<4x8xf32>
        # CHECK-NEXT: return
        return linalg.matmul(lhs, rhs, outs=[init_result.result])

  print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32),
          RankedTensorType.get((16, 8), f32))
      def named_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        #      CHECK: "linalg.matmul"(%{{.*}})
        # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
        # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
        # CHECK-NEXT:    {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
        # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
        return linalg.matmul(lhs, rhs, outs=[init_result.result])

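  # Print in generic op form to expose the implicitly built region and attributes.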
  module.operation.print(print_generic_op_form=True)


# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32),
          RankedTensorType.get((16, 8), f32))
      def generic_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
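        # emit_generic=True emits the equivalent linalg.generic instead of the named op.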
        # CHECK: linalg.generic
        return linalg.matmul(
            lhs, rhs, outs=[init_result.result], emit_generic=True)

  print(module)


# CHECK-LABEL: TEST: testOpResultFromOtherOp
@run
def testOpResultFromOtherOp():
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32),
          RankedTensorType.get((16, 8), f32))
      def pass_an_op_directly(arg0, arg1):
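        # Ops and OpResults are accepted directly wherever a Value operand is expected.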
        one = arith.ConstantOp(F32Type.get(), 1.0)
        # CHECK: %[[LHS:.*]] = linalg.fill
        lhs = linalg.FillOp(arg0, one)
        # CHECK: %[[RHS:.*]] = linalg.fill
        rhs = linalg.FillOp(arg1, one)
        # CHECK: %[[INIT:.*]] = linalg.init_tensor
        init = linalg.InitTensorOp([4, 8], f32)
        # CHECK: linalg.matmul
        # CHECK: ins(%[[LHS]], %[[RHS]]
        # CHECK: outs(%[[INIT]]
        return linalg.matmul(lhs, rhs, outs=init)

  print(module)