Skip to content
Snippets Groups Projects
Unverified Commit afe69747 authored by Lukas Vierling's avatar Lukas Vierling Committed by GitHub
Browse files

Update evaluate.py (#20)

Changed the final_node_class to ReturnAll because there is no TakeBest
parent f2e45c19
No related branches found
No related tags found
No related merge requests found
...@@ -34,7 +34,7 @@ if __name__ == "__main__": ...@@ -34,7 +34,7 @@ if __name__ == "__main__":
num_batches = int(len(test_data) / batch_size) num_batches = int(len(test_data) / batch_size)
evaluator = CrosswordsEvaluator(test_data, batch_size=batch_size, metric="words", window_size=num_batches) evaluator = CrosswordsEvaluator(test_data, batch_size=batch_size, metric="words", window_size=num_batches)
swarm = Swarm(["CrosswordsReflection", "CrosswordsToT"], "crosswords", "gpt-4-1106-preview", #"gpt-3.5-turbo-1106", swarm = Swarm(["CrosswordsReflection", "CrosswordsToT"], "crosswords", "gpt-4-1106-preview", #"gpt-3.5-turbo-1106",
final_node_class="TakeBest", final_node_kwargs={}, edge_optimize=True, final_node_class="ReturnAll", final_node_kwargs={}, edge_optimize=True,
init_connection_probability=init_connection_probability, connect_output_nodes_to_final_node=True) init_connection_probability=init_connection_probability, connect_output_nodes_to_final_node=True)
swarm.connection_dist.load_state_dict(torch.load(f"result/crosswords_Jan15/{experiment_id}_edge_logits_{int(epochs * len(test_data) / batch_size) - 1}.pkl")) swarm.connection_dist.load_state_dict(torch.load(f"result/crosswords_Jan15/{experiment_id}_edge_logits_{int(epochs * len(test_data) / batch_size) - 1}.pkl"))
...@@ -60,4 +60,4 @@ if __name__ == "__main__": ...@@ -60,4 +60,4 @@ if __name__ == "__main__":
utilities += batched_evaluator(evaluator, batch_size, graph, loop) utilities += batched_evaluator(evaluator, batch_size, graph, loop)
print(f"avg. utility = {np.mean(utilities):.3f}") print(f"avg. utility = {np.mean(utilities):.3f}")
with open(f"result/crosswords/{experiment_id}_final_utilities_{i}.pkl", "wb") as file: with open(f"result/crosswords/{experiment_id}_final_utilities_{i}.pkl", "wb") as file:
pickle.dump(utilities, file) pickle.dump(utilities, file)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment