From d03a149d55718eadbbecb286628b0f066cba488d Mon Sep 17 00:00:00 2001
From: Luca Mannini <dev@lucamannini.com>
Date: Mon, 18 Dec 2023 10:45:58 +0100
Subject: [PATCH] Handle typed embedding responses from openai>=1.x

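openai>=1.x returns a typed CreateEmbeddingResponse (a pydantic model) from
client.embeddings.create() instead of a plain dict, so the previous dict-style
checks in OpenAIEncoder ("data" in embeds, embeds["data"]) no longer work.
Switch the encoder to attribute access on the response object, bump the locked
openai version to 1.5.0, build the mocked responses in the unit tests from the
real openai.types models, and convert the example call() cell in
docs/examples/function_calling.ipynb into an executable code cell.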
---
 docs/examples/function_calling.ipynb |  4 +++-
 poetry.lock                          |  6 +++---
 semantic_router/encoders/openai.py   | 11 ++++++++---
 tests/unit/encoders/test_openai.py   | 29 ++++++++++++++++++++++++----
 4 files changed, 39 insertions(+), 11 deletions(-)
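
Note, not part of the applied diff: a minimal sketch of the attribute-style
access the encoder now relies on, assuming openai>=1.x is installed and an
OPENAI_API_KEY is set in the environment; the model name is only illustrative.

    from openai import OpenAI

    client = OpenAI()  # picks up OPENAI_API_KEY from the environment
    res = client.embeddings.create(
        input=["some document"], model="text-embedding-ada-002"
    )
    # res is an openai.types.CreateEmbeddingResponse, not a dict, so
    # '"data" in res' and 'res["data"]' no longer apply; use attributes.
    embeddings = [record.embedding for record in res.data]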

diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb
index a418d05f..d669aab6 100644
--- a/docs/examples/function_calling.ipynb
+++ b/docs/examples/function_calling.ipynb
@@ -407,8 +407,10 @@
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
     "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n",
     "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n",
diff --git a/poetry.lock b/poetry.lock
index f5d58647..3479a366 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1283,13 +1283,13 @@ files = [
 
 [[package]]
 name = "openai"
-version = "1.3.9"
+version = "1.5.0"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.3.9-py3-none-any.whl", hash = "sha256:d30faeffe5995a2cf6b790c0260a5b59647e81c3a1f3b62f51b5e0a0e52681c9"},
-    {file = "openai-1.3.9.tar.gz", hash = "sha256:6f638d96bc89b4394be1d7b37d312f70a055df1a471c92d4c4b2ae3a70c98cb3"},
+    {file = "openai-1.5.0-py3-none-any.whl", hash = "sha256:42d8c84b0714c990e18afe81d37f8a64423e8196bf7157b8ea665b8d8f393253"},
+    {file = "openai-1.5.0.tar.gz", hash = "sha256:4cd91e97988ccd6c44f815107def9495cbc718aeb8b28be33a87b6fa2c432508"},
 ]
 
 [package.dependencies]
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index d02787a3..11251451 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -3,6 +3,7 @@ from time import sleep
 
 import openai
 from openai import OpenAIError
+from openai.types import CreateEmbeddingResponse
 
 from semantic_router.encoders import BaseEncoder
 from semantic_router.utils.logger import logger
@@ -36,7 +37,7 @@ class OpenAIEncoder(BaseEncoder):
             try:
                 logger.info(f"Encoding {len(docs)} documents...")
                 embeds = self.client.embeddings.create(input=docs, model=self.name)
-                if "data" in embeds:
+                if embeds.data:
                     break
             except OpenAIError as e:
                 sleep(2**j)
@@ -46,8 +47,12 @@ class OpenAIEncoder(BaseEncoder):
                 logger.error(f"OpenAI API call failed. Error: {error_message}")
                 raise ValueError(f"OpenAI API call failed. Error: {e}")
 
-        if not embeds or not isinstance(embeds, dict) or "data" not in embeds:
+        if (
+            not embeds
+            or not isinstance(embeds, CreateEmbeddingResponse)
+            or not embeds.data
+        ):
             raise ValueError(f"No embeddings returned. Error: {error_message}")
 
-        embeddings = [r["embedding"] for r in embeds["data"]]
+        embeddings = [r.embedding for r in embeds.data]
         return embeddings
diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py
index 501a9b04..9c2d6058 100644
--- a/tests/unit/encoders/test_openai.py
+++ b/tests/unit/encoders/test_openai.py
@@ -1,5 +1,6 @@
 import pytest
 from openai import OpenAIError
+from openai.types import CreateEmbeddingResponse, Embedding
 
 from semantic_router.encoders import OpenAIEncoder
 
@@ -41,10 +42,20 @@ class TestOpenAIEncoder:
 
     def test_openai_encoder_call_success(self, openai_encoder, mocker):
         mocker.patch("os.getenv", return_value="fake-api-key")
+        mocker.patch("time.sleep", return_value=None)  # To speed up the test
+
+        mock_embedding = Embedding(index=0, object="embedding", embedding=[0.1, 0.2])
+        # Mock the CreateEmbeddingResponse object
+        mock_response = CreateEmbeddingResponse(
+            model="text-embedding-ada-002",
+            object="list",
+            usage={"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 20},
+            data=[mock_embedding],
+        )
+
+        responses = [OpenAIError("OpenAI error"), mock_response]
         mocker.patch.object(
-            openai_encoder.client.embeddings,
-            "create",
-            return_value={"data": [{"embedding": [0.1, 0.2]}]},
+            openai_encoder.client.embeddings, "create", side_effect=responses
         )
         embeddings = openai_encoder(["test document"])
         assert embeddings == [[0.1, 0.2]]
@@ -77,7 +88,17 @@ class TestOpenAIEncoder:
     def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):
         mocker.patch("os.getenv", return_value="fake-api-key")
         mocker.patch("time.sleep", return_value=None)  # To speed up the test
-        responses = [OpenAIError("Test error"), {"data": [{"embedding": [0.1, 0.2]}]}]
+
+        mock_embedding = Embedding(index=0, object="embedding", embedding=[0.1, 0.2])
+        # Mock the CreateEmbeddingResponse object
+        mock_response = CreateEmbeddingResponse(
+            model="text-embedding-ada-002",
+            object="list",
+            usage={"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 20},
+            data=[mock_embedding],
+        )
+
+        responses = [OpenAIError("OpenAI error"), mock_response]
         mocker.patch.object(
             openai_encoder.client.embeddings, "create", side_effect=responses
         )
-- 
GitLab