1// Force this file to use the kDirect method for sparse2sparse.
2// RUN: mlir-opt %s --sparse-compiler="s2s-strategy=2" | \
3// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
4// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
5// RUN: FileCheck %s
6
// Levels: dense, dense, compressed (only the innermost level is compressed).
#Tensor1 = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ] }>
10
// NOTE: a dense level following a compressed level is not currently supported
// as the target of a direct sparse-to-sparse conversion (it is fine as the source).
// Levels: dense, compressed, dense (the middle level is compressed).
#Tensor2 = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ] }>
16
// Same level types as #Tensor1, but the last two dimensions are swapped in
// storage order: levels are (d0, d2, d1).
#Tensor3 = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "dense", "compressed" ],
  dimOrdering = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
}>
21
module {
  //
  // Printing / deallocation helpers for 2x3x4 dense results.
  //

  // Reads the whole tensor into a vector<2x3x4xf64> and prints it.  The -1.0
  // padding value would only be used for out-of-bounds reads, which cannot
  // occur here since the read starts at the origin and the vector shape
  // equals the tensor shape.
  func.func @dump(%tensor: tensor<2x3x4xf64>) {
    %origin = arith.constant 0 : index
    %padding = arith.constant -1.0 : f64
    %values = vector.transfer_read %tensor[%origin, %origin, %origin], %padding
        : tensor<2x3x4xf64>, vector<2x3x4xf64>
    vector.print %values : vector<2x3x4xf64>
    return
  }

  // Prints the tensor, then deallocates the memref obtained for it through
  // bufferization.to_memref.
  func.func @dumpAndRelease_234(%tensor: tensor<2x3x4xf64>) {
    call @dump(%tensor) : (tensor<2x3x4xf64>) -> ()
    %buffer = bufferization.to_memref %tensor : memref<2x3x4xf64>
    memref.dealloc %buffer : memref<2x3x4xf64>
    return
  }

  //
  // Test driver: dense -> sparse -> sparse -> dense round trips.
  //
  func.func @entry() {
    // Dense 2x3x4 input holding the values 1 through 24.
    %dense = arith.constant dense<[
       [  [  1.0,  2.0,  3.0,  4.0 ],
          [  5.0,  6.0,  7.0,  8.0 ],
          [  9.0, 10.0, 11.0, 12.0 ] ],
       [  [ 13.0, 14.0, 15.0, 16.0 ],
          [ 17.0, 18.0, 19.0, 20.0 ],
          [ 21.0, 22.0, 23.0, 24.0 ] ]
    ]> : tensor<2x3x4xf64>

    // Dense input -> one sparse tensor per encoding.
    %sp1 = sparse_tensor.convert %dense : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor1>
    %sp2 = sparse_tensor.convert %dense : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor2>
    %sp3 = sparse_tensor.convert %dense : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor3>

    // Sparse -> sparse: every ordered pair of distinct encodings whose
    // target is not #Tensor2.
    %cvt13 = sparse_tensor.convert %sp1 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64, #Tensor3>
    %cvt21 = sparse_tensor.convert %sp2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64, #Tensor1>
    %cvt23 = sparse_tensor.convert %sp2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64, #Tensor3>
    %cvt31 = sparse_tensor.convert %sp3 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64, #Tensor1>

    // Sparse -> dense, so each result can be printed.
    %out13 = sparse_tensor.convert %cvt13 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64>
    %out21 = sparse_tensor.convert %cvt21 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64>
    %out23 = sparse_tensor.convert %cvt23 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64>
    %out31 = sparse_tensor.convert %cvt31 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64>

    // The input and all four round-tripped tensors must print identically;
    // each dense result is freed right after it is printed.
    // CHECK-COUNT-5: ( ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) ), ( ( 13, 14, 15, 16 ), ( 17, 18, 19, 20 ), ( 21, 22, 23, 24 ) ) )
    call @dump(%dense) : (tensor<2x3x4xf64>) -> ()
    call @dumpAndRelease_234(%out13) : (tensor<2x3x4xf64>) -> ()
    call @dumpAndRelease_234(%out21) : (tensor<2x3x4xf64>) -> ()
    call @dumpAndRelease_234(%out23) : (tensor<2x3x4xf64>) -> ()
    call @dumpAndRelease_234(%out31) : (tensor<2x3x4xf64>) -> ()

    // Free every sparse tensor created above.
    sparse_tensor.release %cvt13 : tensor<2x3x4xf64, #Tensor3>
    sparse_tensor.release %cvt21 : tensor<2x3x4xf64, #Tensor1>
    sparse_tensor.release %cvt23 : tensor<2x3x4xf64, #Tensor3>
    sparse_tensor.release %cvt31 : tensor<2x3x4xf64, #Tensor1>
    sparse_tensor.release %sp1 : tensor<2x3x4xf64, #Tensor1>
    sparse_tensor.release %sp2 : tensor<2x3x4xf64, #Tensor2>
    sparse_tensor.release %sp3 : tensor<2x3x4xf64, #Tensor3>

    return
  }
}
103