1// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
2// RUN: mlir-opt %s --sparse-tensor-conversion --cse | FileCheck %s --check-prefix=CHECK-CONV
3
4#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
5#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
6
7//
8// roundtrip:
9//
10// CHECK-ROUND-LABEL: func.func @sparse_expand(
11// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
12//      CHECK-ROUND:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
13//      CHECK-ROUND:  return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
14//
15// conversion:
16//
17// CHECK-CONV-LABEL: func.func @sparse_expand(
18// CHECK-CONV-DAG:  %[[C0:.*]] = arith.constant 0 : index
19// CHECK-CONV-DAG:  %[[C1:.*]] = arith.constant 1 : index
20// CHECK-CONV-DAG:  %[[C10:.*]] = arith.constant 10 : index
21// CHECK-CONV-DAG:  call @newSparseTensor
22// CHECK-CONV-DAG:  call @newSparseTensor
23// CHECK-CONV:      scf.while : () -> () {
24// CHECK-CONV:        call @getNextF64
// CHECK-CONV:        scf.condition(%{{.*}})
26// CHECK-CONV:      } do {
27// CHECK-CONV:        %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xindex>
28// CHECK-CONV:        %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index
29// CHECK-CONV:        memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<?xindex>
30// CHECK-CONV:        %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index
31// CHECK-CONV:        memref.store %[[R]], %{{.*}}[%[[C1]]] : memref<?xindex>
32// CHECK-CONV:        call @addEltF64
33// CHECK-CONV:        scf.yield
34// CHECK-CONV:      }
35// CHECK-CONV:      %[[N:.*]] = call @newSparseTensor
36// CHECK-CONV:      call @delSparseTensorCOOF64
37// CHECK-CONV:      call @delSparseTensorCOOF64
38// CHECK-CONV:      return %[[N]] : !llvm.ptr<i8>
39//
// Expand a sparse 100-element vector into a sparse 10x10 matrix by splitting
// its single dimension in two. Under --sparse-tensor-conversion each stored
// linear index x is remapped to the pair (x divui 10, x remui 10), as checked
// by the CHECK-CONV lines above.
func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
  %0 = tensor.expand_shape %arg0 [[0, 1]] :
    tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
  return %0 : tensor<10x10xf64, #SparseMatrix>
}
45
46//
47// roundtrip:
48//
49// CHECK-ROUND-LABEL: func.func @sparse_collapse(
50// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
51//      CHECK-ROUND:  %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
52//      CHECK-ROUND:  return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
53//
54// conversion:
55//
56// CHECK-CONV-LABEL: func.func @sparse_collapse(
57// CHECK-CONV-DAG:  %[[C0:.*]] = arith.constant 0 : index
58// CHECK-CONV-DAG:  %[[C1:.*]] = arith.constant 1 : index
59// CHECK-CONV-DAG:  %[[C10:.*]] = arith.constant 10 : index
60// CHECK-CONV-DAG:  call @newSparseTensor
61// CHECK-CONV-DAG:  call @newSparseTensor
62// CHECK-CONV:      scf.while : () -> () {
63// CHECK-CONV:        call @getNextF64
// CHECK-CONV:        scf.condition(%{{.*}})
65// CHECK-CONV:      } do {
66// CHECK-CONV:        %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xindex>
67// CHECK-CONV:        %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
68// CHECK-CONV:        %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xindex>
69// CHECK-CONV:        %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index
70// CHECK-CONV:        memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<?xindex>
71// CHECK-CONV:        call @addEltF64
72// CHECK-CONV:        scf.yield
73// CHECK-CONV:      }
74// CHECK-CONV:      %[[N:.*]] = call @newSparseTensor
75// CHECK-CONV:      call @delSparseTensorCOOF64
76// CHECK-CONV:      call @delSparseTensorCOOF64
77// CHECK-CONV:      return %[[N]] : !llvm.ptr<i8>
78//
// Collapse a sparse 10x10 matrix into a sparse 100-element vector by merging
// its two dimensions. Under --sparse-tensor-conversion each stored index pair
// (i, j) is remapped to the linear index i * 10 + j, as checked by the
// CHECK-CONV lines above.
func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
  %0 = tensor.collapse_shape %arg0 [[0, 1]] :
    tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector>
  return %0 : tensor<100xf64, #SparseVector>
}
84