From 6ed79ba2540eb42b6521e3854aa4be63fe28ff47 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 29 Apr 2024 14:36:35 +0400
Subject: [PATCH] Linting (MyPy fixes) and testing.

---
 docs/02-dynamic-routes.ipynb                  | 26 ++++++++--------
 ...c-routes-via-openai-function-calling.ipynb | 10 +++----
 semantic_router/llms/openai.py                | 30 +++++++++++--------
 3 files changed, 35 insertions(+), 31 deletions(-)

diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb
index 947b0200..41b682fa 100644
--- a/docs/02-dynamic-routes.ipynb
+++ b/docs/02-dynamic-routes.ipynb
@@ -88,7 +88,7 @@
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "Requirement already satisfied: tzdata in c:\\users\\siraj\\documents\\personal\\work\\aurelio\\virtual environments\\semantic_router\\lib\\site-packages (2024.1)\n"
+            "Requirement already satisfied: tzdata in c:\\users\\siraj\\documents\\personal\\work\\aurelio\\virtual environments\\semantic_router_2\\lib\\site-packages (2024.1)\n"
           ]
         },
         {
@@ -135,7 +135,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+            "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_2\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
             "  from .autonotebook import tqdm as notebook_tqdm\n"
           ]
         }
@@ -191,7 +191,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[32m2024-04-29 01:53:55 INFO semantic_router.utils.logger local\u001b[0m\n"
+            "\u001b[32m2024-04-29 14:34:53 INFO semantic_router.utils.logger local\u001b[0m\n"
           ]
         }
       ],
@@ -310,7 +310,7 @@
         {
           "data": {
             "text/plain": [
-              "'17:53'"
+              "'06:34'"
             ]
           },
           "execution_count": 6,
@@ -424,7 +424,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[32m2024-04-29 01:53:56 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n"
+            "\u001b[32m2024-04-29 14:34:54 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n"
           ]
         }
       ],
@@ -466,18 +466,18 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[33m2024-04-29 01:53:57 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n",
-            "\u001b[32m2024-04-29 01:53:57 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n",
-            "\u001b[32m2024-04-29 01:53:58 INFO semantic_router.utils.logger LLM output: {\n",
+            "\u001b[33m2024-04-29 14:34:55 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n",
+            "\u001b[32m2024-04-29 14:34:55 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n",
+            "\u001b[32m2024-04-29 14:34:56 INFO semantic_router.utils.logger LLM output: {\n",
             "\t\"timezone\": \"America/New_York\"\n",
             "}\u001b[0m\n",
-            "\u001b[32m2024-04-29 01:53:58 INFO semantic_router.utils.logger Function inputs: {'timezone': 'America/New_York'}\u001b[0m\n"
+            "\u001b[32m2024-04-29 14:34:56 INFO semantic_router.utils.logger Function inputs: {'timezone': 'America/New_York'}\u001b[0m\n"
           ]
         },
         {
           "data": {
             "text/plain": [
-              "'17:53'"
+              "RouteChoice(name='get_time', function_call={'timezone': 'America/New_York'}, similarity_score=None)"
             ]
           },
           "execution_count": 12,
@@ -487,7 +487,7 @@
       ],
       "source": [
         "out = rl(\"what is the time in new york city?\")\n",
-        "get_time(**out.function_call)"
+        "out\n"
       ]
     },
     {
@@ -498,7 +498,7 @@
         {
           "data": {
             "text/plain": [
-              "OpenAILLM(name='gpt-3.5-turbo', client=<openai.OpenAI object at 0x0000024DD8AEAF10>, temperature=0.01, max_tokens=200)"
+              "'06:34'"
             ]
           },
           "execution_count": 13,
@@ -507,7 +507,7 @@
         }
       ],
       "source": [
-        "time_route.llm"
+        "get_time(**out.function_call)"
       ]
     },
     {
diff --git a/docs/10-dynamic-routes-via-openai-function-calling.ipynb b/docs/10-dynamic-routes-via-openai-function-calling.ipynb
index 17f17637..3a33fc6e 100644
--- a/docs/10-dynamic-routes-via-openai-function-calling.ipynb
+++ b/docs/10-dynamic-routes-via-openai-function-calling.ipynb
@@ -161,7 +161,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[32m2024-04-29 01:53:43 INFO semantic_router.utils.logger local\u001b[0m\n"
+            "\u001b[32m2024-04-29 14:35:09 INFO semantic_router.utils.logger local\u001b[0m\n"
           ]
         }
       ],
@@ -280,7 +280,7 @@
         {
           "data": {
             "text/plain": [
-              "'17:53'"
+              "'06:35'"
             ]
           },
           "execution_count": 6,
@@ -415,7 +415,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[32m2024-04-29 01:53:44 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n"
+            "\u001b[32m2024-04-29 14:35:10 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n"
           ]
         }
       ],
@@ -448,7 +448,7 @@
           "name": "stderr",
           "output_type": "stream",
           "text": [
-            "\u001b[33m2024-04-29 01:53:45 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n"
+            "\u001b[33m2024-04-29 14:35:11 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n"
           ]
         },
         {
@@ -475,7 +475,7 @@
         {
           "data": {
             "text/plain": [
-              "'17:53'"
+              "'06:35'"
             ]
           },
           "execution_count": 12,
diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py
index 36219c5c..4e88030b 100644
--- a/semantic_router/llms/openai.py
+++ b/semantic_router/llms/openai.py
@@ -49,6 +49,7 @@ class OpenAILLM(BaseLLM):
                 tools = [openai_function_schema]
             else:
                 tools = None
+
             completion = self.client.chat.completions.create(
                 model=self.name,
                 messages=[m.to_openai() for m in messages],
@@ -57,12 +58,21 @@ class OpenAILLM(BaseLLM):
                 tools=tools,  # type: ignore # MyPy expecting Iterable[ChatCompletionToolParam] | NotGiven, but dict is accepted by OpenAI.
             )
 
-            output = completion.choices[0].message.content
-
             if openai_function_schema:
-                return completion.choices[0].message.tool_calls
-            if not output:
-                raise Exception("No output generated")
+                tool_calls = completion.choices[0].message.tool_calls
+                if tool_calls is None:
+                    raise ValueError("Invalid output, expected a tool call.")
+                if len(tool_calls) != 1:
+                    raise ValueError("Invalid output, expected a single tool to be specified.")
+                arguments = tool_calls[0].function.arguments
+                if arguments is None:
+                    raise ValueError("Invalid output, expected arguments to be specified.")
+                output = str(arguments) # str to keep MyPy happy.
+            else:
+                content = completion.choices[0].message.content
+                if content is None:
+                    raise ValueError("Invalid output, expected content.")
+                output = str(content) # str to keep MyPy happy.
             return output
         except Exception as e:
             logger.error(f"LLM error: {e}")
@@ -75,12 +85,6 @@ class OpenAILLM(BaseLLM):
         system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
         messages.append(Message(role="system", content=system_prompt))
         messages.append(Message(role="user", content=query))
-        output = self(messages=messages, openai_function_schema=openai_function_schema)
-        if not output:
-            raise Exception("No output generated for extract function input")
-        if len(output) != 1:
-            raise ValueError("Invalid output, expected a single tool to be called")
-        tool_call = output[0]
-        arguments_json = tool_call.function.arguments
-        function_inputs = json.loads(arguments_json)
+        function_inputs_str = self(messages=messages, openai_function_schema=openai_function_schema)
+        function_inputs = json.loads(function_inputs_str)
         return function_inputs
-- 
GitLab