// RUN:   mlir-opt %s -pass-pipeline="async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-arith-to-llvm),convert-vector-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts" \
// RUN: | mlir-cpu-runner                                                      \
// RUN:     -e main -entry-point-result=void -O0                               \
// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext  \
// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext    \
// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext   \
// RUN: | FileCheck %s --dump-input=always

// Entry point of the test. It exercises `!async.value` results in four
// scenarios and prints a value after each one so the output can be matched
// by the FileCheck directives that precede each print:
//   1. an f32 value awaited at function level,
//   2. an f32 value awaited inside another async.execute body,
//   3. a memref allocated inside an async.execute body and yielded out,
//   4. that memref passed back into an async.execute as an operand and
//      updated in place.
func.func @main() {

  // ------------------------------------------------------------------------ //
  // Blocking async.await outside of the async.execute.
  // ------------------------------------------------------------------------ //
  // The region yields a constant; the await below happens in @main itself.
  %token, %result = async.execute -> !async.value<f32> {
    %0 = arith.constant 123.456 : f32
    async.yield %0 : f32
  }
  %1 = async.await %result : !async.value<f32>

  // CHECK: 123.456
  vector.print %1 : f32

  // ------------------------------------------------------------------------ //
  // Non-blocking async.await inside the async.execute.
  // ------------------------------------------------------------------------ //
  // The inner value is awaited inside the outer region and forwarded as the
  // outer region's result. It is named %result1 so it is not confused with
  // the function-level %result2 (a memref value) defined further down.
  %token0, %result0 = async.execute -> !async.value<f32> {
    %token1, %result1 = async.execute -> !async.value<f32> {
      %2 = arith.constant 456.789 : f32
      async.yield %2 : f32
    }
    %3 = async.await %result1 : !async.value<f32>
    async.yield %3 : f32
  }
  %4 = async.await %result0 : !async.value<f32>

  // CHECK: 456.789
  vector.print %4 : f32

  // ------------------------------------------------------------------------ //
  // Memref allocated inside async.execute region.
  // ------------------------------------------------------------------------ //
  // %token0 is listed as a dependency of this region. The region allocates a
  // rank-0 memref, stores 0.25 into it and yields the memref itself.
  %token2, %result2 = async.execute[%token0] -> !async.value<memref<f32>> {
    %5 = memref.alloc() : memref<f32>
    %quarter = arith.constant 0.25 : f32
    memref.store %quarter, %5[]: memref<f32>
    async.yield %5 : memref<f32>
  }
  %6 = async.await %result2 : !async.value<memref<f32>>
  // @printMemrefF32 takes an unranked memref, so cast before calling it.
  %7 = memref.cast %6 :  memref<f32> to memref<*xf32>

  // CHECK: Unranked Memref
  // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
  // CHECK-NEXT: [0.25]
  call @printMemrefF32(%7): (memref<*xf32>) -> ()

  // ------------------------------------------------------------------------ //
  // Memref passed as async.execute operand.
  // ------------------------------------------------------------------------ //
  // The async value from the previous section is unwrapped as %unwrapped
  // inside the region, which doubles the stored element in place.
  %token3 = async.execute(%result2 as %unwrapped : !async.value<memref<f32>>) {
    %8 = memref.load %unwrapped[]: memref<f32>
    %9 = arith.addf %8, %8 : f32
    memref.store %9, %unwrapped[]: memref<f32>
    async.yield
  }
  async.await %token3 : !async.token

  // %7 is the same cast value printed above; the expected output is now the
  // doubled element written by the region.
  // CHECK: Unranked Memref
  // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
  // CHECK-NEXT: [0.5]
  call @printMemrefF32(%7): (memref<*xf32>) -> ()

  // Release the buffer that was allocated inside the async region.
  memref.dealloc %6 : memref<f32>

  return
}

// Declaration (no body) of the routine @main calls to print an unranked f32
// memref. It carries the `llvm.emit_c_interface` attribute.
// NOTE(review): the definition is presumably provided by one of the runner
// utility shared libraries loaded on the command line at the top of this
// file -- confirm.
func.func private @printMemrefF32(memref<*xf32>)
  attributes { llvm.emit_c_interface }