1 // Split the MLIR string: this will produce %t/input.mlir
2 // RUN: split-file %s %t
3 
4 // Compile the MLIR file to LLVM:
5 // RUN: mlir-opt %t/input.mlir \
6 // RUN:  -lower-affine  -convert-scf-to-cf  -convert-memref-to-llvm \
7 // RUN:  -convert-func-to-llvm -reconcile-unrealized-casts \
8 // RUN: | mlir-translate --mlir-to-llvmir -o %t.ll
9 
10 // Generate an object file for the MLIR code
11 // RUN: llc %t.ll -o %t.o -filetype=obj
12 
13 // Compile the current C file and link it to the MLIR code:
14 // RUN: %host_cc %s %t.o -o %t.exe
15 
// Execute the binary and verify its output with FileCheck:
17 // RUN: %t.exe | FileCheck %s
18 
19 /* MLIR_BEGIN
20 //--- input.mlir
21 // Performs: arg0[i, j] = arg0[i, j] + arg1[i, j]
22 func.func private @add_memref(%arg0: memref<?x?xf64>, %arg1: memref<?x?xf64>) -> i64
23    attributes {llvm.emit_c_interface} {
24   %c0 = arith.constant 0 : index
25   %c1 = arith.constant 1 : index
26   %dimI = memref.dim %arg0, %c0 : memref<?x?xf64>
27   %dimJ = memref.dim %arg0, %c1 : memref<?x?xf64>
28   affine.for %i = 0 to %dimI {
29     affine.for %j = 0 to %dimJ {
30       %load0 = memref.load %arg0[%i, %j] : memref<?x?xf64>
31       %load1 = memref.load %arg1[%i, %j] : memref<?x?xf64>
32       %add = arith.addf %load0, %load1 : f64
33       affine.store %add, %arg0[%i, %j] : memref<?x?xf64>
34     }
35   }
36   %c42 = arith.constant 42 : i64
37   return %c42 : i64
38 }
39 
40 //--- end_input.mlir
41 
42 MLIR_END */
43 
44 #include <stdint.h>
45 #include <stdio.h>
46 
// Prototype for the MLIR function, following the calling convention described
// at https://mlir.llvm.org/docs/TargetLLVMIR/#calling-conventions.
//
// When lowered to the LLVM dialect, each of the two 2D memref arguments is
// expanded into its descriptor fields (allocated pointer, aligned pointer,
// offset, two sizes, two strides), giving the MLIR signature:
//
// llvm.func @add_memref(
//     // First memref (%arg0)
//     %allocated_ptr0: !llvm.ptr<f64>, %aligned_ptr0: !llvm.ptr<f64>,
//     %offset0: i64, %size0_d0: i64, %size0_d1: i64,
//     %stride0_d0: i64, %stride0_d1: i64,
//     // Second memref (%arg1)
//     %allocated_ptr1: !llvm.ptr<f64>, %aligned_ptr1: !llvm.ptr<f64>,
//     %offset1: i64, %size1_d0: i64, %size1_d1: i64,
//     %stride1_d0: i64, %stride1_d1: i64) -> i64
//
// The i64 fields are declared as intptr_t here, which matches only where
// intptr_t is 64 bits wide.
long long add_memref(
    // First memref (%arg0).
    double *allocated_ptr0, double *aligned_ptr0, intptr_t offset0,
    intptr_t size0_d0, intptr_t size0_d1, intptr_t stride0_d0,
    intptr_t stride0_d1,
    // Second memref (%arg1).
    double *allocated_ptr1, double *aligned_ptr1, intptr_t offset1,
    intptr_t size1_d0, intptr_t size1_d1, intptr_t stride1_d0,
    intptr_t stride1_d1);
68 
// Because the MLIR function carries `llvm.emit_c_interface`, a second wrapper
// is emitted that receives each memref as a pointer to its descriptor struct:
//
// llvm.func @_mlir_ciface_add_memref(
//     %arg0: !llvm.ptr<struct<(ptr<f64>, ptr<f64>, i64,
//                              array<2 x i64>, array<2 x i64>)>>,
//     %arg1: !llvm.ptr<struct<(ptr<f64>, ptr<f64>, i64,
//                              array<2 x i64>, array<2 x i64>)>>) -> i64
//
// memref_2d_descriptor mirrors that struct field-for-field, so its member
// order and types must not change.
typedef struct memref_2d_descriptor {
  double *allocated;  // Allocated pointer (first descriptor field).
  double *aligned;    // Aligned pointer (second descriptor field).
  intptr_t offset;    // Offset, in elements, of the first element.
  intptr_t size[2];   // Extent of dimension 0 and dimension 1.
  intptr_t stride[2]; // Stride, in elements, of dimension 0 and dimension 1.
} memref_2d_descriptor;

long long _mlir_ciface_add_memref(memref_2d_descriptor *arg0,
                                  memref_2d_descriptor *arg1);
85 
// Shape of the test buffers: N rows by M columns.
enum { N = 4, M = 8 };

// Row-major buffers handed to the MLIR function; arg0 is updated in place
// with arg0 + arg1.
double arg0[N][M];
double arg1[N][M];
90 
dump()91 void dump() {
92   for (int i = 0; i < N; i++) {
93     printf("[");
94     for (int j = 0; j < M; j++)
95       printf("%d,\t", (int)arg0[i][j]);
96     printf("] [");
97     for (int j = 0; j < M; j++)
98       printf("%d,\t", (int)arg1[i][j]);
99     printf("]\n");
100   }
101 }
102 
main()103 int main() {
104   int count = 0;
105   for (int i = 0; i < N; i++) {
106     for (int j = 0; j < M; j++) {
107       arg0[i][j] = count++;
108       arg1[i][j] = count++;
109     }
110   }
111   printf("Before:\n");
112   dump();
113   // clang-format off
114   // CHECK-LABEL: Before:
115   // CHECK: [0,	  2,	4,	6,	8,	10,	12,	14,	] [1,	  3,	5, 7, 9,	11,	13,	15,	]
116   // CHECK: [16,	18,	20,	22, 24, 26,	28,	30,	] [17,	19,	21,	23,	25,	27,	29, 31, ]
117   // CHECK: [32,	34,	36,	38,	40,	42,	44,	46,	] [33,	35, 37, 39,	41,	43,	45,	47,	]
118   // CHECK: [48,	50,	52, 54, 56,	58,	60,	62,	] [49,	51,	53,	55,	57,	59, 61, 63,	]
119   // clang-format on
120 
121   // Call into MLIR.
122   long long result = add_memref((double *)arg0, (double *)arg0, 0, N, M, M, 0,
123                                 //
124                                 (double *)arg1, (double *)arg1, 0, N, M, M, 0);
125 
126   // CHECK-LABEL: Result:
127   // CHECK: 42
128   printf("Result: %d\n", (int)result);
129 
130   printf("After:\n");
131   dump();
132 
133   // clang-format off
134   // CHECK-LABEL: After:
135   // CHECK: [1,	  5,	  9,	  13,	 17,	21,	  25,	  29,	  ] [1, 3,	5,	7,	9,	11,	13,	15,	]
136   // CHECK: [33,	37,  41,	  45,	 49,	53,	  57,	  61,	  ] [17,	19,	21, 23, 25,	27,	29,	31,	]
137   // CHECK: [65,	69,	  73,   77,	 81,	85,	  89,	  93,	  ] [33,	35,	37,	39, 41, 43,	45,	47,	]
138   // CHECK: [97,	101,	105,	109, 113,	117,	121,	125,	] [49,	51,	53,	55,	57,	59, 61, 63,	]
139   // clang-format on
140 
141   // Reset the input and re-apply the same function use the C API wrapper.
142   count = 0;
143   for (int i = 0; i < N; i++) {
144     for (int j = 0; j < M; j++) {
145       arg0[i][j] = count++;
146       arg1[i][j] = count++;
147     }
148   }
149 
150   // Call into MLIR.
151   memref_2d_descriptor arg0_descriptor = {
152       (double *)arg0, (double *)arg0, 0, N, M, M, 0};
153   memref_2d_descriptor arg1_descriptor = {
154       (double *)arg1, (double *)arg1, 0, N, M, M, 0};
155   result = _mlir_ciface_add_memref(&arg0_descriptor, &arg1_descriptor);
156 
157   // CHECK-LABEL: Result2:
158   // CHECK: 42
159   printf("Result2: %d\n", (int)result);
160 
161   printf("After2:\n");
162   dump();
163 
164   // clang-format off
165   // CHECK-LABEL: After2:
166   // CHECK: [1,	  5,	  9,	  13,	 17,	21,	  25,	  29,	  ] [1, 3,	5,	7,	9,	11,	13,	15,	]
167   // CHECK: [33,	37,  41,	  45,	 49,	53,	  57,	  61,	  ] [17,	19,	21, 23, 25,	27,	29,	31,	]
168   // CHECK: [65,	69,	  73,   77,	 81,	85,	  89,	  93,	  ] [33,	35,	37,	39, 41, 43,	45,	47,	]
169   // CHECK: [97,	101,	105,	109, 113,	117,	121,	125,	] [49,	51,	53,	55,	57,	59, 61, 63,	]
170   // clang-format on
171 
172   return 0;
173 }
174