Main Complete + CSV Writing
This commit is contained in:
parent
da54e609f8
commit
00547cbd41
69
main.py
69
main.py
@ -9,18 +9,23 @@ import math
|
||||
import os
|
||||
import numpy as np
|
||||
from shutil import copyfile
|
||||
|
||||
import gen_bn.gen_bn
|
||||
|
||||
NETWORK_SIZE = 20
|
||||
NETWORK_TYPE = "polytree"
|
||||
NUM_SAMPLES = 200
|
||||
ALPHA = 0.75
|
||||
import argparse
|
||||
from csv import writer
|
||||
import re
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("-s", "--size", dest = "NETWORK_SIZE", default = 5, type = int, help = "Number of nodes in the network")
|
||||
parser.add_argument("-t", "--type", dest = "NETWORK_TYPE", default = "dag", help = "Type of network. dag or polytree")
|
||||
parser.add_argument("-n", "--num", dest = "NUM_SAMPLES", default = 100, type = int, help = "Number of samples to take")
|
||||
parser.add_argument("-a", "--alpha", dest = "ALPHA", default = 0.95, type = float, help = "Metropolis-Hastings split probabilities. Must be between 0 and 1.")
|
||||
|
||||
args = parser.parse_args()
|
||||
#Generate a new BN. Specify type and number of nodes in network
|
||||
print("Generating a Bayesian Network of type", NETWORK_TYPE, "and with", NETWORK_SIZE, "nodes.")
|
||||
gen_json(NETWORK_TYPE, NETWORK_SIZE)
|
||||
print("Generating a Bayesian Network of type", args.NETWORK_TYPE, "and with", args.NETWORK_SIZE, "nodes.")
|
||||
gen_json(args.NETWORK_TYPE, args.NETWORK_SIZE)
|
||||
|
||||
#Get our BN
|
||||
bayes_net = import_bayes()
|
||||
@ -35,22 +40,50 @@ def main():
|
||||
print()
|
||||
|
||||
#Get probability of query from LW given evidence
|
||||
LW_prob = likelihood_weighting(X, E, bayes_net, NUM_SAMPLES)
|
||||
print("Probability of", X, "given", E, "with the LW algorithm and", NUM_SAMPLES, "samples is:")
|
||||
LW_prob = likelihood_weighting(X, E, bayes_net, args.NUM_SAMPLES)
|
||||
print("Probability of", X, "given", E, "with the LW algorithm and", args.NUM_SAMPLES, "samples is:")
|
||||
print(LW_prob)
|
||||
|
||||
GS_prob, tmp = gibbs_sampling(X, E, bayes_net, NUM_SAMPLES)
|
||||
GS_prob, tmp = gibbs_sampling(X, E, bayes_net, args.NUM_SAMPLES)
|
||||
|
||||
print("Probability of", X, "given", E, "with the GS algorithm and", NUM_SAMPLES, "samples is:")
|
||||
print("Probability of", X, "given", E, "with the GS algorithm and", args.NUM_SAMPLES, "samples is:")
|
||||
print(GS_prob)
|
||||
|
||||
#Get probability of query from MH given evidence
|
||||
MH_prob = metropolis_hastings(X, E, bayes_net, NUM_SAMPLES, ALPHA)
|
||||
print("Probability of", X, "given", E, "with the MH algorithm and", NUM_SAMPLES, "samples, and a", ALPHA*100, "/", 100-(ALPHA*100), "split is:")
|
||||
MH_prob = metropolis_hastings(X, E, bayes_net, args.NUM_SAMPLES, args.ALPHA)
|
||||
print("Probability of", X, "given", E, "with the MH algorithm and", args.NUM_SAMPLES, "samples, and a", args.ALPHA*100, "/", 100-(args.ALPHA*100), "split is:")
|
||||
print(MH_prob)
|
||||
|
||||
query_var = (list(X.items())[0][0])
|
||||
run_exact(query_var)
|
||||
query_val = (list(X.values())[0])
|
||||
exact_total = run_exact(query_var)
|
||||
|
||||
print("Probability of", X, "with the Variable Elimination algorithm is:")
|
||||
|
||||
#Do extremely sloppy string parsing that I'm too lazy to fix
|
||||
if(query_val):
|
||||
match = re.search(r'\bP_True\b', exact_total)
|
||||
offset = 9
|
||||
else:
|
||||
match = re.search(r'\bP_False\b', exact_total)
|
||||
offset = 10
|
||||
start = match.span()[0]
|
||||
if(query_val):
|
||||
Exact_prob = exact_total[(start+offset):].splitlines()[0]
|
||||
print(Exact_prob)
|
||||
else:
|
||||
result = exact_total[(start+offset):]
|
||||
res_split = result.split()
|
||||
Exact_prob = res_split[0]
|
||||
print(Exact_prob)
|
||||
|
||||
to_write = [args.NETWORK_TYPE, args.NETWORK_SIZE, args.ALPHA, args.NUM_SAMPLES, LW_prob, GS_prob, MH_prob, Exact_prob]
|
||||
append_csv(to_write)
|
||||
|
||||
def append_csv(list_of_ele):
|
||||
with open('results.csv', 'a+', newline='') as file:
|
||||
csv_writer = writer(file)
|
||||
csv_writer.writerow(list_of_ele)
|
||||
|
||||
def run_exact(query):
|
||||
#Get dir of exact_inference
|
||||
@ -63,7 +96,9 @@ def run_exact(query):
|
||||
copyfile("bn.json", dirname+'/bn.json')
|
||||
os.chdir("../exact_inference")
|
||||
#Run the exact_inference on the query variable
|
||||
os.system('python exact_inference.py -f bn.json' + ' ' + '-q' + ' ' + str(query))
|
||||
output = os.popen('python exact_inference.py -f bn.json' + ' ' + '-q' + ' ' + str(query)).read()
|
||||
os.chdir("..")
|
||||
return output
|
||||
|
||||
#Generate a new BN.
|
||||
#Input: Type ("dag", or "polytree")
|
||||
|
||||
3
results.csv
Normal file
3
results.csv
Normal file
@ -0,0 +1,3 @@
|
||||
dag,10,0.95,100,0.0,0.0,0.0,0.9681714923847967
|
||||
dag,10,0.95,100,0.13113707604213187,0.91,0.95,0.8278012710087637
|
||||
dag,10,0.95,100,0.3552298068888378,0.88,0.0,0.1798664152088963
|
||||
|
Loading…
x
Reference in New Issue
Block a user