diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb index 05649eeb3cb4cc6a6d78945cceaaae2f2d97146e..5ee87e715570f7537f26f1fe2258c7d66dc5bd6f 100644 --- a/docs/02-dynamic-routes.ipynb +++ b/docs/02-dynamic-routes.ipynb @@ -86,12 +86,16 @@ "name": "stderr", "output_type": "stream", "text": [ + "WARNING: Ignoring invalid distribution ~ (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "WARNING: Ignoring invalid distribution ~rotobuf (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~ (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "WARNING: Ignoring invalid distribution ~rotobuf (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~ (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "WARNING: Ignoring invalid distribution ~rotobuf (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", + "WARNING: Ignoring invalid distribution ~ (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "WARNING: Ignoring invalid distribution ~illow (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "WARNING: Ignoring invalid distribution ~rotobuf (C:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages)\n", "\n", @@ -102,7 +106,7 @@ ], "source": [ "!pip install tzdata\n", - "!pip install -qU semantic-router" + "# !pip install -qU semantic-router" ] }, { @@ -190,7 +194,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-06 21:44:57 INFO semantic_router.utils.logger local\u001b[0m\n" + "\u001b[32m2024-05-07 00:15:02 INFO semantic_router.utils.logger local\u001b[0m\n" ] } ], @@ -309,7 +313,7 @@ { "data": { "text/plain": [ - "'13:44'" + "'16:15'" ] }, "execution_count": 6, @@ -361,8 +365,8 @@ "source": [ "from semantic_router.llms.openai import get_schemas_openai\n", "\n", - "schema = get_schemas_openai([get_time])\n", - "schema" + "schemas = get_schemas_openai([get_time])\n", + "schemas" ] }, { @@ -389,7 +393,7 @@ " \"what is the time in london?\",\n", " \"I live in Rome, what time is it?\",\n", " ],\n", - " function_schemas=schema,\n", + " function_schemas=schemas,\n", ")" ] }, @@ -426,7 +430,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-06 21:44:58 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" + "\u001b[32m2024-05-07 00:15:03 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" ] } ], @@ -468,8 +472,22 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-05-06 21:44:59 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", - "\u001b[32m2024-05-06 21:45:00 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" + "\u001b[33m2024-05-07 00:15:04 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", + "\u001b[32m2024-05-07 00:15:05 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "##################################################\n", + "tool_calls\n", + "[ChatCompletionMessageToolCall(id='call_WFV3WaT0jSUu5bSxZ84n0jAM', function=Function(arguments='{\"timezone\":\"America/New_York\"}', name='get_time'), type='function')]\n", + "##################################################\n", + "##################################################\n", + "type(tool_calls)\n", + "<class 'list'>\n", + "##################################################\n" ] }, { @@ -514,7 +532,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "13:45\n" + "16:15\n" ] } ], @@ -760,7 +778,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-06 21:45:00 INFO semantic_router.utils.logger local\u001b[0m\n" + "\u001b[32m2024-05-07 00:15:05 INFO semantic_router.utils.logger local\u001b[0m\n" ] } ], @@ -868,8 +886,22 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-05-06 21:45:02 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", - "\u001b[32m2024-05-06 21:45:03 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" + "\u001b[33m2024-05-07 00:15:08 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", + "\u001b[32m2024-05-07 00:15:09 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "##################################################\n", + "tool_calls\n", + "[ChatCompletionMessageToolCall(id='call_QZ52hh4HGESYAP5vCwNNQPZd', function=Function(arguments='{\"timezone\":\"America/New_York\"}', name='get_time'), type='function')]\n", + "##################################################\n", + "##################################################\n", + "type(tool_calls)\n", + "<class 'list'>\n", + "##################################################\n" ] }, { @@ -897,7 +929,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "13:45\n" + "16:15\n" ] } ], @@ -921,7 +953,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-06 21:45:05 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}]\u001b[0m\n" + "\u001b[32m2024-05-07 00:15:11 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}]\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "##################################################\n", + "tool_calls\n", + "[ChatCompletionMessageToolCall(id='call_9PUlZqlwSK3KRu4QHgJhHeQy', function=Function(arguments='{\"timezone1\":\"America/Los_Angeles\",\"timezone2\":\"Europe/Istanbul\"}', name='get_time_difference'), type='function')]\n", + "##################################################\n", + "##################################################\n", + "type(tool_calls)\n", + "<class 'list'>\n", + "##################################################\n" ] }, { @@ -973,7 +1019,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-06 21:45:07 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}]\u001b[0m\n" + "\u001b[32m2024-05-07 00:15:12 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}]\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "##################################################\n", + "tool_calls\n", + "[ChatCompletionMessageToolCall(id='call_4w6FLMpkojbhhJSSBfQnTTDY', function=Function(arguments='{\"time\":\"23:02\",\"from_timezone\":\"Asia/Dubai\",\"to_timezone\":\"Asia/Tokyo\"}', name='convert_time'), type='function')]\n", + "##################################################\n", + "##################################################\n", + "type(tool_calls)\n", + "<class 'list'>\n", + "##################################################\n" ] }, { @@ -1025,7 +1085,21 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-06 21:45:10 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}]\u001b[0m\n" + "\u001b[32m2024-05-07 00:15:15 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}]\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "##################################################\n", + "tool_calls\n", + "[ChatCompletionMessageToolCall(id='call_WrJ7AXz3bsNlVEAFyzz7GVWx', function=Function(arguments='{\"timezone\": \"Europe/Prague\"}', name='get_time'), type='function'), ChatCompletionMessageToolCall(id='call_31PebURXBI83AdgE9HMNQcL2', function=Function(arguments='{\"timezone1\": \"Europe/Berlin\", \"timezone2\": \"Asia/Shanghai\"}', name='get_time_difference'), type='function'), ChatCompletionMessageToolCall(id='call_NuFGDj6PePGbfJuAtCjogtzE', function=Function(arguments='{\"time\": \"05:53\", \"from_timezone\": \"Europe/Lisbon\", \"to_timezone\": \"Asia/Bangkok\"}', name='convert_time'), type='function')]\n", + "##################################################\n", + "##################################################\n", + "type(tool_calls)\n", + "<class 'list'>\n", + "##################################################\n" ] } ], @@ -1067,7 +1141,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "19:45\n", + "22:15\n", "The time difference between Europe/Berlin and Asia/Shanghai is 6.0 hours.\n", "11:53\n" ] diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index bafc8c20bd3ce1c2c41256b65109c45de7d19f2d..c8c1f14d0795066a01fe9172587fe939f4c90841 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -18,14 +18,14 @@ class BaseLLM(BaseModel): def __call__(self, messages: List[Message]) -> Optional[str]: raise NotImplementedError("Subclasses must implement this method") - + def _is_valid_inputs( self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]] ) -> bool: - """Determine if the functions chosen by the LLM exist within the function_schemas, + """Determine if the functions chosen by the LLM exist within the function_schemas, and if the input arguments are valid for those functions.""" try: - # Currently only supporting single functions for most LLMs in Dynamic Routes. + # Currently only supporting single functions for most LLMs in Dynamic Routes. if len(inputs) != 1: logger.error("Only one set of function inputs is allowed.") return False @@ -33,7 +33,9 @@ class BaseLLM(BaseModel): logger.error("Only one function schema is allowed.") return False # Validate the inputs against the function schema - if not self._validate_single_function_inputs(inputs[0], function_schemas[0]): + if not self._validate_single_function_inputs( + inputs[0], function_schemas[0] + ): return False return True @@ -41,7 +43,9 @@ class BaseLLM(BaseModel): logger.error(f"Input validation error: {str(e)}") return False - def _validate_single_function_inputs(self, inputs: Dict[str, Any], function_schema: Dict[str, Any]) -> bool: + def _validate_single_function_inputs( + self, inputs: Dict[str, Any], function_schema: Dict[str, Any] + ) -> bool: """Validate the extracted inputs against the function schema""" try: # Extract parameter names and types from the signature string @@ -136,4 +140,4 @@ Provide JSON output now: logger.info(f"Function inputs: {function_inputs}") if not self._is_valid_inputs(function_inputs, function_schemas): raise ValueError("Invalid inputs") - return function_inputs \ No newline at end of file + return function_inputs diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py index 469183361eadafc40131fb6735db4081f43ec983..102f7fff7d253533a9b6e195cf5aaa8f7a9d3984 100644 --- a/semantic_router/llms/llamacpp.py +++ b/semantic_router/llms/llamacpp.py @@ -80,7 +80,7 @@ class LlamaCppLLM(BaseLLM): def extract_function_inputs( self, query: str, function_schemas: List[Dict[str, Any]] - ) -> Dict: + ) -> List[Dict[str, Any]]: with self._grammar(): return super().extract_function_inputs( query=query, function_schemas=function_schemas diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index be9a10dd47a6684ed8281a10f360c12e9235cac6..8fd2a1b2e715adaf6a4255d6d5da4f1326ebd0d0 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -1,17 +1,22 @@ import os -from typing import List, Optional, Any, Callable, Dict +from typing import List, Optional, Any, Callable, Dict, Union import openai -from openai._types import NotGiven +from openai._types import NotGiven, NOT_GIVEN from semantic_router.llms import BaseLLM from semantic_router.schema import Message from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger import json -from semantic_router.utils.function_call import get_schema, convert_python_type_to_json_type +from semantic_router.utils.function_call import ( + get_schema, + convert_python_type_to_json_type, +) import inspect import re +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall + class OpenAILLM(BaseLLM): client: Optional[openai.OpenAI] @@ -40,19 +45,23 @@ class OpenAILLM(BaseLLM): self.temperature = temperature self.max_tokens = max_tokens - def _extract_tool_calls_info(self, tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _extract_tool_calls_info( + self, tool_calls: List[ChatCompletionMessageToolCall] + ) -> List[Dict[str, Any]]: tool_calls_info = [] for tool_call in tool_calls: if tool_call.function.arguments is None: raise ValueError( "Invalid output, expected arguments to be specified for each tool call." ) - tool_calls_info.append({ - "function_name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments) - }) + tool_calls_info.append( + { + "function_name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) return tool_calls_info - + def __call__( self, messages: List[Message], @@ -61,17 +70,14 @@ class OpenAILLM(BaseLLM): if self.client is None: raise ValueError("OpenAI client is not initialized.") try: - if function_schemas: - tools = function_schemas - else: - tools = NotGiven + tools: Union[List[Dict[str, Any]], NotGiven] = function_schemas if function_schemas is not None else NOT_GIVEN completion = self.client.chat.completions.create( model=self.name, messages=[m.to_openai() for m in messages], temperature=self.temperature, max_tokens=self.max_tokens, - tools=tools, + tools=tools, # type: ignore # We pass a list of dicts which get interpreted as Iterable[ChatCompletionToolParam]. ) if function_schemas: @@ -82,14 +88,26 @@ class OpenAILLM(BaseLLM): raise ValueError( "Invalid output, expected at least one tool to be specified." ) - + # Collecting multiple tool calls information - output = self._extract_tool_calls_info(tool_calls) + # DEBUGGING: Start. + print('#'*50) + print('tool_calls') + print(tool_calls) + print('#'*50) + # DEBUGGING: End. + # DEBUGGING: Start. + print('#'*50) + print('type(tool_calls)') + print(type(tool_calls)) + print('#'*50) + # DEBUGGING: End. + output = str(self._extract_tool_calls_info(tool_calls)) # str in keepign with base type. else: content = completion.choices[0].message.content if content is None: raise ValueError("Invalid output, expected content.") - output = str(content) # str to keep MyPy happy. + output = str(content) # str in keepign with base type. return output @@ -99,21 +117,24 @@ class OpenAILLM(BaseLLM): def extract_function_inputs( self, query: str, function_schemas: List[Dict[str, Any]] - ) -> Dict: + ) -> List[Dict[str, Any]]: messages = [] system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." messages.append(Message(role="system", content=system_prompt)) messages.append(Message(role="user", content=query)) - function_inputs = self(messages=messages, function_schemas=function_schemas) + output = self(messages=messages, function_schemas=function_schemas) + if not output: + raise Exception("No output generated for extract function input") + function_inputs = json.loads(output) logger.info(f"Function inputs: {function_inputs}") if not self._is_valid_inputs(function_inputs, function_schemas): raise ValueError("Invalid inputs") return function_inputs - + def _is_valid_inputs( self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]] ) -> bool: - """Determine if the functions chosen by the LLM exist within the function_schemas, + """Determine if the functions chosen by the LLM exist within the function_schemas, and if the input arguments are valid for those functions.""" try: for input_dict in inputs: @@ -126,14 +147,27 @@ class OpenAILLM(BaseLLM): arguments = input_dict["arguments"] # Find the matching function schema based on function_name - matching_schema = next((schema['function'] for schema in function_schemas if schema['function']['name'] == function_name), None) + matching_schema = next( + ( + schema["function"] + for schema in function_schemas + if schema["function"]["name"] == function_name + ), + None, + ) if not matching_schema: - logger.error(f"No matching function schema found for function name: {function_name}") + logger.error( + f"No matching function schema found for function name: {function_name}" + ) return False # Validate the inputs against the function schema - if not self._validate_single_function_inputs(arguments, matching_schema): - logger.error(f"Validation failed for function name: {function_name}") + if not self._validate_single_function_inputs( + arguments, matching_schema + ): + logger.error( + f"Validation failed for function name: {function_name}" + ) return False return True @@ -141,12 +175,14 @@ class OpenAILLM(BaseLLM): logger.error(f"Input validation error: {str(e)}") return False - def _validate_single_function_inputs(self, inputs: Dict[str, Any], function_schema: Dict[str, Any]) -> bool: + def _validate_single_function_inputs( + self, inputs: Dict[str, Any], function_schema: Dict[str, Any] + ) -> bool: """Validate the extracted inputs against the function schema""" try: # Access the parameters and their properties from the function schema directly - parameters = function_schema['parameters']['properties'] - required_params = function_schema['parameters'].get('required', []) + parameters = function_schema["parameters"]["properties"] + required_params = function_schema["parameters"].get("required", []) # Check if all required parameters are present in the inputs for param_name in required_params: @@ -157,10 +193,14 @@ class OpenAILLM(BaseLLM): # Check if the types of the inputs match the expected types (if type checking is needed) for param_name, param_info in parameters.items(): if param_name in inputs: - expected_type = param_info['type'] + expected_type = param_info["type"] # This is a simple type check, consider expanding it based on your needs - if expected_type == 'string' and not isinstance(inputs[param_name], str): - logger.error(f"Input type for '{param_name}' is not {expected_type}") + if expected_type == "string" and not isinstance( + inputs[param_name], str + ): + logger.error( + f"Input type for '{param_name}' is not {expected_type}" + ) return False return True @@ -168,6 +208,7 @@ class OpenAILLM(BaseLLM): logger.error(f"Single input validation error: {str(e)}") return False + def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: schemas = [] for item in items: @@ -179,9 +220,9 @@ def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: # Initialize the function schema with basic details function_schema = { - "name": basic_schema['name'], - "description": basic_schema['description'], - "parameters": {"type": "object", "properties": {}, "required": []} + "name": basic_schema["name"], + "description": basic_schema["description"], + "parameters": {"type": "object", "properties": {}, "required": []}, } # Extract parameter details from the signature @@ -191,7 +232,11 @@ def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: doc_params = param_doc_regex.findall(docstring) if docstring else [] for param_name, param in signature.parameters.items(): - param_type = param.annotation.__name__ if param.annotation != inspect.Parameter.empty else "Any" + param_type = ( + param.annotation.__name__ + if param.annotation != inspect.Parameter.empty + else "Any" + ) param_description = "No description available." param_required = param.default is inspect.Parameter.empty @@ -203,16 +248,12 @@ def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: function_schema["parameters"]["properties"][param_name] = { "type": convert_python_type_to_json_type(param_type), - "description": param_description + "description": param_description, } if param_required: function_schema["parameters"]["required"].append(param_name) - schemas.append({ - "type": "function", - "function": function_schema - }) + schemas.append({"type": "function", "function": function_schema}) return schemas - diff --git a/semantic_router/route.py b/semantic_router/route.py index e9fe47bbf536ff8df3d4148b96cbb76b81c0dafe..38c165d7dd94f3f831052866dc611fc5d5240f80 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -92,16 +92,19 @@ class Route(BaseModel): @classmethod def from_dict(cls, data: Dict[str, Any]): return cls(**data) + @classmethod - def from_dynamic_route(cls, llm: BaseLLM, entities: List[Union[BaseModel, Callable]]): + def from_dynamic_route( + cls, llm: BaseLLM, entities: List[Union[BaseModel, Callable]], route_name: str + ): """ Generate a dynamic Route object from a list of functions or Pydantic models using LLM """ - schemas = function_call.get_schemas(items=entities) - dynamic_route = cls._generate_dynamic_route(llm=llm, function_schemas=schemas) + schemas = function_call.get_schema_list(items=entities) + dynamic_route = cls._generate_dynamic_route(llm=llm, function_schemas=schemas, route_name=route_name) dynamic_route.function_schemas = schemas return dynamic_route - + @classmethod def _parse_route_config(cls, config: str) -> str: # Regular expression to match content inside <config></config> @@ -115,10 +118,14 @@ class Route(BaseModel): raise ValueError("No <config></config> tags found in the output.") @classmethod - def _generate_dynamic_route(cls, llm: BaseLLM, function_schemas: List[Dict[str, Any]], route_name: str): + def _generate_dynamic_route( + cls, llm: BaseLLM, function_schemas: List[Dict[str, Any]], route_name: str + ): logger.info("Generating dynamic route...") - formatted_schemas = "\n".join([json.dumps(schema, indent=4) for schema in function_schemas]) + formatted_schemas = "\n".join( + [json.dumps(schema, indent=4) for schema in function_schemas] + ) prompt = f""" You are tasked to generate a single JSON configuration for multiple function schemas. Each function schema should contribute five example utterances. diff --git a/semantic_router/splitters/utils.py b/semantic_router/splitters/utils.py index 6f71f979d1e1c668c5bded72722377e8232471ad..3438a7a811e577ffc72e8357b79f156d6e4d907c 100644 --- a/semantic_router/splitters/utils.py +++ b/semantic_router/splitters/utils.py @@ -2,6 +2,7 @@ import regex import tiktoken from typing import List + def split_to_sentences(text: str) -> List[str]: """ Enhanced regex pattern to split a given text into sentences more accurately. diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index 57e4c46704529db32631dba15c2b3591572d19d9..99c9d3853135002375bfa1a556630c1b4e7802b4 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -6,7 +6,6 @@ from pydantic.v1 import BaseModel from semantic_router.llms import BaseLLM from semantic_router.schema import Message, RouteChoice from semantic_router.utils.logger import logger -import re def get_schema_list(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, Any]]: @@ -16,6 +15,7 @@ def get_schema_list(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, A schemas.append(schema) return schemas + def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]: if isinstance(item, BaseModel): signature_parts = [] @@ -65,7 +65,6 @@ def convert_python_type_to_json_type(param_type: str) -> str: return "object" - # TODO: Add route layer object to the input, solve circular import issue async def route_and_execute( query: str, llm: BaseLLM, functions: List[Callable], layer @@ -75,7 +74,7 @@ async def route_and_execute( for function in functions: if function.__name__ == route_choice.name: if route_choice.function_call: - return function(**route_choice.function_call) + return function(**route_choice.function_call[0]) logger.warning("No function found, calling LLM.") llm_input = [Message(role="user", content=query)]