1 /* 2 Copyright (c) 2020 Intel Corporation 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 #define DOCTEST_CONFIG_SUPER_FAST_ASSERTS 18 #include "common/test.h" 19 20 #include "oneapi/tbb/parallel_scan.h" 21 22 #include <vector> 23 24 //! \file conformance_parallel_scan.cpp 25 //! \brief Test for [algorithms.parallel_scan] specification 26 27 constexpr std::size_t size = 1000; 28 29 template<typename T, typename Op> 30 class Body { 31 const T identity; 32 T sum; 33 std::vector<T>& y; 34 const std::vector<T>& z; 35 public: 36 Body( const std::vector<T>& z_, std::vector<T>& y_, T id ) : identity(id), sum(id), y(y_), z(z_) {} 37 T get_sum() const { return sum; } 38 39 template<typename Tag> 40 void operator()( const oneapi::tbb::blocked_range<std::size_t>& r, Tag ) { 41 T temp = sum; 42 for(std::size_t i=r.begin(); i<r.end(); ++i ) { 43 temp = Op()(temp, z[i]); 44 if( Tag::is_final_scan() ) 45 y[i] = temp; 46 } 47 sum = temp; 48 } 49 Body( Body& b, oneapi::tbb::split ): identity(b.identity), sum(b.identity), y(b.y), z(b.z) {} 50 void reverse_join( Body& a ) { sum = Op()(a.sum, sum); } 51 void assign( Body& b ) { sum = b.sum; } 52 }; 53 54 class default_partitioner_tag{}; 55 56 template<typename Partitioner> 57 struct parallel_scan_wrapper{ 58 template<typename... Args> 59 void operator()(Args&&... args) { 60 oneapi::tbb::parallel_scan(std::forward<Args>(args)..., Partitioner()); 61 } 62 }; 63 64 template<> 65 struct parallel_scan_wrapper<default_partitioner_tag>{ 66 template<typename... Args> 67 void operator()(Args&&... args) { 68 oneapi::tbb::parallel_scan(std::forward<Args>(args)...); 69 } 70 }; 71 72 // Test scan tag 73 //! \brief \ref interface 74 TEST_CASE("scan tags testing") { 75 CHECK(oneapi::tbb::pre_scan_tag::is_final_scan()==false); 76 CHECK(oneapi::tbb::final_scan_tag::is_final_scan()==true); 77 CHECK((bool)oneapi::tbb::pre_scan_tag()==false); 78 CHECK((bool)oneapi::tbb::final_scan_tag()==true); 79 } 80 81 //! Test parallel prefix sum calculation for body-based interface 82 //! \brief \ref requirement \ref interface 83 TEST_CASE_TEMPLATE("Test parallel scan with body", Partitioner, default_partitioner_tag, oneapi::tbb::simple_partitioner, oneapi::tbb::auto_partitioner) { 84 std::vector<int> input(size); 85 std::vector<int> output(size); 86 std::vector<int> control(size); 87 88 for(size_t i = 0; i < size; ++i) { 89 input[i] = int(i / 2); 90 if(i) 91 control[i] = control[i-1] + input[i]; 92 else 93 control[i] = input[i]; 94 } 95 Body<int, std::plus<int>> body(input, output, 0); 96 parallel_scan_wrapper<Partitioner>()(oneapi::tbb::blocked_range<std::size_t>(0U, size, 1U), body); 97 CHECK((control == output)); 98 } 99 100 101 //! Test parallel prefix sum calculation for scan-based interface 102 //! \brief \ref requirement \ref interface 103 TEST_CASE_TEMPLATE("Test parallel scan with body", Partitioner, default_partitioner_tag, oneapi::tbb::simple_partitioner, oneapi::tbb::auto_partitioner) { 104 std::vector<int> input(size); 105 std::vector<int> output(size); 106 std::vector<int> control(size); 107 108 for (size_t i = 0; i<size; ++i) { 109 input[i] = int(i); 110 if (i) 111 control[i] = control[i-1]+input[i]; 112 else 113 control[i] = input[i]; 114 } 115 parallel_scan_wrapper<Partitioner>()(oneapi::tbb::blocked_range<std::size_t>(0U, size, 1U), 0U, 116 [&](const oneapi::tbb::blocked_range<std::size_t>& r, int sum, bool is_final) -> int 117 { 118 int temp = sum; 119 for (std::size_t i = r.begin(); i<r.end(); ++i) { 120 temp = temp + input[i]; 121 if (is_final) 122 output[i] = temp; 123 } 124 return temp; 125 }, 126 [](int a, int b) -> int 127 { 128 return a + b; 129 }); 130 131 CHECK((control==output)); 132 } 133