Skip to content
Snippets Groups Projects
Unverified Commit f506995e authored by digriffiths's avatar digriffiths Committed by GitHub
Browse files

Merge branch 'main' into add_tfidf

parents 4394759e 9a2a4b8c
Branches
Tags
No related merge requests found
<?xml version="1.0" ?>
<coverage version="7.3.2" timestamp="1702633916069" lines-valid="344" lines-covered="344" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.2 -->
<coverage version="7.3.3" timestamp="1702894511196" lines-valid="345" lines-covered="345" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.3 -->
<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
<sources>
<source>/Users/jakit/customers/aurelio/semantic-router/semantic_router</source>
......@@ -349,12 +349,12 @@
<line number="2" hits="1"/>
<line number="4" hits="1"/>
<line number="5" hits="1"/>
<line number="7" hits="1"/>
<line number="6" hits="1"/>
<line number="8" hits="1"/>
<line number="11" hits="1"/>
<line number="9" hits="1"/>
<line number="12" hits="1"/>
<line number="14" hits="1"/>
<line number="19" hits="1"/>
<line number="13" hits="1"/>
<line number="15" hits="1"/>
<line number="20" hits="1"/>
<line number="21" hits="1"/>
<line number="22" hits="1"/>
......@@ -362,12 +362,12 @@
<line number="24" hits="1"/>
<line number="25" hits="1"/>
<line number="26" hits="1"/>
<line number="28" hits="1"/>
<line number="27" hits="1"/>
<line number="29" hits="1"/>
<line number="30" hits="1"/>
<line number="31" hits="1"/>
<line number="32" hits="1"/>
<line number="35" hits="1"/>
<line number="33" hits="1"/>
<line number="36" hits="1"/>
<line number="37" hits="1"/>
<line number="38" hits="1"/>
......@@ -380,10 +380,11 @@
<line number="45" hits="1"/>
<line number="46" hits="1"/>
<line number="47" hits="1"/>
<line number="49" hits="1"/>
<line number="48" hits="1"/>
<line number="50" hits="1"/>
<line number="52" hits="1"/>
<line number="53" hits="1"/>
<line number="55" hits="1"/>
<line number="57" hits="1"/>
<line number="58" hits="1"/>
</lines>
</class>
</classes>
......
File added
File added
This diff is collapsed.
File added
This diff is collapsed.
[tool.poetry]
name = "semantic-router"
version = "0.0.10"
version = "0.0.11"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs <james@aurelio.ai>",
......@@ -12,7 +12,7 @@ authors = [
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
python = "^3.9"
pydantic = "^1.8.2"
openai = "^1.3.9"
cohere = "^4.32"
......
......@@ -3,6 +3,7 @@ from time import sleep
import openai
from openai import OpenAIError
from openai.types import CreateEmbeddingResponse
from semantic_router.encoders import BaseEncoder
from semantic_router.utils.logger import logger
......@@ -36,7 +37,7 @@ class OpenAIEncoder(BaseEncoder):
try:
logger.info(f"Encoding {len(docs)} documents...")
embeds = self.client.embeddings.create(input=docs, model=self.name)
if "data" in embeds:
if embeds.data:
break
except OpenAIError as e:
sleep(2**j)
......@@ -46,8 +47,12 @@ class OpenAIEncoder(BaseEncoder):
logger.error(f"OpenAI API call failed. Error: {error_message}")
raise ValueError(f"OpenAI API call failed. Error: {e}")
if not embeds or not isinstance(embeds, dict) or "data" not in embeds:
if (
not embeds
or not isinstance(embeds, CreateEmbeddingResponse)
or not embeds.data
):
raise ValueError(f"No embeddings returned. Error: {error_message}")
embeddings = [r["embedding"] for r in embeds["data"]]
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings
import pytest
from openai import OpenAIError
from openai.types import CreateEmbeddingResponse, Embedding
from openai.types.create_embedding_response import Usage
from semantic_router.encoders import OpenAIEncoder
......@@ -40,11 +42,26 @@ class TestOpenAIEncoder:
)
def test_openai_encoder_call_success(self, openai_encoder, mocker):
mock_embeddings = mocker.Mock()
mock_embeddings.data = [
Embedding(embedding=[0.1, 0.2], index=0, object="embedding")
]
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
mock_embedding = Embedding(index=0, object="embedding", embedding=[0.1, 0.2])
# Mock the CreateEmbeddingResponse object
mock_response = CreateEmbeddingResponse(
model="text-embedding-ada-002",
object="list",
usage=Usage(prompt_tokens=0, total_tokens=20),
data=[mock_embedding],
)
responses = [OpenAIError("OpenAI error"), mock_response]
mocker.patch.object(
openai_encoder.client.embeddings,
"create",
return_value={"data": [{"embedding": [0.1, 0.2]}]},
openai_encoder.client.embeddings, "create", side_effect=responses
)
embeddings = openai_encoder(["test document"])
assert embeddings == [[0.1, 0.2]]
......@@ -59,7 +76,7 @@ class TestOpenAIEncoder:
)
with pytest.raises(ValueError) as e:
openai_encoder(["test document"])
assert "No embeddings returned. Error: Test error" in str(e.value)
assert "No embeddings returned. Error" in str(e.value)
def test_openai_encoder_call_failure_non_openai_error(self, openai_encoder, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
......@@ -75,9 +92,24 @@ class TestOpenAIEncoder:
assert "OpenAI API call failed. Error: Non-OpenAIError" in str(e.value)
def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):
mock_embeddings = mocker.Mock()
mock_embeddings.data = [
Embedding(embedding=[0.1, 0.2], index=0, object="embedding")
]
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
responses = [OpenAIError("Test error"), {"data": [{"embedding": [0.1, 0.2]}]}]
mock_embedding = Embedding(index=0, object="embedding", embedding=[0.1, 0.2])
# Mock the CreateEmbeddingResponse object
mock_response = CreateEmbeddingResponse(
model="text-embedding-ada-002",
object="list",
usage=Usage(prompt_tokens=0, total_tokens=20),
data=[mock_embedding],
)
responses = [OpenAIError("OpenAI error"), mock_response]
mocker.patch.object(
openai_encoder.client.embeddings, "create", side_effect=responses
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment