Added a random query generation

This commit is contained in:
BuildTools 2021-09-16 00:19:21 -05:00
parent 14c651dfa7
commit b6d03c3a7f

28
main.py
View File

@ -27,8 +27,11 @@ def main():
#Generate random evidence
E = gen_ev(bayes_net)
#Generate a random query that is not an evidence variable in the form of {var : val}
X = gen_query(E, bayes_net)
#Get W from LW
W = likelihood_weighting(1, E, bayes_net, 10)
W = likelihood_weighting(X, E, bayes_net, 10)
#Print if desired
print()
@ -41,6 +44,7 @@ def main():
def gen_json(type, num_nodes):
os.chdir("./gen_bn")
os.system('python gen_bn.py' + ' ' + type + ' ' + str(num_nodes))
os.chdir("..")
#Import the BN from the json
def import_bayes():
@ -53,11 +57,11 @@ def import_bayes():
def gen_ev(bayes_net):
total_nodes = len(bayes_net)
#Arbitrarily, let's only generate total_nodes/2 (rounded up) evidence variables at most, but at least 1
num_ev = random.randint(1, int(math.ceil(total_nodes/2)))
num_ev = random.randint(0, int(math.ceil(total_nodes/2)))
fixed_ev = []
#Go through and generate nodes that will be fixed
for i in range(num_ev):
fixed_var = random.randint(0, total_nodes-1)
fixed_var = random.randint(1, total_nodes-1)
if fixed_var not in fixed_ev:
fixed_ev.append(fixed_var)
#Now generate random values for the ev
@ -69,9 +73,23 @@ def gen_ev(bayes_net):
E[str(i)] = True
else:
E[str(i)] = False
return E
#Given the evidence variables and the bayes net, generate a random variable to query and its value
def gen_query(ev, bayes_net):
possible_vars = list(range(len(bayes_net)))
ev_vars = [*ev]
for e in ev_vars:
if int(e) in possible_vars:
possible_vars.remove(int(e))
query = random.choice(possible_vars)
rand_prob = random.random()
if rand_prob >= 0.5:
val = True
else:
val = False
return {query : val}
#Checks if node has parents
def is_root(node, BN):
@ -99,8 +117,6 @@ def get_parents(node, BN):
#Compute the estimate of P(X|e), where X is the query variable, and e is the observed value for variables E
def likelihood_weighting(X, e, bayes_net, num_samples):
W = {}
T = 0
F = 0
for i in range(num_samples):
w = 1