diff --git a/src/learning_graph.cpp b/src/learning_graph.cpp index 3afa768..3893089 100644 --- a/src/learning_graph.cpp +++ b/src/learning_graph.cpp @@ -17,9 +17,12 @@ static const char USAGE[] = R"(Generate a statistical learning graph from multiple runs Usage: - learning_graph ... + learning_graph [options] ... Options: + --testing_only Only count the figures for testing + --learning_only Only count the figures for learning + --accumulate Accumulates the data -h, --help Show this screen --version Show version )"; @@ -73,12 +76,24 @@ void print_quantiles(C const & container, S && selector, ostream & out) { out << selector(sorted_container.back()); } +auto all(datapoint const & p) { + return p.learning_queries + p.learning_inputs + p.testing_queries + p.testing_inputs; +} +auto testing(datapoint const & p) { + return p.testing_queries + p.testing_inputs; +} +auto learning(datapoint const & p) { + return p.learning_queries + p.learning_inputs; +} + int main(int argc, char * argv[]) { const auto args = docopt::docopt(USAGE, {argv + 1, argv + argc}, true, __DATE__ __TIME__); + const auto field = args.at("--testing_only").asBool() ? &testing : args.at("--learning_only").asBool() ? &learning : &all; + vector> dataset_futures; for (auto const & filename : args.at("").asStringList()) { - dataset_futures.emplace_back(async([filename] { + dataset_futures.emplace_back(async([filename, &args] { fstream file(filename); if (!file) throw runtime_error("Could not open file " + filename); @@ -89,7 +104,8 @@ int main(int argc, char * argv[]) { s.push_back(p); } - accumulate_dataset(s); + if (args.at("--accumulate").asBool()) + accumulate_dataset(s); return s; })); @@ -143,10 +159,8 @@ int main(int argc, char * argv[]) { // if we're spot on, update current if (it.next->states == state) it.current = it.next; - const auto v2 = it.next->learning_queries + it.next->learning_inputs - + it.next->testing_queries + it.next->testing_inputs; - const auto v1 = it.current->learning_queries + it.current->learning_inputs - + it.current->testing_queries + it.current->testing_inputs; + const auto v2 = field(*it.next); + const auto v1 = field(*it.current); const auto ratio = it.next->states == state ? 1.0