From 9cef8cbae3d63f5a03dae157e4980ceeea35cdfe Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Mon, 29 Apr 2024 01:54:05 +0400 Subject: [PATCH] Further linting. --- docs/02-dynamic-routes.ipynb | 18 ++--- ...c-routes-via-openai-function-calling.ipynb | 73 ++++++++++++------- semantic_router/llms/openai.py | 15 ++-- semantic_router/route.py | 14 ++-- 4 files changed, 71 insertions(+), 49 deletions(-) diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb index d024ded3..b80cc81f 100644 --- a/docs/02-dynamic-routes.ipynb +++ b/docs/02-dynamic-routes.ipynb @@ -163,7 +163,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-29 01:50:52 INFO semantic_router.utils.logger local\u001b[0m\n" + "\u001b[32m2024-04-29 01:53:55 INFO semantic_router.utils.logger local\u001b[0m\n" ] } ], @@ -282,7 +282,7 @@ { "data": { "text/plain": [ - "'17:50'" + "'17:53'" ] }, "execution_count": 6, @@ -396,7 +396,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-29 01:50:53 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" + "\u001b[32m2024-04-29 01:53:56 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" ] } ], @@ -438,18 +438,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-04-29 01:50:54 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", - "\u001b[32m2024-04-29 01:50:54 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n", - "\u001b[32m2024-04-29 01:50:55 INFO semantic_router.utils.logger LLM output: {\n", + "\u001b[33m2024-04-29 01:53:57 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", + "\u001b[32m2024-04-29 01:53:57 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n", + "\u001b[32m2024-04-29 01:53:58 INFO semantic_router.utils.logger LLM output: {\n", "\t\"timezone\": \"America/New_York\"\n", "}\u001b[0m\n", - "\u001b[32m2024-04-29 01:50:55 INFO semantic_router.utils.logger Function inputs: {'timezone': 'America/New_York'}\u001b[0m\n" + "\u001b[32m2024-04-29 01:53:58 INFO semantic_router.utils.logger Function inputs: {'timezone': 'America/New_York'}\u001b[0m\n" ] }, { "data": { "text/plain": [ - "'17:50'" + "'17:53'" ] }, "execution_count": 12, @@ -470,7 +470,7 @@ { "data": { "text/plain": [ - "OpenAILLM(name='gpt-3.5-turbo', client=<openai.OpenAI object at 0x00000129AFF58190>, temperature=0.01, max_tokens=200)" + "OpenAILLM(name='gpt-3.5-turbo', client=<openai.OpenAI object at 0x0000024DD8AEAF10>, temperature=0.01, max_tokens=200)" ] }, "execution_count": 13, diff --git a/docs/10-dynamic-routes-via-openai-function-calling.ipynb b/docs/10-dynamic-routes-via-openai-function-calling.ipynb index 79a68dc0..17f17637 100644 --- a/docs/10-dynamic-routes-via-openai-function-calling.ipynb +++ b/docs/10-dynamic-routes-via-openai-function-calling.ipynb @@ -40,11 +40,27 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": { "id": "dLElfRhgur0v" }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "[notice] A new release of pip is available: 23.1.2 -> 24.0\n", + "[notice] To update, run: python.exe -m pip install --upgrade pip\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: tzdata in c:\\users\\siraj\\documents\\personal\\work\\aurelio\\virtual environments\\semantic_router_2\\lib\\site-packages (2024.1)\n" + ] + }, { "name": "stderr", "output_type": "stream", @@ -56,7 +72,8 @@ } ], "source": [ - "!pip install -qU semantic-router==0.0.34" + "!pip install -qU semantic-router==0.0.34\n", + "!pip install tzdata" ] }, { @@ -79,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "id": "kc9Ty6Lgur0x" }, @@ -131,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -144,7 +161,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-29 01:48:49 INFO semantic_router.utils.logger local\u001b[0m\n" + "\u001b[32m2024-04-29 01:53:43 INFO semantic_router.utils.logger local\u001b[0m\n" ] } ], @@ -180,7 +197,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -195,7 +212,7 @@ "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -224,7 +241,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "id": "5jaF1Xa5ur0y" }, @@ -250,7 +267,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -263,10 +280,10 @@ { "data": { "text/plain": [ - "'17:48'" + "'17:53'" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -286,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -304,7 +321,7 @@ " 'output': \"<class 'str'>\"}" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -318,7 +335,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -333,7 +350,7 @@ " 'required': ['timezone']}}}" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -356,7 +373,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": { "id": "iesBG9P3ur0z" }, @@ -385,7 +402,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -398,7 +415,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-29 01:48:50 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" + "\u001b[32m2024-04-29 01:53:44 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" ] } ], @@ -417,7 +434,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -431,7 +448,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-04-29 01:48:52 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" + "\u001b[33m2024-04-29 01:53:45 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n" ] }, { @@ -440,7 +457,7 @@ "RouteChoice(name='get_time', function_call={'timezone': 'America/New_York'}, similarity_score=None)" ] }, - "execution_count": 12, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -452,16 +469,16 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'17:48'" + "'17:53'" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -488,7 +505,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -497,7 +514,7 @@ "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -516,7 +533,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -525,7 +542,7 @@ "RouteChoice(name=None, function_call=None, similarity_score=None)" ] }, - "execution_count": 15, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 0989df3e..36219c5c 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -8,7 +8,7 @@ from semantic_router.schema import Message from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger import json -from openai.types.chat import ChatCompletionMessageToolCall + class OpenAILLM(BaseLLM): client: Optional[openai.OpenAI] @@ -37,7 +37,11 @@ class OpenAILLM(BaseLLM): self.temperature = temperature self.max_tokens = max_tokens - def __call__(self, messages: List[Message], openai_function_schema: Optional[dict[str, Any]] = None) -> str: + def __call__( + self, + messages: List[Message], + openai_function_schema: Optional[dict[str, Any]] = None, + ) -> str: if self.client is None: raise ValueError("OpenAI client is not initialized.") try: @@ -50,7 +54,7 @@ class OpenAILLM(BaseLLM): messages=[m.to_openai() for m in messages], temperature=self.temperature, max_tokens=self.max_tokens, - tools=tools, # type: ignore # MyPy expecting Iterable[ChatCompletionToolParam] | NotGiven, but dict is accepted by OpenAI. + tools=tools, # type: ignore # MyPy expecting Iterable[ChatCompletionToolParam] | NotGiven, but dict is accepted by OpenAI. ) output = completion.choices[0].message.content @@ -64,8 +68,9 @@ class OpenAILLM(BaseLLM): logger.error(f"LLM error: {e}") raise Exception(f"LLM error: {e}") from e - - def extract_function_inputs_openai(self, query: str, openai_function_schema: dict[str, Any]) -> dict: + def extract_function_inputs_openai( + self, query: str, openai_function_schema: dict[str, Any] + ) -> dict: messages = [] system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." messages.append(Message(role="system", content=system_prompt)) diff --git a/semantic_router/route.py b/semantic_router/route.py index 080afd61..5d90ff6e 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -67,25 +67,25 @@ class Route(BaseModel): "LLM is required for dynamic routes. Please ensure the `llm` " "attribute is set." ) - if query is None or type(query) != str: + if query is None or not isinstance(query, str): raise ValueError( "Query is required for dynamic routes. Please ensure the `query` " "argument is passed." ) if self.function_schema: - extracted_inputs = self.llm.extract_function_inputs( - query=query, - function_schema=self.function_schema + extracted_inputs = self.llm.extract_function_inputs( + query=query, function_schema=self.function_schema ) func_call = extracted_inputs - elif self.openai_function_schema: # Logically must be self.openai_function_schema, but keeps MyPy happy. + elif ( + self.openai_function_schema + ): # Logically must be self.openai_function_schema, but keeps MyPy happy. if not isinstance(self.llm, OpenAILLM): raise TypeError( "LLM must be an instance of OpenAILLM for openai_function_schema." ) extracted_inputs = self.llm.extract_function_inputs_openai( - query=query, - openai_function_schema=self.openai_function_schema + query=query, openai_function_schema=self.openai_function_schema ) func_call = extracted_inputs else: -- GitLab