From 1eaf06a3672cd5a8c12e7011dff2719d84f6776e Mon Sep 17 00:00:00 2001
From: wuhongsheng <664116298@qq.com>
Date: Wed, 4 Sep 2024 21:33:21 +0800
Subject: [PATCH] feat: Add REST call support for OpenAI-API-style endpoints

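Add an OpenApiModelHandler that forwards text prompts to any
OpenAI-compatible REST endpoint instead of running the language model
locally, together with an OpenApiLanguageModelHandlerArguments dataclass
for its configuration. The handler is selected with --llm open_api.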
---
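Example invocation for local testing (a sketch: the base URL and key are
placeholders for any OpenAI-compatible endpoint, and the flag names follow
the HfArgumentParser field names of the new dataclass):

    python s2s_pipeline.py --llm open_api \
        --open_api_model_name deepseek-chat \
        --open_api_base_url https://api.deepseek.com \
        --open_api_api_key YOUR_API_KEY
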
 LLM/openai_api_language_model.py              | 78 ++++++++++++++++++++
 .../open_api_language_model_arguments.py      | 54 ++++++++++++++
 requirements.txt                              |  1 +
 requirements_mac.txt                          |  2 +-
 s2s_pipeline.py                               | 20 ++++-
 5 files changed, 152 insertions(+), 3 deletions(-)
 create mode 100644 LLM/openai_api_language_model.py
 create mode 100644 arguments_classes/open_api_language_model_arguments.py

diff --git a/LLM/openai_api_language_model.py b/LLM/openai_api_language_model.py
new file mode 100644
index 0000000..393dd39
--- /dev/null
+++ b/LLM/openai_api_language_model.py
@@ -0,0 +1,78 @@
+from openai import OpenAI
+from LLM.chat import Chat
+from baseHandler import BaseHandler
+from rich.console import Console
+import logging
+import time
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class OpenApiModelHandler(BaseHandler):
+    """
+    Handles the language model part.
+    """
+    def setup(
+        self,
+        model_name="deepseek-chat",
+        device="cuda",
+        gen_kwargs={},
+        base_url=None,
+        api_key=None,
+        stream=False,
+        user_role="user",
+        chat_size=1,
+        init_chat_role="system",
+        init_chat_prompt="You are a helpful AI assistant.",
+    ):
+        self.model_name = model_name
+        self.stream = stream
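+        # chat_size bounds how many user/assistant exchanges are kept in history.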
+        self.chat = Chat(chat_size)
+        if init_chat_role:
+            if not init_chat_prompt:
+                raise ValueError(
+                    "An initial promt needs to be specified when setting init_chat_role."
+                )
+            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
+        self.user_role = user_role
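+        # When api_key / base_url are None, the OpenAI client falls back to the
+        # OPENAI_API_KEY / OPENAI_BASE_URL environment variables.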
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+        start = time.time()
+        self.client.chat.completions.create(
+            model=self.model_name,
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant"},
+                {"role": "user", "content": "Hello"},
+            ],
+            stream=False,  # a plain, non-streaming request is enough to warm up
+        )
+        end = time.time()
+        logger.info(
+            f"{self.__class__.__name__}:  warmed up! time: {(end - start):.3f} s"
+        )
+
+    def process(self, prompt):
+        logger.debug("calling OpenAI-compatible API...")
+        self.chat.append({"role": self.user_role, "content": prompt})
+        # Send the full chat history so the model keeps context across turns.
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=self.chat.to_list(),
+            stream=self.stream,
+        )
+        if self.stream:
+            # In streaming mode the API returns chunks; concatenate the deltas.
+            generated_text = "".join(
+                chunk.choices[0].delta.content or "" for chunk in response
+            )
+        else:
+            generated_text = response.choices[0].message.content
+        self.chat.append({"role": "assistant", "content": generated_text})
+        yield generated_text
diff --git a/arguments_classes/open_api_language_model_arguments.py b/arguments_classes/open_api_language_model_arguments.py
new file mode 100644
index 0000000..f497a64
--- /dev/null
+++ b/arguments_classes/open_api_language_model_arguments.py
@@ -0,0 +1,54 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class OpenApiLanguageModelHandlerArguments:
+    open_api_model_name: str = field(
+        default="deepseek-chat",
+        metadata={
+            "help": "The model name requested from the OpenAI-compatible API. Default is 'deepseek-chat'."
+        },
+    )
+    open_api_user_role: str = field(
+        default="user",
+        metadata={
+            "help": "Role assigned to the user in the chat context. Default is 'user'."
+        },
+    )
+    open_api_init_chat_role: str = field(
+        default="system",
+        metadata={
+            "help": "Initial role for setting up the chat context. Default is 'system'."
+        },
+    )
+    open_api_init_chat_prompt: str = field(
+        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
+        metadata={
+            "help": "The initial chat prompt to establish context for the language model."
+        },
+    )
+
+    open_api_chat_size: int = field(
+        default=2,
+        metadata={
+            "help": "Number of assistant-user interaction pairs to keep in the chat history. None for no limit."
+        },
+    )
+    open_api_api_key: str = field(
+        default=None,
+        metadata={
+            "help": "API key used to authenticate requests to the OpenAI-compatible endpoint. Default is None."
+        },
+    )
+    open_api_base_url: str = field(
+        default=None,
+        metadata={
+            "help": "Root URL for the API endpoints, used as the starting point for constructing requests. Default is None."
+        },
+    )
+    open_api_stream: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether the response should be streamed in chunks rather than returned as a single complete message. Default is False."
+        },
+    )
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index fba30cd..fd6542f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ ChatTTS>=0.1.1
 funasr>=1.1.6
 modelscope>=1.17.1
 deepfilternet>=0.5.6
+openai>=1.40.1
\ No newline at end of file
diff --git a/requirements_mac.txt b/requirements_mac.txt
index 4a1c5cb..a146c3b 100644
--- a/requirements_mac.txt
+++ b/requirements_mac.txt
@@ -9,4 +9,4 @@ ChatTTS>=0.1.1
 funasr>=1.1.6
 modelscope>=1.17.1
 deepfilternet>=0.5.6
-
+openai>=1.40.1
\ No newline at end of file
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 8da8298..6090b43 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -21,6 +21,7 @@ from arguments_classes.socket_sender_arguments import SocketSenderArguments
 from arguments_classes.vad_arguments import VADHandlerArguments
 from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
 from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
+from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments
 import torch
 import nltk
 from rich.console import Console
@@ -77,6 +78,7 @@ def main():
             WhisperSTTHandlerArguments,
             ParaformerSTTHandlerArguments,
             LanguageModelHandlerArguments,
+            OpenApiLanguageModelHandlerArguments,
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
@@ -95,6 +97,7 @@ def main():
             whisper_stt_handler_kwargs,
             paraformer_stt_handler_kwargs,
             language_model_handler_kwargs,
+            open_api_language_model_handler_kwargs,
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
@@ -110,6 +113,7 @@ def main():
             whisper_stt_handler_kwargs,
             paraformer_stt_handler_kwargs,
             language_model_handler_kwargs,
+            open_api_language_model_handler_kwargs,
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
@@ -187,6 +191,7 @@ def main():
     prepare_args(whisper_stt_handler_kwargs, "stt")
     prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt")
     prepare_args(language_model_handler_kwargs, "lm")
+    prepare_args(open_api_language_model_handler_kwargs, "open_api")
     prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
     prepare_args(parler_tts_handler_kwargs, "tts")
     prepare_args(melo_tts_handler_kwargs, "melo")
@@ -278,15 +283,26 @@ def main():
             queue_out=lm_response_queue,
             setup_kwargs=vars(language_model_handler_kwargs),
         )
+
+    elif module_kwargs.llm == "open_api":
+        from LLM.openai_api_language_model import OpenApiModelHandler
+
+        lm = OpenApiModelHandler(
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(open_api_language_model_handler_kwargs),
+        )
+
     elif module_kwargs.llm == "mlx-lm":
         from LLM.mlx_language_model import MLXLanguageModelHandler
-
         lm = MLXLanguageModelHandler(
             stop_event,
             queue_in=text_prompt_queue,
             queue_out=lm_response_queue,
             setup_kwargs=vars(mlx_language_model_handler_kwargs),
         )
+
     else:
         raise ValueError("The LLM should be either transformers or mlx-lm")
     if module_kwargs.tts == "parler":
-- 
GitLab