//! main.cpp contains the main fuction that runs the program including flag //! handling and calls to functions that access the database and generate the //! attack graph. //! #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "ag_gen/ag_gen.h" #include "util/db_functions.h" #include "util/build_sql.h" #include "util/db.h" #include "util/hash.h" #include "util/list.h" #include "util/mem.h" #include "mpi/serialize.h" #ifdef REDIS #include "util/redis_manager.h" #endif // REDIS namespace mpi = boost::mpi; namespace mt = mpi::threading; template class ag_visitor : public boost::default_dfs_visitor { std::vector> &to_delete; public: explicit ag_visitor(std::vector> &_to_delete) : to_delete(_to_delete) {} template void back_edge(GraphEdge e, Graph g) { typename boost::property_map::type Edge_Index = boost::get(boost::edge_index, g); int index = Edge_Index[e]; // edges[index].set_deleted(); to_delete.push_back(std::make_pair(e, index)); } }; typedef boost::property> EdgeProperties; typedef boost::property VertexNameProperty; typedef boost::adjacency_list Graph; typedef boost::graph_traits::vertex_descriptor Vertex; typedef boost::graph_traits::edge_descriptor GraphEdge; Graph graph_init() { GraphInfo info = fetch_graph_info(); auto factbase_ids = info.first; auto edges = info.second; Graph g; boost::property_map::type Factbase_ID = boost::get(boost::vertex_name, g); boost::property_map::type Exploit_ID = boost::get(boost::edge_name, g); boost::property_map::type Edge_Index = boost::get(boost::edge_index, g); std::unordered_map vertex_map; for (int fid : factbase_ids) { Vertex v = boost::add_vertex(g); Factbase_ID[v] = fid; vertex_map[fid] = v; } for (auto ei : edges) { int eid = ei[0]; int from_id = ei[1]; int to_id = ei[2]; int exid = ei[3]; Vertex from_v = vertex_map[from_id]; Vertex to_v = 
vertex_map[to_id]; GraphEdge e = boost::add_edge(from_v, to_v, g).first; Exploit_ID[e] = std::to_string(exid); Edge_Index[e] = eid; } return g; } void remove_cycles(Graph &g) { std::vector> to_delete; // ag_visitor vis(edges, to_delete); ag_visitor vis(to_delete); boost::depth_first_search(g, boost::visitor(vis)); std::vector delete_edge_ids; delete_edge_ids.resize(to_delete.size()); for (int i = 0; i < to_delete.size(); ++i) { boost::remove_edge(to_delete[i].first, g); delete_edge_ids[i] = to_delete[i].second; } delete_edges(delete_edge_ids); } void graph_ag(Graph &g, std::string &filename) { std::ofstream gout; std::cout << filename << std::endl; gout.open(filename); boost::write_graphviz(gout, g, boost::default_writer(), boost::make_label_writer(boost::get(boost::edge_name, g))); } /* Try and color code the severe violations void color_code(Graph &g) { if } */ extern "C" { extern FILE *nmin; extern int nmparse(networkmodel *nm); } std::string parse_nm(std::string &filename) { FILE *file = fopen(filename.c_str(), "r"); if(!file) { fprintf(stderr, "Cannot open file.\n"); } networkmodel nm; nm.assets = list_new(); //yydebug = 1; nmin = file; do { nmparse(&nm); } while(!feof(nmin)); // FILE *out = stdout; std::string output; //print_xp_list(xplist); ///////////////////////// // ASSETS ///////////////////////// hashtable *asset_ids = new_hashtable(101); // Preload buffer with SQL prelude size_t bufsize = INITIALBUFSIZE; auto buf = static_cast(getcmem(bufsize)); strcat(buf, "INSERT INTO asset VALUES\n"); // Iterate over each exploit in the list // Generate an "exploit_instance" which contains // the generated exploit id and the sql for // for the exploit. 
for(size_t i=0; isize; i++) { auto asset = static_cast(list_get_idx(nm.assets, i)); add_hashtable(asset_ids, asset, i); asset_instance *ai = make_asset(asset); while(bufsize < strlen(buf) + strlen(ai->sql)) { buf = static_cast(realloc(buf, (bufsize *= 2))); } strcat(buf, ai->sql); } // Replace the last comma with a semicolon char *last = strrchr(buf, ','); *last = ';'; // fprintf(out, "%s\n", buf); output += std::string(buf); ///////////////////////// // FACTS ///////////////////////// // Preload buffer with SQL prelude bufsize = INITIALBUFSIZE; buf = static_cast(getcmem(bufsize)); strcat(buf, "INSERT INTO quality VALUES\n"); size_t buf2size = INITIALBUFSIZE; auto buf2 = static_cast(getcmem(buf2size)); strcat(buf2, "INSERT INTO topology VALUES\n"); // Iterate over each exploit. We then iterate // over each f in the exploit and generate // the sql for it. for(size_t i=0; isize; i++) { auto fct = static_cast(list_get_idx(nm.facts, i)); char *sqlqual,*sqltopo; auto assetFrom = static_cast(get_hashtable(asset_ids, fct->from)); switch(fct->type) { case QUALITY_T: sqlqual = make_quality(assetFrom, fct->st); while(bufsize < (strlen(buf) + strlen(sqlqual))) { buf = static_cast(realloc(buf, (bufsize*=2))); } strcat(buf, sqlqual); break; case TOPOLOGY_T: auto assetTo = static_cast(get_hashtable(asset_ids, fct->to)); sqltopo = make_topology(assetFrom, assetTo, fct->dir, fct->st); while(buf2size < (strlen(buf2) + strlen(sqltopo))) { buf2 = static_cast(realloc(buf2, (buf2size*=2))); } strcat(buf2, sqltopo); break; } } last = strrchr(buf, ','); *last = ';'; char *last2 = strrchr(buf2, ','); *last2 = ';'; output += std::string(buf); output += std::string(buf2); return output; } extern "C" { extern FILE *xpin; extern int xpparse(list *xpplist); } std::string parse_xp(std::string &filename) { FILE *file = fopen(filename.c_str(), "r"); if(!file) { fprintf(stderr, "Cannot open file.\n"); } struct list *xplist = list_new(); //yydebug = 1; xpin = file; do { xpparse(xplist); } 
while(!feof(xpin)); // FILE *out = stdout; std::string output; //print_xp_list(xplist); ///////////////////////// // EXPLOITS ///////////////////////// hashtable *exploit_ids = new_hashtable(101); // Preload buffer with SQL prelude size_t bufsize = INITIALBUFSIZE; auto buf = static_cast(getcmem(bufsize)); strcat(buf, "INSERT INTO exploit VALUES\n"); // Iterate over each exploit in the list // Generate an "exploit_instance" which contains // the generated exploit id and the sql for // for the exploit. for(size_t i=0; isize; i++) { auto xp = static_cast(list_get_idx(xplist, i)); exploit_instance *ei = make_exploit(xp); add_hashtable(exploit_ids, xp->name, ei->id); printf("%s - %d\n", xp->name, get_hashtable(exploit_ids, xp->name)); while(bufsize < strlen(buf) + strlen(ei->sql)) { //std::cout << "Resizing" << std::endl; bufsize*=20; //buf = static_cast(realloc(buf, bufsize)); auto new_buf = static_cast(realloc(buf, bufsize)); buf = new_buf; } strcat(buf, ei->sql); } // Replace the last comma with a semicolon char *last = strrchr(buf, ','); *last = ';'; // fprintf(out, "%s\n", buf); output += std::string(buf); ///////////////////////// // PRECONDITIONS ///////////////////////// // Preload buffer with SQL prelude bufsize = INITIALBUFSIZE; buf = static_cast(getcmem(bufsize)); strcat(buf, "INSERT INTO exploit_precondition VALUES\n"); // Iterate over each exploit. We then iterate // over each f in the exploit and generate // the sql for it. 
for(size_t i=0; isize; i++) { auto xp = static_cast(list_get_idx(xplist, i)); for(size_t j=0; jpreconditions->size; j++) { auto fct = static_cast(list_get_idx(xp->preconditions, j)); // printf("%s: %d\n", fct->from, get_hashtable(exploit_ids, fct->from)); char *sqladd = make_precondition(exploit_ids, xp, fct); while(bufsize < strlen(buf) + strlen(sqladd)) { buf = static_cast(realloc(buf, (bufsize*=2))); } strcat(buf, sqladd); } } last = strrchr(buf, ','); *last = ';'; // fprintf(out, "%s\n", buf); output += std::string(buf); ///////////////////////// // POSTCONDITIONS ///////////////////////// // Preload buffer with SQL prelude bufsize = INITIALBUFSIZE; buf = (char *)getcmem(bufsize); strcat(buf, "INSERT INTO exploit_postcondition VALUES\n"); // Iterate over each exploit. We then iterate // over each f in the exploit and generate // the sql for it. for(size_t i=0; isize; i++) { auto xp = static_cast(list_get_idx(xplist, i)); for(size_t j=0; jpostconditions->size; j++) { auto pc = static_cast(list_get_idx(xp->postconditions, j)); char *sqladd = make_postcondition(exploit_ids, xp, pc); while(bufsize < strlen(buf) + strlen(sqladd)) { buf = static_cast(realloc(buf, (bufsize*=2))); } strcat(buf, sqladd); } } last = strrchr(buf, ','); *last = ';'; // fprintf(out, "%s\n", buf); output += std::string(buf); return output; } /** * @brief Parse a string based on delim. */ void tokenize(std::string const &str, const char delim, std::vector &out) { //Construct stream from given string std::stringstream ss(str); std::string s; while (std::getline(ss, s, delim)) { out.push_back(s); } } /** * @brief Prints command line usage information. */ void print_usage() { std::cout << "Usage: ag_gen [OPTION...]" << std::endl << std::endl; std::cout << "Flags:" << std::endl; std::cout << "\t-c\tConfig section in config.ini" << std::endl; std::cout << "\t-b\tEnables batch processing. The argument is the size of batches." 
<< std::endl; std::cout << "\t-g\tGenerate visual graph using graphviz, dot file for saving" << std::endl; std::cout << "\t-d\tPerform a depth first search to remove cycles" << std::endl; std::cout << "\t-n\tNetwork model file used for generation" << std::endl; std::cout << "\t-x\tExploit pattern file used for generation" << std::endl; std::cout << "\t-r\tUse redis for generation" << std::endl; std::cout << "\t-h\tThis help menu." << std::endl; std::cout << "\t-p\tUse PostgreSQL" << std::endl; std::cout << "\t-e\tUse MPI Tasking" << std::endl; std::cout << "\t-s\tUse MPI Subgraphing" << std::endl; std::cout << "\t-l\tDepth Limit param for MPI Subgraphing" << std::endl; std::cout << "\t-a\tDecimal value for specifiying maximum amount of system memory to use" << std::endl; std::cout << "\t-m\tNumber of MPI Nodes to use" << std::endl; std::cout << "\t-z\tDatabase name for overriding config.ini file" << std::endl; std::cout << "\t-f\tBCL Hash Size" << std::endl; } inline bool file_exists(const std::string &name) { struct stat buffer {}; return (stat(name.c_str(), &buffer) == 0); } const std::string read_file(const std::string &fn) { std::ifstream f(fn); std::stringstream buffer; buffer << f.rdbuf(); return buffer.str(); } // the main function executes the command according to the given flag and throws // and error if an unknown flag is provided. It then uses the database given in // the "config.txt" file to generate an attack graph. 
// Program entry point.
//
// NOTE(review): in this copy of the file, parts of the function text are
// missing (around the database-load timing output and between the fetch_facts
// comment and the total-run-time printf), and the original line breaks are
// gone, so every "//" comment below runs to the end of its physical line. The
// summary here covers only what is still visible.
//
// Visible steps, in order:
//  - records a start time with gettimeofday (ts1);
//  - prints the usage text and returns 0 when fewer than 2 arguments are given;
//  - parses the command line with getopt and the option string
//    "rb:g:dhc:l:n:x:t:q:pa:f:sem:z:"; for -a the stored alpha is the given
//    value minus 0.1; -h prints the usage text and returns 0; in the '?' case a
//    message is printed only when optopt is 'c', and the program exits with
//    EXIT_FAILURE either way;
//  - exits with -1 when both -s (MPI subgraphing) and -e (MPI tasking) are set;
//  - creates the boost::mpi environment asking for mt::multiple; rank 0 prints
//    the requested and the provided threading level;
//  - reads the host name, calls BCL::init() and prints a greeting that
//    contains the MPI rank, the MPI size, the host name and the BCL rank;
//  - uses "default" as the config section when -c was not given and reads
//    "config.ini" into a property tree; with -p it reads the keys
//    database.name, database.host, database.port, database.username and
//    database.password;
//  - converts the -b argument with std::stoi when batch processing is on;
//  - waits on world.barrier();
//  - creates an AGGenInstance and, with -p, assigns fetch_facts() to its facts
//    member;
//  - prints the total run time and returns 0.
//
// NOTE(review): thread_count, init_qsize, mpi_nodes, depth_limit and
// bclhash_size are declared without an initializer and are assigned only when
// their flag (-t, -q, -m, -l, -f) is passed.
int main(int argc, char *argv[]) { //------------------------------ //Program block 1: initialization and database connection //------------------------------ //int thread_count=strtol(argv[5],NULL,10); //int init_qsize=strtol(argv[6],NULL,10); struct timeval ts1,tf1,ts2,tf2,ts3,tf3; gettimeofday(&ts1,NULL); if (argc < 2) { print_usage(); return 0; } printf("Start init\n"); std::string opt_nm; std::string opt_xp; std::string opt_config; std::string opt_graph; std::string opt_batch; std::string db_name = "ag_gen"; int thread_count; int init_qsize; int mpi_nodes; int depth_limit; int bclhash_size; bool should_graph = false; bool no_cycles = false; bool batch_process = false; bool use_redis = false; bool use_postgres = false; bool mpi_subgraphing = false; bool mpi_tasking = false; double alpha = 0.5; int opt; while ((opt = getopt(argc, argv, "rb:g:dhc:l:n:x:t:q:pa:f:sem:z:")) != -1) { switch (opt) { case 'g': should_graph = true; opt_graph = optarg; break; case 'h': print_usage(); return 0; case 'n': opt_nm = optarg; //read in the path of the .nm file from the command line arguments break; case 'x': opt_xp = optarg; //read in the path of the .xp file from the command line arguments break; case 'c': opt_config = optarg; break; case 'l': depth_limit = atoi(optarg); break; case 'd': no_cycles = true; break; case 'r': use_redis = true; break; case 'b': batch_process = true; opt_batch = optarg; break; case 't': thread_count =atoi(optarg); break; case 'q': init_qsize = atoi(optarg); break; case 'p': use_postgres = true; break; case 'a': //Save a 10% buffer for PSQL ops alpha = atof(optarg) - 0.1; break; case 'f': bclhash_size = atoi(optarg); break; case 's': mpi_subgraphing = true; break; case 'e': mpi_tasking = true; break; case 'm': mpi_nodes = atoi(optarg); break; case 'z': db_name = optarg; break; case '?': if (optopt == 'c') fprintf(stderr, "Option -%c requires an argument.\n", optopt); exit(EXIT_FAILURE); case ':': fprintf(stderr, "wtf\n"); exit(EXIT_FAILURE); 
default: fprintf(stderr, "Unknown option -%c.\n", optopt); print_usage(); exit(EXIT_FAILURE); } } if (mpi_subgraphing && mpi_tasking){ std::cout << " You have specified MPI Tasking and MPI Subgraphing. Please choose only 1 option." << std::endl; exit(-1); } std::cout << "Arguments parsed." << std::endl; mt::level mt_level = mt::multiple; boost::mpi::environment env(argc, argv, mt_level); mt ::level provided = env.thread_level(); // std::cout << "Ensure that the MPI package has the MPI_THREAD_MULTIPLE build-time option enabled,"\ << "or change the environment creation to be use MPI threading level of single." << std::endl; // exit(EXIT_FAILURE); mpi::communicator world; if (world.rank() == 0){ std::cout << "Asked the MPI environment to be created with threading level: "\ << mt_level << std::endl; std::cout << "MPI Environment was created with threading level: " << provided \ << std::endl; } char hammer_host[256]; gethostname(hammer_host, 256); BCL::init(); std::cout << "\nHello from process " << world.rank() << " of " << world.size() << " running on " << hammer_host << " with a BCL Rank of " << BCL::rank() << std::endl; std::cout << "Finished init for process " << world.rank() << std::endl; std::string config_section = (opt_config.empty()) ? "default" : opt_config; boost::property_tree::ptree pt; boost::property_tree::ini_parser::read_ini("config.ini", pt); if (use_postgres) { std::string dbName = pt.get("database.name"); std::string host = pt.get("database.host"); std::string port = pt.get("database.port"); std::string username = pt.get("database.username"); std::string password = pt.get("database.password"); //std::cout<The time to load .nm and .xp into the database took %lf ms.<------\n",tdiff3); printf("\n"); } } int batch_size = 0; if (batch_process) batch_size = std::stoi(opt_batch); //Sync all Nodes to ensure everyone has connected to db and models are imported. 
world.barrier(); //------------------------------------------ //program block 3: //------------------------------------------ AGGenInstance _instance; //the following five assignments to _instance's members are all from db_function.cpp //the following five assignments are only applicable when use_postgres is set to true if (use_postgres) { _instance.facts = fetch_facts(); //The above function call returned an Keyvalue object and assigned the object to facts. The object mainly contains hash table and string vector based on all initial property and value. //the following 4 lines can be used to check the content of the facts. It is based on the initial property and value from table quality, postcondition and topology. //for(std::string abc: _instance.facts.get_str_vector()){ //std::cout<<"Fact: "< total run time is %lf ms. <-----------\n",tdiff1); return(0); } /* struct timeval ts4,tf4; gettimeofday(&ts4,NULL); #pragma omp parallel num_threads(2) { int thread_num=1; #pragma omp for schedule(dynamic,1) for(long a1=0;a1<6;a1++) { int tn=omp_get_thread_num(); printf("Thread num:%d and my a1 is %d\n",tn,a1); for(long d1=0;d1<10000;d1++) { double b1; if(a1%3==0) b1=a1*1.1; else if(a1%3==1) b1=a1*1.3; else b1=a1*1.5; } int b1=120000; if(tn==0) {while(b1--);a1=a1+6;} } } gettimeofday(&tf4,NULL); double tdiff4=(tf4.tv_sec-ts4.tv_sec)*1000.0+(tf4.tv_usec-ts4.tv_usec)/1000.0; printf("%lf\n",tdiff4); */ }