1# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s
2
3from mlir.dialects.linalg.opdsl.lang import *
4
5
6# CHECK: ---
7# CHECK-LABEL: matmul
8# CHECK: args:
9# CHECK:     name: A
10# CHECK:     usage: InputOperand
11# CHECK:     type_var: T
12# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
13# CHECK:     name: B
14# CHECK:     usage: InputOperand
15# CHECK:     type_var: T
16# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
17# CHECK:     name: C
18# CHECK:     usage: OutputOperand
19# CHECK:     type_var: U
20# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
# Matrix multiplication accumulated into C, over symbolically shaped operands:
#   A is M x K, B is K x N, C is M x N.
# A and B share element type T; the output C uses type U, and each input
# element is cast to U before the multiply so the product is formed in U.
# D.k appears only on the right-hand side of the `+=`, so it is presumably
# the dimension summed over (reduction); D.m and D.n index the output.
# The expected shape maps listed above correspond to s0=M, s1=K, s2=N.
# Kept as `#` comments (not a docstring) so the dumped op definition is
# unaffected.
@linalg_structured_op
def matmul(
    A=TensorDef(T, S.M, S.K),
    B=TensorDef(T, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True)):
  C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
27
28
29# CHECK: ---
30# CHECK-LABEL: fill
31# CHECK: args:
32# CHECK:     name: value
33# CHECK:     usage: InputOperand
34# CHECK-NOT: shape_map:
35# CHECK:     type_var: T
# Writes the scalar `value` to every (m, n) element of the output tensor O.
# `value` is declared with ScalarDef, so it carries no shape map; the
# directives above require that none appears between its usage and type_var.
# O's shape symbols are named M and N to match the D.m / D.n indices used to
# address it, the same convention as matmul's output C (M x N indexed by
# D.m, D.n). Symbols are numbered positionally in the dumped maps, so the
# emitted shape map for O is unchanged by the naming.
@linalg_structured_op
def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.N, output=True)):
  O[D.m, D.n] = value
39
40
41# CHECK: ---
42# CHECK-LABEL: strided_copy
43# CHECK: args:
44# CHECK:     name: I
45# CHECK:     usage: InputOperand
46# CHECK:     type_var: T
47# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
48# CHECK:     name: O
49# CHECK:     usage: OutputOperand
50# CHECK:     type_var: T
51# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
52# CHECK:     name: strides
53# CHECK:     usage: IndexAttribute
54# CHECK:     type_var: I64
55# CHECK:     attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
# Strided copy: O[oh, ow] = I[oh * SH, ow * SW]. Each output element is read
# from the input at the output index scaled by the per-dimension stride.
#   I is IH x IW (input), O is OH x OW (output).
#   `strides` is an index attribute holding the two symbols SH and SW that
#   scale the input indices; the expected dump gives it type_var I64.
# Expected symbol numbering in the dumped maps (see directives above):
#   s0=IH, s1=IW, s2=OH, s3=OW, s4=SH, s5=SW.
@linalg_structured_op
def strided_copy(
    I=TensorDef(T, S.IH, S.IW),
    O=TensorDef(T, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW)):
  O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
62