Commit 1eaf06a3 authored by wuhongsheng

feat: Add REST call support similar to OpenAI-API style

parent 8afd078a
LLM/openai_api_language_model.py

from openai import OpenAI
from LLM.chat import Chat
from baseHandler import BaseHandler
from rich.console import Console
import logging
import time

logger = logging.getLogger(__name__)
console = Console()


class OpenApiModelHandler(BaseHandler):
    """
    Handles the language model part via an OpenAI-compatible REST API.
    """

    def setup(
        self,
        model_name="deepseek-chat",
        device="cuda",   # kept for signature parity with other LM handlers; unused by the REST client
        gen_kwargs={},   # likewise unused here
        base_url=None,
        api_key=None,
        stream=False,
        user_role="user",
        chat_size=1,
        init_chat_role="system",
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.model_name = model_name
        self.stream = stream
        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")
        start = time.time()
        self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": "Hello"},
            ],
            stream=self.stream,
        )
        end = time.time()
        logger.info(
            f"{self.__class__.__name__}: warmed up! time: {(end - start):.3f} s"
        )

    def process(self, prompt):
        logger.debug("call api language model...")
        self.chat.append({"role": self.user_role, "content": prompt})
        # Send the accumulated history rather than only the latest prompt so the
        # model keeps conversational context (Chat.to_list is assumed to return
        # the stored messages as a list of role/content dicts).
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.chat.to_list(),
            stream=self.stream,
        )
        if self.stream:
            # With stream=True the API returns chunks whose deltas must be
            # concatenated to recover the full reply.
            generated_text = ""
            for chunk in response:
                generated_text += chunk.choices[0].delta.content or ""
        else:
            generated_text = response.choices[0].message.content
        self.chat.append({"role": "assistant", "content": generated_text})
        yield generated_text
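The handler above is a thin wrapper over the official OpenAI SDK, so any server that speaks the OpenAI chat-completions protocol can sit behind it. A minimal sketch of the same call pattern, with an assumed DeepSeek base URL and a placeholder key:

from openai import OpenAI

# Assumed values: any OpenAI-compatible endpoint and a valid key work here.
client = OpenAI(api_key="sk-...", base_url="https://api.deepseek.com")
response = client.chat.completions.create(
    model="deepseek-chat",
    messages=[{"role": "user", "content": "Hello"}],
    stream=False,
)
print(response.choices[0].message.content)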
arguments_classes/open_api_language_model_arguments.py

from dataclasses import dataclass, field


@dataclass
class OpenApiLanguageModelHandlerArguments:
    open_api_model_name: str = field(
        default="deepseek-chat",
        metadata={
            "help": "The model served by the API to use. Default is 'deepseek-chat'."
        },
    )
    open_api_user_role: str = field(
        default="user",
        metadata={
            "help": "Role assigned to the user in the chat context. Default is 'user'."
        },
    )
    open_api_init_chat_role: str = field(
        default="system",
        metadata={
            "help": "Initial role for setting up the chat context. Default is 'system'."
        },
    )
    open_api_init_chat_prompt: str = field(
        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
        metadata={
            "help": "The initial chat prompt to establish context for the language model."
        },
    )
    open_api_chat_size: int = field(
        default=2,
        metadata={
            "help": "Number of assistant-user interactions to keep in the chat history. None for no limitation."
        },
    )
    open_api_api_key: str = field(
        default=None,
        metadata={
            "help": "A unique code used to authenticate and authorize requests to the API. Default is None."
        },
    )
    open_api_base_url: str = field(
        default=None,
        metadata={
            "help": "The root URL for all endpoints of the API, used as the starting point for constructing requests. Default is None."
        },
    )
    open_api_stream: bool = field(
        default=False,
        metadata={
            "help": "Whether responses should be streamed as a continuous flow of chunks rather than returned as a single, complete reply, useful for real-time output. Default is False."
        },
    )
\ No newline at end of file
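Once registered with the pipeline's argument parser (see s2s_pipeline.py below), each field surfaces as a CLI flag, and prepare_args strips the open_api_ prefix so the remaining keys line up with OpenApiModelHandler.setup(). A hypothetical sketch of that mapping, with placeholder values; the dict comprehension mimics the assumed prefix-stripping:

# Hypothetical: build the arguments object directly and derive the kwargs
# that the pipeline passes to OpenApiModelHandler.setup().
args = OpenApiLanguageModelHandlerArguments(
    open_api_api_key="sk-...",                     # placeholder key
    open_api_base_url="https://api.deepseek.com",  # assumed endpoint
)
setup_kwargs = {k[len("open_api_"):]: v for k, v in vars(args).items()}
print(setup_kwargs["model_name"])  # -> "deepseek-chat"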
@@ -7,3 +7,4 @@ ChatTTS>=0.1.1
funasr>=1.1.6
modelscope>=1.17.1
deepfilternet>=0.5.6
openai>=1.40.1
\ No newline at end of file
@@ -9,4 +9,4 @@ ChatTTS>=0.1.1
funasr>=1.1.6
modelscope>=1.17.1
deepfilternet>=0.5.6
openai>=1.40.1
\ No newline at end of file
@@ -21,6 +21,7 @@ from arguments_classes.socket_sender_arguments import SocketSenderArguments
from arguments_classes.vad_arguments import VADHandlerArguments
from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments
import torch
import nltk
from rich.console import Console
@@ -77,6 +78,7 @@ def main():
WhisperSTTHandlerArguments,
ParaformerSTTHandlerArguments,
LanguageModelHandlerArguments,
OpenApiLanguageModelHandlerArguments,
MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments,
MeloTTSHandlerArguments,
@@ -95,6 +97,7 @@ def main():
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
@@ -110,6 +113,7 @@ def main():
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
@@ -187,6 +191,7 @@ def main():
prepare_args(whisper_stt_handler_kwargs, "stt")
prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt")
prepare_args(language_model_handler_kwargs, "lm")
prepare_args(open_api_language_model_handler_kwargs, "open_api")
prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
prepare_args(parler_tts_handler_kwargs, "tts")
prepare_args(melo_tts_handler_kwargs, "melo")
@@ -278,15 +283,26 @@ def main():
            queue_out=lm_response_queue,
            setup_kwargs=vars(language_model_handler_kwargs),
        )
    elif module_kwargs.llm == "open_api":
        from LLM.openai_api_language_model import OpenApiModelHandler

        lm = OpenApiModelHandler(
            stop_event,
            queue_in=text_prompt_queue,
            queue_out=lm_response_queue,
            setup_kwargs=vars(open_api_language_model_handler_kwargs),
        )
    elif module_kwargs.llm == "mlx-lm":
        from LLM.mlx_language_model import MLXLanguageModelHandler

        lm = MLXLanguageModelHandler(
            stop_event,
            queue_in=text_prompt_queue,
            queue_out=lm_response_queue,
            setup_kwargs=vars(mlx_language_model_handler_kwargs),
        )
    else:
        raise ValueError("The LLM should be either transformers, open_api, or mlx-lm")

    if module_kwargs.tts == "parler":
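For reference, a hypothetical standalone smoke test of the new handler, wired the same way the pipeline does above (the queue and event types are assumptions about what BaseHandler accepts; setup() fires a warmup request, so a reachable endpoint and valid key are required):

from queue import Queue
from threading import Event
from LLM.openai_api_language_model import OpenApiModelHandler

stop_event = Event()
text_prompt_queue, lm_response_queue = Queue(), Queue()
lm = OpenApiModelHandler(
    stop_event,
    queue_in=text_prompt_queue,
    queue_out=lm_response_queue,
    setup_kwargs={
        "model_name": "deepseek-chat",
        "api_key": "sk-...",                     # placeholder key
        "base_url": "https://api.deepseek.com",  # assumed endpoint
    },
)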