diff --git a/wavelet/CMakeLists.txt b/wavelet/CMakeLists.txt new file mode 100644 index 0000000..caf1670 --- /dev/null +++ b/wavelet/CMakeLists.txt @@ -0,0 +1,7 @@ +file(GLOB sources *.cpp) + +foreach(source ${sources}) + get_filename_component(exec ${source} NAME_WE) + add_executable(${exec} ${source}) + target_link_libraries(${exec} ${libs}) +endforeach() diff --git a/wavelet/periodic_iterator.hpp b/wavelet/periodic_iterator.hpp new file mode 100644 index 0000000..289f429 --- /dev/null +++ b/wavelet/periodic_iterator.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include + +template +class periodic_iterator : public boost::iterator_adaptor, Iterator> { + typedef boost::iterator_adaptor, Iterator> super_t; + friend class boost::iterator_core_access; +public: + periodic_iterator() + : super_t() + , end() + , length(1) + {} + + periodic_iterator(Iterator begin, Iterator end) + : super_t(begin) + , end(end) + , length(end - begin) + {} + +private: + void advance(typename super_t::difference_type n){ + this->base_reference() += n; + if(this->base() >= end) this->base_reference() -= length; + } + + void increment() { advance(1); } + void decrement() { advance(-1); } + + Iterator end; + int length; +}; + +template +periodic_iterator periodic(Iterator begin, Iterator end){ + return periodic_iterator(begin, end); +} diff --git a/wavelet/periodic_iterator_test.cpp b/wavelet/periodic_iterator_test.cpp new file mode 100644 index 0000000..22f6c5e --- /dev/null +++ b/wavelet/periodic_iterator_test.cpp @@ -0,0 +1,28 @@ +#include +#include +#include +#include + +#include + +#include "periodic_iterator.hpp" +#include "striding_iterator.hpp" + +template +void print10(Iterator it){ + for(int i = 0; i < 10; ++i){ + std::cout << *it++ << std::endl; + } +} + +int main(){ + using namespace boost::assign; + std::vector v; + v += 0,1,2,3,4; + + print10(periodic(v.begin(), v.end())); + std::cout << "***\n"; + print10(periodic(strided(v.begin(), 1), strided(v.end(), 1)) + 2); + std::cout << "***\n"; + std::copy(v.begin(), v.end(), std::ostream_iterator(std::cout, "\n")); +} diff --git a/wavelet/striding_iterator.hpp b/wavelet/striding_iterator.hpp new file mode 100644 index 0000000..c69887e --- /dev/null +++ b/wavelet/striding_iterator.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include + +template +class striding_iterator : public boost::iterator_adaptor, Iterator> { + typedef boost::iterator_adaptor, Iterator> super_t; + friend class boost::iterator_core_access; +public: + striding_iterator() + : super_t() + , stride(1) + {} + + striding_iterator(Iterator it, int stride) + : super_t(it) + , stride(stride) + {} + + template + striding_iterator(striding_iterator const& r, typename boost::enable_if_convertible::type* = 0) + : super_t(r.base()) + {} + + int stride; + +private: + void advance(typename super_t::difference_type n){ + this->base_reference() += stride * n; + } + + template + typename super_t::difference_type distance_to(striding_iterator const & that) const { + int s = that.base() - this->base(); + if(s >= 0) return (s + stride - 1) / stride; + return (s - stride + 1) / stride; + } + + void increment() { advance(1); } + void decrement() { advance(-1); } +}; + +template +striding_iterator strided(Iterator it, int stride = 1){ + return striding_iterator(it, stride); +} diff --git a/wavelet/striding_iterator_test.cpp b/wavelet/striding_iterator_test.cpp new file mode 100644 index 0000000..b00e883 --- /dev/null +++ b/wavelet/striding_iterator_test.cpp @@ -0,0 +1,43 @@ +#include +#include + +#include + +#include "striding_iterator.hpp" + +template +striding_iterator double_stride(striding_iterator it){ + it.stride *= 2; + return it; +} + +template +void print(Iterator begin, Iterator end){ + while(begin < end){ + std::cout << *begin++ << ", "; + } +} + +template +void print_some_rec(Iterator begin, Iterator end){ + print(begin, end); + std::cout << std::endl; + + if(std::distance(begin, end) >= 2){ + print_some_rec(double_stride(begin), double_stride(end)); + print_some_rec(double_stride(begin+1), double_stride(end)); + } +} + +template +void print_some(Iterator begin, Iterator end){ + print_some_rec(strided(begin, 1), strided(end, 1)); +} + +int main(){ + using namespace boost::assign; + std::vector v; + v += 0,1,2,3,4; + + print_some(v.begin(), v.end()); +} diff --git a/wavelet/wavelet.hpp b/wavelet/wavelet.hpp new file mode 100644 index 0000000..f60487c --- /dev/null +++ b/wavelet/wavelet.hpp @@ -0,0 +1,44 @@ +#pragma once + +static double const evn_coef[] = { + (1.0 + std::sqrt(3.0))/(std::sqrt(32.0)), + (3.0 + std::sqrt(3.0))/(std::sqrt(32.0)), + (3.0 - std::sqrt(3.0))/(std::sqrt(32.0)), + (1.0 - std::sqrt(3.0))/(std::sqrt(32.0)) +}; + +static double const odd_coef[] = { + evn_coef[3], + -evn_coef[2], + evn_coef[1], + -evn_coef[0] +}; + +template +void wavelet_mul(Iterator begin, Iterator end){ + int mul = end - begin; + std::vector out(mul, 0.0); + for(int i = 0; i < mul; i += 2){ + out[i] = std::inner_product(evn_coef, evn_coef+4, periodic(begin, end) + i, 0.0); + out[i+1] = std::inner_product(odd_coef, odd_coef+4, periodic(begin, end) + i, 0.0); + } + for(int i = 0; i < mul; ++i){ + *begin++ = out[i]; + } +} + +template +void wavelet_inv(Iterator begin, Iterator end){ + int mul = end - begin; + std::vector out(mul, 0.0); + Iterator bc = begin; + for(int i = 0; i < mul; i += 2, begin += 2){ + Iterator b2 = begin + 1; + for(int j = 0; j < 4; ++j){ + out[(i+j) % mul] += *begin * evn_coef[j] + *b2 * odd_coef[j]; + } + } + for(int i = 0; i < mul; ++i){ + *bc++ = out[i]; + } +} diff --git a/wavelet/wavelet2.cpp b/wavelet/wavelet2.cpp new file mode 100644 index 0000000..f998746 --- /dev/null +++ b/wavelet/wavelet2.cpp @@ -0,0 +1,116 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "striding_iterator.hpp" +#include "periodic_iterator.hpp" + +#include "wavelet.hpp" + +bool is_pow_of_two(int n){ + return (n & (n - 1)) == 0; +} + +template +void shuffle(Iterator begin, Iterator end){ + typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::difference_type diff_type; + diff_type s = end - begin; + assert(s % 2 == 0); + + std::vector v(s); + std::copy(strided(begin , 2), strided(end , 2), v.begin()); + std::copy(strided(begin+1, 2), strided(end+1, 2), v.begin() + s/2); + std::copy(v.begin(), v.end(), begin); +} + +template +void unshuffle(Iterator begin, Iterator end){ + typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::difference_type diff_type; + diff_type s = end - begin; + assert(s % 2 == 0); + + std::vector v(s); + std::copy(begin, begin + s/2, strided(v.begin(), 2)); + std::copy(begin + s/2, end, strided(v.begin()+1, 2)); + std::copy(v.begin(), v.end(), begin); +} + +template +void wavelet(Iterator begin, Iterator end){ + int s = end - begin; + for(int i = s; i >= 4; i >>= 1){ + // half interval + end = begin + i; + assert(is_pow_of_two(end - begin)); + + // multiply with Wn + wavelet_mul(begin, end); + // then with Sn + shuffle(begin, end); + } +} + +template +void unwavelet(Iterator begin, Iterator end){ + int s = end - begin; + for(int i = 4; i <= s; i <<= 1){ + // double interval + end = begin + i; + assert(is_pow_of_two(end - begin)); + + // unshuffle: Sn^-1 + unshuffle(begin, end); + // then Wn^-1 + wavelet_inv(begin, end); + } +} + +struct filter{ + filter(double threshold) + : threshold(threshold) + {} + + void operator()(double& x){ + if(std::abs(x) <= threshold) x = 0; + } + + double threshold; +}; + +int main(){ + using namespace boost::assign; + std::vector input; + input += 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0; + + // print input + std::copy(input.begin(), input.end(), std::ostream_iterator(std::cout, "\n")); + std::cout << std::endl; + + std::vector thresholds; + thresholds += 0.0, 0.1, 0.2, 0.5; + for(int i = 0; i < thresholds.size(); ++i){ + std::vector v; + v = input; + + // transform to wavelet domain + wavelet(v.begin(), v.end()); + + // apply threshold + std::for_each(v.begin(), v.end(), filter(thresholds[i])); + int zeros = std::count(v.begin(), v.end(), 0.0); + + // transform back to sample domain + unwavelet(v.begin(), v.end()); + + // print compressed + std::cout << "\ncp: " << zeros / double(v.size()) << std::endl; + std::copy(v.begin(), v.end(), std::ostream_iterator(std::cout, "\n")); + std::cout << std::endl; + } +}