/*
    Copyright (c) 2005-2021 Intel Corporation

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/
16d86ed7fbStbbdev
//
// Self-organizing map in TBB flow::graph
//
//   we will do a color map (the simple example.)
//
//  serial algorithm
//
//       initialize map with vectors (could be random, gradient, or something else)
//       for some number of iterations
//           update radius r, weight of change L
//           for each example V
//               find the best matching unit
//               for each part of map within radius of BMU W
//                   update vector:  W(t+1) = W(t) + w(dist)*L*(V - W(t))
31d86ed7fbStbbdev
32d86ed7fbStbbdev #include "oneapi/tbb/task_group.h"
33d86ed7fbStbbdev
34d86ed7fbStbbdev #include "som.hpp"
35d86ed7fbStbbdev
operator <<(std::ostream & out,const SOM_element & s)36d86ed7fbStbbdev std::ostream &operator<<(std::ostream &out, const SOM_element &s) {
37d86ed7fbStbbdev out << "(";
38d86ed7fbStbbdev for (int i = 0; i < (int)s.w.size(); ++i) {
39d86ed7fbStbbdev out << s.w[i];
40d86ed7fbStbbdev if (i < (int)s.w.size() - 1) {
41d86ed7fbStbbdev out << ",";
42d86ed7fbStbbdev }
43d86ed7fbStbbdev }
44d86ed7fbStbbdev out << ")";
45d86ed7fbStbbdev return out;
46d86ed7fbStbbdev }
47d86ed7fbStbbdev
remark_SOM_element(const SOM_element & s)48d86ed7fbStbbdev void remark_SOM_element(const SOM_element &s) {
49d86ed7fbStbbdev printf("(");
50d86ed7fbStbbdev for (int i = 0; i < (int)s.w.size(); ++i) {
51d86ed7fbStbbdev printf("%g", s.w[i]);
52d86ed7fbStbbdev if (i < (int)s.w.size() - 1) {
53d86ed7fbStbbdev printf(",");
54d86ed7fbStbbdev }
55d86ed7fbStbbdev }
56d86ed7fbStbbdev printf(")");
57d86ed7fbStbbdev }
58d86ed7fbStbbdev
operator <<(std::ostream & out,const search_result_type & s)59d86ed7fbStbbdev std::ostream &operator<<(std::ostream &out, const search_result_type &s) {
60d86ed7fbStbbdev out << "<";
61d86ed7fbStbbdev out << std::get<RADIUS>(s);
62d86ed7fbStbbdev out << ", " << std::get<XV>(s);
63d86ed7fbStbbdev out << ", ";
64d86ed7fbStbbdev out << std::get<YV>(s);
65d86ed7fbStbbdev out << ">";
66d86ed7fbStbbdev return out;
67d86ed7fbStbbdev }
68d86ed7fbStbbdev
remark_search_result_type(const search_result_type & s)69d86ed7fbStbbdev void remark_search_result_type(const search_result_type &s) {
70d86ed7fbStbbdev printf("<%g,%d,%d>", std::get<RADIUS>(s), std::get<XV>(s), std::get<YV>(s));
71d86ed7fbStbbdev }
72d86ed7fbStbbdev
// Returns a pseudo-random value in the closed interval [lowlimit, highlimit],
// obtained by scaling one rand() draw. Both ends are reachable because rand()
// may return 0 or RAND_MAX.
double randval(double lowlimit, double highlimit) {
    const double unit_draw = double(rand()) / double(RAND_MAX); // in [0, 1]
    const double span = highlimit - lowlimit;
    return unit_draw * span + lowlimit;
}
76d86ed7fbStbbdev
find_data_ranges(teaching_vector_type & teaching,SOM_element & max_range,SOM_element & min_range)77d86ed7fbStbbdev void find_data_ranges(teaching_vector_type &teaching,
78d86ed7fbStbbdev SOM_element &max_range,
79d86ed7fbStbbdev SOM_element &min_range) {
80d86ed7fbStbbdev if (teaching.size() == 0)
81d86ed7fbStbbdev return;
82d86ed7fbStbbdev max_range = min_range = teaching[0];
83d86ed7fbStbbdev for (int i = 1; i < (int)teaching.size(); ++i) {
84d86ed7fbStbbdev max_range.elementwise_max(teaching[i]);
85d86ed7fbStbbdev min_range.elementwise_min(teaching[i]);
86d86ed7fbStbbdev }
87d86ed7fbStbbdev }
88d86ed7fbStbbdev
add_fraction_of_difference(SOM_element & to,SOM_element const & from,double frac)89d86ed7fbStbbdev void add_fraction_of_difference(SOM_element &to, SOM_element const &from, double frac) {
90d86ed7fbStbbdev for (int i = 0; i < (int)from.size(); ++i) {
91d86ed7fbStbbdev to[i] += frac * (from[i] - to[i]);
92d86ed7fbStbbdev }
93d86ed7fbStbbdev }
94d86ed7fbStbbdev
distance_squared(SOM_element x,SOM_element y)95d86ed7fbStbbdev double distance_squared(SOM_element x, SOM_element y) {
96d86ed7fbStbbdev double rval = 0.0;
97d86ed7fbStbbdev for (int i = 0; i < (int)x.size(); ++i) {
98d86ed7fbStbbdev double diff = x[i] - y[i];
99d86ed7fbStbbdev rval += diff * diff;
100d86ed7fbStbbdev }
101d86ed7fbStbbdev return rval;
102d86ed7fbStbbdev }
103d86ed7fbStbbdev
initialize(InitializeType it,SOM_element & max_range,SOM_element & min_range)104d86ed7fbStbbdev void SOMap::initialize(InitializeType it, SOM_element &max_range, SOM_element &min_range) {
105d86ed7fbStbbdev for (int x = 0; x < xMax; ++x) {
106d86ed7fbStbbdev for (int y = 0; y < yMax; ++y) {
107d86ed7fbStbbdev for (int i = 0; i < (int)max_range.size(); ++i) {
108d86ed7fbStbbdev if (it == InitializeRandom) {
109d86ed7fbStbbdev my_map[x][y][i] = (randval(min_range[i], max_range[i]));
110d86ed7fbStbbdev }
111d86ed7fbStbbdev else if (it == InitializeGradient) {
112d86ed7fbStbbdev my_map[x][y][i] =
113d86ed7fbStbbdev ((double)(x + y) / (xMax + yMax) * (max_range[i] - min_range[i]) +
114d86ed7fbStbbdev min_range[i]);
115d86ed7fbStbbdev }
116d86ed7fbStbbdev }
117d86ed7fbStbbdev }
118d86ed7fbStbbdev }
119d86ed7fbStbbdev }
120d86ed7fbStbbdev
121d86ed7fbStbbdev // subsquare [low,high)
BMU_range(const SOM_element & s,int & xval,int & yval,subsquare_type & r)122d86ed7fbStbbdev double SOMap::BMU_range(const SOM_element &s, int &xval, int &yval, subsquare_type &r) {
123d86ed7fbStbbdev double min_distance_squared = DBL_MAX;
124d86ed7fbStbbdev int min_x = -1;
125d86ed7fbStbbdev int min_y = -1;
126d86ed7fbStbbdev for (int x = r.rows().begin(); x != r.rows().end(); ++x) {
127d86ed7fbStbbdev for (int y = r.cols().begin(); y != r.cols().end(); ++y) {
128d86ed7fbStbbdev double dist = distance_squared(s, my_map[x][y]);
129d86ed7fbStbbdev if (dist < min_distance_squared) {
130d86ed7fbStbbdev min_distance_squared = dist;
131d86ed7fbStbbdev min_x = x;
132d86ed7fbStbbdev min_y = y;
133d86ed7fbStbbdev }
134d86ed7fbStbbdev if (cancel_test && oneapi::tbb::is_current_task_group_canceling()) {
135d86ed7fbStbbdev xval = r.rows().begin();
136d86ed7fbStbbdev yval = r.cols().begin();
137d86ed7fbStbbdev return DBL_MAX;
138d86ed7fbStbbdev }
139d86ed7fbStbbdev }
140d86ed7fbStbbdev }
141d86ed7fbStbbdev xval = min_x;
142d86ed7fbStbbdev yval = min_y;
143d86ed7fbStbbdev return sqrt(min_distance_squared);
144d86ed7fbStbbdev }
145d86ed7fbStbbdev
epoch_update_range(SOM_element const & s,int epoch,int min_x,int min_y,double radius,double learning_rate,oneapi::tbb::blocked_range<int> & r)146d86ed7fbStbbdev void SOMap::epoch_update_range(SOM_element const &s,
147d86ed7fbStbbdev int epoch,
148d86ed7fbStbbdev int min_x,
149d86ed7fbStbbdev int min_y,
150d86ed7fbStbbdev double radius,
151d86ed7fbStbbdev double learning_rate,
152d86ed7fbStbbdev oneapi::tbb::blocked_range<int> &r) {
153d86ed7fbStbbdev int min_xiter = (int)((double)min_x - radius);
154d86ed7fbStbbdev if (min_xiter < 0)
155d86ed7fbStbbdev min_xiter = 0;
156d86ed7fbStbbdev int max_xiter = (int)((double)min_x + radius);
157d86ed7fbStbbdev if (max_xiter > (int)my_map.size() - 1)
158d86ed7fbStbbdev max_xiter = (int)my_map.size() - 1;
159d86ed7fbStbbdev for (int xx = r.begin(); xx <= r.end(); ++xx) {
160d86ed7fbStbbdev double xrsq = (xx - min_x) * (xx - min_x);
161d86ed7fbStbbdev double ysq = radius * radius - xrsq; // max extent of y influence
162d86ed7fbStbbdev double yd;
163d86ed7fbStbbdev if (ysq > 0) {
164d86ed7fbStbbdev yd = sqrt(ysq);
165d86ed7fbStbbdev int lb = (int)(min_y - yd);
166d86ed7fbStbbdev int ub = (int)(min_y + yd);
167d86ed7fbStbbdev for (int yy = lb; yy < ub; ++yy) {
168d86ed7fbStbbdev if (yy >= 0 && yy < (int)my_map[xx].size()) {
169d86ed7fbStbbdev // [xx, yy] is in the range of the update.
170d86ed7fbStbbdev double my_rsq = xrsq + (yy - min_y) * (yy - min_y); // distance from BMU squared
171d86ed7fbStbbdev double theta = exp(-(radius * radius) / (2.0 * my_rsq));
172d86ed7fbStbbdev add_fraction_of_difference(my_map[xx][yy], s, theta * learning_rate);
173d86ed7fbStbbdev }
174d86ed7fbStbbdev }
175d86ed7fbStbbdev }
176d86ed7fbStbbdev }
177d86ed7fbStbbdev }
178d86ed7fbStbbdev
teach(teaching_vector_type & in)179d86ed7fbStbbdev void SOMap::teach(teaching_vector_type &in) {
180d86ed7fbStbbdev for (int i = 0; i < nPasses; ++i) {
181d86ed7fbStbbdev int j = (int)(randval(0, (double)in.size())); // this won't be reproducible.
182d86ed7fbStbbdev if (j == in.size())
183d86ed7fbStbbdev --j;
184d86ed7fbStbbdev
185d86ed7fbStbbdev int min_x = -1;
186d86ed7fbStbbdev int min_y = -1;
187d86ed7fbStbbdev subsquare_type br2(0, (int)my_map.size(), 1, 0, (int)my_map[0].size(), 1);
188d86ed7fbStbbdev (void)BMU_range(in[j], min_x, min_y, br2); // just need min_x, min_y
189d86ed7fbStbbdev // radius of interest
190d86ed7fbStbbdev double radius = max_radius * exp(-(double)i * radius_decay_rate);
191d86ed7fbStbbdev // update circle is min_xiter to max_xiter inclusive.
192d86ed7fbStbbdev double learning_rate = max_learning_rate * exp(-(double)i * learning_decay_rate);
193d86ed7fbStbbdev epoch_update(in[j], i, min_x, min_y, radius, learning_rate);
194d86ed7fbStbbdev }
195d86ed7fbStbbdev }
196d86ed7fbStbbdev
debug_output()197d86ed7fbStbbdev void SOMap::debug_output() {
198d86ed7fbStbbdev printf("SOMap:\n");
199d86ed7fbStbbdev for (int i = 0; i < (int)(this->my_map.size()); ++i) {
200d86ed7fbStbbdev for (int j = 0; j < (int)(this->my_map[i].size()); ++j) {
201d86ed7fbStbbdev printf("map[%d, %d] == ", i, j);
202d86ed7fbStbbdev remark_SOM_element(this->my_map[i][j]);
203d86ed7fbStbbdev printf("\n");
204d86ed7fbStbbdev }
205d86ed7fbStbbdev }
206d86ed7fbStbbdev }
207d86ed7fbStbbdev
208d86ed7fbStbbdev #define RED 0
209d86ed7fbStbbdev #define GREEN 1
210d86ed7fbStbbdev #define BLUE 2
211d86ed7fbStbbdev
readInputData()212d86ed7fbStbbdev void readInputData() {
213d86ed7fbStbbdev my_teaching.push_back(SOM_element());
214d86ed7fbStbbdev my_teaching.push_back(SOM_element());
215d86ed7fbStbbdev my_teaching.push_back(SOM_element());
216d86ed7fbStbbdev my_teaching.push_back(SOM_element());
217d86ed7fbStbbdev my_teaching.push_back(SOM_element());
218d86ed7fbStbbdev my_teaching[0][RED] = 1.0;
219d86ed7fbStbbdev my_teaching[0][GREEN] = 0.0;
220d86ed7fbStbbdev my_teaching[0][BLUE] = 0.0;
221d86ed7fbStbbdev my_teaching[1][RED] = 0.0;
222d86ed7fbStbbdev my_teaching[1][GREEN] = 1.0;
223d86ed7fbStbbdev my_teaching[1][BLUE] = 0.0;
224d86ed7fbStbbdev my_teaching[2][RED] = 0.0;
225d86ed7fbStbbdev my_teaching[2][GREEN] = 0.0;
226d86ed7fbStbbdev my_teaching[2][BLUE] = 1.0;
227d86ed7fbStbbdev my_teaching[3][RED] = 0.3;
228d86ed7fbStbbdev my_teaching[3][GREEN] = 0.3;
229d86ed7fbStbbdev my_teaching[3][BLUE] = 0.0;
230d86ed7fbStbbdev my_teaching[4][RED] = 0.5;
231d86ed7fbStbbdev my_teaching[4][GREEN] = 0.5;
232d86ed7fbStbbdev my_teaching[4][BLUE] = 0.9;
233d86ed7fbStbbdev }
234