diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb index 06c4f737be0526d1dba1ea670bdcc09dc2b9d1dd..d024ded3cfd35d255513479809fe7586b78822af 100644 --- a/docs/02-dynamic-routes.ipynb +++ b/docs/02-dynamic-routes.ipynb @@ -46,9 +46,36 @@ "metadata": { "id": "dLElfRhgur0v" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "[notice] A new release of pip is available: 23.1.2 -> 24.0\n", + "[notice] To update, run: python.exe -m pip install --upgrade pip\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: tzdata in c:\\users\\siraj\\documents\\personal\\work\\aurelio\\virtual environments\\semantic_router\\lib\\site-packages (2024.1)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "[notice] A new release of pip is available: 23.1.2 -> 24.0\n", + "[notice] To update, run: python.exe -m pip install --upgrade pip\n" + ] + } + ], "source": [ - "!pip install -qU semantic-router==0.0.34" + "!pip install -qU semantic-router==0.0.34\n", + "!pip install tzdata" ] }, { @@ -136,7 +163,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-27 02:19:43 INFO semantic_router.utils.logger local\u001b[0m\n" + "\u001b[32m2024-04-29 01:50:52 INFO semantic_router.utils.logger local\u001b[0m\n" ] } ], @@ -255,7 +282,7 @@ { "data": { "text/plain": [ - "'18:19'" + "'17:50'" ] }, "execution_count": 6, @@ -369,7 +396,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-27 02:19:45 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" + "\u001b[32m2024-04-29 01:50:53 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" ] } ], @@ -411,48 +438,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-04-27 02:19:45 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", - "\u001b[32m2024-04-27 02:19:45 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "##################################################\n", - "tools\n", - "None\n", - "##################################################\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-04-27 02:19:46 INFO semantic_router.utils.logger LLM output: {\n", + "\u001b[33m2024-04-29 01:50:54 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", + "\u001b[32m2024-04-29 01:50:54 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n", + "\u001b[32m2024-04-29 01:50:55 INFO semantic_router.utils.logger LLM output: {\n", "\t\"timezone\": \"America/New_York\"\n", "}\u001b[0m\n", - "\u001b[32m2024-04-27 02:19:46 INFO semantic_router.utils.logger Function inputs: {'timezone': 'America/New_York'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "##################################################\n", - "completion.choices[0].message.tool_calls\n", - "None\n", - "##################################################\n", - "##################################################\n", - "extracted_inputs\n", - "{'timezone': 'America/New_York'}\n", - "##################################################\n" + "\u001b[32m2024-04-29 01:50:55 INFO semantic_router.utils.logger Function inputs: {'timezone': 'America/New_York'}\u001b[0m\n" ] }, { "data": { "text/plain": [ - "'18:19'" + "'17:50'" ] }, "execution_count": 12, @@ -473,7 +470,7 @@ { "data": { "text/plain": [ - "OpenAILLM(name='gpt-3.5-turbo', client=<openai.OpenAI object at 0x00000152CAD11ED0>, temperature=0.01, max_tokens=200)" + "OpenAILLM(name='gpt-3.5-turbo', client=<openai.OpenAI object at 0x00000129AFF58190>, temperature=0.01, max_tokens=200)" ] }, "execution_count": 13, diff --git a/docs/10-dynamic-routes-via-openai-function-calling.ipynb b/docs/10-dynamic-routes-via-openai-function-calling.ipynb index e20ade1b8218bdf94269d24669f9b7f79a991949..79a68dc05db8880b598fb976bff332899a42f4a0 100644 --- a/docs/10-dynamic-routes-via-openai-function-calling.ipynb +++ b/docs/10-dynamic-routes-via-openai-function-calling.ipynb @@ -40,11 +40,21 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "id": "dLElfRhgur0v" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "[notice] A new release of pip is available: 23.1.2 -> 24.0\n", + "[notice] To update, run: python.exe -m pip install --upgrade pip\n" + ] + } + ], "source": [ "!pip install -qU semantic-router==0.0.34" ] @@ -69,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "id": "kc9Ty6Lgur0x" }, @@ -78,7 +88,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_2\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } @@ -121,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -134,7 +144,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-28 22:54:05 INFO semantic_router.utils.logger local\u001b[0m\n" + "\u001b[32m2024-04-29 01:48:49 INFO semantic_router.utils.logger local\u001b[0m\n" ] } ], @@ -170,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -185,7 +195,7 @@ "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -214,7 +224,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "id": "5jaF1Xa5ur0y" }, @@ -240,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -253,10 +263,10 @@ { "data": { "text/plain": [ - "'14:54'" + "'17:48'" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -276,7 +286,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -294,7 +304,7 @@ " 'output': \"<class 'str'>\"}" ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -308,7 +318,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -323,7 +333,7 @@ " 'required': ['timezone']}}}" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -346,7 +356,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "id": "iesBG9P3ur0z" }, @@ -375,7 +385,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -388,7 +398,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-28 22:54:06 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" + "\u001b[32m2024-04-29 01:48:50 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" ] } ], @@ -407,7 +417,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -421,7 +431,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-04-28 22:54:07 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" + "\u001b[33m2024-04-29 01:48:52 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" ] }, { @@ -430,7 +440,7 @@ "RouteChoice(name='get_time', function_call={'timezone': 'America/New_York'}, similarity_score=None)" ] }, - "execution_count": 11, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -442,16 +452,16 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'14:54'" + "'17:48'" ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -478,7 +488,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -487,7 +497,7 @@ "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" ] }, - "execution_count": 13, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -506,7 +516,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -515,7 +525,7 @@ "RouteChoice(name=None, function_call=None, similarity_score=None)" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 9072821092c32bb105865624b051890d21c48757..0989df3e91439c766f7530c9a7db99906d3c87de 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -1,5 +1,5 @@ import os -from typing import Any, List, Optional +from typing import List, Optional, Any import openai @@ -8,6 +8,7 @@ from semantic_router.schema import Message from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger import json +from openai.types.chat import ChatCompletionMessageToolCall class OpenAILLM(BaseLLM): client: Optional[openai.OpenAI] @@ -36,12 +37,12 @@ class OpenAILLM(BaseLLM): self.temperature = temperature self.max_tokens = max_tokens - def __call__(self, messages: List[Message], function_schema: dict = None) -> str: + def __call__(self, messages: List[Message], openai_function_schema: Optional[dict[str, Any]] = None) -> str: if self.client is None: raise ValueError("OpenAI client is not initialized.") try: - if function_schema: - tools = [function_schema] + if openai_function_schema: + tools = [openai_function_schema] else: tools = None completion = self.client.chat.completions.create( @@ -49,34 +50,27 @@ class OpenAILLM(BaseLLM): messages=[m.to_openai() for m in messages], temperature=self.temperature, max_tokens=self.max_tokens, - tools=tools, + tools=tools, # type: ignore # MyPy expecting Iterable[ChatCompletionToolParam] | NotGiven, but dict is accepted by OpenAI. ) output = completion.choices[0].message.content - if function_schema: + if openai_function_schema: return completion.choices[0].message.tool_calls - # tool_calls = completion.choices[0].message.tool_calls - # if not tool_calls: - # raise Exception("No tool calls available in the completion response.") - # tool_call = tool_calls[0] - # arguments_json = tool_call.function.arguments - # arguments_dict = json.loads(arguments_json) - # return arguments_dict - if not output: raise Exception("No output generated") return output except Exception as e: logger.error(f"LLM error: {e}") raise Exception(f"LLM error: {e}") from e -# - def extract_function_inputs_openai(self, query: str, function_schema: dict) -> dict: + + + def extract_function_inputs_openai(self, query: str, openai_function_schema: dict[str, Any]) -> dict: messages = [] system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." messages.append(Message(role="system", content=system_prompt)) messages.append(Message(role="user", content=query)) - output = self(messages=messages, function_schema=function_schema) + output = self(messages=messages, openai_function_schema=openai_function_schema) if not output: raise Exception("No output generated for extract function input") if len(output) != 1: @@ -85,4 +79,3 @@ class OpenAILLM(BaseLLM): arguments_json = tool_call.function.arguments function_inputs = json.loads(arguments_json) return function_inputs - \ No newline at end of file diff --git a/semantic_router/route.py b/semantic_router/route.py index 3a62fab163ca33737cdcc170e42d2dfa5120295e..080afd61598090bf60db1c49010ea165ca56c5e9 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -67,23 +67,27 @@ class Route(BaseModel): "LLM is required for dynamic routes. Please ensure the `llm` " "attribute is set." ) - elif query is None: + if query is None or type(query) != str: raise ValueError( "Query is required for dynamic routes. Please ensure the `query` " "argument is passed." ) - if self.function_schema: - extracted_inputs = self.llm.extract_function_inputs( - query=query, function_schema=self.function_schema - ) - func_call = extracted_inputs - elif self.openai_function_schema: - if not isinstance(self.llm, OpenAILLM): - raise TypeError("LLM must be an instance of OpenAILLM for openai_function_schema.") - extracted_inputs = self.llm.extract_function_inputs_openai( - query=query, function_schema=self.openai_function_schema - ) - func_call = extracted_inputs + if self.function_schema: + extracted_inputs = self.llm.extract_function_inputs( + query=query, + function_schema=self.function_schema + ) + func_call = extracted_inputs + elif self.openai_function_schema: # Logically must be self.openai_function_schema, but keeps MyPy happy. + if not isinstance(self.llm, OpenAILLM): + raise TypeError( + "LLM must be an instance of OpenAILLM for openai_function_schema." + ) + extracted_inputs = self.llm.extract_function_inputs_openai( + query=query, + openai_function_schema=self.openai_function_schema + ) + func_call = extracted_inputs else: # otherwise we just pass None for the call func_call = None diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index c5e3a355691300e00b1003d1ad2cd4a336cd7aa4..562da07b35018dcb8d6621f164578dbfa477afb7 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -8,6 +8,7 @@ from semantic_router.schema import Message, RouteChoice from semantic_router.utils.logger import logger import re + def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]: if isinstance(item, BaseModel): signature_parts = [] @@ -39,6 +40,7 @@ def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]: } return schema + def convert_param_type_to_json_type(param_type: str) -> str: if param_type == "int": return "number" @@ -55,48 +57,50 @@ def convert_param_type_to_json_type(param_type: str) -> str: else: return "object" + def get_schema_openai_func_calling(item: Callable) -> Dict[str, Any]: if not callable(item): raise ValueError("Provided item must be a callable function.") - + docstring = inspect.getdoc(item) signature = inspect.signature(item) - + schema = { "type": "function", "function": { "name": item.__name__, "description": docstring if docstring else "No description available.", - "parameters": { - "type": "object", - "properties": {}, - "required": [] - } - } + "parameters": {"type": "object", "properties": {}, "required": []}, + }, } - + for param_name, param in signature.parameters.items(): - param_type = param.annotation.__name__ if param.annotation != inspect.Parameter.empty else "Any" + param_type = ( + param.annotation.__name__ + if param.annotation != inspect.Parameter.empty + else "Any" + ) param_description = "No description available." param_required = param.default is inspect.Parameter.empty - + # Attempt to extract the parameter description from the docstring if docstring: param_doc_regex = re.compile(rf":param {param_name}:(.*?)\n(?=:\w|$)", re.S) match = param_doc_regex.search(docstring) if match: param_description = match.group(1).strip() - + schema["function"]["parameters"]["properties"][param_name] = { "type": convert_param_type_to_json_type(param_type), - "description": param_description + "description": param_description, } - + if param_required: schema["function"]["parameters"]["required"].append(param_name) - + return schema + # TODO: Add route layer object to the input, solve circular import issue async def route_and_execute( query: str, llm: BaseLLM, functions: List[Callable], layer