Unverified commit 13f22ae3 authored by Mingchen Zhuge, committed by GitHub

Cost Calculation (#22)

* ADD GAIA DOWNLOAD

* finish the cost cal

* fix the bug

* fix a bug

* fix a bug

* fix the typos

* fix typos

* fix typos

* add reset
parent bc70cecb
### GAIA Download
* Official Download: https://huggingface.co/datasets/gaia-benchmark/GAIA
* We have also reconstructed the dataset; it is available on Google Drive: https://drive.google.com/file/d/1Mzcy3Z5S23FSWcQ2qjs8Ma8NbYl3wTkN/view?usp=sharing
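For quick reference, here is a minimal sketch of pulling the benchmark with the Hugging Face `datasets` library. It assumes the gated-access terms on the dataset page have been accepted and `huggingface-cli login` has been run; the `2023_all` config name follows the dataset card and may change.

```python
# Sketch only (not part of this commit): fetch GAIA from Hugging Face.
# Requires accepting the gated-access terms and an authenticated HF token.
from datasets import load_dataset

gaia = load_dataset("gaia-benchmark/GAIA", "2023_all")  # config name per dataset card
print(gaia)                    # available splits
print(gaia["validation"][0])   # one task record
```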
......@@ -19,9 +19,9 @@ from swarm.environment.agents.gaia.web_io import WebIO
from swarm.environment.agents.gaia.tool_tot import ToolTOT
from swarm.environment.operations import DirectAnswer
from swarm.memory.memory import GlobalMemory
from swarm.utils.globals import Time
from swarm.utils.globals import Time, Cost, CompletionTokens, PromptTokens
from swarm.utils.const import GPTSWARM_ROOT
from swarm.utils.log import logger
from swarm.utils.log import initialize_log_file, logger, swarmlog
from swarm.environment.domain.gaia import question_scorer
from swarm.environment.operations.final_decision import MergingStrategy
......@@ -48,6 +48,12 @@ async def main():
result_path = GPTSWARM_ROOT / "result"
os.makedirs(result_path, exist_ok=True)
current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
Time.instance().value = current_time
log_file_path = initialize_log_file("GAIA", Time.instance().value)
if args.config:
config_args = YAMLReader.parse(args.config, return_str=False)
for key, value in config_args.items():
......@@ -97,6 +103,8 @@ async def main():
ground_truth = item["Final answer"]
inputs = {"task": task, "files": files, "GT": ground_truth}
swarmlog("🐝GPTSWARM SYS", f"Finish {i} samples...", Cost.instance().value, PromptTokens.instance().value, CompletionTokens.instance().value, log_file_path)
# Swarm
# answer = await swarm.composite_graph.run(inputs)
# answer = answer[-1].split("FINAL ANSWER: ")[-1]
......@@ -129,8 +137,8 @@ async def main():
print("-----")
"""
current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
Time.instance().value = current_time
# current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
# Time.instance().value = current_time
result_dir = Path(f"{GPTSWARM_ROOT}/result/eval")
result_file = result_file or (result_dir / f"{'_'.join(experiment_name.split())}_{args.llm}_{current_time}.json")
......@@ -156,6 +164,7 @@ async def main():
"Total executed": total_executed + 1,
"Accuracy": (total_solved + is_solved) / (total_executed + 1),
"Time": exe_time,
"Total Cost": Cost.instance().value,
}
data.append(updated_item)
......
......@@ -33,6 +33,7 @@ def gpt_chat(
max_tokens: int = 8192,
temperature: float = 0.0,
num_comps=1,
return_cost=False,
) -> Union[List[str], str]:
if messages[0].content == '$skip$':
return ''
......@@ -60,6 +61,7 @@ def gpt_chat(
return response.choices[0].message.content
cost_count(response, model)
return [choice.message.content for choice in response.choices]
......@@ -70,6 +72,7 @@ async def gpt_achat(
max_tokens: int = 8192,
temperature: float = 0.0,
num_comps=1,
return_cost=False,
) -> Union[List[str], str]:
if messages[0].content == '$skip$':
return ''
......@@ -101,6 +104,7 @@ async def gpt_achat(
return response.choices[0].message.content
cost_count(response, model)
return [choice.message.content for choice in response.choices]
......
......@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
from swarm.utils.log import swarmlog
from swarm.utils.globals import Cost
from swarm.utils.globals import Cost, PromptTokens, CompletionTokens
# GPT-4: https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
# GPT3.5: https://platform.openai.com/docs/models/gpt-3-5
......@@ -45,6 +45,8 @@ def cost_count(response, model_name):
completion_len = response.usage.completion_tokens
Cost.instance().value += price
PromptTokens.instance().value += prompt_len
CompletionTokens.instance().value += completion_len
# print(f"Prompt Tokens: {prompt_len}, Completion Tokens: {completion_len}")
return price, prompt_len, completion_len
......@@ -52,6 +54,12 @@ def cost_count(response, model_name):
OPENAI_MODEL_INFO ={
"gpt-4": {
"current_recommended": "gpt-4-1106-preview",
"gpt-4-0125-preview": {
"context window": 128000,
"training": "Jan 2024",
"input": 0.01,
"output": 0.03
},
"gpt-4-1106-preview": {
"context window": 128000,
"training": "Apr 2023",
......@@ -97,6 +105,12 @@ OPENAI_MODEL_INFO ={
},
"gpt-3.5": {
"current_recommended": "gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0125": {
"context window": 16385,
"training": "Jan 2024",
"input": 0.0010,
"output": 0.0020
},
"gpt-3.5-turbo-1106": {
"context window": 16385,
"training": "Sep 2021",
......
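The `input`/`output` fields in `OPENAI_MODEL_INFO` are dollar rates per 1K tokens, and the collapsed portion of `cost_count` presumably converts the usage counts with these rates before accumulating into `Cost`. A rough, illustrative sketch of that arithmetic (the actual implementation is elided in this diff):

```python
# Illustrative sketch of per-1K-token pricing; the real cost_count body is collapsed above.
def estimate_price(prompt_tokens: int, completion_tokens: int,
                   input_per_1k: float, output_per_1k: float) -> float:
    return (prompt_tokens / 1000) * input_per_1k + (completion_tokens / 1000) * output_per_1k

# 1,200 prompt + 300 completion tokens on gpt-4-1106-preview ($0.01 / $0.03 per 1K):
# 0.012 + 0.009 = $0.021
print(estimate_price(1200, 300, 0.01, 0.03))
```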
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
from typing import Optional
class Singleton:
......@@ -11,11 +12,22 @@ class Singleton:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def reset(self):
self.value = 0.0
class Cost(Singleton):
def __init__(self):
self.value = 0.0
class PromptTokens(Singleton):
def __init__(self):
self.value = 0.0
class CompletionTokens(Singleton):
def __init__(self):
self.value = 0.0
class Time(Singleton):
def __init__(self):
self.value = ""
......
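These counters are process-wide singletons, so any module that imports them sees the same running totals. A hypothetical usage snippet (not from this commit) showing how a caller might read them and use the newly added `reset()` between independent runs:

```python
# Hypothetical usage of the global counters (not part of this commit).
from swarm.utils.globals import Cost, PromptTokens, CompletionTokens

# ... run a batch of LLM calls; cost_count() accumulates into the singletons ...

print(f"cost=${Cost.instance().value:.5f} | "
      f"prompt tokens={PromptTokens.instance().value} | "
      f"completion tokens={CompletionTokens.instance().value}")

# reset() zeroes a counter before the next independent run
for counter in (Cost, PromptTokens, CompletionTokens):
    counter.instance().reset()
```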
......@@ -5,6 +5,7 @@ import os
import sys
from pathlib import Path
from loguru import logger
# from globals import CompletionTokens, PromptTokens, Cost
from swarm.utils.const import GPTSWARM_ROOT
def configure_logging(print_level: str = "INFO", logfile_level: str = "DEBUG") -> None:
......@@ -17,9 +18,9 @@ def configure_logging(print_level: str = "INFO", logfile_level: str = "DEBUG") -
"""
logger.remove()
logger.add(sys.stderr, level=print_level)
logger.add(GPTSWARM_ROOT / 'logs/log.txt', level=logfile_level)
logger.add(GPTSWARM_ROOT / 'logs/log.txt', level=logfile_level, rotation="10 MB")
def initialize_log_file(mode: str, time_stamp: str) -> Path:
def initialize_log_file(experiment_name: str, time_stamp: str) -> Path:
"""
Initialize the log file with a start message and return its path.
......@@ -31,7 +32,7 @@ def initialize_log_file(mode: str, time_stamp: str) -> Path:
Path: The path to the initialized log file.
"""
try:
log_file_path = GPTSWARM_ROOT / f'result/{mode}/logs/log_{time_stamp}.txt'
log_file_path = GPTSWARM_ROOT / f'result/{experiment_name}/logs/log_{time_stamp}.txt'
os.makedirs(log_file_path.parent, exist_ok=True)
with open(log_file_path, 'w') as file:
file.write("============ Start ============\n")
......@@ -40,9 +41,9 @@ def initialize_log_file(mode: str, time_stamp: str) -> Path:
raise
return log_file_path
def swarmlog(sender: str, text: str, cost: float, result_file: Path = None, solution: list = []) -> None:
def swarmlog(sender: str, text: str, cost: float, prompt_tokens: int, complete_tokens: int, log_file_path: str) -> None:
"""
Custom log function for swarm operations.
Custom log function for swarm operations. Includes dynamic global variables.
Args:
sender (str): The name of the sender.
......@@ -51,12 +52,28 @@ def swarmlog(sender: str, text: str, cost: float, result_file: Path = None, solu
result_file (Path, optional): Path to the result file. Default is None.
solution (list, optional): Solution data to be logged. Default is an empty list.
"""
formatted_message = f"{sender} | 💵Total Cost: {cost:.5f}\n{text}"
# Compose a single log line from the cost and token counters passed in
formatted_message = (
f"{sender} | 💵Total Cost: ${cost:.5f} | "
f"Prompt Tokens: {prompt_tokens} | "
f"Completion Tokens: {complete_tokens} | \n {text}"
)
logger.info(formatted_message)
# Append the formatted message to the per-run log file
try:
os.makedirs(log_file_path.parent, exist_ok=True)
with open(log_file_path, 'a') as file:
file.write(f"{formatted_message}\n")
except OSError as error:
logger.error(f"Error initializing log file: {error}")
raise
def main():
configure_logging()
# Example usage of swarmlog; the token counts and log path below are illustrative
swarmlog("SenderName", "This is a test message.", 0.123, 200, 50, GPTSWARM_ROOT / "logs/example_log.txt")
if __name__ == "__main__":
main()