Dynamic switching from memory storage to Postgres for AGInstance and Frontier

This commit is contained in:
Noah L. Schrick 2021-11-09 16:32:03 -06:00
parent e3e2f37e6f
commit 9a15ef3c7a
15 changed files with 429 additions and 50 deletions

Binary file not shown.

Binary file not shown.

View File

@ -3,6 +3,11 @@ network model =
# Cars # Cars
car1; car1;
car2; car2;
car3;
car4;
car5;
car6;
facts: facts:
quality:car1,brake_months=6; quality:car1,brake_months=6;
@ -27,6 +32,49 @@ network model =
quality:car2,vacuum_vio=false; quality:car2,vacuum_vio=false;
quality:car2,compliance_vio=false; quality:car2,compliance_vio=false;
quality:car3,brake_months=6;
quality:car3,exhaust_months=12;
quality:car3,ac_odometer=120000;
quality:car3,vacuum_odometer=120000;
quality:car3,engine=gas;
quality:car3,brake_vio=false;
quality:car3,exhaust_vio=false;
quality:car3,ac_vio=false;
quality:car3,vacuum_vio=false;
quality:car3,compliance_vio=false;
quality:car4,brake_months=6;
quality:car4,exhaust_months=12;
quality:car4,ac_odometer=120000;
quality:car4,vacuum_odometer=120000;
quality:car4,engine=diesel;
quality:car4,brake_vio=false;
quality:car4,exhaust_vio=false;
quality:car4,ac_vio=false;
quality:car4,vacuum_vio=false;
quality:car4,compliance_vio=false;
quality:car5,brake_months=6;
quality:car5,exhaust_months=12;
quality:car5,ac_odometer=120000;
quality:car5,vacuum_odometer=120000;
quality:car5,engine=gas;
quality:car5,brake_vio=false;
quality:car5,exhaust_vio=false;
quality:car5,ac_vio=false;
quality:car5,vacuum_vio=false;
quality:car5,compliance_vio=false;
quality:car6,brake_months=6;
quality:car6,exhaust_months=12;
quality:car6,ac_odometer=120000;
quality:car6,vacuum_odometer=120000;
quality:car6,engine=diesel;
quality:car6,brake_vio=false;
quality:car6,exhaust_vio=false;
quality:car6,ac_vio=false;
quality:car6,vacuum_vio=false;
quality:car6,compliance_vio=false;
topology:car1<->car2,road; topology:car1<->car2,road;
tags: tags:
. .

View File

@ -15,6 +15,7 @@
#include "../util/odometer.h" #include "../util/odometer.h"
#include "../util/db_functions.h" #include "../util/db_functions.h"
#include "../util/avail_mem.h"
#ifdef REDIS #ifdef REDIS
@ -130,7 +131,7 @@ createPostConditions(std::tuple<Exploit, AssetGroup> &group, Keyvalue &facts) {
* break and continue checking with the next exploit. * break and continue checking with the next exploit.
* 5. Push the new network state onto the frontier to be expanded later. * 5. Push the new network state onto the frontier to be expanded later.
*/ */
AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd, int initQSize ) { AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd, int initQSize, double mem_threshold) {
std::vector<Exploit> exploit_list = instance.exploits; std::vector<Exploit> exploit_list = instance.exploits;
//Create a vector that contains all the groups of exploits to be fired synchonously //Create a vector that contains all the groups of exploits to be fired synchonously
@ -170,6 +171,7 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
od_map[num_params] = od.get_all(); od_map[num_params] = od.get_all();
} }
} }
/*
//might be where to apply parallelization. //might be where to apply parallelization.
while (frontier.size()<initQSize){//while starts, test multiple thread case THIS WAS THE ONE MING USED while (frontier.size()<initQSize){//while starts, test multiple thread case THIS WAS THE ONE MING USED
//while (frontier.size()!=0){//while starts, test single thread case //while (frontier.size()!=0){//while starts, test single thread case
@ -301,6 +303,7 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
//int numThrd=32; //int numThrd=32;
printf("The number of threads used is %d\n",numThrd); printf("The number of threads used is %d\n",numThrd);
printf("The initial QSize is %d\n",initQSize); printf("The initial QSize is %d\n",initQSize);
*/
int frt_size=frontier.size(); int frt_size=frontier.size();
printf("The actual QSize to start using multiple threads is %d\n",frt_size); printf("The actual QSize to start using multiple threads is %d\n",frt_size);
@ -309,16 +312,43 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
struct timeval t1,t2; struct timeval t1,t2;
gettimeofday(&t1,NULL); gettimeofday(&t1,NULL);
//#pragma omp parallel for num_threads(numThrd) default(none) shared(esize,counter,exploit_list,od_map,frt_size,total_t,t1,t2) schedule(dynamic,1) //#pragma omp parallel for num_threads(numThrd) default(none) shared(esize,counter,exploit_list,od_map,frt_size,total_t,t1,t2) schedule(dynamic,1)
#pragma omp parallel for num_threads(numThrd) default(none) shared(esize,counter,exploit_list,od_map,frt_size,total_t,t1,t2,std::cout) schedule(dynamic,1) #pragma omp parallel for num_threads(numThrd) default(none) shared(esize,counter,exploit_list,od_map,frt_size,total_t,t1,t2,std::cout, mem_threshold) schedule(dynamic,1)
//auto ag_start = std::chrono::system_clock::now(); //auto ag_start = std::chrono::system_clock::now();
for(int k=0;k<frt_size;k++){ for(int k=0;k<frt_size;k++){
//printf("State %d in Frontier\n",k); //printf("State %d in Frontier\n",k);
std::deque<NetworkState> localFrontier; //double alpha = get_alpha();
double f_alpha = 0.0;
auto tot_sys_mem = getTotalSystemMemory();
std::deque<NetworkState> localFrontier;
localFrontier.emplace_front(frontier[k]); localFrontier.emplace_front(frontier[k]);
while (!localFrontier.empty()){//while starts while (!localFrontier.empty() || !unex_empty()){//while starts
//std::cout<<"SIZE: "<<localFrontier.size()<<std::endl; //We need to refill the localFrontier with states from the database if it's empty
if(localFrontier.empty()) {
std::cout << "Frontier empty, retrieving from database" << std::endl;
double total_tt = 0.0;
struct timeval tt1,tt2;
gettimeofday(&tt1,NULL);
int retrv_counter = 0;
//TODO: One (or a few) larger queries to pull in new states, rather than single queries that pull states one-by-one
do {
NetworkState db_new_state = fetch_unexplored(instance.facts);
localFrontier.emplace_front(db_new_state);
//alpha = get_alpha();
f_alpha = (static_cast<double>(localFrontier.size()) * (localFrontier.back().get_size()))/tot_sys_mem;
retrv_counter += 1;
}
//Leave a 30% buffer in alpha
while((f_alpha <= (mem_threshold * 0.7)) && !unex_empty());
std::cout << "Retrieved " << retrv_counter << " factbases from the database." << std::endl;
gettimeofday(&tt2,NULL);
total_tt+=(tt2.tv_sec-tt1.tv_sec)*1000.0+(tt2.tv_usec-tt1.tv_usec)/1000.0;
//printf("Retrieving from db took %lf s.\n", total_tt);
}
//std::cout<<"FRONTIER SIZE: "<<localFrontier.size()<<std::endl;
auto current_state = localFrontier.back(); auto current_state = localFrontier.back();
auto current_hash = current_state.get_hash(instance.facts); auto current_hash = current_state.get_hash(instance.facts);
localFrontier.pop_back(); localFrontier.pop_back();
@ -456,10 +486,10 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
} }
} }
//TODO: Comment/think more. If there are other assets in group, //TODO: If there are other assets in group,
//but you check idr_idx after filling and it's still empty //but you check idr_idx after filling and it's still empty
//you know that the other asset isn't ready to be fired yet, so wait. //you know that the other asset isn't ready to be fired yet, so wait.
//CORRECT: THIS BREAKS CODE IF ONLY 1 ASSET IN GROUP EXPLOIT. NEED TO FIGURE OUT HOW TO SEE HOW MANY ASSETS ARE IN GROUP //THIS BREAKS CODE IF ONLY 1 ASSET IN GROUP EXPLOIT. NEED TO FIGURE OUT HOW TO SEE HOW MANY ASSETS ARE IN GROUP
//std::cout<<std::get<1>(e).size()<<std::endl; //std::cout<<std::get<1>(e).size()<<std::endl;
//if(std::get<1>(e).size()>1){ //if(std::get<1>(e).size()>1){
if(idr_idx.empty()){ if(idr_idx.empty()){
@ -546,9 +576,65 @@ AGGenInstance &AGGen::generate(bool batch_process, int batch_size, int numThrd,
hash_map.insert(std::make_pair(new_state.get_hash(instance.facts), new_state.get_id())); hash_map.insert(std::make_pair(new_state.get_hash(instance.facts), new_state.get_id()));
//localFrontier.emplace_front(new_state); //localFrontier.emplace_front(new_state);
save_unexplored_to_db(new_state); //See memory usage. If it exceeds the threshold, store new states in the DB
NetworkState new_state = fetch_unexplored(instance.facts); double i_alpha = 0.0;
localFrontier.emplace_front(new_state); //double i_usage = (sizeof(instance.factbases) + (sizeof(instance.factbases[0]) * instance.factbases.size()) +\
sizeof(instance.factbase_items) + (sizeof(instance.factbase_items[0]) * instance.factbase_items.size()) +\
sizeof(instance.edges) + (sizeof(instance.edges[0]) * instance.edges.size()));
//Get the most recent Factbase's size * total number of factbases, rough approximation of *2 to account for factbase_items
double i_usage = instance.factbases.back().get_size() * instance.factbases.size() * 2 + sizeof(instance.edges[0]) * instance.edges.size();
i_alpha = i_usage/tot_sys_mem;
if (!localFrontier.empty())
f_alpha = (static_cast<double>(localFrontier.size()) * (localFrontier.back().get_size()))/tot_sys_mem;
else
f_alpha = 0.0;
//std::cout << "Frontier Alpha: " << f_alpha << std::endl;
//std::cout << "Instance Alpha: " << i_alpha << std::endl;
//std::cout << "Mem Threshold: " << mem_threshold << std::endl;
if (f_alpha >= (mem_threshold/2)) {
std::cout << "Frontier Alpha prior to database storing: " << f_alpha << std::endl;
//std::cout << "Factbase Usage Before: " << sizeof(instance.factbases) + (sizeof(instance.factbases[0]) * instance.factbases.size()) << std::endl;
//std::cout << "Factbase Item Usage Before: " << sizeof(instance.factbase_items) + (sizeof(instance.factbase_items[0]) * instance.factbase_items.size()) << std::endl;
//std::cout << "Edge Usage Before: " << sizeof(instance.edges) + (sizeof(instance.edges[0]) * instance.edges.size()) << std::endl;
save_unexplored_to_db(new_state);
if (!localFrontier.empty())
f_alpha = (static_cast<double>(localFrontier.size()) * (localFrontier.back().get_size()))/tot_sys_mem;
else
f_alpha = 0;
std::cout << "Frontier Alpha after database storing: " << f_alpha << std::endl;
//std::cout << "Storing in database." << std::endl;
}
//Store new state in database to ensure proper ordering of the FIFO queue
else if (!unex_empty()){
save_unexplored_to_db(new_state);
}
//Otherwise, we can just store in memory
else {
localFrontier.emplace_front(new_state);
}
if (i_alpha >= mem_threshold/2){
std::cout << "Instance Alpha prior to database storing: " << i_alpha << std::endl;
save_ag_to_db(instance, true);
//Clear vectors and free memory
std::vector<Factbase>().swap(instance.factbases);
std::vector<FactbaseItems>().swap(instance.factbase_items);
std::vector<Edge>().swap(instance.edges);
i_usage = (sizeof(instance.factbases) + (sizeof(instance.factbases[0]) * instance.factbases.size()) +\
sizeof(instance.factbase_items) + (sizeof(instance.factbase_items[0]) * instance.factbase_items.size()) +\
sizeof(instance.edges) + (sizeof(instance.edges[0]) * instance.edges.size()));
i_alpha = i_usage/tot_sys_mem;
std::cout << "Instance Alpha after database storing: " << i_alpha << std::endl;
}
//NetworkState new_state = fetch_unexplored(instance.facts);
Edge ed(current_state.get_id(), new_state.get_id(), exploit, assetGroup); Edge ed(current_state.get_id(), new_state.get_id(), exploit, assetGroup);
ed.set_id(); ed.set_id();

View File

@ -68,7 +68,7 @@ class AGGen {
AGGen(AGGenInstance &_instance, RedisManager &_rman); AGGen(AGGenInstance &_instance, RedisManager &_rman);
#endif #endif
AGGenInstance &generate(bool batch_process, int batch_num, int numThrd, int initQSize); AGGenInstance &generate(bool batch_process, int batch_num, int numThrd, int initQSize, double mem_threshold);
}; };
#endif // AG_GEN_HPP #endif // AG_GEN_HPP

View File

@ -196,3 +196,8 @@ void Factbase::print() const {
topo.print(); topo.print();
} }
} }
int Factbase::get_size() {
return (sizeof(id) + (qualities.size() * sizeof(qualities.back())) +\
(topologies.size() * sizeof(topologies.back())));
}

View File

@ -47,6 +47,7 @@ class Factbase {
void force_set_id(int i); void force_set_id(int i);
int get_id() const; int get_id() const;
size_t hash(Keyvalue &factlist) const; size_t hash(Keyvalue &factlist) const;
int get_size();
}; };
#endif #endif

View File

@ -39,6 +39,8 @@ void NetworkState::force_set_id(int i) { factbase.force_set_id(i); }
*/ */
int NetworkState::get_id() { return factbase.get_id(); } int NetworkState::get_id() { return factbase.get_id(); }
int NetworkState::get_size() { return factbase.get_size(); }
/** /**
* @return The Factbase for the NetworkState * @return The Factbase for the NetworkState
*/ */

View File

@ -35,6 +35,7 @@ class NetworkState {
void set_id(); void set_id();
void force_set_id(int i); void force_set_id(int i);
int get_id(); int get_id();
int get_size();
void add_qualities(std::vector<Quality> q); void add_qualities(std::vector<Quality> q);
void add_topologies(std::vector<Topology> t); void add_topologies(std::vector<Topology> t);

View File

@ -421,6 +421,7 @@ int main(int argc, char *argv[]) {
print_usage(); print_usage();
return 0; return 0;
} }
printf("Start init\n"); printf("Start init\n");
std::string opt_nm; std::string opt_nm;
std::string opt_xp; std::string opt_xp;
@ -436,8 +437,10 @@ int main(int argc, char *argv[]) {
bool use_redis = false; bool use_redis = false;
bool use_postgres = false; bool use_postgres = false;
double alpha = 0.5;
int opt; int opt;
while ((opt = getopt(argc, argv, "rb:g:dhc:n:x:t:q:p")) != -1) { while ((opt = getopt(argc, argv, "rb:g:dhc:n:x:t:q:pa:")) != -1) {
switch (opt) { switch (opt) {
case 'g': case 'g':
should_graph = true; should_graph = true;
@ -474,6 +477,10 @@ int main(int argc, char *argv[]) {
case 'p': case 'p':
use_postgres = true; use_postgres = true;
break; break;
case 'a':
//Save a 10% buffer for PSQL ops
alpha = atof(optarg) - 0.1;
break;
case '?': case '?':
if (optopt == 'c') if (optopt == 'c')
fprintf(stderr, "Option -%c requires an argument.\n", optopt); fprintf(stderr, "Option -%c requires an argument.\n", optopt);
@ -517,16 +524,11 @@ int main(int argc, char *argv[]) {
else else
{ {
std::cout << "No" << std::endl; std::cout << "Error: Not yet implemented. Database must be used.\
//pqxx::connection c; Please use the '-p' argument when running this program." << std::endl;
//pqxx::work w(c); return(1);
} }
//int a2=1;
//a2=a2+1;
//while(a2);
//-------------------------------------------- //--------------------------------------------
//program block 2: read in network model and exploit pattern and store them in local database //program block 2: read in network model and exploit pattern and store them in local database
@ -624,25 +626,19 @@ int main(int argc, char *argv[]) {
std::cout << "Generating Attack Graph: " << std::flush; std::cout << "Generating Attack Graph: " << std::flush;
AGGen gen(_instance);//use AGGen class to instantiate an obj with the name gen! _instance obj as the parameter! constructor defined in ag_gen.cpp AGGen gen(_instance);//use AGGen class to instantiate an obj with the name gen! _instance obj as the parameter! constructor defined in ag_gen.cpp
postinstance = gen.generate(batch_process, batch_size, thread_count, init_qsize); //The method call to generate the attack graph, defined in ag_gen.cpp. postinstance = gen.generate(batch_process, batch_size, thread_count, init_qsize, alpha); //The method call to generate the attack graph, defined in ag_gen.cpp.
std::cout << "Done\n"; std::cout << "Done\n";
std::cout << "# of edges " <<postinstance.edges.size()<<std::endl; //std::cout << "# of edges " <<postinstance.edges.size()<<std::endl;
std::cout << "# of edge_asset_binding" <<postinstance.edges.size()<<std::endl; //std::cout << "# of edge_asset_binding" <<postinstance.edges.size()<<std::endl;
// std::cout << "# of factbase " <<postinstance.factbases.size()<<std::endl; // std::cout << "# of factbase " <<postinstance.factbases.size()<<std::endl;
// std::cout << "# of factbase_item " <<postinstance.factbase_items.size()<<std::endl; // std::cout << "# of factbase_item " <<postinstance.factbase_items.size()<<std::endl;
std::cout << "Total Time: " << postinstance.elapsed_seconds.count() << " seconds\n";
std::cout << "Total States: " << postinstance.factbases.size() << "\n";
std::cout << "Saving Attack Graph to Database: " << std::flush;
save_ag_to_db(postinstance, true); save_ag_to_db(postinstance, true);
std::cout << "Total Edges: " << get_num_edges() << std::endl;
std::cout << "Total Time: " << postinstance.elapsed_seconds.count() << " seconds\n";
//std::cout << "Total States: " << postinstance.factbases.size() << "\n";
std::cout << "Total States: " << get_num_states() << std::endl;
std::cout << "Saving Attack Graph to Database: " << std::flush;
std::cout << "Done\n"; std::cout << "Done\n";
if(should_graph) { if(should_graph) {

72
src/tests/avail_mem.c Normal file
View File

@ -0,0 +1,72 @@
#include <unistd.h>
#include <ios>
#include <iostream>
#include <fstream>
#include <string>
//////////////////////////////////////////////////////////////////////////////
//
// process_mem_usage(double &, double &) - takes two doubles by reference,
// attempts to read the system-dependent data for a process' virtual memory
// size and resident set size, and return the results in KB.
//
// On failure, returns 0.0, 0.0
void process_mem_usage(double& vm_usage, double& resident_set)
{
using std::ios_base;
using std::ifstream;
using std::string;
vm_usage = 0.0;
resident_set = 0.0;
// 'file' stat seems to give the most reliable results
//
ifstream stat_stream("/proc/self/stat",ios_base::in);
// dummy vars for leading entries in stat that we don't care about
//
string pid, comm, state, ppid, pgrp, session, tty_nr;
string tpgid, flags, minflt, cminflt, majflt, cmajflt;
string utime, stime, cutime, cstime, priority, nice;
string O, itrealvalue, starttime;
// the two fields we want
//
unsigned long vsize;
long rss;
stat_stream >> pid >> comm >> state >> ppid >> pgrp >> session >> tty_nr
>> tpgid >> flags >> minflt >> cminflt >> majflt >> cmajflt
>> utime >> stime >> cutime >> cstime >> priority >> nice
>> O >> itrealvalue >> starttime >> vsize >> rss; // don't care about the rest
stat_stream.close();
long page_size_kb = sysconf(_SC_PAGE_SIZE) / 1024; // in case x86-64 is configured to use 2MB pages
vm_usage = vsize / 1024.0;
resident_set = rss * page_size_kb;
}
unsigned long long getTotalSystemMemory()
{
long pages = sysconf(_SC_PHYS_PAGES);
long page_size = sysconf(_SC_PAGE_SIZE);
return pages * page_size;
}
int main()
{
using std::cout;
using std::endl;
double vm, rss;
process_mem_usage(vm, rss);
unsigned long long tot_mem = getTotalSystemMemory();
auto percent = rss/tot_mem;
cout << "Percent: " << percent << endl;
cout << "VM: " << vm << "; RSS: " << rss << endl;
cout << "Total Mem: " << tot_mem;
return 0;
}

75
src/util/avail_mem.cpp Normal file
View File

@ -0,0 +1,75 @@
#include <unistd.h>
#include <ios>
#include <iostream>
#include <fstream>
#include <string>
#include "avail_mem.h"
//////////////////////////////////////////////////////////////////////////////
//
// process_mem_usage(double &, double &) - takes two doubles by reference,
// attempts to read the system-dependent data for a process' virtual memory
// size and resident set size, and return the results in KB.
//
// On failure, returns 0.0, 0.0
void process_mem_usage(double& vm_usage, double& resident_set)
{
using std::ios_base;
using std::ifstream;
using std::string;
vm_usage = 0.0;
resident_set = 0.0;
// 'file' stat seems to give the most reliable results
//
ifstream stat_stream("/proc/self/stat",ios_base::in);
// dummy vars for leading entries in stat that we don't care about
//
string pid, comm, state, ppid, pgrp, session, tty_nr;
string tpgid, flags, minflt, cminflt, majflt, cmajflt;
string utime, stime, cutime, cstime, priority, nice;
string O, itrealvalue, starttime;
// the two fields we want
//
unsigned long vsize;
long rss;
stat_stream >> pid >> comm >> state >> ppid >> pgrp >> session >> tty_nr
>> tpgid >> flags >> minflt >> cminflt >> majflt >> cmajflt
>> utime >> stime >> cutime >> cstime >> priority >> nice
>> O >> itrealvalue >> starttime >> vsize >> rss; // don't care about the rest
stat_stream.close();
long page_size_kb = sysconf(_SC_PAGE_SIZE) / 1024; // in case x86-64 is configured to use 2MB pages
vm_usage = vsize / 1024.0;
resident_set = rss * page_size_kb;
}
unsigned long long getTotalSystemMemory()
{
long pages = sysconf(_SC_PHYS_PAGES);
long page_size = sysconf(_SC_PAGE_SIZE);
return pages * page_size;
}
double get_alpha()
{
using std::cout;
using std::endl;
double vm, rss, alpha;
//vm and rss are KB
process_mem_usage(vm, rss);
//tot_mem is B
unsigned long long tot_mem = getTotalSystemMemory();
auto percent = (rss*1024)/tot_mem;
//cout << "Percent: " << percent << endl;
//cout << "VM: " << vm << "; RSS: " << rss << endl;
alpha = percent;
return alpha;
}

8
src/util/avail_mem.h Normal file
View File

@ -0,0 +1,8 @@
#ifndef _AVAIL_MEM_H
#define _AVAIL_MEM_H
void process_mem_usage(double& vm_usage, double& resident_set);
unsigned long long getTotalSystemMemory();
double get_alpha();
#endif

View File

@ -587,26 +587,39 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
gettimeofday(&t1,NULL); gettimeofday(&t1,NULL);
db.exec("BEGIN;"); db.exec("BEGIN;");
if (!factbases.empty()){ if (!factbases.empty()){
printf("The size of the factbases is %ld\n",factbases.size()); //printf("The size of the factbases is %ld\n",factbases.size());
std::string factbase_sql_query = "INSERT INTO factbase VALUES "; std::string factbase_sql_query = "INSERT INTO factbase VALUES ";
int factbase_counter = 0;
int flag = 0;
for (int i = 0; i < factbases.size(); ++i) { for (int i = 0; i < factbases.size(); ++i) {
/** redirect state info printing to states.txt file **/ /** redirect state info printing to states.txt file **/
factbases[i].print(); factbases[i].print();
if (i == 0) { if (i == 0 || flag == 1) {
factbase_sql_query += "(" + std::to_string(factbases[i].get_id()) + factbase_sql_query += "(" + std::to_string(factbases[i].get_id()) +
",'" + ",'" +
std::to_string(factbases[i].hash(factlist)) + "')"; std::to_string(factbases[i].hash(factlist)) + "')";
flag = 0;
} else { } else {
factbase_sql_query += ",(" + std::to_string(factbases[i].get_id()) + factbase_sql_query += ",(" + std::to_string(factbases[i].get_id()) +
",'" + ",'" +
std::to_string(factbases[i].hash(factlist)) + "')"; std::to_string(factbases[i].hash(factlist)) + "')";
} }
//Break up query due to memory constraints. Suboptimal approach.
factbase_counter += 1;
if (factbase_counter >= 500){
factbase_sql_query += ";";
db.execAsync(factbase_sql_query);
flag = 1;
factbase_sql_query = "INSERT INTO factbase VALUES ";
db.exec("COMMIT;");
db.exec("BEGIN;");
}
} }
factbase_sql_query += ";"; factbase_sql_query += ";";
db.execAsync(factbase_sql_query); db.execAsync(factbase_sql_query);
@ -619,9 +632,12 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
//gettimeofday(&t2,NULL); //gettimeofday(&t2,NULL);
//printf("The preparation of factbases took %lf ms\n",(t2.tv_sec-t1.tv_sec)*1000.0+(t2.tv_usec-t1.tv_usec)/1000.0); //printf("The preparation of factbases took %lf ms\n",(t2.tv_sec-t1.tv_sec)*1000.0+(t2.tv_usec-t1.tv_usec)/1000.0);
//this part takes 1.5s //this part takes 1.5s
//gettimeofday(&t1,NULL); //gettimeofday(&t1,NULL);
//potentially bug in forming the string for db storage, check if the first record has a comma before it //potentially bug in forming the string for db storage, check if the first record has a comma before it
/*FACTBASE ITEMS
if (!factbase_items.empty()) { if (!factbase_items.empty()) {
int fis=factbase_items.size(); int fis=factbase_items.size();
//fis=4; //fis=4;
@ -635,6 +651,7 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
std::string quality_sql_query = ""; std::string quality_sql_query = "";
std::string topology_sql_query = ""; std::string topology_sql_query = "";
int sql_index=0; int sql_index=0;
*///FACTBASE ITEMS
/* /*
for (int j = 0; j<fis/4+((k==3)?(fis%4):0); j++){ for (int j = 0; j<fis/4+((k==3)?(fis%4):0); j++){
@ -646,7 +663,11 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
auto topo = std::get<1>(items); auto topo = std::get<1>(items);
for (auto qi : quals) { for (auto qi : quals) {
*/ */
//UNCOMMENT FOR FACTBASE ITEM STORAGE
/*
int fbi_counter = 0;
int fbi_q_flag = 0;
int fbi_t_flag = 0;
for (int j = 0, sql_index = 0; j < factbase_items.size(); ++j) { for (int j = 0, sql_index = 0; j < factbase_items.size(); ++j) {
auto fbi = factbase_items[j]; auto fbi = factbase_items[j];
int id = std::get<1>(fbi); int id = std::get<1>(fbi);
@ -654,10 +675,12 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
auto quals = std::get<0>(items); auto quals = std::get<0>(items);
auto topo = std::get<1>(items); auto topo = std::get<1>(items);
for (auto qi : quals) { for (auto qi : quals) {
if (sql_index == 0) if (sql_index == 0 || fbi_q_flag == 1){
quality_sql_query += "(" + std::to_string(id) + "," + quality_sql_query += "(" + std::to_string(id) + "," +
std::to_string(qi.get_encoding()) + std::to_string(qi.get_encoding()) +
",'quality')"; ",'quality')";
fbi_q_flag = 0;
}
else else
quality_sql_query += ",(" + std::to_string(id) + "," + quality_sql_query += ",(" + std::to_string(id) + "," +
@ -667,10 +690,12 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
} }
for (auto ti : topo) { for (auto ti : topo) {
if (sql_index == 0) if (sql_index == 0 || fbi_t_flag == 1) {
topology_sql_query += "(" + std::to_string(id) + "," + topology_sql_query += "(" + std::to_string(id) + "," +
std::to_string(ti.get_encoding()) + std::to_string(ti.get_encoding()) +
",'topology')"; ",'topology')";
fbi_t_flag = 0;
}
else else
topology_sql_query += ",(" + std::to_string(id) + "," + topology_sql_query += ",(" + std::to_string(id) + "," +
@ -678,9 +703,23 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
",'topology')"; ",'topology')";
sql_index++; sql_index++;
} }
//Break up query due to memory constraints. Suboptimal approach.
fbi_counter += 1;
if (fbi_counter >= 500){
item_sql_query += quality_sql_query + topology_sql_query + "ON CONFLICT DO NOTHING;";
db.exec("BEGIN;");
db.execAsync(item_sql_query);
db.execAsync("COMMIT;");
fbi_q_flag = 1;
fbi_t_flag = 1;
item_sql_query = "INSERT INTO factbase_item VALUES ";
quality_sql_query = "";
topology_sql_query = "";
}
} }
item_sql_query += quality_sql_query + topology_sql_query + "ON CONFLICT DO NOTHING;"; item_sql_query += quality_sql_query + topology_sql_query + "ON CONFLICT DO NOTHING;";
int thrd_num=omp_get_thread_num(); int thrd_num=omp_get_thread_num();
if(thrd_num==0){ if(thrd_num==0){
db.exec("BEGIN;"); db.exec("BEGIN;");
@ -705,12 +744,14 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
} }
} }
gettimeofday(&t2,NULL); gettimeofday(&t2,NULL);
printf("The saving of factbase and items took %lf ms\n",(t2.tv_sec-t1.tv_sec)*1000.0+(t2.tv_usec-t1.tv_usec)/1000.0); //printf("The saving of factbase and items took %lf ms\n",(t2.tv_sec-t1.tv_sec)*1000.0+(t2.tv_usec-t1.tv_usec)/1000.0);
*///FACTBASE ITEMS
//this part takes 42.0s //this part takes 42.0s
//probably where to apply parallelization, this part takes most of the db saving time (>=80%) //probably where to apply parallelization, this part takes most of the db saving time (>=80%)
gettimeofday(&t1,NULL); gettimeofday(&t1,NULL);
int edge_counter = 0;
int edge_flag = 0;
if (!edges.empty()) { if (!edges.empty()) {
std::vector<std::string> edge_queries; std::vector<std::string> edge_queries;
edge_queries.resize(edges.size()); edge_queries.resize(edges.size());
@ -743,15 +784,27 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
int eid = edges[i].get_id(); int eid = edges[i].get_id();
//int eid = edges[i].set_id(); //int eid = edges[i].set_id();
if (ii == 0) { if (ii == 0 | edge_flag == 1) {
edge_sql_query += "(" + std::to_string(eid) + "," + ei.first; edge_sql_query += "(" + std::to_string(eid) + "," + ei.first;
edge_assets_sql_query += edges[i].get_asset_query(); edge_assets_sql_query += edges[i].get_asset_query();
edge_flag = 0;
} else { } else {
edge_sql_query += ",(" + std::to_string(eid) + "," + ei.first; edge_sql_query += ",(" + std::to_string(eid) + "," + ei.first;
edge_assets_sql_query += "," + edges[i].get_asset_query(); edge_assets_sql_query += "," + edges[i].get_asset_query();
} }
++ii; ++ii;
//Break up query due to memory constraints. Suboptimal approach.
edge_counter += 1;
if (edge_counter >= 500){
edge_sql_query += "ON CONFLICT DO NOTHING;";
edge_assets_sql_query += "ON CONFLICT DO NOTHING;";
db.execAsync(edge_sql_query);
db.execAsync(edge_assets_sql_query);
edge_flag = 1;
edge_sql_query = "INSERT INTO edge VALUES ";
edge_assets_sql_query = "INSERT INTO edge_asset_binding VALUES ";
}
} }
edge_sql_query += " ON CONFLICT DO NOTHING;"; edge_sql_query += " ON CONFLICT DO NOTHING;";
edge_assets_sql_query += " ON CONFLICT DO NOTHING;"; edge_assets_sql_query += " ON CONFLICT DO NOTHING;";
@ -761,8 +814,8 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
gettimeofday(&t6,NULL); gettimeofday(&t6,NULL);
db.execAsync(edge_assets_sql_query); //7.6s db.execAsync(edge_assets_sql_query); //7.6s
gettimeofday(&t7,NULL); gettimeofday(&t7,NULL);
printf("The saving of edges took %lf ms\n",(t6.tv_sec-t5.tv_sec)*1000.0+(t6.tv_usec-t5.tv_usec)/1000.0); // printf("The saving of edges took %lf ms\n",(t6.tv_sec-t5.tv_sec)*1000.0+(t6.tv_usec-t5.tv_usec)/1000.0);
printf("The saving of edges_assets_bingding took %lf ms\n",(t7.tv_sec-t6.tv_sec)*1000.0+(t7.tv_usec-t6.tv_usec)/1000.0); //printf("The saving of edges_assets_bingding took %lf ms\n",(t7.tv_sec-t6.tv_sec)*1000.0+(t7.tv_usec-t6.tv_usec)/1000.0);
//----second method to store into db //----second method to store into db
//---------------------------------- //----------------------------------
@ -837,7 +890,7 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
} }
gettimeofday(&t2,NULL); gettimeofday(&t2,NULL);
printf("The saving of edge and edge_asset_binding took %lf ms\n",(t2.tv_sec-t1.tv_sec)*1000.0+(t2.tv_usec-t1.tv_usec)/1000.0);//42.0s //printf("The saving of edge and edge_asset_binding took %lf ms\n",(t2.tv_sec-t1.tv_sec)*1000.0+(t2.tv_usec-t1.tv_usec)/1000.0);//42.0s
//this part takes around 8.6s //this part takes around 8.6s
gettimeofday(&t1,NULL); gettimeofday(&t1,NULL);
@ -860,7 +913,7 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue){
} }
db.execAsync("COMMIT;"); db.execAsync("COMMIT;");
gettimeofday(&t2,NULL); gettimeofday(&t2,NULL);
printf("The saving of keyvalue took %lf ms\n",(t2.tv_sec-t1.tv_sec)*1000.0+(t2.tv_usec-t1.tv_usec)/1000.0); //printf("The saving of keyvalue took %lf ms\n",(t2.tv_sec-t1.tv_sec)*1000.0+(t2.tv_usec-t1.tv_usec)/1000.0);
} }
void save_unexplored_to_db(NetworkState netstate) { void save_unexplored_to_db(NetworkState netstate) {
@ -997,4 +1050,32 @@ NetworkState fetch_unexplored(Keyvalue &facts){
fetched_state.force_set_id(id); fetched_state.force_set_id(id);
return fetched_state; return fetched_state;
}
bool unex_empty()
{
std::vector<Row> row = db.exec("SELECT count(*) FROM unex_state_q LIMIT 1;");
for (auto &r : row){
auto count = std::stoi(r[0]);
if (count == 0)
return true;
else
return false;
}
}
int get_num_states()
{
std::vector<Row> row = db.exec("SELECT count(*) AS exact_count FROM factbase;");
for (auto &r : row) {
return std::stoi(r[0]);
}
}
int get_num_edges()
{
std::vector<Row> row = db.exec("SELECT count(*) AS exact_count FROM edge;");
for (auto &r : row) {
return std::stoi(r[0]);
}
} }

View File

@ -79,5 +79,9 @@ void save_ag_to_db(AGGenInstance &instance, bool save_keyvalue);
void save_unexplored_to_db(NetworkState netstate); void save_unexplored_to_db(NetworkState netstate);
NetworkState fetch_unexplored(Keyvalue &facts); NetworkState fetch_unexplored(Keyvalue &facts);
bool unex_empty();
int get_num_states();
int get_num_edges();
#endif #endif