# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s

from mlir.dialects.linalg.opdsl.lang import *

# Argument-dumping test for the Linalg OpDSL. The command on the first line
# loads the ops defined in this file through `dump_oplib`; FileCheck then
# matches the emitted YAML against the directive comments preceding each op.
#
# NOTE(review): the ops below are annotated with plain comments rather than
# docstrings, because a docstring would presumably be picked up by
# `linalg_structured_op` as op documentation and alter the dump - confirm
# before adding one. Explanatory comments here also stay clear of FileCheck
# directive syntax so the matcher ignores them.


# `matmul` exercises tensor arguments (two inputs, one output) together with
# binary, unary and type-function attribute arguments and their defaults.
# The expected shape maps number the shape symbols in first-use order of the
# signature: M -> s0, K -> s1, N -> s2.
# CHECK: ---
# CHECK-LABEL: matmul
# CHECK: args:
# CHECK: name: A
# CHECK: kind: input_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
# CHECK: name: B
# CHECK: kind: input_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
# CHECK: name: C
# CHECK: kind: output_tensor
# CHECK: type_var: U
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
# CHECK: name: bfn
# CHECK: kind: binary_fn_attr
# CHECK: default_fn: mul
# CHECK: name: ufn
# CHECK: kind: unary_fn_attr
# CHECK: default_fn: exp
# CHECK: name: cast
# CHECK: kind: type_fn_attr
# CHECK: default_fn: cast_signed
@linalg_structured_op
def matmul(
        A=TensorDef(T, S.M, S.K),
        B=TensorDef(T, S.K, S.N),
        C=TensorDef(U, S.M, S.N, output=True),
        bfn=BinaryFnAttrDef(default=BinaryFn.mul),
        # Not referenced by the body below; the expectations above still list
        # it as an argument with its `exp` default.
        ufn=UnaryFnAttrDef(default=UnaryFn.exp),
        cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
    # Both operands go through `cast` to the output type variable U
    # (`cast_signed` by default) before `bfn` (`mul` by default) combines
    # them; the result is accumulated into C with `+=`.
    C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))


# `fill` verifies that a `ScalarDef` argument is dumped with a type variable
# but without a shape map.
# NOTE(review): O is declared over shape symbols (M, K) while the body indexes
# it with dims (m, n). Only `value` is matched below, so this naming mismatch
# is not exercised - confirm whether S.N was intended instead of S.K.
# CHECK: ---
# CHECK-LABEL: fill
# CHECK: args:
# CHECK: name: value
# CHECK: kind: scalar
# CHECK-NOT: shape_map:
# CHECK: type_var: T
@linalg_structured_op
def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
    O[D.m, D.n] = value


# `strided_copy` verifies `IndexAttrDef`: its symbols (SH, SW) are dumped as an
# index_attr_map and its default [1, 2] as default_indices. Shape symbols are
# numbered in first-use order of the signature: IH, IW, OH, OW, SH, SW map to
# s0 through s5.
# CHECK: ---
# CHECK-LABEL: strided_copy
# CHECK: args:
# CHECK: name: I
# CHECK: kind: input_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
# CHECK: name: O
# CHECK: kind: output_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
# CHECK: name: strides
# CHECK: kind: index_attr
# CHECK: index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
# CHECK: default_indices:
# CHECK: - 1
# CHECK: - 2
@linalg_structured_op
def strided_copy(
        I=TensorDef(T, S.IH, S.IW),
        O=TensorDef(T, S.OH, S.OW, output=True),
        strides=IndexAttrDef(S.SH, S.SW, default=[1, 2])):
    # Output element (oh, ow) reads the input at (oh * SH, ow * SW), i.e. the
    # input index is the output index scaled by the stride symbols.
    O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]