From c261e62c77dcf48dd6b016084e8608df09b512bc Mon Sep 17 00:00:00 2001
From: Arash Mosharraf <armoshar@microsoft.com>
Date: Tue, 16 Jan 2024 15:48:23 -0600
Subject: [PATCH] AzureOpenAILLM works

---
 .gitignore                   |  2 +-
 docs/test.ipynb              | 51 +++++++++++++++++-------------------
 semantic_router/llms/base.py |  6 ++---
 semantic_router/llms/zure.py |  3 ++-
 4 files changed, 30 insertions(+), 32 deletions(-)

diff --git a/.gitignore b/.gitignore
index cc461499..c45ff835 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,5 +25,5 @@ output
 node_modules
 package-lock.json
 package.json
-
+test.ipynb
 ```
diff --git a/docs/test.ipynb b/docs/test.ipynb
index d6638f0b..2ab38fdd 100644
--- a/docs/test.ipynb
+++ b/docs/test.ipynb
@@ -68,11 +68,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def get_direction(start)->str:\n",
+    "def get_direction(start: str) -> str:\n",
     "    \"\"\"just produce a direction from the starting point to the library\n",
     "    :param start: the starting address\n",
     "    :type start: str\n",
@@ -101,7 +101,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -114,30 +114,30 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "{'name': 'do_irrelevant',\n",
-       " 'description': 'Handle the irrelevant questions \\n\\nreturn: the text',\n",
-       " 'signature': '() -> str',\n",
+       "{'name': 'get_direction',\n",
+       " 'description': 'just produce a direction from the starting point to the library\\n:param start: the starting address\\n:type start: str\\n\\n\\n:return: the direction',\n",
+       " 'signature': '(start) -> str',\n",
        " 'output': \"<class 'str'>\"}"
       ]
      },
-     "execution_count": 11,
+     "execution_count": 9,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "irrelevant_schema"
+    "direction_schema"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -192,14 +192,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[32m2024-01-16 15:05:08 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
+      "\u001b[32m2024-01-16 15:24:32 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
      ]
     }
    ],
@@ -214,26 +214,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[32m2024-01-16 15:05:09 INFO semantic_router.utils.logger LLM  `name='gpt-35-turbo' client=<openai.lib.azure.AzureOpenAI object at 0x000001B4C401AE30> temperature=0.01 max_tokens=200` is chosen\u001b[0m\n",
-      "\u001b[32m2024-01-16 15:05:09 INFO semantic_router.utils.logger this is the llm passed to route object name='gpt-35-turbo' client=<openai.lib.azure.AzureOpenAI object at 0x000001B4C401AE30> temperature=0.01 max_tokens=200\u001b[0m\n",
-      "\u001b[32m2024-01-16 15:05:09 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n",
-      "\u001b[31m2024-01-16 15:05:09 ERROR semantic_router.utils.logger Input validation error: list index out of range\u001b[0m\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "{\n",
+      "\u001b[32m2024-01-16 15:24:33 INFO semantic_router.utils.logger LLM  `name='gpt-35-turbo' client=<openai.lib.azure.AzureOpenAI object at 0x000002324D651630> temperature=0.01 max_tokens=200` is chosen\u001b[0m\n",
+      "\u001b[32m2024-01-16 15:24:33 INFO semantic_router.utils.logger this is the llm passed to route object name='gpt-35-turbo' client=<openai.lib.azure.AzureOpenAI object at 0x000002324D651630> temperature=0.01 max_tokens=200\u001b[0m\n",
+      "\u001b[32m2024-01-16 15:24:33 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n",
+      "\u001b[32m2024-01-16 15:24:33 INFO semantic_router.utils.logger LLM output: {\n",
       "            \"start\": \"my home\"\n",
-      "        }\n"
+      "        }\u001b[0m\n",
+      "\u001b[32m2024-01-16 15:24:33 INFO semantic_router.utils.logger Function inputs: {'start': 'my home'}\u001b[0m\n",
+      "\u001b[32m2024-01-16 15:24:33 INFO semantic_router.utils.logger param info ['start) -> st']\u001b[0m\n",
+      "\u001b[32m2024-01-16 15:24:33 INFO semantic_router.utils.logger param names ['start) -> st']\u001b[0m\n",
+      "\u001b[31m2024-01-16 15:24:33 ERROR semantic_router.utils.logger Input validation error: list index out of range\u001b[0m\n"
      ]
     },
     {
@@ -243,10 +240,10 @@
      "traceback": [
       "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
       "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
-      "Cell \u001b[1;32mIn[14], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mrl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mhow do I get to the nearest gas station from my home?\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
+      "Cell \u001b[1;32mIn[8], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mrl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mhow do I get to the nearest gas station from my home?\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
       "File \u001b[1;32mc:\\Users\\armoshar\\OneDrive - Microsoft\\Projects\\OpenAI\\semantic-router\\semantic_router\\layer.py:203\u001b[0m, in \u001b[0;36mRouteLayer.__call__\u001b[1;34m(self, text)\u001b[0m\n\u001b[0;32m    201\u001b[0m             route\u001b[38;5;241m.\u001b[39mllm \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mllm\n\u001b[0;32m    202\u001b[0m     logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLLM  `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mroute\u001b[38;5;241m.\u001b[39mllm\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` is chosen\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m--> 203\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mroute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    204\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m    205\u001b[0m     \u001b[38;5;66;03m# if no route passes threshold, return empty route choice\u001b[39;00m\n\u001b[0;32m    206\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m RouteChoice()\n",
       "File \u001b[1;32mc:\\Users\\armoshar\\OneDrive - Microsoft\\Projects\\OpenAI\\semantic-router\\semantic_router\\route.py:57\u001b[0m, in \u001b[0;36mRoute.__call__\u001b[1;34m(self, query)\u001b[0m\n\u001b[0;32m     52\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m     53\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLLM is required for dynamic routes. Please ensure the `llm` \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m     54\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattribute is set.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m     55\u001b[0m     )\n\u001b[0;32m     56\u001b[0m \u001b[38;5;66;03m# if a function schema is provided we generate the inputs\u001b[39;00m\n\u001b[1;32m---> 57\u001b[0m extracted_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mllm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mextract_function_inputs\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m     58\u001b[0m \u001b[43m    \u001b[49m\u001b[43mquery\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunction_schema\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunction_schema\u001b[49m\n\u001b[0;32m     59\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     60\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mextracted inputs \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mextracted_inputs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     61\u001b[0m func_call \u001b[38;5;241m=\u001b[39m extracted_inputs\n",
-      "File \u001b[1;32mc:\\Users\\armoshar\\OneDrive - Microsoft\\Projects\\OpenAI\\semantic-router\\semantic_router\\llms\\base.py:87\u001b[0m, in \u001b[0;36mBaseLLM.extract_function_inputs\u001b[1;34m(self, query, function_schema)\u001b[0m\n\u001b[0;32m     85\u001b[0m function_inputs \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(output)\n\u001b[0;32m     86\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_is_valid_inputs(function_inputs, function_schema):\n\u001b[1;32m---> 87\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid inputs\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m function_inputs\n",
+      "File \u001b[1;32mc:\\Users\\armoshar\\OneDrive - Microsoft\\Projects\\OpenAI\\semantic-router\\semantic_router\\llms\\base.py:90\u001b[0m, in \u001b[0;36mBaseLLM.extract_function_inputs\u001b[1;34m(self, query, function_schema)\u001b[0m\n\u001b[0;32m     88\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFunction inputs: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunction_inputs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     89\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_is_valid_inputs(function_inputs, function_schema):\n\u001b[1;32m---> 90\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid inputs\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     91\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m function_inputs\n",
       "\u001b[1;31mValueError\u001b[0m: Invalid inputs"
      ]
     }
diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index e0b3007b..da5c6054 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -31,7 +31,6 @@ class BaseLLM(BaseModel):
             param_types = [
                 info.split(":")[1].strip().split("=")[0].strip() for info in param_info
             ]
-
             for name, type_str in zip(param_names, param_types):
                 if name not in inputs:
                     logger.error(f"Input {name} missing from query")
@@ -76,13 +75,14 @@ class BaseLLM(BaseModel):
         """
         llm_input = [Message(role="user", content=prompt)]
         output = self(llm_input)
-        print(output)
+
         if not output:
             raise Exception("No output generated for extract function input")
 
         output = output.replace("'", '"').strip().rstrip(",")
-
+        logger.info(f"LLM output: {output}")
         function_inputs = json.loads(output)
+        logger.info(f"Function inputs: {function_inputs}")
         if not self._is_valid_inputs(function_inputs, function_schema):
             raise ValueError("Invalid inputs")
         return function_inputs
diff --git a/semantic_router/llms/zure.py b/semantic_router/llms/zure.py
index 2810a167..01b13adb 100644
--- a/semantic_router/llms/zure.py
+++ b/semantic_router/llms/zure.py
@@ -25,9 +25,10 @@ class AzureOpenAILLM(BaseLLM):
         if name is None:
             name = os.getenv("OPENAI_CHAT_MODEL_NAME", "gpt-35-turbo")
         super().__init__(name=name)
-        api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
+        api_key = openai_api_key or os.getenv("AZURE_OPENAI_API_KEY")
         if api_key is None:
             raise ValueError("OpenAI API key cannot be 'None'.")
+        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
         if azure_endpoint is None:
             raise ValueError("Azure endpoint API key cannot be 'None'.")
         try:
-- 
GitLab