xref: /oneTBB/examples/graph/som/som.cpp (revision b15aabb3)
1d86ed7fbStbbdev /*
2*b15aabb3Stbbdev     Copyright (c) 2005-2021 Intel Corporation
3d86ed7fbStbbdev 
4d86ed7fbStbbdev     Licensed under the Apache License, Version 2.0 (the "License");
5d86ed7fbStbbdev     you may not use this file except in compliance with the License.
6d86ed7fbStbbdev     You may obtain a copy of the License at
7d86ed7fbStbbdev 
8d86ed7fbStbbdev         http://www.apache.org/licenses/LICENSE-2.0
9d86ed7fbStbbdev 
10d86ed7fbStbbdev     Unless required by applicable law or agreed to in writing, software
11d86ed7fbStbbdev     distributed under the License is distributed on an "AS IS" BASIS,
12d86ed7fbStbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13d86ed7fbStbbdev     See the License for the specific language governing permissions and
14d86ed7fbStbbdev     limitations under the License.
15d86ed7fbStbbdev */
16d86ed7fbStbbdev 
17d86ed7fbStbbdev //
18d86ed7fbStbbdev // Self-organizing map in TBB flow::graph
19d86ed7fbStbbdev //
20d86ed7fbStbbdev // we will do a color map (the simple example.)
21d86ed7fbStbbdev //
22d86ed7fbStbbdev //  serial algorithm
23d86ed7fbStbbdev //
24d86ed7fbStbbdev //       initialize map with vectors (could be random, gradient, or something else)
25d86ed7fbStbbdev //       for some number of iterations
26d86ed7fbStbbdev //           update radius r, weight of change L
27d86ed7fbStbbdev //           for each example V
28d86ed7fbStbbdev //               find the best matching unit
29d86ed7fbStbbdev //               for each part of map within radius of BMU W
30d86ed7fbStbbdev //                   update vector:  W(t+1) = W(t) + w(dist)*L*(V - W(t))
31d86ed7fbStbbdev 
32d86ed7fbStbbdev #include "oneapi/tbb/task_group.h"
33d86ed7fbStbbdev 
34d86ed7fbStbbdev #include "som.hpp"
35d86ed7fbStbbdev 
operator <<(std::ostream & out,const SOM_element & s)36d86ed7fbStbbdev std::ostream &operator<<(std::ostream &out, const SOM_element &s) {
37d86ed7fbStbbdev     out << "(";
38d86ed7fbStbbdev     for (int i = 0; i < (int)s.w.size(); ++i) {
39d86ed7fbStbbdev         out << s.w[i];
40d86ed7fbStbbdev         if (i < (int)s.w.size() - 1) {
41d86ed7fbStbbdev             out << ",";
42d86ed7fbStbbdev         }
43d86ed7fbStbbdev     }
44d86ed7fbStbbdev     out << ")";
45d86ed7fbStbbdev     return out;
46d86ed7fbStbbdev }
47d86ed7fbStbbdev 
remark_SOM_element(const SOM_element & s)48d86ed7fbStbbdev void remark_SOM_element(const SOM_element &s) {
49d86ed7fbStbbdev     printf("(");
50d86ed7fbStbbdev     for (int i = 0; i < (int)s.w.size(); ++i) {
51d86ed7fbStbbdev         printf("%g", s.w[i]);
52d86ed7fbStbbdev         if (i < (int)s.w.size() - 1) {
53d86ed7fbStbbdev             printf(",");
54d86ed7fbStbbdev         }
55d86ed7fbStbbdev     }
56d86ed7fbStbbdev     printf(")");
57d86ed7fbStbbdev }
58d86ed7fbStbbdev 
operator <<(std::ostream & out,const search_result_type & s)59d86ed7fbStbbdev std::ostream &operator<<(std::ostream &out, const search_result_type &s) {
60d86ed7fbStbbdev     out << "<";
61d86ed7fbStbbdev     out << std::get<RADIUS>(s);
62d86ed7fbStbbdev     out << ", " << std::get<XV>(s);
63d86ed7fbStbbdev     out << ", ";
64d86ed7fbStbbdev     out << std::get<YV>(s);
65d86ed7fbStbbdev     out << ">";
66d86ed7fbStbbdev     return out;
67d86ed7fbStbbdev }
68d86ed7fbStbbdev 
remark_search_result_type(const search_result_type & s)69d86ed7fbStbbdev void remark_search_result_type(const search_result_type &s) {
70d86ed7fbStbbdev     printf("<%g,%d,%d>", std::get<RADIUS>(s), std::get<XV>(s), std::get<YV>(s));
71d86ed7fbStbbdev }
72d86ed7fbStbbdev 
// Returns a pseudo-random value in the closed interval [lowlimit, highlimit].
// The value is derived from rand(), so the sequence depends on the C library seed.
double randval(double lowlimit, double highlimit) {
    // Fraction of the way through the interval, in [0, 1].
    const double unit_fraction = double(rand()) / double(RAND_MAX);
    const double span = highlimit - lowlimit;
    return unit_fraction * span + lowlimit;
}
76d86ed7fbStbbdev 
find_data_ranges(teaching_vector_type & teaching,SOM_element & max_range,SOM_element & min_range)77d86ed7fbStbbdev void find_data_ranges(teaching_vector_type &teaching,
78d86ed7fbStbbdev                       SOM_element &max_range,
79d86ed7fbStbbdev                       SOM_element &min_range) {
80d86ed7fbStbbdev     if (teaching.size() == 0)
81d86ed7fbStbbdev         return;
82d86ed7fbStbbdev     max_range = min_range = teaching[0];
83d86ed7fbStbbdev     for (int i = 1; i < (int)teaching.size(); ++i) {
84d86ed7fbStbbdev         max_range.elementwise_max(teaching[i]);
85d86ed7fbStbbdev         min_range.elementwise_min(teaching[i]);
86d86ed7fbStbbdev     }
87d86ed7fbStbbdev }
88d86ed7fbStbbdev 
add_fraction_of_difference(SOM_element & to,SOM_element const & from,double frac)89d86ed7fbStbbdev void add_fraction_of_difference(SOM_element &to, SOM_element const &from, double frac) {
90d86ed7fbStbbdev     for (int i = 0; i < (int)from.size(); ++i) {
91d86ed7fbStbbdev         to[i] += frac * (from[i] - to[i]);
92d86ed7fbStbbdev     }
93d86ed7fbStbbdev }
94d86ed7fbStbbdev 
distance_squared(SOM_element x,SOM_element y)95d86ed7fbStbbdev double distance_squared(SOM_element x, SOM_element y) {
96d86ed7fbStbbdev     double rval = 0.0;
97d86ed7fbStbbdev     for (int i = 0; i < (int)x.size(); ++i) {
98d86ed7fbStbbdev         double diff = x[i] - y[i];
99d86ed7fbStbbdev         rval += diff * diff;
100d86ed7fbStbbdev     }
101d86ed7fbStbbdev     return rval;
102d86ed7fbStbbdev }
103d86ed7fbStbbdev 
initialize(InitializeType it,SOM_element & max_range,SOM_element & min_range)104d86ed7fbStbbdev void SOMap::initialize(InitializeType it, SOM_element &max_range, SOM_element &min_range) {
105d86ed7fbStbbdev     for (int x = 0; x < xMax; ++x) {
106d86ed7fbStbbdev         for (int y = 0; y < yMax; ++y) {
107d86ed7fbStbbdev             for (int i = 0; i < (int)max_range.size(); ++i) {
108d86ed7fbStbbdev                 if (it == InitializeRandom) {
109d86ed7fbStbbdev                     my_map[x][y][i] = (randval(min_range[i], max_range[i]));
110d86ed7fbStbbdev                 }
111d86ed7fbStbbdev                 else if (it == InitializeGradient) {
112d86ed7fbStbbdev                     my_map[x][y][i] =
113d86ed7fbStbbdev                         ((double)(x + y) / (xMax + yMax) * (max_range[i] - min_range[i]) +
114d86ed7fbStbbdev                          min_range[i]);
115d86ed7fbStbbdev                 }
116d86ed7fbStbbdev             }
117d86ed7fbStbbdev         }
118d86ed7fbStbbdev     }
119d86ed7fbStbbdev }
120d86ed7fbStbbdev 
121d86ed7fbStbbdev // subsquare [low,high)
BMU_range(const SOM_element & s,int & xval,int & yval,subsquare_type & r)122d86ed7fbStbbdev double SOMap::BMU_range(const SOM_element &s, int &xval, int &yval, subsquare_type &r) {
123d86ed7fbStbbdev     double min_distance_squared = DBL_MAX;
124d86ed7fbStbbdev     int min_x = -1;
125d86ed7fbStbbdev     int min_y = -1;
126d86ed7fbStbbdev     for (int x = r.rows().begin(); x != r.rows().end(); ++x) {
127d86ed7fbStbbdev         for (int y = r.cols().begin(); y != r.cols().end(); ++y) {
128d86ed7fbStbbdev             double dist = distance_squared(s, my_map[x][y]);
129d86ed7fbStbbdev             if (dist < min_distance_squared) {
130d86ed7fbStbbdev                 min_distance_squared = dist;
131d86ed7fbStbbdev                 min_x = x;
132d86ed7fbStbbdev                 min_y = y;
133d86ed7fbStbbdev             }
134d86ed7fbStbbdev             if (cancel_test && oneapi::tbb::is_current_task_group_canceling()) {
135d86ed7fbStbbdev                 xval = r.rows().begin();
136d86ed7fbStbbdev                 yval = r.cols().begin();
137d86ed7fbStbbdev                 return DBL_MAX;
138d86ed7fbStbbdev             }
139d86ed7fbStbbdev         }
140d86ed7fbStbbdev     }
141d86ed7fbStbbdev     xval = min_x;
142d86ed7fbStbbdev     yval = min_y;
143d86ed7fbStbbdev     return sqrt(min_distance_squared);
144d86ed7fbStbbdev }
145d86ed7fbStbbdev 
epoch_update_range(SOM_element const & s,int epoch,int min_x,int min_y,double radius,double learning_rate,oneapi::tbb::blocked_range<int> & r)146d86ed7fbStbbdev void SOMap::epoch_update_range(SOM_element const &s,
147d86ed7fbStbbdev                                int epoch,
148d86ed7fbStbbdev                                int min_x,
149d86ed7fbStbbdev                                int min_y,
150d86ed7fbStbbdev                                double radius,
151d86ed7fbStbbdev                                double learning_rate,
152d86ed7fbStbbdev                                oneapi::tbb::blocked_range<int> &r) {
153d86ed7fbStbbdev     int min_xiter = (int)((double)min_x - radius);
154d86ed7fbStbbdev     if (min_xiter < 0)
155d86ed7fbStbbdev         min_xiter = 0;
156d86ed7fbStbbdev     int max_xiter = (int)((double)min_x + radius);
157d86ed7fbStbbdev     if (max_xiter > (int)my_map.size() - 1)
158d86ed7fbStbbdev         max_xiter = (int)my_map.size() - 1;
159d86ed7fbStbbdev     for (int xx = r.begin(); xx <= r.end(); ++xx) {
160d86ed7fbStbbdev         double xrsq = (xx - min_x) * (xx - min_x);
161d86ed7fbStbbdev         double ysq = radius * radius - xrsq; // max extent of y influence
162d86ed7fbStbbdev         double yd;
163d86ed7fbStbbdev         if (ysq > 0) {
164d86ed7fbStbbdev             yd = sqrt(ysq);
165d86ed7fbStbbdev             int lb = (int)(min_y - yd);
166d86ed7fbStbbdev             int ub = (int)(min_y + yd);
167d86ed7fbStbbdev             for (int yy = lb; yy < ub; ++yy) {
168d86ed7fbStbbdev                 if (yy >= 0 && yy < (int)my_map[xx].size()) {
169d86ed7fbStbbdev                     // [xx, yy] is in the range of the update.
170d86ed7fbStbbdev                     double my_rsq = xrsq + (yy - min_y) * (yy - min_y); // distance from BMU squared
171d86ed7fbStbbdev                     double theta = exp(-(radius * radius) / (2.0 * my_rsq));
172d86ed7fbStbbdev                     add_fraction_of_difference(my_map[xx][yy], s, theta * learning_rate);
173d86ed7fbStbbdev                 }
174d86ed7fbStbbdev             }
175d86ed7fbStbbdev         }
176d86ed7fbStbbdev     }
177d86ed7fbStbbdev }
178d86ed7fbStbbdev 
teach(teaching_vector_type & in)179d86ed7fbStbbdev void SOMap::teach(teaching_vector_type &in) {
180d86ed7fbStbbdev     for (int i = 0; i < nPasses; ++i) {
181d86ed7fbStbbdev         int j = (int)(randval(0, (double)in.size())); // this won't be reproducible.
182d86ed7fbStbbdev         if (j == in.size())
183d86ed7fbStbbdev             --j;
184d86ed7fbStbbdev 
185d86ed7fbStbbdev         int min_x = -1;
186d86ed7fbStbbdev         int min_y = -1;
187d86ed7fbStbbdev         subsquare_type br2(0, (int)my_map.size(), 1, 0, (int)my_map[0].size(), 1);
188d86ed7fbStbbdev         (void)BMU_range(in[j], min_x, min_y, br2); // just need min_x, min_y
189d86ed7fbStbbdev         // radius of interest
190d86ed7fbStbbdev         double radius = max_radius * exp(-(double)i * radius_decay_rate);
191d86ed7fbStbbdev         // update circle is min_xiter to max_xiter inclusive.
192d86ed7fbStbbdev         double learning_rate = max_learning_rate * exp(-(double)i * learning_decay_rate);
193d86ed7fbStbbdev         epoch_update(in[j], i, min_x, min_y, radius, learning_rate);
194d86ed7fbStbbdev     }
195d86ed7fbStbbdev }
196d86ed7fbStbbdev 
debug_output()197d86ed7fbStbbdev void SOMap::debug_output() {
198d86ed7fbStbbdev     printf("SOMap:\n");
199d86ed7fbStbbdev     for (int i = 0; i < (int)(this->my_map.size()); ++i) {
200d86ed7fbStbbdev         for (int j = 0; j < (int)(this->my_map[i].size()); ++j) {
201d86ed7fbStbbdev             printf("map[%d, %d] == ", i, j);
202d86ed7fbStbbdev             remark_SOM_element(this->my_map[i][j]);
203d86ed7fbStbbdev             printf("\n");
204d86ed7fbStbbdev         }
205d86ed7fbStbbdev     }
206d86ed7fbStbbdev }
207d86ed7fbStbbdev 
// Component indices of a color teaching vector. Typed constants are used instead of
// object-like macros so the names obey normal scoping rules.
constexpr int RED = 0;
constexpr int GREEN = 1;
constexpr int BLUE = 2;
211d86ed7fbStbbdev 
readInputData()212d86ed7fbStbbdev void readInputData() {
213d86ed7fbStbbdev     my_teaching.push_back(SOM_element());
214d86ed7fbStbbdev     my_teaching.push_back(SOM_element());
215d86ed7fbStbbdev     my_teaching.push_back(SOM_element());
216d86ed7fbStbbdev     my_teaching.push_back(SOM_element());
217d86ed7fbStbbdev     my_teaching.push_back(SOM_element());
218d86ed7fbStbbdev     my_teaching[0][RED] = 1.0;
219d86ed7fbStbbdev     my_teaching[0][GREEN] = 0.0;
220d86ed7fbStbbdev     my_teaching[0][BLUE] = 0.0;
221d86ed7fbStbbdev     my_teaching[1][RED] = 0.0;
222d86ed7fbStbbdev     my_teaching[1][GREEN] = 1.0;
223d86ed7fbStbbdev     my_teaching[1][BLUE] = 0.0;
224d86ed7fbStbbdev     my_teaching[2][RED] = 0.0;
225d86ed7fbStbbdev     my_teaching[2][GREEN] = 0.0;
226d86ed7fbStbbdev     my_teaching[2][BLUE] = 1.0;
227d86ed7fbStbbdev     my_teaching[3][RED] = 0.3;
228d86ed7fbStbbdev     my_teaching[3][GREEN] = 0.3;
229d86ed7fbStbbdev     my_teaching[3][BLUE] = 0.0;
230d86ed7fbStbbdev     my_teaching[4][RED] = 0.5;
231d86ed7fbStbbdev     my_teaching[4][GREEN] = 0.5;
232d86ed7fbStbbdev     my_teaching[4][BLUE] = 0.9;
233d86ed7fbStbbdev }
234