From 16fc325a370b90074a18ef54da85133a6c076391 Mon Sep 17 00:00:00 2001
From: Simonas <20096648+simjak@users.noreply.github.com>
Date: Mon, 18 Dec 2023 12:02:20 +0200
Subject: [PATCH] Fix OpenAIEncoder to use OpenAI v1 response objects and update tests

---
 coverage.xml                         |  5 +-
 docs/examples/function_calling.ipynb | 83 ++++++++++++++++------------
 semantic_router/encoders/openai.py   |  5 +-
 tests/unit/encoders/test_openai.py   | 17 +++++-
 4 files changed, 68 insertions(+), 42 deletions(-)

diff --git a/coverage.xml b/coverage.xml
index 8971e377..3b0a9de9 100644
--- a/coverage.xml
+++ b/coverage.xml
@@ -1,5 +1,5 @@
 <?xml version="1.0" ?>
-<coverage version="7.3.3" timestamp="1702890603439" lines-valid="344" lines-covered="344" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
+<coverage version="7.3.3" timestamp="1702893702032" lines-valid="345" lines-covered="345" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
 	<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.3 -->
 	<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
 	<sources>
@@ -382,8 +382,9 @@
 						<line number="47" hits="1"/>
 						<line number="49" hits="1"/>
 						<line number="50" hits="1"/>
-						<line number="52" hits="1"/>
+						<line number="51" hits="1"/>
 						<line number="53" hits="1"/>
+						<line number="54" hits="1"/>
 					</lines>
 				</class>
 			</classes>
diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb
index 2b75af65..a6546eec 100644
--- a/docs/examples/function_calling.ipynb
+++ b/docs/examples/function_calling.ipynb
@@ -9,9 +9,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
    "source": [
     "# OpenAI\n",
     "import openai\n",
@@ -39,7 +48,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -91,7 +100,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -113,7 +122,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -130,7 +139,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -212,7 +221,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -228,7 +237,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -308,19 +317,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
     "from semantic_router.schema import Route\n",
-    "from semantic_router.encoders import CohereEncoder\n",
+    "from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
     "from semantic_router.layer import RouteLayer\n",
     "from semantic_router.utils.logger import logger\n",
     "\n",
     "\n",
     "def create_router(routes: list[dict]) -> RouteLayer:\n",
     "    logger.info(\"Creating route layer...\")\n",
-    "    encoder = CohereEncoder()\n",
+    "    encoder = OpenAIEncoder()\n",
     "\n",
     "    route_list: list[Route] = []\n",
     "    for route in routes:\n",
@@ -342,7 +351,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -383,16 +392,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[32m2023-12-18 11:00:14 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:14 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:17 INFO semantic_router.utils.logger AI message: \n",
+      "\u001b[32m2023-12-18 11:46:34 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:34 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger AI message: \n",
       "    Example output:\n",
       "    {\n",
       "        \"name\": \"get_time\",\n",
@@ -404,10 +413,10 @@
       "            \"Can you tell me the time in Berlin?\"\n",
       "        ]\n",
       "    }\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:17 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:17 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:17 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:21 INFO semantic_router.utils.logger AI message: \n",
+      "\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger AI message: \n",
       "    Example output:\n",
       "    {\n",
       "        \"name\": \"get_news\",\n",
@@ -419,8 +428,9 @@
       "            \"What's the latest news from Germany?\"\n",
       "        ]\n",
       "    }\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:21 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:21 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n"
+      "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Encoding 10 documents...\u001b[0m\n"
      ]
     },
     {
@@ -460,20 +470,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[32m2023-12-18 11:00:22 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:22 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:23 INFO semantic_router.utils.logger AI message: \n",
+      "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Encoding 1 documents...\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger AI message: \n",
       "    {\n",
       "        \"location\": \"Stockholm\"\n",
       "    }\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:23 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n"
+      "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Encoding 1 documents...\u001b[0m\n"
      ]
     },
     {
@@ -488,14 +500,15 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[32m2023-12-18 11:00:23 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:23 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:25 INFO semantic_router.utils.logger AI message: \n",
+      "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:45 INFO semantic_router.utils.logger AI message: \n",
       "    {\n",
       "        \"category\": \"tech\",\n",
       "        \"country\": \"Lithuania\"\n",
       "    }\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:25 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n"
+      "\u001b[32m2023-12-18 11:46:45 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:45 INFO semantic_router.utils.logger Encoding 1 documents...\u001b[0m\n"
      ]
     },
     {
@@ -510,9 +523,9 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[33m2023-12-18 11:00:25 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:25 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
-      "\u001b[32m2023-12-18 11:00:26 INFO semantic_router.utils.logger AI message:  How can I help you today?\u001b[0m\n"
+      "\u001b[33m2023-12-18 11:46:46 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:46 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-18 11:46:46 INFO semantic_router.utils.logger AI message:  How can I help you today?\u001b[0m\n"
      ]
     },
     {
@@ -521,7 +534,7 @@
        "' How can I help you today?'"
       ]
      },
-     "execution_count": 20,
+     "execution_count": 11,
      "metadata": {},
      "output_type": "execute_result"
     }
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index 56f148d7..65acbf7d 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -46,8 +46,9 @@ class OpenAIEncoder(BaseEncoder):
                 logger.error(f"OpenAI API call failed. Error: {error_message}")
                 raise ValueError(f"OpenAI API call failed. Error: {e}")
 
-        if not embeds or not isinstance(embeds, dict) or "data" not in embeds:
+        if embeds is None or embeds.data is None:
+            logger.error(f"No embeddings returned. Error: {error_message}")
             raise ValueError(f"No embeddings returned. Error: {error_message}")
 
-        embeddings = [r["embedding"] for r in embeds["data"]]
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py
index 501a9b04..98ef2c5e 100644
--- a/tests/unit/encoders/test_openai.py
+++ b/tests/unit/encoders/test_openai.py
@@ -1,5 +1,6 @@
 import pytest
 from openai import OpenAIError
+from openai.types.embedding import Embedding
 
 from semantic_router.encoders import OpenAIEncoder
 
@@ -40,11 +41,16 @@ class TestOpenAIEncoder:
         )
 
     def test_openai_encoder_call_success(self, openai_encoder, mocker):
+        mock_embeddings = mocker.Mock()
+        mock_embeddings.data = [
+            Embedding(embedding=[0.1, 0.2], index=0, object="embedding")
+        ]
+
         mocker.patch("os.getenv", return_value="fake-api-key")
         mocker.patch.object(
             openai_encoder.client.embeddings,
             "create",
-            return_value={"data": [{"embedding": [0.1, 0.2]}]},
+            return_value=mock_embeddings,
         )
         embeddings = openai_encoder(["test document"])
         assert embeddings == [[0.1, 0.2]]
@@ -59,7 +65,7 @@ class TestOpenAIEncoder:
         )
         with pytest.raises(ValueError) as e:
             openai_encoder(["test document"])
-        assert "No embeddings returned. Error: Test error" in str(e.value)
+        assert "No embeddings returned. Error" in str(e.value)
 
     def test_openai_encoder_call_failure_non_openai_error(self, openai_encoder, mocker):
         mocker.patch("os.getenv", return_value="fake-api-key")
@@ -75,9 +81,14 @@ class TestOpenAIEncoder:
         assert "OpenAI API call failed. Error: Non-OpenAIError" in str(e.value)
 
     def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):
+        mock_embeddings = mocker.Mock()
+        mock_embeddings.data = [
+            Embedding(embedding=[0.1, 0.2], index=0, object="embedding")
+        ]
+
         mocker.patch("os.getenv", return_value="fake-api-key")
         mocker.patch("time.sleep", return_value=None)  # To speed up the test
-        responses = [OpenAIError("Test error"), {"data": [{"embedding": [0.1, 0.2]}]}]
+        responses = [OpenAIError("Test error"), mock_embeddings]
         mocker.patch.object(
             openai_encoder.client.embeddings, "create", side_effect=responses
         )
-- 
GitLab