1 // RUN: %libomptarget-compilexx-run-and-check-generic
2 
3 #include <cassert>
4 #include <iostream>
5 #include <memory>
6 #include <vector>
7 
/// A dense float matrix stored as a grid of separately allocated, equally
/// sized tiles ("blocks").
///
/// Layout, as implemented by Initialize()/Compare():
///  * The flat source matrix is indexed [row * nCols + col].
///  * GetBlock(i, j) is the tile covering columns
///    [i * colsPerBlock, (i + 1) * colsPerBlock) and rows
///    [j * rowsPerBlock, (j + 1) * rowsPerBlock).
///  * Inside a tile, element (row jj, col ii) is stored at
///    ii + jj * colsPerBlock.
///
/// Only whole tiles are stored: if nRows / nCols are not multiples of the
/// tile size, the trailing rows / columns are not represented.
class BlockMatrix {
private:
  const int rowsPerBlock;
  const int colsPerBlock;
  const long nRows;
  const long nCols;
  // Number of tiles stacked along the rows (valid range of GetBlock's `j`).
  const int nBlocksPerRow;
  // Number of tiles side by side along the columns (valid range of `i`).
  const int nBlocksPerCol;
  // Blocks[i][j] owns the storage returned by GetBlock(i, j). Declared after
  // nBlocksPerCol because its initializer depends on that member.
  std::vector<std::vector<std::unique_ptr<float[]>>> Blocks;

public:
  /// Allocates every tile. All elements start at 0.0f, so a freshly built
  /// matrix can be used directly as an accumulator (e.g. the C operand of a
  /// `C += A * B` kernel).
  BlockMatrix(const int _rowsPerBlock, const int _colsPerBlock,
              const long _nRows, const long _nCols)
      : rowsPerBlock(_rowsPerBlock), colsPerBlock(_colsPerBlock), nRows(_nRows),
        nCols(_nCols), nBlocksPerRow(_nRows / _rowsPerBlock),
        nBlocksPerCol(_nCols / _colsPerBlock), Blocks(nBlocksPerCol) {
    const long BlockElems = static_cast<long>(_rowsPerBlock) * _colsPerBlock;
    for (auto &TileColumn : Blocks) {
      TileColumn.reserve(nBlocksPerRow);
      for (int j = 0; j < nBlocksPerRow; j++)
        // make_unique<float[]> value-initializes (zeroes) the elements;
        // `new float[n]` would leave them indeterminate, and reading them
        // (e.g. via `+=`) would be undefined behaviour.
        TileColumn.push_back(std::make_unique<float[]>(BlockElems));
    }
  }

  /// Copies a flat matrix, indexed [row * nCols + col], into the tiles.
  /// `matrix` must hold at least nRows * nCols elements.
  void Initialize(const std::vector<float> &matrix) {
    for (int i = 0; i < nBlocksPerCol; i++)
      for (int j = 0; j < nBlocksPerRow; j++) {
        float *CurrBlock = GetBlock(i, j);
        for (int ii = 0; ii < colsPerBlock; ++ii)
          for (int jj = 0; jj < rowsPerBlock; ++jj) {
            const long curri = static_cast<long>(i) * colsPerBlock + ii; // col
            const long currj = static_cast<long>(j) * rowsPerBlock + jj; // row
            CurrBlock[ii + jj * colsPerBlock] = matrix[curri + currj * nCols];
          }
      }
  }

  /// Returns how many stored elements differ (exact `!=`) from the
  /// corresponding element of `matrix`, using the same indexing as
  /// Initialize(). 0 means every represented element matches.
  long Compare(const std::vector<float> &matrix) const {
    long fail = 0;
    for (int i = 0; i < nBlocksPerCol; i++)
      for (int j = 0; j < nBlocksPerRow; j++) {
        const float *CurrBlock = GetBlock(i, j);
        for (int ii = 0; ii < colsPerBlock; ++ii)
          for (int jj = 0; jj < rowsPerBlock; ++jj) {
            const long curri = static_cast<long>(i) * colsPerBlock + ii; // col
            const long currj = static_cast<long>(j) * rowsPerBlock + jj; // row
            float m_value = matrix[curri + currj * nCols];
            float bm_value = CurrBlock[ii + jj * colsPerBlock];
            if (bm_value != m_value) {
              fail++;
            }
          }
      }
    return fail;
  }

  /// Returns the storage of tile (i, j): column block i, row block j.
  float *GetBlock(int i, int j) const {
    assert(i >= 0 && i < nBlocksPerCol && j >= 0 && j < nBlocksPerRow &&
           "Accessing outside block");
    return Blocks[i][j].get();
  }
};
69 
/// Tile edge length used by the blocked kernel and the BlockMatrix operands.
constexpr int BS = 256;
/// Row and column count of the (square) test matrices.
constexpr int N = 1024;
// The blocked kernel iterates over N / BS whole tiles; any remainder rows or
// columns would be silently skipped, so require an exact tiling.
static_assert(N % BS == 0, "N must be a multiple of BS");
72 
73 int BlockMatMul_TargetNowait(BlockMatrix &A, BlockMatrix &B, BlockMatrix &C) {
74 #pragma omp parallel
75 #pragma omp master
76   for (int i = 0; i < N / BS; ++i)
77     for (int j = 0; j < N / BS; ++j) {
78       float *BlockC = C.GetBlock(i, j);
79       for (int k = 0; k < N / BS; ++k) {
80         float *BlockA = A.GetBlock(i, k);
81         float *BlockB = B.GetBlock(k, j);
82 // clang-format off
83 #pragma omp target depend(in: BlockA[0], BlockB[0]) depend(inout: BlockC[0])   \
84             map(to: BlockA[:BS * BS], BlockB[:BS * BS])                        \
85             map(tofrom: BlockC[:BS * BS]) nowait
86 // clang-format on
87 #pragma omp parallel for
88         for (int ii = 0; ii < BS; ii++)
89           for (int jj = 0; jj < BS; jj++) {
90             for (int kk = 0; kk < BS; ++kk)
91               BlockC[ii + jj * BS] +=
92                   BlockA[ii + kk * BS] * BlockB[kk + jj * BS];
93           }
94       }
95     }
96   return 0;
97 }
98 
99 void Matmul(const std::vector<float> &a, const std::vector<float> &b,
100             std::vector<float> &c) {
101   for (int i = 0; i < N; ++i) {
102     for (int j = 0; j < N; ++j) {
103       float sum = 0.0;
104       for (int k = 0; k < N; ++k) {
105         sum = sum + a[i * N + k] * b[k * N + j];
106       }
107       c[i * N + j] = sum;
108     }
109   }
110 }
111 
112 int main(int argc, char *argv[]) {
113   std::vector<float> a(N * N);
114   std::vector<float> b(N * N);
115   std::vector<float> c(N * N, 0.0);
116 
117   for (int i = 0; i < N; ++i) {
118     for (int j = 0; j < N; ++j) {
119       a[i * N + j] = b[i * N + j] = i + j % 100;
120     }
121   }
122 
123   auto BlockedA = BlockMatrix(BS, BS, N, N);
124   BlockedA.Initialize(a);
125   BlockedA.Compare(a);
126   auto BlockedB = BlockMatrix(BS, BS, N, N);
127   BlockedB.Initialize(b);
128   BlockedB.Compare(b);
129 
130   Matmul(a, b, c);
131 
132   auto BlockedC = BlockMatrix(BS, BS, N, N);
133   BlockMatMul_TargetNowait(BlockedA, BlockedB, BlockedC);
134 
135   if (BlockedC.Compare(c) > 0) {
136     return 1;
137   }
138 
139   std::cout << "PASS\n";
140 
141   return 0;
142 }
143 
144 // CHECK: PASS
145