Commit d5e46072 authored by eustlb, committed by GitHub

Merge pull request #81 from wuhongsheng/open_api

feat: Add REST call support in the style of the OpenAI API
parents 8c7272b7 e127cc7a
LLM/openai_api_language_model.py

import logging
import time

from nltk import sent_tokenize  # requires the NLTK "punkt" tokenizer data
from openai import OpenAI
from rich.console import Console

from baseHandler import BaseHandler
from LLM.chat import Chat

logger = logging.getLogger(__name__)
console = Console()
class OpenApiModelHandler(BaseHandler):
    """
    Handles the language-model part of the pipeline by calling an
    OpenAI-compatible chat-completions REST API.
    """

    def setup(
        self,
        model_name="deepseek-chat",
        device="cuda",
        gen_kwargs={},
        base_url=None,
        api_key=None,
        stream=False,
        user_role="user",
        chat_size=1,
        init_chat_role="system",
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.model_name = model_name
        self.stream = stream
        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.warmup()
    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")
        start = time.time()
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": "Hello"},
            ],
            stream=self.stream,
        )
        end = time.time()
        logger.info(
            f"{self.__class__.__name__}: warmed up! time: {(end - start):.3f} s"
        )
    def process(self, prompt):
        logger.debug("call api language model...")
        # A (prompt, language_code) tuple may arrive from the STT stage;
        # unpack it before storing the text in the chat history.
        language_code = None
        if isinstance(prompt, tuple):
            prompt, language_code = prompt
        self.chat.append({"role": self.user_role, "content": prompt})
        response = self.client.chat.completions.create(
            model=self.model_name,
            # note: only the latest user message is sent to the API here
            messages=[
                {"role": self.user_role, "content": prompt},
            ],
            stream=self.stream,
        )
        if self.stream:
            generated_text, printable_text = "", ""
            for chunk in response:
                new_text = chunk.choices[0].delta.content or ""
                generated_text += new_text
                printable_text += new_text
                sentences = sent_tokenize(printable_text)
                if len(sentences) > 1:
                    # yield each completed sentence as soon as it is available,
                    # so the TTS stage can start speaking early
                    yield sentences[0], language_code
                    printable_text = new_text
            self.chat.append({"role": "assistant", "content": generated_text})
            # don't forget the last sentence
            yield printable_text, language_code
        else:
            generated_text = response.choices[0].message.content
            self.chat.append({"role": "assistant", "content": generated_text})
            yield generated_text, language_code
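
For orientation, here is a minimal sketch of driving the handler directly, outside the full pipeline. It reuses the `(stop_event, queue_in=..., queue_out=..., setup_kwargs=...)` call pattern that `get_llm_handler` below passes to the handler; the `base_url` and `api_key` values are placeholders for whatever OpenAI-compatible provider is actually used, and `warmup()` will issue one real API call during setup.

# Minimal usage sketch, not part of the commit. The endpoint and key below
# are placeholders; any OpenAI-compatible server should work.
from queue import Queue
from threading import Event

handler = OpenApiModelHandler(
    Event(),
    queue_in=Queue(),
    queue_out=Queue(),
    setup_kwargs={
        "model_name": "deepseek-chat",
        "base_url": "https://api.example.com/v1",  # placeholder endpoint
        "api_key": "sk-...",                       # placeholder key
        "stream": True,
    },
)
# process() is a generator yielding (sentence, language_code) pairs
for sentence, language_code in handler.process("Tell me a short joke."):
    print(sentence)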
arguments_classes/open_api_language_model_arguments.py

from dataclasses import dataclass, field


@dataclass
class OpenApiLanguageModelHandlerArguments:
    open_api_model_name: str = field(
        default="deepseek-chat",
        metadata={
            "help": "The pretrained language model to use. Default is 'deepseek-chat'."
        },
    )
    open_api_user_role: str = field(
        default="user",
        metadata={
            "help": "Role assigned to the user in the chat context. Default is 'user'."
        },
    )
    open_api_init_chat_role: str = field(
        default="system",
        metadata={
            "help": "Initial role for setting up the chat context. Default is 'system'."
        },
    )
    open_api_init_chat_prompt: str = field(
        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
        metadata={
            "help": "The initial chat prompt to establish context for the language model. Default is a polite, concise assistant persona."
        },
    )
    open_api_chat_size: int = field(
        default=2,
        metadata={
            "help": "Number of assistant-user interactions to keep in the chat history. None for no limit."
        },
    )
    open_api_api_key: str = field(
        default=None,
        metadata={
            "help": "The API key used to authenticate and authorize requests to the API. Default is None."
        },
    )
    open_api_base_url: str = field(
        default=None,
        metadata={
            "help": "The root URL for all endpoints of the API, used as the starting point for constructing request URLs. Default is None."
        },
    )
    open_api_stream: bool = field(
        default=False,
        metadata={
            "help": "Whether data should be transmitted as a continuous flow of chunks rather than as a single, complete response; useful for large or real-time output. Default is False."
        },
    )
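
These fields are exposed as `--open_api_*` command-line flags via HfArgumentParser, and `rename_args(open_api_language_model_handler_kwargs, "open_api")` in s2s_pipeline.py (see below) presumably strips the prefix so that each field lines up with a `setup()` keyword of the handler. A hypothetical sketch of that mapping, with illustrative values:

# Hypothetical illustration of the prefix stripping performed by rename_args:
# each "open_api_*" field becomes the matching setup() keyword argument.
fields = {
    "open_api_model_name": "deepseek-chat",
    "open_api_base_url": None,
    "open_api_api_key": None,
    "open_api_stream": False,
}
setup_kwargs = {k[len("open_api_"):]: v for k, v in fields.items()}
# -> {"model_name": ..., "base_url": ..., "api_key": ..., "stream": ...}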
@@ -7,3 +7,4 @@ ChatTTS>=0.1.1
 funasr>=1.1.6
 modelscope>=1.17.1
 deepfilternet>=0.5.6
+openai>=1.40.1
@@ -9,4 +9,4 @@ ChatTTS>=0.1.1
 funasr>=1.1.6
 modelscope>=1.17.1
 deepfilternet>=0.5.6
+openai>=1.40.1
@@ -21,6 +21,7 @@ from arguments_classes.socket_sender_arguments import SocketSenderArguments
 from arguments_classes.vad_arguments import VADHandlerArguments
 from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
 from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
+from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments
 import torch
 import nltk
 from rich.console import Console
@@ -76,6 +77,7 @@ def parse_arguments():
             WhisperSTTHandlerArguments,
             ParaformerSTTHandlerArguments,
             LanguageModelHandlerArguments,
+            OpenApiLanguageModelHandlerArguments,
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
@@ -160,6 +162,7 @@ def prepare_all_args(
     whisper_stt_handler_kwargs,
     paraformer_stt_handler_kwargs,
     language_model_handler_kwargs,
+    open_api_language_model_handler_kwargs,
     mlx_language_model_handler_kwargs,
     parler_tts_handler_kwargs,
     melo_tts_handler_kwargs,
@@ -170,16 +173,19 @@ def prepare_all_args(
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
         language_model_handler_kwargs,
+        open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
         chat_tts_handler_kwargs,
     )

     rename_args(whisper_stt_handler_kwargs, "stt")
     rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
     rename_args(language_model_handler_kwargs, "lm")
     rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
+    rename_args(open_api_language_model_handler_kwargs, "open_api")
     rename_args(parler_tts_handler_kwargs, "tts")
     rename_args(melo_tts_handler_kwargs, "melo")
     rename_args(chat_tts_handler_kwargs, "chat_tts")
@@ -205,6 +211,7 @@ def build_pipeline(
     whisper_stt_handler_kwargs,
     paraformer_stt_handler_kwargs,
     language_model_handler_kwargs,
+    open_api_language_model_handler_kwargs,
     mlx_language_model_handler_kwargs,
     parler_tts_handler_kwargs,
     melo_tts_handler_kwargs,
@@ -218,7 +225,6 @@ def build_pipeline(
     spoken_prompt_queue = queues_and_events["spoken_prompt_queue"]
     text_prompt_queue = queues_and_events["text_prompt_queue"]
     lm_response_queue = queues_and_events["lm_response_queue"]
-
     if module_kwargs.mode == "local":
         from connections.local_audio_streamer import LocalAudioStreamer
@@ -257,7 +263,7 @@ def build_pipeline(
     )
     stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
-    lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs)
+    lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs)
     tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
     return ThreadManager([*comms_handlers, vad, stt, lm, tts])
@@ -292,7 +298,15 @@ def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_
     raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")

-def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs):
+def get_llm_handler(
+    module_kwargs,
+    stop_event,
+    text_prompt_queue,
+    lm_response_queue,
+    language_model_handler_kwargs,
+    open_api_language_model_handler_kwargs,
+    mlx_language_model_handler_kwargs
+):
     if module_kwargs.llm == "transformers":
         from LLM.language_model import LanguageModelHandler
         return LanguageModelHandler(
@@ -301,6 +315,15 @@ def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_qu
             queue_out=lm_response_queue,
             setup_kwargs=vars(language_model_handler_kwargs),
         )
+    elif module_kwargs.llm == "open_api":
+        from LLM.openai_api_language_model import OpenApiModelHandler
+        return OpenApiModelHandler(
+            stop_event,
+            queue_in=text_prompt_queue,
+            queue_out=lm_response_queue,
+            setup_kwargs=vars(open_api_language_model_handler_kwargs),
+        )
     elif module_kwargs.llm == "mlx-lm":
         from LLM.mlx_language_model import MLXLanguageModelHandler
         return MLXLanguageModelHandler(
@@ -309,6 +332,7 @@ def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_qu
             queue_out=lm_response_queue,
             setup_kwargs=vars(mlx_language_model_handler_kwargs),
         )
+
     else:
         raise ValueError("The LLM should be either transformers or mlx-lm")
@@ -364,6 +388,7 @@ def main():
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
         language_model_handler_kwargs,
+        open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
@@ -377,6 +402,7 @@ def main():
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
         language_model_handler_kwargs,
+        open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
@@ -393,6 +419,7 @@ def main():
         whisper_stt_handler_kwargs,
         paraformer_stt_handler_kwargs,
         language_model_handler_kwargs,
+        open_api_language_model_handler_kwargs,
         mlx_language_model_handler_kwargs,
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,