From afe69747e670ce87de03b04443cb1d3711469afb Mon Sep 17 00:00:00 2001 From: Lukas Vierling <90094980+lukasVierling@users.noreply.github.com> Date: Sat, 16 Mar 2024 01:02:53 +0800 Subject: [PATCH] Update evaluate.py (#20) Changed the final_node_class to ReturnAll because there is no TakeBest --- experiments/crosswords/evaluate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experiments/crosswords/evaluate.py b/experiments/crosswords/evaluate.py index 49ae8d4..8c9a4e2 100644 --- a/experiments/crosswords/evaluate.py +++ b/experiments/crosswords/evaluate.py @@ -34,7 +34,7 @@ if __name__ == "__main__": num_batches = int(len(test_data) / batch_size) evaluator = CrosswordsEvaluator(test_data, batch_size=batch_size, metric="words", window_size=num_batches) swarm = Swarm(["CrosswordsReflection", "CrosswordsToT"], "crosswords", "gpt-4-1106-preview", #"gpt-3.5-turbo-1106", - final_node_class="TakeBest", final_node_kwargs={}, edge_optimize=True, + final_node_class="ReturnAll", final_node_kwargs={}, edge_optimize=True, init_connection_probability=init_connection_probability, connect_output_nodes_to_final_node=True) swarm.connection_dist.load_state_dict(torch.load(f"result/crosswords_Jan15/{experiment_id}_edge_logits_{int(epochs * len(test_data) / batch_size) - 1}.pkl")) @@ -60,4 +60,4 @@ if __name__ == "__main__": utilities += batched_evaluator(evaluator, batch_size, graph, loop) print(f"avg. utility = {np.mean(utilities):.3f}") with open(f"result/crosswords/{experiment_id}_final_utilities_{i}.pkl", "wb") as file: - pickle.dump(utilities, file) \ No newline at end of file + pickle.dump(utilities, file) -- GitLab