From 13f22ae3a2af3e0b7f2f90bf95f9fdc2d8cb0553 Mon Sep 17 00:00:00 2001
From: Mingchen Zhuge <64179323+mczhuge@users.noreply.github.com>
Date: Wed, 27 Mar 2024 17:27:24 +0300
Subject: [PATCH] Cost Calculation (#22)

* ADD GAIA DOWNLOAD

* finish the cost cal

* fix the bug

* fix a bug

* fix a bug

* fix the typos

* fix typos

* fix typos

* add reset
---
 datasets/gaia/README.md |  5 +++++
 experiments/run_gaia.py | 17 +++++++++++++----
 swarm/llm/gpt_chat.py   |  4 ++++
 swarm/llm/price.py      | 16 +++++++++++++++-
 swarm/utils/globals.py  | 12 ++++++++++++
 swarm/utils/log.py      | 31 ++++++++++++++++++++++++-------
 6 files changed, 73 insertions(+), 12 deletions(-)
 create mode 100644 datasets/gaia/README.md

diff --git a/datasets/gaia/README.md b/datasets/gaia/README.md
new file mode 100644
index 0000000..7efbc79
--- /dev/null
+++ b/datasets/gaia/README.md
@@ -0,0 +1,5 @@
+### GAIA Download
+
+* Official Download: https://huggingface.co/datasets/gaia-benchmark/GAIA
+
+* We have also re-constructed the datasets, found in Google Could: https://drive.google.com/file/d/1Mzcy3Z5S23FSWcQ2qjs8Ma8NbYl3wTkN/view?usp=sharing
diff --git a/experiments/run_gaia.py b/experiments/run_gaia.py
index a22ad41..d05d44b 100644
--- a/experiments/run_gaia.py
+++ b/experiments/run_gaia.py
@@ -19,9 +19,9 @@ from swarm.environment.agents.gaia.web_io import WebIO
 from swarm.environment.agents.gaia.tool_tot import ToolTOT
 from swarm.environment.operations import DirectAnswer
 from swarm.memory.memory import GlobalMemory
-from swarm.utils.globals import Time
+from swarm.utils.globals import Time, Cost, CompletionTokens, PromptTokens
 from swarm.utils.const import GPTSWARM_ROOT
-from swarm.utils.log import logger
+from swarm.utils.log import initialize_log_file, logger, swarmlog
 from swarm.environment.domain.gaia import question_scorer
 from swarm.environment.operations.final_decision import MergingStrategy
 
@@ -48,6 +48,12 @@ async def main():
     result_path = GPTSWARM_ROOT / "result"
     os.makedirs(result_path, exist_ok=True)
 
+    current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
+    Time.instance().value = current_time
+
+
+    log_file_path = initialize_log_file("GAIA", Time.instance().value)
+
     if args.config:
         config_args = YAMLReader.parse(args.config, return_str=False)
         for key, value in config_args.items():
@@ -97,6 +103,8 @@ async def main():
         ground_truth = item["Final answer"]
         inputs = {"task": task, "files": files, "GT": ground_truth}
 
+        swarmlog("🐝GPTSWARM SYS", f"Finish {i} samples...", Cost.instance().value, PromptTokens.instance().value, CompletionTokens.instance().value, log_file_path)
+
         # Swarm
         # answer = await swarm.composite_graph.run(inputs)
         # answer = answer[-1].split("FINAL ANSWER: ")[-1]
@@ -129,8 +137,8 @@ async def main():
         print("-----")
         """
 
-        current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
-        Time.instance().value = current_time
+        # current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
+        # Time.instance().value = current_time
         
         result_dir = Path(f"{GPTSWARM_ROOT}/result/eval")
         result_file = result_file or (result_dir / f"{'_'.join(experiment_name.split())}_{args.llm}_{current_time}.json")
@@ -156,6 +164,7 @@ async def main():
             "Total executed": total_executed + 1,
             "Accuracy": (total_solved + is_solved) / (total_executed + 1),
             "Time": exe_time,
+            "Total Cost": Cost.instance().value,
         }
         data.append(updated_item)
 
diff --git a/swarm/llm/gpt_chat.py b/swarm/llm/gpt_chat.py
index 7d78897..fb48b93 100644
--- a/swarm/llm/gpt_chat.py
+++ b/swarm/llm/gpt_chat.py
@@ -33,6 +33,7 @@ def gpt_chat(
     max_tokens: int = 8192,
     temperature: float = 0.0,
     num_comps=1,
+    return_cost=False,
 ) -> Union[List[str], str]:
     if messages[0].content == '$skip$':
         return ''
@@ -60,6 +61,7 @@ def gpt_chat(
         return response.choices[0].message.content
 
     cost_count(response, model)
+
     return [choice.message.content for choice in response.choices]
 
 
@@ -70,6 +72,7 @@ async def gpt_achat(
     max_tokens: int = 8192,
     temperature: float = 0.0,
     num_comps=1,
+    return_cost=False,
 ) -> Union[List[str], str]:
     if messages[0].content == '$skip$':
         return '' 
@@ -101,6 +104,7 @@ async def gpt_achat(
         return response.choices[0].message.content
     
     cost_count(response, model)
+
     return [choice.message.content for choice in response.choices]
 
 
diff --git a/swarm/llm/price.py b/swarm/llm/price.py
index 0ddaacd..e05ddb8 100644
--- a/swarm/llm/price.py
+++ b/swarm/llm/price.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 from swarm.utils.log import swarmlog
-from swarm.utils.globals import Cost
+from swarm.utils.globals import Cost, PromptTokens, CompletionTokens
 
 # GPT-4:  https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
 # GPT3.5: https://platform.openai.com/docs/models/gpt-3-5
@@ -45,6 +45,8 @@ def cost_count(response, model_name):
         completion_len = response.usage.completion_tokens
 
     Cost.instance().value += price
+    PromptTokens.instance().value += prompt_len
+    CompletionTokens.instance().value += completion_len
 
     # print(f"Prompt Tokens: {prompt_len}, Completion Tokens: {completion_len}")
     return price, prompt_len, completion_len
@@ -52,6 +54,12 @@ def cost_count(response, model_name):
 OPENAI_MODEL_INFO ={
     "gpt-4": {
         "current_recommended": "gpt-4-1106-preview",
+        "gpt-4-0125-preview": {
+            "context window": 128000, 
+            "training": "Jan 2024", 
+            "input": 0.01, 
+            "output": 0.03
+        },      
         "gpt-4-1106-preview": {
             "context window": 128000, 
             "training": "Apr 2023", 
@@ -97,6 +105,12 @@ OPENAI_MODEL_INFO ={
     },
     "gpt-3.5": {
         "current_recommended": "gpt-3.5-turbo-1106",
+        "gpt-3.5-turbo-0125": {
+            "context window": 16385, 
+            "training": "Jan 2024", 
+            "input": 0.0010, 
+            "output": 0.0020
+        },
         "gpt-3.5-turbo-1106": {
             "context window": 16385, 
             "training": "Sep 2021", 
diff --git a/swarm/utils/globals.py b/swarm/utils/globals.py
index 8276f70..d6b766e 100644
--- a/swarm/utils/globals.py
+++ b/swarm/utils/globals.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import sys
 from typing import Optional
 
 class Singleton:
@@ -11,11 +12,22 @@ class Singleton:
         if cls._instance is None:
             cls._instance = cls()
         return cls._instance
+    
+    def reset(self):
+        self.value = 0.0
 
 class Cost(Singleton):
     def __init__(self):
         self.value = 0.0
 
+class PromptTokens(Singleton):
+    def __init__(self):
+        self.value = 0.0
+
+class CompletionTokens(Singleton):
+    def __init__(self):
+        self.value = 0.0
+
 class Time(Singleton):
     def __init__(self):
         self.value = ""
diff --git a/swarm/utils/log.py b/swarm/utils/log.py
index d28c762..10f18bb 100644
--- a/swarm/utils/log.py
+++ b/swarm/utils/log.py
@@ -5,6 +5,7 @@ import os
 import sys
 from pathlib import Path
 from loguru import logger
+# from globals import CompletionTokens, PromptTokens, Cost
 from swarm.utils.const import GPTSWARM_ROOT
 
 def configure_logging(print_level: str = "INFO", logfile_level: str = "DEBUG") -> None:
@@ -17,9 +18,9 @@ def configure_logging(print_level: str = "INFO", logfile_level: str = "DEBUG") -
     """
     logger.remove()
     logger.add(sys.stderr, level=print_level)
-    logger.add(GPTSWARM_ROOT / 'logs/log.txt', level=logfile_level)
+    logger.add(GPTSWARM_ROOT / 'logs/log.txt', level=logfile_level, rotation="10 MB")
 
-def initialize_log_file(mode: str, time_stamp: str) -> Path:
+def initialize_log_file(experiment_name: str, time_stamp: str) -> Path:
     """
     Initialize the log file with a start message and return its path.
 
@@ -31,7 +32,7 @@ def initialize_log_file(mode: str, time_stamp: str) -> Path:
         Path: The path to the initialized log file.
     """
     try:
-        log_file_path = GPTSWARM_ROOT / f'result/{mode}/logs/log_{time_stamp}.txt'
+        log_file_path = GPTSWARM_ROOT / f'result/{experiment_name}/logs/log_{time_stamp}.txt'
         os.makedirs(log_file_path.parent, exist_ok=True)
         with open(log_file_path, 'w') as file:
             file.write("============ Start ============\n")
@@ -40,9 +41,9 @@ def initialize_log_file(mode: str, time_stamp: str) -> Path:
         raise
     return log_file_path
 
-def swarmlog(sender: str, text: str, cost: float, result_file: Path = None, solution: list = []) -> None:
+def swarmlog(sender: str, text: str, cost: float,  prompt_tokens: int, complete_tokens: int, log_file_path: str) -> None:
     """
-    Custom log function for swarm operations.
+    Custom log function for swarm operations. Includes dynamic global variables.
 
     Args:
         sender (str): The name of the sender.
@@ -51,12 +52,28 @@ def swarmlog(sender: str, text: str, cost: float, result_file: Path = None, solu
         result_file (Path, optional): Path to the result file. Default is None.
         solution (list, optional): Solution data to be logged. Default is an empty list.
     """
-    formatted_message = f"{sender} | 💵Total Cost: {cost:.5f}\n{text}"
+    # Directly reference global variables for dynamic values
+    formatted_message = (
+        f"{sender} | 💵Total Cost: ${cost:.5f} | "
+        f"Prompt Tokens: {prompt_tokens} | "
+        f"Completion Tokens: {complete_tokens} | \n {text}"
+    )
     logger.info(formatted_message)
 
-# It's generally a good practice to have a main function to control the flow of your script
+    try:
+        os.makedirs(log_file_path.parent, exist_ok=True)
+        with open(log_file_path, 'a') as file:
+            file.write(f"{formatted_message}\n")
+    except OSError as error:
+        logger.error(f"Error initializing log file: {error}")
+        raise
+
+
 def main():
     configure_logging()
+    # Example usage of swarmlog with dynamic values
+    swarmlog("SenderName", "This is a test message.", 0.123)
 
 if __name__ == "__main__":
     main()
+
-- 
GitLab