Sync. Fire Rewrite

This commit is contained in:
Noah L. Schrick 2022-01-18 21:48:57 -06:00
parent 7cf0889d31
commit a6a737ef9d

View File

@ -337,7 +337,7 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
gettimeofday(&t1,NULL); gettimeofday(&t1,NULL);
int num_tasks = 6; int num_tasks = 6;
#pragma omp parallel for num_threads(numThrd) default(none) shared(esize,counter,exploit_list,od_map,frt_size,total_t,t1,t2,std::cout, mem_threshold, num_tasks) schedule(dynamic,1) #pragma omp parallel for num_threads(numThrd) default(none) shared(esize,counter,exploit_list,od_map,frt_size,total_t,t1,t2,std::cout, mem_threshold, num_tasks, ex_groups) schedule(dynamic,1)
//auto ag_start = std::chrono::system_clock::now(); //auto ag_start = std::chrono::system_clock::now();
for(int k=0;k<frt_size;k++){ for(int k=0;k<frt_size;k++){
@ -412,99 +412,59 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
//task_two(); //task_two();
auto appl_expl_size = appl_exploits.size(); std::map<std::string, int> group_fired; //Map to hold fired status per group
std::map<std::string, std::vector<std::tuple<Exploit, AssetGroup>>> sync_vectors; //Map to hold all group exploits
//skip flag is used to ensure that the egroup loop is not repeatedly run more than necessary for (auto map_group : ex_groups)
int skip_flag=0; {
group_fired.insert(std::pair<std::string, int> (map_group, 0));
}
//for (size_t j = 0; j < appl_expl_size; j++) { //(OLD) for loop for new states starts //Build up the map of synchronous fire exploits
for(auto itr=appl_exploits.begin(); itr!=appl_exploits.end(); itr++){
//auto e = appl_exploits.at(itr);
auto e = *itr;
auto egroup = std::get<0>(e).get_group();
//vector for holding the appl_exploits indices at which groups exist if (egroup != "null"){
std::vector<int> idr_idx; sync_vectors[egroup].push_back(e);
}
}
//vector for holding indices that have already fired
std::vector<int> fired_idx;
//iterator for the applicable exploits vector
auto itr=appl_exploits.begin();
int break_flag=0;
int testing_flag=0;
//loop through the vector //loop through the vector
for(auto itr=appl_exploits.begin(); itr!=appl_exploits.end(); itr++){ for(auto itr=appl_exploits.begin(); itr!=appl_exploits.end(); itr++){
//keep track of index for later use
auto index=std::distance(appl_exploits.begin(), itr);
//reset break flag
break_flag=0;
//To avoid double-fire, check if an index has already been run. auto e = *itr;
//If it has, then there is no need to run through this loop again.
for(auto itr_f=fired_idx.begin(); itr_f!=fired_idx.end(); itr_f++){
auto index_f=std::distance(fired_idx.begin(),itr_f);
if(index==index_f)
break_flag=1;
}
if (break_flag==1)
break;
//empty the appl_exploits index vector at the start of each loop so that
//it doesn't contain stale data from a previous loop
idr_idx.clear();
NetworkState new_state{current_state};
//auto e = appl_exploits.at(j);
/* Synchronous fire function
First: double/sanity checks to see if there are other exploits that need to be fired
This also prevents the firing from occurring when it shouldn't via a regular passthrough
(as in, when this gets checked from NOT the goto.)
After popping, it checks if the vector is empty. If it is, then we no longer need to
re-fill the vector since we've gone through all possibilities
*/
SYNCH_FIRE:;
if(!idr_idx.empty()){
//std::cout<<"IDR Size " << idr_idx.size()<<std::endl;
index=idr_idx.back();
idr_idx.pop_back();
if(idr_idx.empty())
skip_flag=1;
fired_idx.push_back(index);
}
auto e = appl_exploits.at(index);
auto exploit = std::get<0>(e); auto exploit = std::get<0>(e);
auto assetGroup = std::get<1>(e);
//std::cout<<exploit.get_name()<<std::endl; //std::cout<<exploit.get_name()<<std::endl;
//For synchronous firing: get indices of all exploits in the same group and
//push them onto the index vector for later use
auto egroup=exploit.get_group(); auto egroup=exploit.get_group();
if (egroup!="null" && idr_idx.empty() && skip_flag==0){
for(int i=0; i!=appl_exploits.size(); i++){ if ((egroup != "null" && group_fired[egroup] == 0) || egroup == "null"){
if((std::get<0>(appl_exploits.at(i))).get_group()==egroup && i!=index){ NetworkState new_state{current_state};
idr_idx.emplace_back(i); std::vector<std::tuple<Exploit, AssetGroup>> sync_exploits;
if (egroup == "null")
sync_exploits.push_back(e);
else {
sync_exploits = sync_vectors[egroup];
//TODO: Does not work if only some assets belong to a group. This only works if
//all assets are in the group
if(sync_exploits.size() < instance.assets.size()){
break;
} }
} }
//TODO: If there are other assets in group, for(auto sync_itr=sync_exploits.begin(); sync_itr!=sync_exploits.end(); sync_itr++){
//but you check idr_idx after filling and it's still empty e = *sync_itr;
//you know that the other asset isn't ready to be fired yet, so wait. exploit = std::get<0>(e);
//THIS BREAKS CODE IF ONLY 1 ASSET IN GROUP EXPLOIT. NEED TO FIGURE OUT HOW TO SEE HOW MANY ASSETS ARE IN GROUP egroup=exploit.get_group();
//std::cout<<std::get<1>(e).size()<<std::endl; assetGroup = std::get<1>(e);
//if(std::get<1>(e).size()>1){ group_fired[egroup] = 1;
if(idr_idx.empty()){
testing_flag=1;
}
// }
}
if(testing_flag==1)
break;
skip_flag=0;
auto assetGroup = std::get<1>(e);
//assetGroup.print_group();
//std::cout<<std::endl;
auto postconditions = createPostConditions(e, instance.facts); auto postconditions = createPostConditions(e, instance.facts);
auto qualities = std::get<0>(postconditions); auto qualities = std::get<0>(postconditions);
auto topologies = std::get<1>(postconditions); auto topologies = std::get<1>(postconditions);
@ -540,6 +500,7 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
break; break;
} }
} }
for(auto &topo : topologies) { for(auto &topo : topologies) {
auto action = std::get<0>(topo); auto action = std::get<0>(topo);
auto fact = std::get<1>(topo); auto fact = std::get<1>(topo);
@ -555,33 +516,25 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
break; break;
} }
} }
//appl_exploits.erase(appl_exploits.begin()+index); }//Sync. Fire for
if(!idr_idx.empty())
goto SYNCH_FIRE;
auto hash_num = new_state.get_hash(instance.facts); auto hash_num = new_state.get_hash(instance.facts);
if (hash_num == current_hash) if (hash_num == current_hash)
continue; continue;
//gettimeofday(&t1,NULL);
#pragma omp critical #pragma omp critical
if (hash_map.find(hash_num) == hash_map.end()) {//although local frontier is updated, the global hash is also updated to avoid testing on explored states. //although local frontier is updated, the global hash is also updated to avoid testing on explored states.
if (hash_map.find(hash_num) == hash_map.end()) {
new_state.set_id(); new_state.set_id();
auto facts_tuple = new_state.get_factbase().get_facts_tuple(); auto facts_tuple = new_state.get_factbase().get_facts_tuple();
FactbaseItems new_items = FactbaseItems new_items = std::make_tuple(facts_tuple, new_state.get_id());
std::make_tuple(facts_tuple, new_state.get_id());
instance.factbase_items.push_back(new_items); instance.factbase_items.push_back(new_items);
instance.factbases.push_back(new_state.get_factbase()); instance.factbases.push_back(new_state.get_factbase());
hash_map.insert(std::make_pair(new_state.get_hash(instance.facts), new_state.get_id())); hash_map.insert(std::make_pair(new_state.get_hash(instance.facts), new_state.get_id()));
//localFrontier.emplace_front(new_state);
//See memory usage. If it exceeds the threshold, store new states in the DB //See memory usage. If it exceeds the threshold, store new states in the DB
double i_alpha = 0.0; double i_alpha = 0.0;
//double i_usage = (sizeof(instance.factbases) + (sizeof(instance.factbases[0]) * instance.factbases.size()) +\
sizeof(instance.factbase_items) + (sizeof(instance.factbase_items[0]) * instance.factbase_items.size()) +\
sizeof(instance.edges) + (sizeof(instance.edges[0]) * instance.edges.size()));
//Get the most recent Factbase's size * total number of factbases, rough approximation of *2 to account for factbase_items //Get the most recent Factbase's size * total number of factbases, rough approximation of *2 to account for factbase_items
double i_usage = instance.factbases.back().get_size() * instance.factbases.size() * 2 + sizeof(instance.edges[0]) * instance.edges.size(); double i_usage = instance.factbases.back().get_size() * instance.factbases.size() * 2 + sizeof(instance.edges[0]) * instance.edges.size();
@ -590,29 +543,22 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
f_alpha = (static_cast<double>(localFrontier.size()) * (localFrontier.back().get_size()))/tot_sys_mem; f_alpha = (static_cast<double>(localFrontier.size()) * (localFrontier.back().get_size()))/tot_sys_mem;
else else
f_alpha = 0.0; f_alpha = 0.0;
//std::cout << "Frontier Alpha: " << f_alpha << std::endl;
//std::cout << "Instance Alpha: " << i_alpha << std::endl;
//std::cout << "Mem Threshold: " << mem_threshold << std::endl;
if (f_alpha >= (mem_threshold/2)) { if (f_alpha >= (mem_threshold/2)) {
std::cout << "Frontier Alpha prior to database storing: " << f_alpha << std::endl; std::cout << "Frontier Alpha prior to database storing: " << f_alpha << std::endl;
//std::cout << "Factbase Usage Before: " << sizeof(instance.factbases) + (sizeof(instance.factbases[0]) * instance.factbases.size()) << std::endl;
//std::cout << "Factbase Item Usage Before: " << sizeof(instance.factbase_items) + (sizeof(instance.factbase_items[0]) * instance.factbase_items.size()) << std::endl;
//std::cout << "Edge Usage Before: " << sizeof(instance.edges) + (sizeof(instance.edges[0]) * instance.edges.size()) << std::endl;
save_unexplored_to_db(new_state); save_unexplored_to_db(new_state);
if (!localFrontier.empty()) if (!localFrontier.empty())
f_alpha = (static_cast<double>(localFrontier.size()) * (localFrontier.back().get_size()))/tot_sys_mem; f_alpha = (static_cast<double>(localFrontier.size()) * (localFrontier.back().get_size()))/tot_sys_mem;
else else
f_alpha = 0; f_alpha = 0;
std::cout << "Frontier Alpha after database storing: " << f_alpha << std::endl; std::cout << "Frontier Alpha after database storing: " << f_alpha << std::endl;
//std::cout << "Storing in database." << std::endl;
} }
//Store new state in database to ensure proper ordering of the FIFO queue //Store new state in database to ensure proper ordering of the FIFO queue
else if (!unex_empty()){ else if (!unex_empty()){
save_unexplored_to_db(new_state); save_unexplored_to_db(new_state);
} }
//Otherwise, we can just store in memory //Otherwise, we can just store in memory
else { else {
localFrontier.emplace_front(new_state); localFrontier.emplace_front(new_state);
@ -635,19 +581,20 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
} }
//NetworkState new_state = fetch_unexplored(instance.facts);
Edge ed(current_state.get_id(), new_state.get_id(), exploit, assetGroup); Edge ed(current_state.get_id(), new_state.get_id(), exploit, assetGroup);
ed.set_id(); ed.set_id();
instance.edges.push_back(ed); instance.edges.push_back(ed);
counter++; } //END if (hash_map.find(hash_num) == hash_map.end())
}
else { else {
int id = hash_map[hash_num]; int id = hash_map[hash_num];
Edge ed(current_state.get_id(), id, exploit, assetGroup); Edge ed(current_state.get_id(), id, exploit, assetGroup);
ed.set_id(); ed.set_id();
instance.edges.push_back(ed); instance.edges.push_back(ed);
} }
} //sync fire if
else
break;
} //for loop for new states ends } //for loop for new states ends
} //while ends } //while ends
auto ag_end= std::chrono::system_clock::now(); auto ag_end= std::chrono::system_clock::now();