MPI Subgraphing

This commit is contained in:
Noah L. Schrick 2022-02-06 22:55:58 -06:00
parent c3976f1d46
commit ea713ea9cb
15 changed files with 14392 additions and 116 deletions

File diff suppressed because it is too large Load Diff

12046
build/ag.svg

File diff suppressed because it is too large Load Diff

Before

Width:  |  Height:  |  Size: 6.9 KiB

After

Width:  |  Height:  |  Size: 785 KiB

Binary file not shown.

View File

@ -57,9 +57,9 @@ if [ "$TYPE" == "$strval1" ]; then
if [ "$(dnsdomainname)" = "hammer.esg.utulsa.edu" ]; then
mpiexec --mca btl_openib_allow_ib 1 --mca btl openib,self,vader --mca opal_warn_on_missing_libcuda 0 --bind-to numa --map-by numa -np "$NODES" --timeout 129600 ./ag_gen -n ../Oct_2021/nm_files/"$CARS"_car_timeline_maintenance.nm -x ../Oct_2021/Sync/6_Exploits/"$NUM_SERV"_Serv/sync_timeline_maintenance.xp -t "$NUM_THREADS" -q 1 -p -a 0.6 -z "$DBNAME"
mpiexec --mca btl_openib_allow_ib 1 --mca btl openib,self,vader --mca opal_warn_on_missing_libcuda 0 --bind-to numa --map-by numa -np "$NODES" --timeout 129600 ./ag_gen -n ../Oct_2021/nm_files/"$CARS"_car_timeline_maintenance.nm -x ../Oct_2021/Sync/6_Exploits/"$NUM_SERV"_Serv/sync_timeline_maintenance.xp -t "$NUM_THREADS" -q 1 -p -a 0.6 -z "$DBNAME" -s -l 20
else
mpiexec --mca btl_openib_allow_ib 1 --mca opal_warn_on_missing_libcuda 0 --bind-to numa --map-by numa -np "$NODES" --timeout 129600 ./ag_gen -n ../Oct_2021/nm_files/"$CARS"_car_timeline_maintenance.nm -x ../Oct_2021/Sync/4_Exploits/"$NUM_SERV"_Serv/sync_timeline_maintenance.xp -t "$NUM_THREADS" -q 1 -p -a 0.6 -z "$DBNAME"
mpiexec --mca btl_openib_allow_ib 1 --mca opal_warn_on_missing_libcuda 0 --bind-to numa --map-by numa -np "$NODES" --timeout 129600 ./ag_gen -n ../Oct_2021/nm_files/"$CARS"_car_timeline_maintenance.nm -x ../Oct_2021/Sync/4_Exploits/"$NUM_SERV"_Serv/sync_timeline_maintenance.xp -t "$NUM_THREADS" -q 1 -p -s -l 20 -a 0.6 -g DOTFILE.dot -z "$DBNAME"
fi
# 4 Exploit
#mpiexec -np "$NODES" --bind-to numa --map-by numa ./ag_gen -n ../Oct_2021/nm_files/"$CARS"_car_timeline_maintenance.nm -x ../Oct_2021/Sync/4_Exploits/"$NUM_SERV"_Serv/sync_timeline_maintenance.xp -t 1 -q 1 -p -a 0.6 -z "$DBNAME"
@ -77,6 +77,9 @@ else
fi
#Graphviz Strict graphing to avoid duplicate nodes and edges
#echo -n 'strict ' | cat - DOTFILE.dot > temp && mv temp DOTFILE.dot
#dot -Tsvg new.dot > ag.svg
#dot -Tsvg DOTFILE.dot > ag.svg

View File

@ -11,6 +11,8 @@
#include <sys/time.h>
#include <string.h>
#include <map>
#include <random>
#include <unordered_set>
#include "ag_gen.h"
@ -29,6 +31,7 @@
#include <boost/serialization/assume_abstract.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/vector.hpp>
#include <boost/serialization/deque.hpp>
#include <boost/mpi.hpp>
#include <boost/mpi/environment.hpp>
@ -595,6 +598,9 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
struct timeval t41,t42;
gettimeofday(&t41,NULL);
save_ag_to_db(instance, true);
std::vector<Factbase>().swap(instance.factbases);
std::vector<FactbaseItems>().swap(instance.factbase_items);
std::vector<Edge>().swap(instance.edges);
gettimeofday(&t42,NULL);
total_task4+=(t42.tv_sec-t41.tv_sec)*1000.0+(t42.tv_usec-t41.tv_usec)/1000.0;
}
@ -1168,3 +1174,517 @@ AGGenInstance &AGGen::single_generate(bool batch_process, int batch_num, int num
return instance;
}
AGGenInstance &AGGen::sg_generate(bool batch_process, int batch_num, int numThrd,\
int initQSize, double mem_threshold, mpi::communicator &world, int state_limit){
//Init all Nodes with these variables
std::vector<Exploit> exploit_list = instance.exploits;
//Create a vector that contains all the groups of exploits to be fired synchonously
std::vector<std::string> ex_groups;
for (const auto &ex : exploit_list) {
//If the group isn't already in the vector
if(!(std::find(ex_groups.begin(), ex_groups.end(), ex.get_group()) !=ex_groups.end())) {
//Don't include the "no" group
if(ex.get_group()!="null")
ex_groups.emplace_back(ex.get_group());
}
}
//Print out the groups if desired
if (world.rank() == 0){
std::cout <<"\nThere are "<<ex_groups.size()<<" groups: ";
for(int i=0; i<ex_groups.size(); i++){
std::cout<<ex_groups[i] << ". ";
}
std::cout<<"\n";
}
auto counter = 0;
auto start = std::chrono::system_clock::now();
unsigned long esize = exploit_list.size();
if (world.rank() ==0){
printf("esize: %lu\n", esize);
}
bool save_queued = false;
if (world.rank() == 0){
std::cout << "Generating Attack Graph Using Subgraphing" << std::endl;
}
std::unordered_map<size_t, PermSet<size_t>> od_map;
size_t assets_size = instance.assets.size();
for (const auto &ex : exploit_list) {
size_t num_params = ex.get_num_params();
if (od_map.find(num_params) == od_map.end()) {
Odometer<size_t> od(num_params, assets_size);
od_map[num_params] = od.get_all();
}
}
int frt_size=frontier.size();
if (world.rank() ==0){
printf("The actual QSize to start using multiple threads is %d\n",frt_size);
}
double total_t=0.0;
//unit:ms
double total_task0, total_task1, total_task2, total_task3, total_task4 = 0.0;
struct timeval t1,t2;
gettimeofday(&t1,NULL);
std::deque<NetworkState> localFrontier;
if(world.rank() == 0){
double f_alpha = 0.0;
auto tot_sys_mem = getTotalSystemMemory();
localFrontier.emplace_front(frontier[0]);
}
int finished_signal = 0;
int send_msg = 0;
int state_count = 0;
std::random_device rd; // obtain a random number from hardware
std::mt19937 gen(rd()); // seed the generator
//Make separate comm to not sync with db node
mpi::communicator work_comm = world.split(world.rank() != 1);
//std::unordered_set<NetworkState> localFrontier_seen;
//Send new Network State to all worker nodes, if we have enough unex states to do so
if(world.rank() == 0){
//2 offset for root node and db node
for (int w = 0; w < std::min((int)world.size()-2, (int)localFrontier.size()); w++){
mpi::request state_req = world.isend(w+2, 1, localFrontier.front());
localFrontier.pop_front();
state_req.wait();
}
}
//Main Work Loop
while((!localFrontier.empty() && !finished_signal) || world.rank() > 0){
//Refill localFrontier if needed
if(localFrontier.empty() && world.rank() == 0) {
task_zero(instance, localFrontier, mem_threshold);
}
//Don't sync with db node
work_comm.barrier();
if (world.rank() > 1){
//Check for finished signal
if(world.iprobe(0, 99)){
break;
}
if (world.iprobe(0, 1)){
NetworkState current_state;
world.recv(0, 1, current_state);
state_count = 0;
send_msg = 1;
while(!localFrontier.empty()){
if (state_count < state_limit){
//Do work
//Force set NS id to hash
auto current_state = localFrontier.front();
auto current_hash = current_state.get_hash(instance.facts);
localFrontier.pop_front();
std::vector<std::tuple<Exploit, AssetGroup>> appl_exploits;
for (size_t i = 0; i < esize; i++) {//for loop for applicable exploits starts
auto e = exploit_list.at(i);
size_t num_params = e.get_num_params();
auto preconds_q = e.precond_list_q();
auto preconds_t = e.precond_list_t();
auto perms = od_map[num_params];
std::vector<AssetGroup> asset_groups;
for (auto perm : perms) {
std::vector<Quality> asset_group_quals;
std::vector<Topology> asset_group_topos;
asset_group_quals.reserve(preconds_q.size());
asset_group_topos.reserve(preconds_t.size());
//std::vector<int>::size_type sz;
//sz=asset_group_quals.capacity();
for (auto &precond : preconds_q) {
//Old quality encode caused this to crash
asset_group_quals.emplace_back(
perm[precond.get_param_num()], precond.name, precond.op,
precond.value, instance.facts);
}
for (auto &precond : preconds_t) {
auto dir = precond.get_dir();
auto prop = precond.get_property();
auto op = precond.get_operation();
auto val = precond.get_value();
asset_group_topos.emplace_back(
perm[precond.get_from_param()],
perm[precond.get_to_param()], dir, prop, op, val, instance.facts);
}
asset_groups.emplace_back(asset_group_quals, asset_group_topos,
perm);
}
auto assetgroup_size = asset_groups.size();
for (size_t j = 0; j < assetgroup_size; j++) {
auto asset_group = asset_groups.at(j);
for (auto &quality : asset_group.get_hypo_quals()) {
if (!current_state.get_factbase().find_quality(quality)) {
goto LOOPCONTINUE1;
}
}
for (auto &topology : asset_group.get_hypo_topos()) {
if (!current_state.get_factbase().find_topology(topology)) {
goto LOOPCONTINUE1;
}
}
{
auto new_appl_exploit = std::make_tuple(e, asset_group);
appl_exploits.push_back(new_appl_exploit);
}
LOOPCONTINUE1:;
}
} //for loop for creating applicable exploits ends
std::map<std::string, int> group_fired; //Map to hold fired status per group
std::map<std::string, std::vector<std::tuple<Exploit, AssetGroup>>> sync_vectors; //Map to hold all group exploits
for (auto map_group : ex_groups)
{
group_fired.insert(std::pair<std::string, int> (map_group, 0));
}
//Build up the map of synchronous fire exploits
for(auto itr=appl_exploits.begin(); itr!=appl_exploits.end(); itr++){
//auto e = appl_exploits.at(itr);
auto e = *itr;
auto egroup = std::get<0>(e).get_group();
if (egroup != "null"){
sync_vectors[egroup].push_back(e);
}
}
//loop through the vector
for(auto itr=appl_exploits.begin(); itr!=appl_exploits.end(); itr++){
auto e = *itr;
auto exploit = std::get<0>(e);
auto assetGroup = std::get<1>(e);
//std::cout<<exploit.get_name()<<std::endl;
auto egroup=exploit.get_group();
if ((egroup != "null" && group_fired[egroup] == 0) || egroup == "null"){
NetworkState new_state{current_state};
std::vector<std::tuple<Exploit, AssetGroup>> sync_exploits;
if (egroup == "null")
sync_exploits.push_back(e);
else {
sync_exploits = sync_vectors[egroup];
//TODO: Does not work if only some assets belong to a group. This only works if
//all assets are in the group
if(sync_exploits.size() < instance.assets.size()){
break;
}
}
for(auto sync_itr=sync_exploits.begin(); sync_itr!=sync_exploits.end(); sync_itr++){
e = *sync_itr;
exploit = std::get<0>(e);
egroup=exploit.get_group();
assetGroup = std::get<1>(e);
group_fired[egroup] = 1;
auto postconditions = createPostConditions(e, instance.facts);
auto qualities = std::get<0>(postconditions);
auto topologies = std::get<1>(postconditions);
for(auto &qual : qualities) {
auto action = std::get<0>(qual);
auto fact = std::get<1>(qual);
switch(action) {
case ADD_T:
new_state.add_quality(fact);
break;
case UPDATE_T:
new_state.update_quality(fact);
//TODO: if fact!= "="" call new_state function, passing fact and instance.facts. Update the quality, and insert it into the hash_table instead of this convoluted mess
if(fact.get_op()=="+="){
//std::cout<<" AFTER UPDATE "<<new_state.compound_assign(fact)<<std::endl;
std::unordered_map<std::string,int>::const_iterator got = instance.facts.hash_table.find(new_state.compound_assign(fact));
//If the value is not already in the hash_table, insert it.
//Since the compound operators include a value that is not in the original Keyvalue object, the unordered map does not include it
//As a result, you have to manually add it.
if(got==instance.facts.hash_table.end()){
instance.facts.hash_table[new_state.compound_assign(fact)]=instance.facts.size();
instance.facts.length++;
instance.facts.str_vector.push_back(new_state.compound_assign(fact));
//Update ALL nodes (include ttwo_comm nodes) with new data
for (int w = 0; w < world.size(); w++)
{
if(w != world.rank() && w != 1)
{
mpi::request ns_req = world.isend(w, 5, new_state);
mpi::request fact_req = world.isend(w, 6, fact);
ns_req.wait();
fact_req.wait();
}
}
}
}
break;
case DELETE_T:
new_state.delete_quality(fact);
break;
}
}
for(auto &topo : topologies) {
auto action = std::get<0>(topo);
auto fact = std::get<1>(topo);
switch(action) {
case ADD_T:
new_state.add_topology(fact);
break;
case UPDATE_T:
new_state.update_topology(fact);
break;
case DELETE_T:
new_state.delete_topology(fact);
break;
}
}
}//Sync. Fire for
auto hash_num = new_state.get_hash(instance.facts);
if (hash_num == current_hash)
continue;
#pragma omp critical
//although local frontier is updated, the global hash is also updated to avoid testing on explored states.
if (hash_map.find(hash_num) == hash_map.end()) {
new_state.force_set_id(hash_num);
instance.factbases.push_back(new_state.get_factbase());
hash_map.insert(std::make_pair(new_state.get_hash(instance.facts), new_state.get_id()));
localFrontier.emplace_front(new_state);
Edge ed(current_state.get_hash(instance.facts), new_state.get_id(), exploit, assetGroup);
ed.set_id();
instance.edges.push_back(ed);
} //END if (hash_map.find(hash_num) == hash_map.end())
else {
auto id = hash_map[hash_num];
Edge ed(current_state.get_id(), id, exploit, assetGroup);
ed.set_id();
instance.edges.push_back(ed);
}
} //sync fire if
else
break;
} //for loop for new states ends
}
else
break;
}
}
//Let root node we finished, but only if we have done work since our last message
if(send_msg == 1){
world.isend(0, 2, 1);
send_msg = 0;
if(localFrontier.size() > 0){
world.isend(0, 3, localFrontier);
}
//Send new states and edges, then clear worker instance
world.isend(0, 10, instance.factbases);
world.isend(0, 11, instance.edges);
std::vector<Factbase>().swap(instance.factbases);
std::vector<FactbaseItems>().swap(instance.factbase_items);
std::vector<Edge>().swap(instance.edges);
}
//Check for new fact and new state that caused an update in the hash table and facts
while(world.iprobe(mpi::any_source, 5) || world.iprobe(mpi::any_source, 6)){
NetworkState update_state;
Quality update_fact;
world.recv(mpi::any_source, 5, update_state);
world.recv(mpi::any_source, 6, update_fact);
//Update
instance.facts.hash_table[update_state.compound_assign(update_fact)]=instance.facts.size();
instance.facts.length++;
instance.facts.str_vector.push_back(update_state.compound_assign(update_fact));
}
} //end worker nodes
else if (world.rank() == 1){
//Check for finished signal, assuming we have no more storage requests
if(world.iprobe(0, 99) && !world.iprobe(0,7) && !world.iprobe(0,8)){
break;
}
//Check for instance storage requests
if(world.iprobe(0, 7) || world.iprobe(0, 8)){
std::vector<Factbase> factbases_dump;
std::vector<Edge> edges_dump;
world.recv(0, 7, factbases_dump);
world.recv(0, 8, edges_dump);
instance.factbases = factbases_dump;
instance.edges = edges_dump;
save_ag_to_db(instance, true);
//Clear vectors and free memory
std::vector<Factbase>().swap(instance.factbases);
std::vector<FactbaseItems>().swap(instance.factbase_items);
std::vector<Edge>().swap(instance.edges);
}
//Check for frontier storage requests
while(world.iprobe(0, 50)){
NetworkState save_state;
world.recv(0, 50, save_state);
save_unexplored_to_db(save_state);
}
} //end world rank =1
//World Rank = 0
else{
std::map<int, int> deque_marker;
int finish_count = 0;
//Receive states and edges when nodes finish
while(finish_count != world.size() -2){
while(world.iprobe(mpi::any_source, 10) || world.iprobe(mpi::any_source, 11)) {
std::vector<Factbase> node_factbases;
std::vector<Edge> node_edges;
world.recv(mpi::any_source, 10, node_factbases);
world.recv(mpi::any_source, 11, node_edges);
state_merge(node_factbases, node_edges, hash_map, instance, mem_threshold, world);
}
//Nodes finish
for(int w = 2; w < world.size(); w++){
int dummy_finish = 0;
if(world.iprobe(w, 2)){
world.recv(w, 2, dummy_finish);
finish_count++;
}
}
//Check for new fact and new state that caused an update in the hash table and facts
while(world.iprobe(mpi::any_source, 5) || world.iprobe(mpi::any_source, 6)){
NetworkState update_state;
Quality update_fact;
world.recv(mpi::any_source, 5, update_state);
world.recv(mpi::any_source, 6, update_fact);
//Update
instance.facts.hash_table[update_state.compound_assign(update_fact)]=instance.facts.size();
instance.facts.length++;
instance.facts.str_vector.push_back(update_state.compound_assign(update_fact));
}
/*
//Rather than just busy-wait for the remainder of the time, remove duplicates from localFrontier
for(auto itr = localFrontier.begin(); itr != localFrontier.end();){
auto ret = localFrontier_seen.insert(*itr);
if(!ret.second){
itr = localFrontier.erase(itr);
}
else
itr++;
}
*/
}
if(localFrontier.empty() && finish_count == world.size() -2){
finished_signal = 1;
}
//Receive node frontiers and merge them into root frontier
for(int w = 2; w < world.size(); w++){
if(world.iprobe(w, 3)){
std::deque<NetworkState> nodeFrontier;
world.recv(w, 3, nodeFrontier);
localFrontier.insert(localFrontier.begin(), \
std::make_move_iterator(nodeFrontier.begin()),\
std::make_move_iterator(nodeFrontier.end()));
deque_marker[w] = (int)(localFrontier.size()-1);
}
//If a node doesn't have a specific state to pull, randomly assign it one
if (!deque_marker.count(w)){
//Randomly assign a state index to pop
std::uniform_int_distribution<> distr(0, localFrontier.size()-1); // define the range
deque_marker[w] = distr(gen);
}
}
//Send new Network State to all worker nodes, if we have enough unex states to do so
//2 offset for root node and db node
for (int w = 0; w < world.size()-2; w++){
int proceed = 0;
while(!proceed){
if(localFrontier.size() > 0){
if(deque_marker[w+2] > localFrontier.size())
deque_marker[w+2]--;
//auto deque_access = localFrontier.begin() + deque_marker[w+2];
auto deque_access = deque_marker[w+2];
NetworkState send_state = localFrontier.at(deque_access);
//Don't explore on states we already have explored
if(hash_map.find(send_state.get_id()) != hash_map.end()){
localFrontier.erase(localFrontier.begin()+deque_access);
deque_marker[w+2]--;
}
if(proceed){
mpi::request state_req = world.isend(w+2, 1, send_state);
localFrontier.erase(localFrontier.begin()+deque_access);
state_req.wait();
}
}
}
}
} //end world rank 0
} // end main work loop
//Tell all nodes to finish
if(world.rank() == 0){
for(int w = 1; w < world.size(); w++){
world.send(w, 99, 1);
}
}
world.barrier();
if(world.rank() == 0){
gettimeofday(&t2,NULL);
total_t+=(t2.tv_sec-t1.tv_sec)*1000.0+(t2.tv_usec-t1.tv_usec)/1000.0;
printf("AG TOOK %lf ms.\n", total_t);
auto end = std::chrono::system_clock::now();
std::chrono::duration<double> elapsed_seconds = end - start;
instance.elapsed_seconds = elapsed_seconds;
}
return instance;
}

View File

@ -88,6 +88,9 @@ class AGGen {
AGGen(AGGenInstance &_instance, RedisManager &_rman);
#endif
AGGenInstance &sg_generate(bool batch_process, int batch_num, int numThrd,\
int initQSize, double mem_threshold, boost::mpi::communicator &world, int state_limit);
AGGenInstance &generate(bool batch_process, int batch_num, int numThrd,\
int initQSize, double mem_threshold, boost::mpi::communicator &world);

View File

@ -17,13 +17,17 @@
* @param ex Exploit associated with the Edge
* @param ag AssetGroup associated with the Edge
*/
Edge::Edge(int iFrom, int iTo, Exploit &ex, AssetGroup &ag)
//Edge::Edge(int iFrom, int iTo, Exploit &ex, AssetGroup &ag)
// : from_node(iFrom), to_node(iTo), exploit(ex), assetGroup(ag), deleted(false) {}
Edge::Edge(size_t iFrom, size_t iTo, Exploit &ex, AssetGroup &ag)
: from_node(iFrom), to_node(iTo), exploit(ex), assetGroup(ag), deleted(false) {}
Edge::Edge()
{
}
/**
* @return The Edge ID
*/
@ -33,18 +37,28 @@ void Edge::set_deleted() { deleted = true; }
bool Edge::is_deleted() { return deleted; }
int Edge::get_from_id()
//int Edge::get_from_id()
//{
//return from_node;
//}
size_t Edge::get_from_id()
{
return from_node;
}
int Edge::get_to_id()
//int Edge::get_to_id()
//{
// return to_node;
//}
size_t Edge::get_to_id()
{
return to_node;
}
int Edge::get_exploit_id()

View File

@ -21,8 +21,9 @@
/** Edge class
* @brief Edge of the graph.
* @brief Edge of the graph based on integer id.
*/
/*
class Edge {
static int edge_current_id;
int id;
@ -55,5 +56,40 @@ class Edge {
void set_deleted();
bool is_deleted();
};
*/
// Edge class based on hash
class Edge {
static int edge_current_id;
int id;
size_t from_node;
size_t to_node;
Exploit exploit;
AssetGroup assetGroup;
bool deleted;
friend std::ostream & operator << (std::ostream &os, const Edge &ed);
friend class boost::serialization::access;
template<class Archive>
void serialize(Archive &ar, const unsigned int version){
ar & edge_current_id & id & from_node & to_node & exploit & assetGroup & deleted;
}
public:
Edge(size_t, size_t, Exploit &, AssetGroup &);
Edge();
std::string get_query();
std::string get_asset_query();
int get_id();
int set_id();
size_t get_from_id();
size_t get_to_id();
int get_exploit_id();
void set_deleted();
bool is_deleted();
};
#endif // AG_GEN_EDGE_H

View File

@ -42,10 +42,14 @@ void Factbase::force_set_id(int i) {
id = i;
}
void Factbase::force_set_id(size_t i) {
id = i;
}
/**
* @return The current Factbase ID.
*/
int Factbase::get_id() const { return id; }
size_t Factbase::get_id() const { return id; }
std::tuple<std::vector<Quality>, std::vector<Topology>> Factbase::get_facts_tuple() const {
return std::make_tuple(qualities, topologies);

View File

@ -32,7 +32,7 @@ class Factbase {
friend std::ostream & operator << (std::ostream &os, const Factbase &fb);
friend class boost::serialization::access;
int id;
size_t id;
int qsize;
int tsize;
@ -73,7 +73,9 @@ class Factbase {
void print() const;
void set_id();
void force_set_id(int i);
int get_id() const;
void force_set_id(size_t i);
size_t get_id() const;
size_t hash(Keyvalue &factlist) const;
int get_size();
};

View File

@ -42,7 +42,7 @@ void NetworkState::force_set_id(int i) { factbase.force_set_id(i); }
/**
* @return The ID of the NetworkState
*/
int NetworkState::get_id() { return factbase.get_id(); }
size_t NetworkState::get_id() { return factbase.get_id(); }
int NetworkState::get_size() { return factbase.get_size(); }
@ -179,6 +179,7 @@ void NetworkState::delete_topology(Topology &t) {
}
}
// int NetworkState::compare(std::string &hash, RedisManager* rman) const {
// if (!rman->check_collision(hash)) {
// if (!rman->check_facts(hash, factbase.qualities, factbase.topologies))

View File

@ -57,7 +57,7 @@ class NetworkState {
void set_id();
void force_set_id(int i);
int get_id();
size_t get_id();
int get_size();
void add_qualities(std::vector<Quality> q);
@ -72,6 +72,8 @@ class NetworkState {
void delete_quality(Quality &q);
void delete_topology(Topology &t);
//bool operator==(NetworkState& foo) {return get_id() == foo.get_id();}
};
BOOST_SERIALIZATION_ASSUME_ABSTRACT(NetworkState)

View File

@ -444,17 +444,19 @@ int main(int argc, char *argv[]) {
int thread_count;
int init_qsize;
int mpi_nodes;
int depth_limit;
bool should_graph = false;
bool no_cycles = false;
bool batch_process = false;
bool use_redis = false;
bool use_postgres = false;
bool mpi_subgraphing = false;
double alpha = 0.5;
int opt;
while ((opt = getopt(argc, argv, "rb:g:dhc:n:x:t:q:pa:m:z:")) != -1) {
while ((opt = getopt(argc, argv, "rb:g:dhc:l:n:x:t:q:pa:sm:z:")) != -1) {
switch (opt) {
case 'g':
should_graph = true;
@ -472,6 +474,9 @@ int main(int argc, char *argv[]) {
case 'c':
opt_config = optarg;
break;
case 'l':
depth_limit = atoi(optarg);
break;
case 'd':
no_cycles = true;
break;
@ -495,6 +500,9 @@ int main(int argc, char *argv[]) {
//Save a 10% buffer for PSQL ops
alpha = atof(optarg) - 0.1;
break;
case 's':
mpi_subgraphing = true;
break;
case 'm':
mpi_nodes = atoi(optarg);
break;
@ -678,7 +686,9 @@ int main(int argc, char *argv[]) {
std::cout << "Generating Attack Graph: " << std::flush;
AGGen gen(_instance);//use AGGen class to instantiate an obj with the name gen! _instance obj as the parameter! constructor defined in ag_gen.cpp
if (world.size() > 1)
if (mpi_subgraphing && world.size() > 3)
postinstance = gen.sg_generate(batch_process, batch_size, thread_count, init_qsize, alpha, world, depth_limit); //The method call to generate the attack graph, defined in ag_gen.cpp.
else if (world.size() > 1)
postinstance = gen.generate(batch_process, batch_size, thread_count, init_qsize, alpha, world); //The method call to generate the attack graph, defined in ag_gen.cpp.
else
postinstance = gen.single_generate(batch_process, batch_size, thread_count, init_qsize, alpha, world); //The method call to generate the attack graph, defined in ag_gen.cpp.

View File

@ -423,8 +423,8 @@ void task_three(AGGenInstance &instance, NetworkState &new_state, std::deque<Net
if (hash_map.find(hash_num) == hash_map.end()) {
new_state.set_id();
auto facts_tuple = new_state.get_factbase().get_facts_tuple();
FactbaseItems new_items = std::make_tuple(facts_tuple, new_state.get_id());
instance.factbase_items.push_back(new_items);
//FactbaseItems new_items = std::make_tuple(facts_tuple, new_state.get_id());
//instance.factbase_items.push_back(new_items);
instance.factbases.push_back(new_state.get_factbase());
hash_map.insert(std::make_pair(new_state.get_hash(instance.facts), new_state.get_id()));
@ -506,3 +506,52 @@ int send_check(boost::mpi::communicator &world, int curr_node){
return send_to;
}
void state_merge(std::vector<Factbase> node_factbases, std::vector<Edge> node_edges,\
std::unordered_map<size_t, int> &hash_map, AGGenInstance &instance, double mem_threshold, mpi::communicator &world){
auto tot_sys_mem = getTotalSystemMemory();
for(auto fb : node_factbases){
//std::cout << "Started Task 3." << std::endl;
auto hash_num = fb.get_id();
//although local frontier is updated, the global hash is also updated to avoid testing on explored states.
if (hash_map.find(hash_num) == hash_map.end()) {
instance.factbases.push_back(fb);
hash_map.insert(std::make_pair(fb.get_id(), fb.get_id()));
//See memory usage. If it exceeds the threshold, store new states in the DB
double i_alpha = 0.0;
//Get the most recent Factbase's size * total number of factbases, rough approximation of *2 to account for factbase_items
double i_usage = instance.factbases.back().get_size() * instance.factbases.size() * 2 + sizeof(instance.edges[0]) * instance.edges.size();
if (i_alpha >= mem_threshold/2){
//std::cout << "Instance Alpha prior to database storing: " << i_alpha << std::endl;
mpi::request fb_req = world.isend(1, 7, instance.factbases);
mpi::request ed_req = world.isend(1, 8, instance.edges);
//save_ag_to_db(instance, true);
fb_req.wait();
ed_req.wait();
//Clear vectors and free memory
std::vector<Factbase>().swap(instance.factbases);
std::vector<FactbaseItems>().swap(instance.factbase_items);
std::vector<Edge>().swap(instance.edges);
i_usage = (sizeof(instance.factbases) + (sizeof(instance.factbases[0]) * instance.factbases.size()) +\
sizeof(instance.factbase_items) + (sizeof(instance.factbase_items[0]) * instance.factbase_items.size()) +\
sizeof(instance.edges) + (sizeof(instance.edges[0]) * instance.edges.size()));
i_alpha = i_usage/tot_sys_mem;
//std::cout << "Instance Alpha after database storing: " << i_alpha << std::endl;
}
}
}
//This does add duplicate edges
for (auto ed : node_edges){
instance.edges.push_back(ed);
}
}

View File

@ -23,5 +23,7 @@ void task_four(NetworkState &new_state);
int send_check(boost::mpi::communicator &world, int curr_node);
void state_merge(std::vector<Factbase> node_factbases, std::vector<Edge> node_edges,\
std::unordered_map<size_t, int> &hash_map, AGGenInstance &instance, double mem_threshold, mpi::communicator &world);
#endif //TASKS_H