From 3ab633d09e763248d1cf299d6f7099b68e835c5e Mon Sep 17 00:00:00 2001
From: Luca Mannini <dev@lucamannini.com>
Date: Fri, 15 Dec 2023 18:47:24 +0100
Subject: [PATCH] Fix for embeddings

---
 .gitignore                           |   1 +
 docs/examples/function_calling.ipynb | 140 +++++++++++++--------------
 pyproject.toml                       |   2 +-
 semantic_router/encoders/openai.py   |   4 +-
 tests/unit/encoders/test_openai.py   |   1 +
 5 files changed, 72 insertions(+), 76 deletions(-)

diff --git a/.gitignore b/.gitignore
index 807674fa..219091bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,4 @@ mac.env
 .coverage
 .coverage.*
 .pytest_cache
+test.py
\ No newline at end of file
diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb
index 650b5b94..c12eed73 100644
--- a/docs/examples/function_calling.ipynb
+++ b/docs/examples/function_calling.ipynb
@@ -9,15 +9,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 213,
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%reload_ext dotenv\n",
+    "%dotenv"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# OpenAI\n",
+    "import os\n",
     "import openai\n",
     "from semantic_router.utils.logger import logger\n",
     "\n",
-    "\n",
     "# Docs # https://platform.openai.com/docs/guides/function-calling\n",
     "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n",
     "    try:\n",
@@ -39,7 +49,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 214,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -48,7 +58,7 @@
     "import requests\n",
     "\n",
     "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n",
-    "HF_API_TOKEN = os.environ[\"HF_API_TOKEN\"]\n",
+    "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n",
     "\n",
     "\n",
     "def llm_mistral(prompt: str) -> str:\n",
@@ -180,7 +190,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 217,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -242,6 +252,23 @@
     "Set up the routing layer"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from semantic_router.schema import Route\n",
+    "from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
+    "from semantic_router.layer import RouteLayer\n",
+    "from semantic_router.utils.logger import logger\n",
+    "\n",
+    "\n",
+    "def create_router(routes: list[dict]) -> RouteLayer:\n",
+    "    logger.info(\"Creating route layer...\")\n",
+    "    encoder = OpenAIEncoder"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -256,7 +283,7 @@
     "\n",
     "def create_router(routes: list[dict]) -> RouteLayer:\n",
     "    logger.info(\"Creating route layer...\")\n",
-    "    encoder = CohereEncoder()\n",
+    "    encoder = OpenAIEncoder()\n",
     "\n",
     "    route_list: list[Route] = []\n",
     "    for route in routes:\n",
@@ -278,7 +305,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 219,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -322,6 +349,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "\n",
     "def get_time(location: str) -> str:\n",
     "    \"\"\"Useful to get the time in a specific location\"\"\"\n",
     "    print(f\"Calling `get_time` function with location: {location}\")\n",
@@ -349,72 +377,38 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 220,
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_time(location: str) -> str:\n",
+    "    \"\"\"Useful to get the time in a specific location\"\"\"\n",
+    "    print(f\"Calling `get_time` function with location: {location}\")\n",
+    "    return \"get_time\"\n",
+    "\n",
+    "\n",
+    "def get_news(category: str, country: str) -> str:\n",
+    "    \"\"\"Useful to get the news in a specific country\"\"\"\n",
+    "    print(\n",
+    "        f\"Calling `get_news` function with category: {category} and country: {country}\"\n",
+    "    )\n",
+    "    return \"get_news\"\n",
+    "\n",
+    "\n",
+    "# Registering functions to the router\n",
+    "route_get_time = generate_route(get_time)\n",
+    "route_get_news = generate_route(get_news)\n",
+    "\n",
+    "routes = [route_get_time, route_get_news]\n",
+    "router = create_router(routes)\n",
+    "\n",
+    "# Tools\n",
+    "tools = [get_time, get_news]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
-      "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
-      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger AI message: \n",
-      "    {\n",
-      "        'location': 'Stockholm'\n",
-      "    }\u001b[0m\n",
-      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "parameters: {'location': 'Stockholm'}\n",
-      "Calling `get_time` function with location: Stockholm\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
-      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
-      "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger AI message: \n",
-      "    {\n",
-      "        'category': 'tech',\n",
-      "        'country': 'Lithuania'\n",
-      "    }\u001b[0m\n",
-      "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "parameters: {'category': 'tech', 'country': 'Lithuania'}\n",
-      "Calling `get_news` function with category: tech and country: Lithuania\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "\u001b[33m2023-12-15 11:41:57 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
-      "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
-      "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger AI message:  How can I help you today?\u001b[0m\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "' How can I help you today?'"
-      ]
-     },
-     "execution_count": 220,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
    "source": [
     "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n",
     "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n",
@@ -438,7 +432,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.3"
+   "version": "3.11.5"
   }
  },
  "nbformat": 4,
diff --git a/pyproject.toml b/pyproject.toml
index b041516c..030a8f72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "semantic-router"
-version = "0.0.9"
+version = "0.0.10"
 description = "Super fast semantic router for AI decision making"
 authors = [
     "James Briggs <james@aurelio.ai>",
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index 1375c57a..920fac20 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -3,7 +3,8 @@ from time import sleep
 
 import openai
 from openai import OpenAIError
+from openai.types import CreateEmbeddingResponse
 
 from semantic_router.encoders import BaseEncoder
 from semantic_router.utils.logger import logger
 
@@ -36,7 +36,7 @@ class OpenAIEncoder(BaseEncoder):
             try:
                 logger.info(f"Encoding {len(docs)} documents...")
                 embeds = self.client.embeddings.create(input=docs, model=self.name)
-                if isinstance(embeds, dict) and "data" in embeds:
+                if isinstance(embeds, CreateEmbeddingResponse) and embeds.data:
                     break
             except OpenAIError as e:
                 sleep(2**j)
diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py
index f67bce1e..a1f458d1 100644
--- a/tests/unit/encoders/test_openai.py
+++ b/tests/unit/encoders/test_openai.py
@@ -71,6 +71,7 @@ class TestOpenAIEncoder:
         )
         with pytest.raises(ValueError) as e:
             openai_encoder(["test document"])
+
         assert "OpenAI API call failed. Error: Non-OpenAIError" in str(e.value)
 
     def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):
-- 
GitLab