From 1eaf06a3672cd5a8c12e7011dff2719d84f6776e Mon Sep 17 00:00:00 2001
From: wuhongsheng <664116298@qq.com>
Date: Wed, 4 Sep 2024 21:33:21 +0800
Subject: [PATCH] feat: Add REST call support similar to the OpenAI API style

---
 LLM/openai_api_language_model.py                   | 81 ++++++++++++++++++++
 .../open_api_language_model_arguments.py           | 53 ++++++++++++++
 requirements.txt                                   |  1 +
 requirements_mac.txt                               |  2 +-
 s2s_pipeline.py                                    | 20 +++++++-
 5 files changed, 154 insertions(+), 3 deletions(-)
 create mode 100644 LLM/openai_api_language_model.py
 create mode 100644 arguments_classes/open_api_language_model_arguments.py

diff --git a/LLM/openai_api_language_model.py b/LLM/openai_api_language_model.py
new file mode 100644
index 0000000..393dd39
--- /dev/null
+++ b/LLM/openai_api_language_model.py
@@ -0,0 +1,81 @@
+from openai import OpenAI
+from LLM.chat import Chat
+from baseHandler import BaseHandler
+from rich.console import Console
+import logging
+import time
+
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class OpenApiModelHandler(BaseHandler):
+    """
+    Handles the language model part via an OpenAI-compatible REST API.
+    """
+
+    def setup(
+        self,
+        model_name="deepseek-chat",
+        device="cuda",
+        gen_kwargs={},
+        base_url=None,
+        api_key=None,
+        stream=False,
+        user_role="user",
+        chat_size=1,
+        init_chat_role="system",
+        init_chat_prompt="You are a helpful AI assistant.",
+    ):
+        self.model_name = model_name
+        self.stream = stream
+        self.chat = Chat(chat_size)
+        if init_chat_role:
+            if not init_chat_prompt:
+                raise ValueError(
+                    "An initial prompt needs to be specified when setting init_chat_role."
+                )
+            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
+        self.user_role = user_role
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+        start = time.time()
+        # A single non-streaming round trip is enough to warm up the connection.
+        self.client.chat.completions.create(
+            model=self.model_name,
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant"},
+                {"role": "user", "content": "Hello"},
+            ],
+            stream=False,
+        )
+        end = time.time()
+        logger.info(
+            f"{self.__class__.__name__}: warmed up! time: {(end - start):.3f} s"
+        )
+
+    def process(self, prompt):
+        logger.debug("calling API language model...")
+        self.chat.append({"role": self.user_role, "content": prompt})
+        response = self.client.chat.completions.create(
+            model=self.model_name,
+            # Send the accumulated chat history, including the init prompt,
+            # rather than only the latest user message.
+            messages=self.chat.to_list(),
+            stream=self.stream,
+        )
+        if self.stream:
+            # Streaming responses arrive as chunks; concatenate their deltas.
+            generated_text = ""
+            for chunk in response:
+                delta = chunk.choices[0].delta.content
+                if delta:
+                    generated_text += delta
+        else:
+            generated_text = response.choices[0].message.content
+        self.chat.append({"role": "assistant", "content": generated_text})
+        yield generated_text
diff --git a/arguments_classes/open_api_language_model_arguments.py b/arguments_classes/open_api_language_model_arguments.py
new file mode 100644
index 0000000..f497a64
--- /dev/null
+++ b/arguments_classes/open_api_language_model_arguments.py
@@ -0,0 +1,53 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class OpenApiLanguageModelHandlerArguments:
+    open_api_model_name: str = field(
+        default="deepseek-chat",
+        metadata={
+            "help": "The model served by the OpenAI-compatible API to use. Default is 'deepseek-chat'."
+        },
+    )
+    open_api_user_role: str = field(
+        default="user",
+        metadata={
+            "help": "Role assigned to the user in the chat context. Default is 'user'."
+        },
+    )
+    open_api_init_chat_role: str = field(
+        default="system",
+        metadata={
+            "help": "Initial role for setting up the chat context. Default is 'system'."
+        },
+    )
+    open_api_init_chat_prompt: str = field(
+        default="You are a helpful and friendly AI assistant. You are polite and respectful.",
+        metadata={
+            "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful and friendly AI assistant. You are polite and respectful.'"
+        },
+    )
+    open_api_chat_size: int = field(
+        default=2,
+        metadata={
+            "help": "Number of assistant-user interactions to keep in the chat history. None for no limit."
+        },
+    )
+    open_api_api_key: str = field(
+        default=None,
+        metadata={
+            "help": "The API key used to authenticate and authorize requests to the API. Default is None."
+        },
+    )
+    open_api_base_url: str = field(
+        default=None,
+        metadata={
+            "help": "The root URL for all endpoints of the API, used as the starting point for constructing API requests. Default is None."
+        },
+    )
+    open_api_stream: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether responses should be transmitted as a continuous flow of chunks rather than as a single, complete response, often used for large or real-time data. Default is False."
+        },
+    )
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index fba30cd..fd6542f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ ChatTTS>=0.1.1
 funasr>=1.1.6
 modelscope>=1.17.1
 deepfilternet>=0.5.6
+openai>=1.40.1
\ No newline at end of file
diff --git a/requirements_mac.txt b/requirements_mac.txt
index 4a1c5cb..a146c3b 100644
--- a/requirements_mac.txt
+++ b/requirements_mac.txt
@@ -9,4 +9,4 @@ ChatTTS>=0.1.1
 funasr>=1.1.6
 modelscope>=1.17.1
 deepfilternet>=0.5.6
-
+openai>=1.40.1
\ No newline at end of file
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 8da8298..6090b43 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -21,6 +21,7 @@ from arguments_classes.socket_sender_arguments import SocketSenderArguments
 from arguments_classes.vad_arguments import VADHandlerArguments
 from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
 from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
+from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments
 import torch
 import nltk
 from rich.console import Console
@@ -77,6 +78,7 @@ def main():
             WhisperSTTHandlerArguments,
             ParaformerSTTHandlerArguments,
             LanguageModelHandlerArguments,
+            OpenApiLanguageModelHandlerArguments,
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
@@ -95,6 +97,7 @@ def main():
             whisper_stt_handler_kwargs,
             paraformer_stt_handler_kwargs,
             language_model_handler_kwargs,
+            open_api_language_model_handler_kwargs,
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
@@ -110,6 +113,7 @@ def main():
             whisper_stt_handler_kwargs,
             paraformer_stt_handler_kwargs,
             language_model_handler_kwargs,
+            open_api_language_model_handler_kwargs,
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
@@ -187,6 +191,7 @@ def main():
     prepare_args(whisper_stt_handler_kwargs, "stt")
     prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt")
     prepare_args(language_model_handler_kwargs, "lm")
+    prepare_args(open_api_language_model_handler_kwargs, "open_api")
     prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
     prepare_args(parler_tts_handler_kwargs, "tts")
     prepare_args(melo_tts_handler_kwargs, "melo")
@@ -278,15 +283,26 @@ def main():
             queue_out=lm_response_queue,
             setup_kwargs=vars(language_model_handler_kwargs),
         )
+
+    elif module_kwargs.llm == "open_api":
+        from LLM.openai_api_language_model import OpenApiModelHandler
+
+        lm = OpenApiModelHandler(
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(open_api_language_model_handler_kwargs),
+        )
+
     elif module_kwargs.llm == "mlx-lm":
         from LLM.mlx_language_model import MLXLanguageModelHandler
-
         lm = MLXLanguageModelHandler(
             stop_event,
             queue_in=text_prompt_queue,
             queue_out=lm_response_queue,
             setup_kwargs=vars(mlx_language_model_handler_kwargs),
         )
+
     else:
-        raise ValueError("The LLM should be either transformers or mlx-lm")
+        raise ValueError("The LLM should be transformers, open_api, or mlx-lm")
     if module_kwargs.tts == "parler":
-- 
GitLab
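
Usage note: since the argument dataclasses are parsed into CLI flags named after
their fields, the new handler should be selectable with something like
"python s2s_pipeline.py --llm open_api --open_api_model_name deepseek-chat
--open_api_base_url https://api.deepseek.com --open_api_api_key <your key>"
(flag names inferred from the dataclass; the DeepSeek endpoint is just one
example of an OpenAI-compatible base URL).

The stream option matters because the openai v1 client returns a different
shape in each mode. A minimal standalone sketch of the difference, assuming
only the openai>=1.40.1 dependency added above and placeholder credentials:

    from openai import OpenAI

    # Placeholder key and endpoint, for illustration only.
    client = OpenAI(api_key="sk-...", base_url="https://api.deepseek.com")

    # stream=False: one response object; the full reply is in message.content.
    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
    )
    print(response.choices[0].message.content)

    # stream=True: an iterator of chunks; the reply is the concatenation of
    # the per-chunk deltas (the final chunk may carry no content).
    stream = client.chat.completions.create(
        model="deepseek-chat",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    print("".join(
        chunk.choices[0].delta.content or ""
        for chunk in stream
        if chunk.choices
    ))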