Skip to content
Snippets Groups Projects
Commit bcaa22ec authored by James Briggs's avatar James Briggs
Browse files

feat: further docstrings and cleanup

parent f16f620f
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,20 @@ from semantic_router.utils.logger import logger
class Parameter(BaseModel):
"""Parameter for a function.
:param name: The name of the parameter.
:type name: str
:param description: The description of the parameter.
:type description: Optional[str]
:param type: The type of the parameter.
:type type: str
:param default: The default value of the parameter.
:type default: Any
:param required: Whether the parameter is required.
:type required: bool
"""
class Config:
arbitrary_types_allowed = True
......@@ -21,6 +35,11 @@ class Parameter(BaseModel):
required: bool = Field(description="Whether the parameter is required")
def to_ollama(self):
"""Convert the parameter to a dictionary for an Ollama-compatible function schema.
:return: The parameter in dictionary format.
:rtype: Dict[str, Any]
"""
return {
self.name: {
"description": self.description,
......@@ -41,6 +60,11 @@ class FunctionSchema:
parameters: List[Parameter] = Field(description="The parameters of the function")
def __init__(self, function: Union[Callable, BaseModel]):
"""Initialize the FunctionSchema.
:param function: The function to consume.
:type function: Union[Callable, BaseModel]
"""
self.function = function
if callable(function):
self._process_function(function)
......@@ -50,6 +74,11 @@ class FunctionSchema:
raise TypeError("Function must be a Callable or BaseModel")
def _process_function(self, function: Callable):
"""Process the function to get the name, description, signature, and output.
:param function: The function to process.
:type function: Callable
"""
self.name = function.__name__
self.description = str(inspect.getdoc(function))
self.signature = str(inspect.signature(function))
......@@ -67,6 +96,11 @@ class FunctionSchema:
self.parameters = parameters
def to_ollama(self):
"""Convert the FunctionSchema to an Ollama-compatible function schema dictionary.
:return: The function schema in dictionary format.
:rtype: Dict[str, Any]
"""
schema_dict = {
"type": "function",
"function": {
......@@ -94,6 +128,13 @@ class FunctionSchema:
return schema_dict
def _ollama_type_mapping(self, param_type: str) -> str:
"""Map the parameter type to an Ollama-compatible type.
:param param_type: The type of the parameter.
:type param_type: str
:return: The Ollama-compatible type.
:rtype: str
"""
if param_type == "int":
return "number"
elif param_type == "float":
......@@ -107,6 +148,13 @@ class FunctionSchema:
def get_schema_list(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, Any]]:
"""Get a list of function schemas from a list of functions or Pydantic BaseModels.
:param items: The functions or BaseModels to get the schemas for.
:type items: List[Union[BaseModel, Callable]]
:return: A list of function schemas.
:rtype: List[Dict[str, Any]]
"""
schemas = []
for item in items:
schema = get_schema(item)
......@@ -115,6 +163,13 @@ def get_schema_list(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, A
def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]:
"""Get a function schema from a function or Pydantic BaseModel.
:param item: The function or BaseModel to get the schema for.
:type item: Union[BaseModel, Callable]
:return: The function schema.
:rtype: Dict[str, Any]
"""
if isinstance(item, BaseModel):
signature_parts = []
for field_name, field_model in item.__annotations__.items():
......@@ -147,6 +202,13 @@ def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]:
def convert_python_type_to_json_type(param_type: str) -> str:
"""Convert a Python type to a JSON type.
:param param_type: The type of the parameter.
:type param_type: str
:return: The JSON type.
:rtype: str
"""
if param_type == "int":
return "number"
if param_type == "float":
......@@ -167,6 +229,19 @@ def convert_python_type_to_json_type(param_type: str) -> str:
async def route_and_execute(
query: str, llm: BaseLLM, functions: List[Callable], layer
) -> Any:
"""Route and execute a function.
:param query: The query to route and execute.
:type query: str
:param llm: The LLM to use.
:type llm: BaseLLM
:param functions: The functions to execute.
:type functions: List[Callable]
:param layer: The layer to use.
:type layer: Layer
:return: The result of the function.
:rtype: Any
"""
route_choice: RouteChoice = layer(query)
for function in functions:
......
import os
from typing import Optional
import openai
from semantic_router.utils.logger import logger
def llm(prompt: str) -> Optional[str]:
    """Generate a completion for `prompt` via OpenRouter's Mistral 7B Instruct.

    Uses the `OPENROUTER_API_KEY` environment variable for authentication and
    a near-zero temperature so responses are close to deterministic.

    :param prompt: The user prompt to send to the model.
    :type prompt: str
    :return: The generated completion text (never empty).
    :rtype: Optional[str]
    :raises Exception: If the API call fails or the model returns no output.
    """
    try:
        client = openai.OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=os.getenv("OPENROUTER_API_KEY"),
        )
        completion = client.chat.completions.create(
            model="mistralai/mistral-7b-instruct",
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                },
            ],
            # near-deterministic output; capped length for routing-style use
            temperature=0.01,
            max_tokens=200,
        )
        output = completion.choices[0].message.content
    except Exception as e:
        logger.error(f"LLM error: {e}")
        raise Exception(f"LLM error: {e}") from e
    # Checked outside the try: previously this raise happened inside it and
    # was immediately caught, logged, and re-wrapped by the broad except.
    # Raising here keeps the same final message/log without double handling.
    if not output:
        logger.error("LLM error: No output generated")
        raise Exception("LLM error: No output generated")
    return output
# TODO integrate async LLM function
# async def allm(prompt: str) -> Optional[str]:
# try:
# client = openai.AsyncOpenAI(
# base_url="https://openrouter.ai/api/v1",
# api_key=os.getenv("OPENROUTER_API_KEY"),
# )
# completion = await client.chat.completions.create(
# model="mistralai/mistral-7b-instruct",
# messages=[
# {
# "role": "user",
# "content": prompt,
# },
# ],
# temperature=0.01,
# max_tokens=200,
# )
# output = completion.choices[0].message.content
# if not output:
# raise Exception("No output generated")
# return output
# except Exception as e:
# logger.error(f"LLM error: {e}")
# raise Exception(f"LLM error: {e}") from e
......@@ -4,6 +4,9 @@ import colorlog
class CustomFormatter(colorlog.ColoredFormatter):
"""Custom formatter for the logger.
"""
def __init__(self):
super().__init__(
"%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
......@@ -21,6 +24,8 @@ class CustomFormatter(colorlog.ColoredFormatter):
def add_coloured_handler(logger):
"""Add a coloured handler to the logger.
"""
formatter = CustomFormatter()
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
......@@ -29,6 +34,8 @@ def add_coloured_handler(logger):
def setup_custom_logger(name):
"""Setup a custom logger.
"""
logger = logging.getLogger(name)
if not logger.hasHandlers():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment