// RUN: mlir-opt %s --sparse-compiler | \
// RUN: mlir-cpu-runner \
// RUN:  -e entry -entry-point-result=void \
// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

// Sparse encoding in which both matrix dimensions use the "compressed"
// level type.
#SparseMatrix = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "compressed" ]
}>

// Reduces every element of a 2-D input A into a 0-D output x.
// Both loops are reduction loops because the output map drops i and j.
#trait_sum_reduce = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> ()>      // x (out)
  ],
  iterator_types = ["reduction", "reduction"],
  doc = "x += A(i,j)"
}

module {
  //
  // A kernel that sum-reduces a bf16 matrix (stored with the #SparseMatrix
  // encoding) to a single scalar. The sum is accumulated on top of the
  // initial value held in %argx, and the updated scalar tensor is returned.
  //
  func.func @kernel_sum_reduce(%arga: tensor<?x?xbf16, #SparseMatrix>,
                               %argx: tensor<bf16>) -> tensor<bf16> {
    %0 = linalg.generic #trait_sum_reduce
      ins(%arga: tensor<?x?xbf16, #SparseMatrix>)
      outs(%argx: tensor<bf16>) {
        ^bb(%a: bf16, %x: bf16):
          // Running sum: x = x + A(i,j).
          %0 = arith.addf %x, %a : bf16
          linalg.yield %0 : bf16
    } -> tensor<bf16>
    return %0 : tensor<bf16>
  }

  //
  // Main driver: builds a sparse matrix from a dense constant, calls the
  // sparse kernel, and prints the reduced value for FileCheck.
  //
  func.func @entry() {
    // Setup input sparse matrix from a dense constant.
    %d = arith.constant dense<[
       [ 1.1,  1.2,  0.0,  1.4 ],
       [ 0.0,  0.0,  0.0,  0.0 ],
       [ 3.1,  0.0,  3.3,  3.4 ]
    ]> : tensor<3x4xbf16>
    %a = sparse_tensor.convert %d
      : tensor<3x4xbf16> to tensor<?x?xbf16, #SparseMatrix>

    // Setup memory for a single reduction scalar,
    // initialized to zero.
    %d0 = arith.constant 0.0 : bf16
    %x = tensor.from_elements %d0 : tensor<bf16>

    // Call the kernel.
    %0 = call @kernel_sum_reduce(%a, %x)
      : (tensor<?x?xbf16, #SparseMatrix>, tensor<bf16>) -> tensor<bf16>

    // Print the result for verification. The nonzero entries are
    // 1.1 + 1.2 + 1.4 + 3.1 + 3.3 + 3.4 = 13.5. Zero entries do not
    // change the sum.
    //
    // CHECK: 13.5
    //
    %v = tensor.extract %0[] : tensor<bf16>
    // Widen to f32 before printing. Presumably vector.print is used with
    // f32 here because bf16 printing is not supported -- confirm.
    %vf = arith.extf %v: bf16 to f32
    vector.print %vf : f32

    // Release the resources.
    bufferization.dealloc_tensor %a : tensor<?x?xbf16, #SparseMatrix>

    return
  }
}