diff --git a/experiments/crosswords/evaluate.py b/experiments/crosswords/evaluate.py index 49ae8d430f58a6d63edbbafdf2fd8ee2182e8901..8c9a4e236968f2d06cdc17e4e63e8f742f4cea52 100644 --- a/experiments/crosswords/evaluate.py +++ b/experiments/crosswords/evaluate.py @@ -34,7 +34,7 @@ if __name__ == "__main__": num_batches = int(len(test_data) / batch_size) evaluator = CrosswordsEvaluator(test_data, batch_size=batch_size, metric="words", window_size=num_batches) swarm = Swarm(["CrosswordsReflection", "CrosswordsToT"], "crosswords", "gpt-4-1106-preview", #"gpt-3.5-turbo-1106", - final_node_class="TakeBest", final_node_kwargs={}, edge_optimize=True, + final_node_class="ReturnAll", final_node_kwargs={}, edge_optimize=True, init_connection_probability=init_connection_probability, connect_output_nodes_to_final_node=True) swarm.connection_dist.load_state_dict(torch.load(f"result/crosswords_Jan15/{experiment_id}_edge_logits_{int(epochs * len(test_data) / batch_size) - 1}.pkl")) @@ -60,4 +60,4 @@ if __name__ == "__main__": utilities += batched_evaluator(evaluator, batch_size, graph, loop) print(f"avg. utility = {np.mean(utilities):.3f}") with open(f"result/crosswords/{experiment_id}_final_utilities_{i}.pkl", "wb") as file: - pickle.dump(utilities, file) \ No newline at end of file + pickle.dump(utilities, file)