1// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
2
3// The script is designed to make adding checks to
4// a test case fast, it is *not* designed to be authoritative
5// about what constitutes a good test! The CHECK should be
6// minimized and named to reflect the test intent.
7
8// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=8" -canonicalize | \
9// RUN:   FileCheck %s
10
// Sparse tensor encoding for a 1-D tensor whose only dimension uses the
// "compressed" level type.
11#SparseVector = #sparse_tensor.encoding<{
12  dimLevelType = ["compressed"]
13}>
14
// linalg.generic trait used by both kernels in this file: one input and one
// output, each accessed with the identity map over a single parallel loop.
// The element-wise "op" named in the doc string is supplied by each kernel body.
15#trait_1d = {
16  indexing_maps = [
17    affine_map<(i) -> (i)>,  // a
18    affine_map<(i) -> (i)>   // x (out)
19  ],
20  iterator_types = ["parallel"],
21  doc = "X(i) = a(i) op i"
22}
23
// Conjunctive case: X(i) = a(i) * i with a sparse (compressed) operand a.
// The expected IR is a single scf.for over the stored entries (bounds loaded
// from the pointers array, step 8) that masked-loads the stored indices and
// values, casts the index vector to i64, multiplies, and scatters the products
// into the zero-filled dense output buffer.
24// CHECK-LABEL:   func @sparse_index_1d_conj(
25// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<8xi64> {
26// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<0> : vector<8xi64>
27// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<0> : vector<8xindex>
28// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
29// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
30// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : i64
31// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
32// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
33// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
34// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
35// CHECK-DAG:       %[[VAL_10:.*]] = memref.alloc() : memref<8xi64>
36// CHECK-DAG:       linalg.fill ins(%[[VAL_5]] : i64) outs(%[[VAL_10]] : memref<8xi64>)
37// CHECK-DAG:       %[[VAL_11:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
38// CHECK-DAG:       %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
39// CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_12]] step %[[VAL_3]] {
40// CHECK:             %[[VAL_14:.*]] = affine.min #map0(%[[VAL_13]]){{\[}}%[[VAL_12]]]
41// CHECK:             %[[VAL_15:.*]] = vector.create_mask %[[VAL_14]] : vector<8xi1>
42// CHECK:             %[[VAL_16:.*]] = vector.maskedload %[[VAL_8]]{{\[}}%[[VAL_13]]], %[[VAL_15]], %[[VAL_2]] : memref<?xindex>, vector<8xi1>, vector<8xindex> into vector<8xindex>
43// CHECK:             %[[VAL_17:.*]] = vector.maskedload %[[VAL_9]]{{\[}}%[[VAL_13]]], %[[VAL_15]], %[[VAL_1]] : memref<?xi64>, vector<8xi1>, vector<8xi64> into vector<8xi64>
44// CHECK:             %[[VAL_18:.*]] = arith.index_cast %[[VAL_16]] : vector<8xindex> to vector<8xi64>
45// CHECK:             %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_18]] : vector<8xi64>
46// CHECK:             vector.scatter %[[VAL_10]]{{\[}}%[[VAL_6]]] {{\[}}%[[VAL_16]]], %[[VAL_15]], %[[VAL_19]] : memref<8xi64>, vector<8xindex>, vector<8xi1>, vector<8xi64>
47// CHECK:           }
48// CHECK:           %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<8xi64>
49// CHECK:           return %[[VAL_20]] : tensor<8xi64>
50// CHECK:         }
51func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> {
52  %init = linalg.init_tensor [8] : tensor<8xi64>  // output operand for the generic op
53  %r = linalg.generic #trait_1d
54      ins(%arga: tensor<8xi64, #SparseVector>)
55     outs(%init: tensor<8xi64>) {
56      ^bb(%a: i64, %x: i64):
57        %i = linalg.index 0 : index  // iteration index along dimension 0
58        %ii = arith.index_cast %i : index to i64  // index converted to the i64 element type
59        %m1 = arith.muli %a, %ii : i64  // a(i) * i
60        linalg.yield %m1 : i64
61  } -> tensor<8xi64>
62  return %r : tensor<8xi64>
63}
64
// Disjunctive case: X(i) = a(i) + i with a sparse (compressed) operand a.
// The expected IR first co-iterates the stored entries of a with the dense
// index in an scf.while (storing a(i) + i when the stored index matches the
// dense index, and i otherwise), then covers the remaining dense indices with
// a masked, vectorized scf.for that stores i.
// The scf.for lower bound is pinned to result #1 of that scf.while (its final
// dense index) instead of a fresh wildcard, so the expectation fails if the
// tail loop no longer resumes where the while loop stopped.
65// CHECK-LABEL:   func @sparse_index_1d_disj(
66// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<8xi64> {
67// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
68// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
69// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : i64
70// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
71// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
72// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
73// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
74// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
75// CHECK-DAG:       %[[VAL_9:.*]] = memref.alloc() : memref<8xi64>
76// CHECK-DAG:       linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_9]] : memref<8xi64>)
77// CHECK-DAG:       %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
78// CHECK-DAG:       %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_2]]] : memref<?xindex>
79// CHECK:           %[[VAL_12:.*]]:2 = scf.while (%[[VAL_13:.*]] = %[[VAL_10]], %[[VAL_14:.*]] = %[[VAL_5]]) : (index, index) -> (index, index) {
80// CHECK:             %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_11]] : index
81// CHECK:             scf.condition(%[[VAL_15]]) %[[VAL_13]], %[[VAL_14]] : index, index
82// CHECK:           } do {
83// CHECK:           ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index):
84// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
85// CHECK:             %[[VAL_19:.*]] = arith.cmpi eq, %[[VAL_18]], %[[VAL_17]] : index
86// CHECK:             scf.if %[[VAL_19]] {
87// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xi64>
88// CHECK:               %[[VAL_21:.*]] = arith.index_cast %[[VAL_17]] : index to i64
89// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : i64
90// CHECK:               memref.store %[[VAL_22]], %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<8xi64>
91// CHECK:             } else {
92// CHECK:               %[[VAL_23:.*]] = arith.index_cast %[[VAL_17]] : index to i64
93// CHECK:               memref.store %[[VAL_23]], %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<8xi64>
94// CHECK:             }
95// CHECK:             %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_18]], %[[VAL_17]] : index
96// CHECK:             %[[VAL_25:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index
97// CHECK:             %[[VAL_26:.*]] = arith.select %[[VAL_24]], %[[VAL_25]], %[[VAL_16]] : index
98// CHECK:             %[[VAL_27:.*]] = arith.addi %[[VAL_17]], %[[VAL_2]] : index
99// CHECK:             scf.yield %[[VAL_26]], %[[VAL_27]] : index, index
100// CHECK:           }
101// CHECK:           scf.for %[[VAL_28:.*]] = %[[VAL_12]]#1 to %[[VAL_4]] step %[[VAL_4]] {
102// CHECK:             %[[VAL_30:.*]] = affine.min #map1(%[[VAL_28]])
103// CHECK:             %[[VAL_31:.*]] = vector.create_mask %[[VAL_30]] : vector<8xi1>
104// CHECK:             %[[VAL_32:.*]] = vector.broadcast %[[VAL_28]] : index to vector<8xindex>
105// CHECK:             %[[VAL_33:.*]] = arith.addi %[[VAL_32]], %[[VAL_1]] : vector<8xindex>
106// CHECK:             %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : vector<8xindex> to vector<8xi64>
107// CHECK:             vector.maskedstore %[[VAL_9]]{{\[}}%[[VAL_28]]], %[[VAL_31]], %[[VAL_34]] : memref<8xi64>, vector<8xi1>, vector<8xi64>
108// CHECK:           }
109// CHECK:           %[[VAL_35:.*]] = bufferization.to_tensor %[[VAL_9]] : memref<8xi64>
110// CHECK:           return %[[VAL_35]] : tensor<8xi64>
111// CHECK:         }
112func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> {
113  %init = linalg.init_tensor [8] : tensor<8xi64>  // output operand for the generic op
114  %r = linalg.generic #trait_1d
115      ins(%arga: tensor<8xi64, #SparseVector>)
116     outs(%init: tensor<8xi64>) {
117      ^bb(%a: i64, %x: i64):
118        %i = linalg.index 0 : index  // iteration index along dimension 0
119        %ii = arith.index_cast %i : index to i64  // index converted to the i64 element type
120        %m1 = arith.addi %a, %ii : i64  // a(i) + i
121        linalg.yield %m1 : i64
122  } -> tensor<8xi64>
123  return %r : tensor<8xi64>
124}
125