Skip to content
Snippets Groups Projects
Unverified Commit f497e329 authored by James Briggs's avatar James Briggs
Browse files

feat: add token limit to openai encoder

parent 17fd5195
No related branches found
No related tags found
No related merge requests found
......@@ -45,7 +45,7 @@ jobs:
- name: Download nltk data
run: |
python -m nltk.downloader punkt stopwords wordnet
- name: Pytest
- name: Pytest All
env:
PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
......
......@@ -13,3 +13,10 @@ lint lint_diff:
test:
poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml
test_functional:
poetry run pytest -vv -n 20 tests/functional
test_unit:
poetry run pytest -vv -n 20 tests/unit
test_integration:
poetry run pytest -vv -n 20 tests/integration
\ No newline at end of file
This diff is collapsed.
import os
from time import sleep
from typing import List, Optional, Union
from typing import Any, List, Optional, Union
import openai
from openai import OpenAIError
from openai._types import NotGiven
from openai.types import CreateEmbeddingResponse
import tiktoken
from semantic_router.encoders import BaseEncoder
from semantic_router.schema import EncoderInfo
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger
# Known OpenAI embedding models and their metadata, keyed by model name.
# OpenAIEncoder.__init__ looks the requested model up here to set the
# max-input-token limit used for truncation; unknown names simply get no
# token limit.
model_configs = {
    "text-embedding-ada-002": EncoderInfo(
        name="text-embedding-ada-002",
        type="openai",
        token_limit=4000,
    ),
    # NOTE(review): the official OpenAI model names are
    # "text-embedding-3-small" and "text-embedding-3-large"; the
    # "text-embed-3-*" keys are not real model names but are kept so any
    # caller relying on the original mapping keeps working.
    "text-embed-3-small": EncoderInfo(
        name="text-embed-3-small",
        type="openai",
        token_limit=8192,
    ),
    "text-embed-3-large": EncoderInfo(
        name="text-embed-3-large",
        type="openai",
        token_limit=8192,
    ),
    "text-embedding-3-small": EncoderInfo(
        name="text-embedding-3-small",
        type="openai",
        token_limit=8192,
    ),
    "text-embedding-3-large": EncoderInfo(
        name="text-embedding-3-large",
        type="openai",
        token_limit=8192,
    ),
}
class OpenAIEncoder(BaseEncoder):
client: Optional[openai.Client]
dimensions: Union[int, NotGiven] = NotGiven()
token_limit: Optional[int] = None
token_encoder: Optional[Any] = None
type: str = "openai"
def __init__(
......@@ -44,13 +67,32 @@ class OpenAIEncoder(BaseEncoder):
) from e
# set dimensions to support openai embed 3 dimensions param
self.dimensions = dimensions
# if model name is known, set token limit
if name in model_configs:
self.token_limit = model_configs[name].token_limit
# get token encoder
self.token_encoder = tiktoken.encoding_for_model(name)
def __call__(self, docs: List[str]) -> List[List[float]]:
def __call__(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
"""Encode a list of text documents into embeddings using OpenAI API.
:param docs: List of text documents to encode.
:param truncate: Whether to truncate the documents to token limit. If
False and a document exceeds the token limit, an error will be
raised.
:return: List of embeddings for each document."""
if self.client is None:
raise ValueError("OpenAI client is not initialized.")
embeds = None
error_message = ""
if truncate:
# check if any document exceeds token limit and truncate if so
for i in range(len(docs)):
logger.info(f"Document {i+1} length: {len(docs[i])}")
docs[i] = self._truncate(docs[i])
logger.info(f"Document {i+1} trunc length: {len(docs[i])}")
# Exponential backoff
for j in range(1, 7):
try:
......@@ -74,7 +116,20 @@ class OpenAIEncoder(BaseEncoder):
or not isinstance(embeds, CreateEmbeddingResponse)
or not embeds.data
):
logger.info(f"Returned embeddings: {embeds}")
raise ValueError(f"No embeddings returned. Error: {error_message}")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings
def _truncate(self, text: str) -> str:
tokens = self.token_encoder.encode(text)
if len(tokens) > self.token_limit:
logger.warning(
f"Document exceeds token limit: {len(tokens)} > {self.token_limit}"
"\nTruncating document..."
)
text = self.token_encoder.decode(tokens[:self.token_limit-1])
logger.info(f"Trunc length: {len(self.token_encoder.encode(text))}")
return text
return text
......@@ -2,16 +2,6 @@ from enum import Enum
from typing import List, Optional
from pydantic.v1 import BaseModel
from pydantic.v1.dataclasses import dataclass
from semantic_router.encoders import (
BaseEncoder,
CohereEncoder,
FastEmbedEncoder,
GoogleEncoder,
MistralEncoder,
OpenAIEncoder,
)
class EncoderType(Enum):
......@@ -23,40 +13,17 @@ class EncoderType(Enum):
GOOGLE = "google"
class EncoderInfo(BaseModel):
    """Static metadata describing a known embedding model."""

    # Model name as passed to the provider's API.
    name: str
    # Provider family the model belongs to (coerced from str by pydantic).
    type: EncoderType
    # Maximum number of input tokens the model accepts.
    token_limit: int
class RouteChoice(BaseModel):
    """A route selection result; all fields default to None."""

    # Name of the chosen route; presumably None when nothing matched —
    # TODO(review): confirm against the router that produces this.
    name: Optional[str] = None
    # Optional function-call payload associated with the route.
    function_call: Optional[dict] = None
    # Similarity score of the match, when available.
    similarity_score: Optional[float] = None
@dataclass
class Encoder:
    """Factory wrapper that builds the concrete encoder for a type string.

    Maps an :class:`EncoderType` value (given as its string form) to the
    matching ``*Encoder`` implementation and delegates calls to it.
    """

    type: EncoderType
    name: Optional[str]
    model: BaseEncoder

    def __init__(self, type: str, name: Optional[str]):
        """Instantiate the concrete encoder for *type*.

        :param type: String value of an :class:`EncoderType` member
            (raises ``ValueError`` from the enum if unrecognized).
        :param name: Model name forwarded to the encoder; may be None to
            use the encoder's default.
        :raises NotImplementedError: For the HuggingFace encoder type.
        """
        self.type = EncoderType(type)
        self.name = name
        if self.type == EncoderType.HUGGINGFACE:
            raise NotImplementedError("HuggingFace encoder is not supported.")
        elif self.type == EncoderType.FASTEMBED:
            self.model = FastEmbedEncoder(name=name)
        elif self.type == EncoderType.OPENAI:
            self.model = OpenAIEncoder(name=name)
        elif self.type == EncoderType.COHERE:
            self.model = CohereEncoder(name=name)
        elif self.type == EncoderType.MISTRAL:
            self.model = MistralEncoder(name=name)
        elif self.type == EncoderType.GOOGLE:
            self.model = GoogleEncoder(name=name)
        else:
            # Unreachable unless EncoderType gains a member without a
            # branch here; include the offending value for diagnosis.
            raise ValueError(f"Unsupported encoder type: {self.type}")

    def __call__(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* with the underlying encoder model."""
        return self.model(texts)
class Message(BaseModel):
role: str
content: str
......
......@@ -5,7 +5,6 @@ from openai.types.create_embedding_response import Usage
from semantic_router.encoders import OpenAIEncoder
@pytest.fixture
def openai_encoder(mocker):
mocker.patch("openai.Client")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment