From 28f7497e0a73583fe10a6e9a4d38fa0d4f0bb161 Mon Sep 17 00:00:00 2001 From: Joshua Moerman Date: Mon, 13 Jan 2014 15:40:38 +0100 Subject: [PATCH] Cleans up a bit, adds options for iterations --- wavelet/wavelet_parallel_mockup.cpp | 64 +++++++++++------------------ 1 file changed, 23 insertions(+), 41 deletions(-) diff --git a/wavelet/wavelet_parallel_mockup.cpp b/wavelet/wavelet_parallel_mockup.cpp index f8d027e..7134a66 100644 --- a/wavelet/wavelet_parallel_mockup.cpp +++ b/wavelet/wavelet_parallel_mockup.cpp @@ -7,7 +7,7 @@ #include "wavelet_parallel.hpp" // Number of iterations to improve time measurements -const unsigned int ITERS = 1; +static unsigned int ITERS = 1; // Static :(, will be set in main static unsigned int P; @@ -101,46 +101,20 @@ static void seq_wavelet(){ std::copy(v.begin(), v.end(), seq_result.begin()); } -// Checks whether seq and par agree -// NOTE: modifies the global par_result -static void check_equality(double threshold){ - if(par_result == seq_result){ - std::cout << colors::green("SUCCES:") << " Results are bitwise equal" << std::endl; - } else { - for(unsigned int i = 0; i < N; ++i){ - auto sq = par_result[i] - seq_result[i]; - par_result[i] = sq*sq; - } - auto rmse = std::sqrt(std::accumulate(par_result.begin(), par_result.end(), 0.0) / N); - if(rmse <= threshold){ - std::cout << colors::green("SUCCES:") << " Results are almost the same: rmse = " << rmse << std::endl; - } else { - std::cout << colors::red("FAIL:") << " Results differ: rmse = " << rmse << std::endl; - } - } -} +// square difference, used to calculate root mean squared error +static double sq_diff(double x, double y){ return (x-y)*(x-y); } -// Checks whether inverse gives us the data back -// NOTE: modifies the global seq_result -static void check_inverse(double threshold){ - for(unsigned int i = 0; i < ITERS; ++i){ - wvlt::unwavelet(seq_result.data(), seq_result.size(), 1); +static void compare_results(std::vector const & lh, std::vector const & rh, double threshold){ + if(lh == rh){ + std::cout << colors::green("SUCCES:") << " bitwise qual" << std::endl; + return; } - bool same = true; - for(unsigned int i = 0; i < N; ++i){ - if(data(i) != seq_result[i]) same = false; - auto sq = data(i) - seq_result[i]; - seq_result[i] = sq*sq; - } - auto rmse = std::sqrt(std::accumulate(seq_result.begin(), seq_result.end(), 0.0) / N); - if(same){ - std::cout << colors::green("SUCCES:") << " Inverse is bitwise correct" << std::endl; + + double rmse = std::sqrt(std::inner_product(lh.begin(), lh.end(), rh.begin(), 0.0, std::plus(), &sq_diff) / lh.size()); + if(rmse <= threshold){ + std::cout << colors::green("SUCCES:") << " error within threshold, rmse = " << rmse << std::endl; } else { - if(rmse <= threshold){ - std::cout << colors::green("SUCCES:") << " Inverse is almost correct: rmse = " << rmse << std::endl; - } else { - std::cout << colors::red("FAIL:") << " Inverse seems wrong: rmse = " << rmse << std::endl; - } + std::cout << colors::red("FAIL:") << " error to big, rmse = " << rmse << std::endl; } } @@ -152,8 +126,9 @@ int main(int argc, char** argv){ opts.add_options() ("p", po::value(), "number of processors") ("n", po::value(), "number of elements") + ("iterations", po::value()->default_value(5), "number of iterations") ("help", po::value(), "show this help") - ("check", po::value(&should_check), "enables correctness checks"); + ("check", po::value(), "enables correctness checks"); po::variables_map vm; // Parse and set options @@ -169,6 +144,7 @@ int main(int argc, char** argv){ N = vm["n"].as(); P = vm["p"].as(); + ITERS = vm["iterations"].as(); if(!is_pow_of_two(N)) throw po::error("n is not a power of two"); if(!is_pow_of_two(P)) throw po::error("p is not a power of two"); @@ -190,7 +166,13 @@ int main(int argc, char** argv){ // Checking equality of algorithms if(vm.count("check")){ double threshold = 1.0e-8; - check_equality(threshold); - check_inverse(threshold); + std::cout << "Checking results "; + compare_results(seq_result, par_result, threshold); + + for(int i = 0; i < ITERS; ++i) wvlt::unwavelet(seq_result.data(), seq_result.size(), 1); + for(unsigned int i = 0; i < par_result.size(); ++i) par_result[i] = data(i); + + std::cout << "Checking inverse "; + compare_results(seq_result, par_result, threshold); } }