diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 5b2aad846ad8531fbeb02504f42e1688b22ceebc..46eb2eb73a4ac1fc3ede4c365a3100d07721c40b 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -9,6 +9,7 @@ from semantic_router.encoders import (
     CohereEncoder,
     OpenAIEncoder,
 )
+from semantic_router.llms import BaseLLM
 from semantic_router.linear import similarity_matrix, top_scores
 from semantic_router.route import Route
 from semantic_router.schema import Encoder, EncoderType, RouteChoice
@@ -142,12 +143,16 @@ class RouteLayer:
     score_threshold: float = 0.82
 
     def __init__(
-        self, encoder: BaseEncoder | None = None, routes: list[Route] | None = None
+        self,
+        encoder: BaseEncoder | None = None,
+        llm: BaseLLM | None = None,
+        routes: list[Route] | None = None,
     ):
         logger.info("Initializing RouteLayer")
         self.index = None
         self.categories = None
         self.encoder = encoder if encoder is not None else CohereEncoder()
+        self.llm = llm
         self.routes: list[Route] = routes if routes is not None else []
         # decide on default threshold based on encoder
         if isinstance(encoder, OpenAIEncoder):
@@ -168,6 +173,13 @@ class RouteLayer:
         if passed:
             # get chosen route object
             route = [route for route in self.routes if route.name == top_class][0]
+            if route.function_schema and not isinstance(route.llm, BaseLLM):
+                if not self.llm:
+                    raise ValueError(
+                        "LLM is required for dynamic routes. Please ensure the 'llm' is set."
+                    )
+                else:
+                    route.llm = self.llm
             return route(text)
         else:
             # if no route passes threshold, return empty route choice
diff --git a/semantic_router/llms/__init__.py b/semantic_router/llms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..446f7c42f51c88175dff53bed5d32c4406c8cd34
--- /dev/null
+++ b/semantic_router/llms/__init__.py
@@ -0,0 +1,7 @@
+from semantic_router.llms.base import BaseLLM
+from semantic_router.llms.openai import OpenAI
+from semantic_router.llms.openrouter import OpenRouter
+from semantic_router.llms.cohere import Cohere
+
+
+__all__ = ["BaseLLM", "OpenAI", "OpenRouter", "Cohere"]
diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a1a038e34314e9306c8313db9d527ff2ece141c
--- /dev/null
+++ b/semantic_router/llms/base.py
@@ -0,0 +1,11 @@
+from pydantic import BaseModel
+
+
+class BaseLLM(BaseModel):
+    name: str
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __call__(self, prompt) -> str | None:
+        raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py
new file mode 100644
index 0000000000000000000000000000000000000000..80512d5c521fb3b7aaf81c60efcb7f50d61078b3
--- /dev/null
+++ b/semantic_router/llms/cohere.py
@@ -0,0 +1,43 @@
+import os
+import cohere
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
+
+
+class Cohere(BaseLLM):
+    client: cohere.Client | None = None
+
+    def __init__(
+        self,
+        name: str | None = None,
+        cohere_api_key: str | None = None,
+    ):
+        if name is None:
+            name = os.getenv("COHERE_CHAT_MODEL_NAME", "command")
+        super().__init__(name=name)
+        cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
+        if cohere_api_key is None:
+            raise ValueError("Cohere API key cannot be 'None'.")
+        try:
+            self.client = cohere.Client(cohere_api_key)
+        except Exception as e:
+            raise ValueError(f"Cohere API client failed to initialize. Error: {e}")
+
+    def __call__(self, messages: list[Message]) -> str:
+        if self.client is None:
+            raise ValueError("Cohere client is not initialized.")
+        try:
+            completion = self.client.chat(
+                model=self.name,
+                chat_history=[m.to_cohere() for m in messages[:-1]],
+                message=messages[-1].content,
+            )
+
+            output = completion.text
+
+            if not output:
+                raise Exception("No output generated")
+            return output
+
+        except Exception as e:
+            raise ValueError(f"Cohere API call failed. Error: {e}")
diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b6e7063ef16e00c8c29681c4e9818f1cb26bbb
--- /dev/null
+++ b/semantic_router/llms/openai.py
@@ -0,0 +1,51 @@
+import os
+import openai
+from semantic_router.utils.logger import logger
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
+
+
+class OpenAI(BaseLLM):
+    client: openai.OpenAI | None
+    temperature: float | None
+    max_tokens: int | None
+
+    def __init__(
+        self,
+        name: str | None = None,
+        openai_api_key: str | None = None,
+        temperature: float = 0.01,
+        max_tokens: int = 200,
+    ):
+        if name is None:
+            name = os.getenv("OPENAI_CHAT_MODEL_NAME", "gpt-3.5-turbo")
+        super().__init__(name=name)
+        api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
+        if api_key is None:
+            raise ValueError("OpenAI API key cannot be 'None'.")
+        try:
+            self.client = openai.OpenAI(api_key=api_key)
+        except Exception as e:
+            raise ValueError(f"OpenAI API client failed to initialize. Error: {e}")
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+
+    def __call__(self, messages: list[Message]) -> str:
+        if self.client is None:
+            raise ValueError("OpenAI client is not initialized.")
+        try:
+            completion = self.client.chat.completions.create(
+                model=self.name,
+                messages=[m.to_openai() for m in messages],
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+            )
+
+            output = completion.choices[0].message.content
+
+            if not output:
+                raise Exception("No output generated")
+            return output
+        except Exception as e:
+            logger.error(f"LLM error: {e}")
+            raise Exception(f"LLM error: {e}")
diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b7a9b49130ebcfbbb7f30cf4c52ea170e823b83
--- /dev/null
+++ b/semantic_router/llms/openrouter.py
@@ -0,0 +1,56 @@
+import os
+import openai
+from semantic_router.utils.logger import logger
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
+
+
+class OpenRouter(BaseLLM):
+    client: openai.OpenAI | None
+    base_url: str | None
+    temperature: float | None
+    max_tokens: int | None
+
+    def __init__(
+        self,
+        name: str | None = None,
+        openrouter_api_key: str | None = None,
+        base_url: str = "https://openrouter.ai/api/v1",
+        temperature: float = 0.01,
+        max_tokens: int = 200,
+    ):
+        if name is None:
+            name = os.getenv(
+                "OPENROUTER_CHAT_MODEL_NAME", "mistralai/mistral-7b-instruct"
+            )
+        super().__init__(name=name)
+        self.base_url = base_url
+        api_key = openrouter_api_key or os.getenv("OPENROUTER_API_KEY")
+        if api_key is None:
+            raise ValueError("OpenRouter API key cannot be 'None'.")
+        try:
+            self.client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
+        except Exception as e:
+            raise ValueError(f"OpenRouter API client failed to initialize. Error: {e}")
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+
+    def __call__(self, messages: list[Message]) -> str:
+        if self.client is None:
+            raise ValueError("OpenRouter client is not initialized.")
+        try:
+            completion = self.client.chat.completions.create(
+                model=self.name,
+                messages=[m.to_openai() for m in messages],
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+            )
+
+            output = completion.choices[0].message.content
+
+            if not output:
+                raise Exception("No output generated")
+            return output
+        except Exception as e:
+            logger.error(f"LLM error: {e}")
+            raise Exception(f"LLM error: {e}")
diff --git a/semantic_router/route.py b/semantic_router/route.py
index 12afa7fe0a6824882a6a9da4d2d3845f6b68a62c..7a8803d7b454343560d188eea4cb719712b99ca6 100644
--- a/semantic_router/route.py
+++ b/semantic_router/route.py
@@ -4,11 +4,13 @@ from typing import Any, Callable, Union
 
 from pydantic import BaseModel
 
+from semantic_router.llms import BaseLLM
 from semantic_router.schema import RouteChoice
 from semantic_router.utils import function_call
-from semantic_router.utils.llm import llm
 from semantic_router.utils.logger import logger
 
+from semantic_router.schema import Message
+
 
 def is_valid(route_config: str) -> bool:
     try:
@@ -43,12 +45,17 @@ class Route(BaseModel):
     utterances: list[str]
     description: str | None = None
     function_schema: dict[str, Any] | None = None
+    llm: BaseLLM | None = None
 
     def __call__(self, query: str) -> RouteChoice:
         if self.function_schema:
+            if not self.llm:
+                raise ValueError(
+                    "LLM is required for dynamic routes. Please ensure the 'llm' is set."
+                )
             # if a function schema is provided we generate the inputs
             extracted_inputs = function_call.extract_function_inputs(
-                query=query, function_schema=self.function_schema
+                query=query, llm=self.llm, function_schema=self.function_schema
             )
             func_call = extracted_inputs
         else:
@@ -60,16 +67,16 @@ class Route(BaseModel):
         return self.dict()
 
     @classmethod
-    def from_dict(cls, data: dict):
+    def from_dict(cls, data: dict[str, Any]):
         return cls(**data)
 
     @classmethod
-    def from_dynamic_route(cls, entity: Union[BaseModel, Callable]):
+    def from_dynamic_route(cls, llm: BaseLLM, entity: Union[BaseModel, Callable]):
         """
         Generate a dynamic Route object from a function or Pydantic model using LLM
         """
         schema = function_call.get_schema(item=entity)
-        dynamic_route = cls._generate_dynamic_route(function_schema=schema)
+        dynamic_route = cls._generate_dynamic_route(llm=llm, function_schema=schema)
         return dynamic_route
 
     @classmethod
@@ -85,7 +92,7 @@ class Route(BaseModel):
         raise ValueError("No <config></config> tags found in the output.")
 
     @classmethod
-    def _generate_dynamic_route(cls, function_schema: dict[str, Any]):
+    def _generate_dynamic_route(cls, llm: BaseLLM, function_schema: dict[str, Any]):
         logger.info("Generating dynamic route...")
 
         prompt = f"""
@@ -113,7 +120,8 @@ class Route(BaseModel):
         {function_schema}
         """
 
-        output = llm(prompt)
+        llm_input = [Message(role="user", content=prompt)]
+        output = llm(llm_input)
         if not output:
             raise Exception("No output generated for dynamic route")
 
@@ -122,5 +130,7 @@ class Route(BaseModel):
         logger.info(f"Generated route config:\n{route_config}")
 
         if is_valid(route_config):
-            return Route.from_dict(json.loads(route_config))
+            route_config_dict = json.loads(route_config)
+            route_config_dict["llm"] = llm
+            return Route.from_dict(route_config_dict)
         raise Exception("No config generated")
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 465cfaacb1ecbbc0c26bb2ee8f67f814da308e1e..62eecc7d5bc91ae9731aba40af7de19074a96ba1 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -49,6 +49,14 @@ class Message(BaseModel):
     role: str
     content: str
 
+    def to_openai(self):
+        if self.role.lower() not in ["user", "assistant", "system"]:
+            raise ValueError("Role must be either 'user', 'assistant' or 'system'")
+        return {"role": self.role, "content": self.content}
+
+    def to_cohere(self):
+        return {"role": self.role, "message": self.content}
+
 
 class Conversation(BaseModel):
     messages: list[Message]
diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py
index 2ead3ab58dd5c54eaf26fdc5b2f73ea95f4bef9c..d93d027ccce4b1697af3c3a4e0e823c4d68d2101 100644
--- a/semantic_router/utils/function_call.py
+++ b/semantic_router/utils/function_call.py
@@ -4,7 +4,8 @@ from typing import Any, Callable, Union
 
 from pydantic import BaseModel
 
-from semantic_router.utils.llm import llm
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
 from semantic_router.utils.logger import logger
 
 
@@ -40,7 +41,9 @@ def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]:
     return schema
 
 
-def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict:
+def extract_function_inputs(
+    query: str, llm: BaseLLM, function_schema: dict[str, Any]
+) -> dict:
     logger.info("Extracting function input...")
 
     prompt = f"""
@@ -71,8 +74,8 @@ def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict
     schema: {function_schema}
     Result:
     """
-
-    output = llm(prompt)
+    llm_input = [Message(role="user", content=prompt)]
+    output = llm(llm_input)
     if not output:
         raise Exception("No output generated for extract function input")
 
@@ -113,15 +116,18 @@ def call_function(function: Callable, inputs: dict[str, str]):
 
 
 # TODO: Add route layer object to the input, solve circular import issue
-async def route_and_execute(query: str, functions: list[Callable], route_layer):
+async def route_and_execute(
+    query: str, llm: BaseLLM, functions: list[Callable], route_layer
+):
     function_name = route_layer(query)
     if not function_name:
         logger.warning("No function found, calling LLM...")
-        return llm(query)
+        llm_input = [Message(role="user", content=query)]
+        return llm(llm_input)
 
     for function in functions:
         if function.__name__ == function_name:
            print(f"Calling function: {function.__name__}")
            schema = get_schema(function)
-           inputs = extract_function_inputs(query, schema)
+           inputs = extract_function_inputs(query, llm, schema)
            call_function(function, inputs)
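
Usage note (not part of the patch above): the new semantic_router.llms classes are invoked with a list of Message objects rather than a raw prompt string, and Message.to_openai() / Message.to_cohere() produce the provider-specific payloads. A minimal sketch of calling one of the LLM wrappers directly, assuming OPENAI_API_KEY is exported:

    from semantic_router.llms import OpenAI
    from semantic_router.schema import Message

    # Model name falls back to the OPENAI_CHAT_MODEL_NAME env var, then "gpt-3.5-turbo".
    llm = OpenAI()
    # Each Message is converted via to_openai() before the chat.completions call.
    reply = llm([Message(role="user", content="Hello there!")])
    print(reply)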
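
For dynamic routes, the LLM can now be supplied either on the Route itself (from_dynamic_route attaches it to the generated config) or once on the RouteLayer, whose __call__ copies self.llm onto any matched route that has a function_schema but no llm of its own. A sketch of that wiring under the same API-key assumption; get_time is a hypothetical helper used only to derive a function schema, and choice.function_call assumes the RouteChoice fields defined in semantic_router.schema:

    from semantic_router.layer import RouteLayer
    from semantic_router.llms import OpenAI
    from semantic_router.route import Route

    def get_time(timezone: str) -> str:
        """Hypothetical function; its signature becomes the route's function schema."""
        return f"12:00 in {timezone}"

    llm = OpenAI()
    # The LLM drafts the route config (name, utterances) and is stored on the route.
    time_route = Route.from_dynamic_route(llm=llm, entity=get_time)

    # Alternatively, pass the llm once at the layer level; the layer injects it
    # into the matched dynamic route before extracting function inputs.
    layer = RouteLayer(llm=llm, routes=[time_route])
    choice = layer("what time is it in Rome?")
    print(choice.name, choice.function_call)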