// ag_gen/src/mpi/serialize.cpp
// Snapshot: 2022-01-13 14:11:56 -06:00
//
// (source listing metadata: 380 lines, 10 KiB, C++)

#include <iostream>
#include <fstream>
#include <vector>
#include <string>
#include <unistd.h>
#include "../ag_gen/asset.h"
#include "../ag_gen/assetgroup.h"
#include "../ag_gen/edge.h"
#include "../ag_gen/exploit.h"
#include "../ag_gen/factbase.h"
#include "../ag_gen/network_state.h"
#include "../ag_gen/ag_gen.h"
#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/properties.hpp>
#include <boost/graph/graphviz.hpp>
#include <boost/property_tree/ptree.hpp>
#include <boost/property_tree/ini_parser.hpp>
#include <boost/graph/visitors.hpp>
#include <boost/graph/depth_first_search.hpp>
#include <boost/archive/tmpdir.hpp>
#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
#include <boost/serialization/base_object.hpp>
#include <boost/serialization/utility.hpp>
#include <boost/serialization/list.hpp>
#include <boost/serialization/assume_abstract.hpp>
#include <boost/serialization/string.hpp>
#include <boost/mpi.hpp>
#include <boost/mpi/environment.hpp>
#include <boost/mpi/communicator.hpp>
#include <boost/serialization/is_bitwise_serializable.hpp>
namespace mpi = boost::mpi;
void save_quality(const Quality &q, const char * filename){
std::ofstream ofs(filename);
boost::archive::text_oarchive oa(ofs);
oa << q;
}
// Loads `q` from the Boost text archive stored in `filename`
// (the counterpart of save_quality). No explicit check is made that the
// file opened successfully.
void restore_quality(Quality &q, const char * filename){
    std::ifstream in_file{filename};
    boost::archive::text_iarchive archive{in_file};
    archive >> q;
}
// Field-by-field comparison of two qualities.
// Returns 1 when asset id, name, op and value all compare equal, otherwise 0.
int quality_check(Quality &q1, Quality &q2){
    const bool identical = q1.get_asset_id() == q2.get_asset_id()
                        && q1.get_name() == q2.get_name()
                        && q1.get_op() == q2.get_op()
                        && q1.get_value() == q2.get_value();
    return identical ? 1 : 0;
}
void save_topology(const Topology &t, const char * filename){
std::ofstream ofs(filename);
boost::archive::text_oarchive oa(ofs);
oa << t;
}
// Loads `t` from the Boost text archive stored in `filename`
// (the counterpart of save_topology). No explicit check is made that the
// file opened successfully.
void restore_topology(Topology &t, const char * filename){
    std::ifstream in_file{filename};
    boost::archive::text_iarchive archive{in_file};
    archive >> t;
}
// Field-by-field comparison of two topologies.
// Returns 1 when from/to asset ids, direction, property, op and value all
// compare equal, otherwise 0.
int topology_check(Topology &t1, Topology &t2){
    const bool identical = t1.get_from_asset_id() == t2.get_from_asset_id()
                        && t1.get_to_asset_id() == t2.get_to_asset_id()
                        && t1.get_dir() == t2.get_dir()
                        && t1.get_property() == t2.get_property()
                        && t1.get_op() == t2.get_op()
                        && t1.get_value() == t2.get_value();
    return identical ? 1 : 0;
}
void save_factbase(const Factbase &fb, const char * filename){
std::ofstream ofs(filename);
boost::archive::text_oarchive oa(ofs);
oa << fb;
}
// Loads `fb` from the Boost text archive stored in `filename`
// (the counterpart of save_factbase). No explicit check is made that the
// file opened successfully.
void restore_factbase(Factbase &fb, const char * filename){
    std::ifstream in_file{filename};
    boost::archive::text_iarchive archive{in_file};
    archive >> fb;
}
// Compares two factbases element by element.
//
// The quality lists and the topology lists of the two factbases are walked
// in lockstep. Corresponding elements are compared with quality_check /
// topology_check. A position present in only one factbase (the lists have
// different lengths) is counted but never marked correct, so it shows up
// as a mismatch. The factbase ids must also be equal.
//
// Returns 1 if the ids match and every compared position matched, else 0.
int factbase_check(Factbase &fb1, Factbase &fb2){
    int qual_count = 0;
    int qual_corr = 0;
    int topo_count = 0;
    int topo_corr = 0;

    auto fb1_tuple = fb1.get_facts_tuple();
    auto fb1_quals = std::get<0>(fb1_tuple);
    auto fb1_topos = std::get<1>(fb1_tuple);
    auto fb2_tuple = fb2.get_facts_tuple();
    auto fb2_quals = std::get<0>(fb2_tuple);
    auto fb2_topos = std::get<1>(fb2_tuple);

    auto itq1 = fb1_quals.begin();
    auto itq2 = fb2_quals.begin();
    while (itq1 != fb1_quals.end() || itq2 != fb2_quals.end()) {
        // Dereference only when both sides still have an element; the loop
        // keeps going while *either* does, so one iterator may be at end().
        if (itq1 != fb1_quals.end() && itq2 != fb2_quals.end())
            qual_corr += quality_check(*itq1, *itq2);
        ++qual_count;
        if (itq1 != fb1_quals.end())
            ++itq1;
        if (itq2 != fb2_quals.end())
            ++itq2;
    }

    auto itt1 = fb1_topos.begin();
    auto itt2 = fb2_topos.begin();
    while (itt1 != fb1_topos.end() || itt2 != fb2_topos.end()) {
        // Same guard as above: never dereference an end() iterator.
        if (itt1 != fb1_topos.end() && itt2 != fb2_topos.end())
            topo_corr += topology_check(*itt1, *itt2);
        ++topo_count;
        if (itt1 != fb1_topos.end())
            ++itt1;
        if (itt2 != fb2_topos.end())
            ++itt2;
    }

    const bool equal = fb1.get_id() == fb2.get_id()
                    && qual_count == qual_corr
                    && topo_count == topo_corr;
    return equal ? 1 : 0;
}
void save_network_state(const NetworkState &ns, const char * filename){
std::ofstream ofs(filename);
boost::archive::text_oarchive oa(ofs);
oa << ns;
}
// Loads `ns` from the Boost text archive stored in `filename`
// (the counterpart of save_network_state). No explicit check is made that
// the file opened successfully.
void restore_network_state(NetworkState &ns, const char * filename){
    std::ifstream in_file{filename};
    boost::archive::text_iarchive archive{in_file};
    archive >> ns;
}
// Compares two network states by comparing only their factbases
// (via factbase_check). Returns 1 on match, 0 otherwise.
int network_state_check(NetworkState &ns1, NetworkState &ns2){
    auto lhs_facts = ns1.get_factbase();
    auto rhs_facts = ns2.get_factbase();
    return factbase_check(lhs_facts, rhs_facts);
}
// Prints rank 0's end-of-run verdict for one serialization category.
// `label` is spliced into the message (e.g. "Quality", "Network State").
// Returns 1 if total_correct differs from expected (failures were seen),
// otherwise 0.
static int report_serialization_outcome(const std::string &label, int total_correct, int expected){
    if (total_correct == expected) {
        std::cout << "100% Success Rate for " << label << " Serialization." << std::endl;
        return 0;
    }
    std::cout << "Errors occurred in the " << label << " Serialization." << std::endl;
    return 1;
}

// MPI round-trip test of the Boost serialization of Quality, Topology,
// Factbase and NetworkState.
//
// Must be called collectively by every rank in `world`: it uses barriers,
// broadcasts and reductions. Rank 0 broadcasts the following, and every rank
// compares what it received with its own local copy using the *_check
// helpers:
//   - each initial quality and initial topology,
//   - each factbase in instance.factbases,
//   - a NetworkState built from the initial qualities/topologies.
// Per-rank match counts are summed onto rank 0, which prints a per-category
// score and an overall summary.
//
// NOTE(review): assumes every rank holds an identical `instance` (same
// elements in the same order); otherwise mismatches are reported even if
// serialization itself is correct — confirm against the caller.
void serialization_unit_testing(AGGenInstance &instance, boost::mpi::communicator &world){
    // gethostname() need not NUL-terminate on truncation, so zero-fill the
    // buffer and never let it write the final byte.
    char hammer_host[256] = {};
    gethostname(hammer_host, sizeof(hammer_host) - 1);
    std::string str_host(hammer_host);

    if (world.rank() == 0) {
        printf("\n");
        std::cout << "---------STARTING SERIALIZATION UNIT TESTING---------" << std::endl;
    }
    world.barrier();

    // Roll call: every non-root rank reports its rank and host to rank 0.
    std::string rollcall = "Hello from process " + std::to_string(world.rank())
        + " of " + std::to_string(world.size()) + " running on "
        + str_host + ".";
    if (world.rank() != 0) {
        world.send(0, 0, rollcall);
    } else {
        std::string incoming;
        for (int i = 0; i < world.size() - 1; i++) {
            world.recv(mpi::any_source, 0, incoming);
            std::cout << incoming << std::endl;
        }
        std::cout << "" << std::endl;
    }
    world.barrier();

    auto init_quals = instance.initial_qualities;
    auto init_topos = instance.initial_topologies;
    NetworkState init_state(init_quals, init_topos); // state built from the initial input

    int e_flag = 0;
    int qual_count = 0;
    int qual_corr = 0;
    int top_count = 0;
    int top_corr = 0;
    int fb_count = 0;
    int fb_corr = 0;
    int ns_count = 0;
    int ns_corr = 0;

    // --- Quality ---
    if (world.rank() == 0) {
        std::cout << "Performing Unit Testing on Quality Serialization." << std::endl;
    }
    for (auto qual : instance.initial_qualities) {
        Quality new_qual;
        if (world.rank() == 0)
            new_qual = qual;
        mpi::broadcast(world, new_qual, 0);
        qual_count++;
        qual_corr += quality_check(qual, new_qual);
    }
    int total_qual_corr = 0;
    mpi::reduce(world, qual_corr, total_qual_corr, std::plus<int>(), 0);
    if (world.rank() == 0) {
        std::cout << "Quality Unit Testing: " << std::to_string(total_qual_corr) << "/" << std::to_string(world.size() * qual_count) << std::endl;
        printf("\n");
        std::cout << "Performing Unit Testing on Topology Serialization." << std::endl;
    }

    // --- Topology ---
    for (auto topo : instance.initial_topologies) {
        Topology new_top;
        if (world.rank() == 0)
            new_top = topo;
        mpi::broadcast(world, new_top, 0);
        top_count++;
        top_corr += topology_check(topo, new_top);
    }
    int total_top_corr = 0;
    mpi::reduce(world, top_corr, total_top_corr, std::plus<int>(), 0);
    if (world.rank() == 0) {
        std::cout << "Topology Unit Testing: " << std::to_string(total_top_corr) << "/" << std::to_string(world.size() * top_count) << std::endl;
        printf("\n");
        std::cout << "Performing Unit Testing on Factbase Serialization." << std::endl;
    }

    // --- Factbase ---
    for (auto fb : instance.factbases) {
        Factbase new_fb;
        int id = 0;
        if (world.rank() == 0) {
            new_fb = fb;
            id = fb.get_id();
        }
        mpi::broadcast(world, new_fb, 0);
        // The id is sent separately and forced onto the received copy;
        // presumably it is not part of the serialized representation.
        mpi::broadcast(world, id, 0);
        new_fb.force_set_id(id);
        fb_count++;
        fb_corr += factbase_check(fb, new_fb);
    }
    int total_fb_corr = 0;
    mpi::reduce(world, fb_corr, total_fb_corr, std::plus<int>(), 0);
    if (world.rank() == 0) {
        std::cout << "Factbase Unit Testing: " << std::to_string(total_fb_corr) << "/" << std::to_string(world.size() * fb_count) << std::endl;
        printf("\n");
        std::cout << "Performing Unit Testing on Network State Serialization." << std::endl;
    }

    // --- Network state ---
    NetworkState new_ns;
    int ns_id = 0;
    if (world.rank() == 0) {
        new_ns = init_state;
        ns_id = init_state.get_id();
    }
    mpi::broadcast(world, new_ns, 0);
    mpi::broadcast(world, ns_id, 0);
    new_ns.force_set_id(ns_id);
    ns_count++;
    ns_corr += network_state_check(init_state, new_ns);
    int total_ns_corr = 0;
    mpi::reduce(world, ns_corr, total_ns_corr, std::plus<int>(), 0);
    if (world.rank() == 0) {
        std::cout << "Network State Unit Testing: " << std::to_string(total_ns_corr) << "/" << std::to_string(world.size() * ns_count) << std::endl;
        printf("\n");
    }

    // --- Summary (reduced totals are only meaningful on rank 0) ---
    if (world.rank() == 0) {
        e_flag |= report_serialization_outcome("Quality", total_qual_corr, world.size() * qual_count);
        e_flag |= report_serialization_outcome("Topology", total_top_corr, world.size() * top_count);
        e_flag |= report_serialization_outcome("Factbase", total_fb_corr, world.size() * fb_count);
        e_flag |= report_serialization_outcome("Network State", total_ns_corr, world.size() * ns_count);
        std::cout << "" << std::endl;
        if (e_flag == 1)
            std::cout << "-------------ERRORS OCCURRED DURING SERIALIZATION UNIT TESTING---------------" << std::endl;
        std::cout << "---------FINISHED SERIALIZATION UNIT TESTING---------" << std::endl;
        printf("\n");
    }
    // MPI finalization is intentionally not done here; it is left to the
    // caller's MPI environment.
}