From afe69747e670ce87de03b04443cb1d3711469afb Mon Sep 17 00:00:00 2001
From: Lukas Vierling <90094980+lukasVierling@users.noreply.github.com>
Date: Sat, 16 Mar 2024 01:02:53 +0800
Subject: [PATCH] Update evaluate.py (#20)

Changed the final_node_class to ReturnAll because there is no TakeBest
---
 experiments/crosswords/evaluate.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/experiments/crosswords/evaluate.py b/experiments/crosswords/evaluate.py
index 49ae8d4..8c9a4e2 100644
--- a/experiments/crosswords/evaluate.py
+++ b/experiments/crosswords/evaluate.py
@@ -34,7 +34,7 @@ if __name__ == "__main__":
     num_batches = int(len(test_data) / batch_size)
     evaluator = CrosswordsEvaluator(test_data, batch_size=batch_size, metric="words", window_size=num_batches)
     swarm = Swarm(["CrosswordsReflection", "CrosswordsToT"], "crosswords", "gpt-4-1106-preview", #"gpt-3.5-turbo-1106",
-                final_node_class="TakeBest", final_node_kwargs={}, edge_optimize=True,
+                final_node_class="ReturnAll", final_node_kwargs={}, edge_optimize=True,
                 init_connection_probability=init_connection_probability, connect_output_nodes_to_final_node=True)
     swarm.connection_dist.load_state_dict(torch.load(f"result/crosswords_Jan15/{experiment_id}_edge_logits_{int(epochs * len(test_data) / batch_size) - 1}.pkl"))
 
@@ -60,4 +60,4 @@ if __name__ == "__main__":
             utilities += batched_evaluator(evaluator, batch_size, graph, loop)
         print(f"avg. utility = {np.mean(utilities):.3f}")
         with open(f"result/crosswords/{experiment_id}_final_utilities_{i}.pkl", "wb") as file:
-            pickle.dump(utilities, file)
\ No newline at end of file
+            pickle.dump(utilities, file)
-- 
GitLab