Skip to content
Snippets Groups Projects
Commit 51000d5b authored by Simonas's avatar Simonas
Browse files

embed all decision at once + logging

parent 4a81de14
No related branches found
No related tags found
No related merge requests found
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]] [[package]]
name = "aiohttp" name = "aiohttp"
...@@ -439,6 +439,23 @@ files = [ ...@@ -439,6 +439,23 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
] ]
[[package]]
name = "colorlog"
version = "6.8.0"
description = "Add colours to the output of Python's logging module."
optional = false
python-versions = ">=3.6"
files = [
{file = "colorlog-6.8.0-py3-none-any.whl", hash = "sha256:4ed23b05a1154294ac99f511fabe8c1d6d4364ec1f7fc989c7fb515ccc29d375"},
{file = "colorlog-6.8.0.tar.gz", hash = "sha256:fbb6fdf9d5685f2517f388fb29bb27d54e8654dd31f58bc2a3b217e967a95ca6"},
]
[package.dependencies]
colorama = {version = "*", markers = "sys_platform == \"win32\""}
[package.extras]
development = ["black", "flake8", "mypy", "pytest", "types-colorama"]
[[package]] [[package]]
name = "comm" name = "comm"
version = "0.2.0" version = "0.2.0"
...@@ -1987,4 +2004,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p ...@@ -1987,4 +2004,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "b751e9eced707d903729ec6f473ec547e00bd7ef98e7536da003e5d2f4a80783" content-hash = "64e772051ca3411e09defc8ab06235a7c3e39f9bf60e58fb06b25317c5a34053"
...@@ -18,6 +18,7 @@ openai = "^0.28.1" ...@@ -18,6 +18,7 @@ openai = "^0.28.1"
cohere = "^4.32" cohere = "^4.32"
numpy = "^1.25.2" numpy = "^1.25.2"
pinecone-text = "^0.7.0" pinecone-text = "^0.7.0"
colorlog = "^6.8.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
......
...@@ -15,10 +15,16 @@ class CohereEncoder(BaseEncoder): ...@@ -15,10 +15,16 @@ class CohereEncoder(BaseEncoder):
cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
if cohere_api_key is None: if cohere_api_key is None:
raise ValueError("Cohere API key cannot be 'None'.") raise ValueError("Cohere API key cannot be 'None'.")
self.client = cohere.Client(cohere_api_key) try:
self.client = cohere.Client(cohere_api_key)
except Exception as e:
raise ValueError(f"Cohere API client failed to initialize. Error: {e}")
def __call__(self, docs: list[str]) -> list[list[float]]: def __call__(self, docs: list[str]) -> list[list[float]]:
if self.client is None: if self.client is None:
raise ValueError("Cohere client is not initialized.") raise ValueError("Cohere client is not initialized.")
embeds = self.client.embed(docs, input_type="search_query", model=self.name) try:
return embeds.embeddings embeds = self.client.embed(docs, input_type="search_query", model=self.name)
return embeds.embeddings
except Exception as e:
raise ValueError(f"Cohere API call failed. Error: {e}")
...@@ -2,9 +2,10 @@ import os ...@@ -2,9 +2,10 @@ import os
from time import sleep from time import sleep
import openai import openai
from openai.error import RateLimitError from openai.error import RateLimitError, ServiceUnavailableError
from semantic_router.encoders import BaseEncoder from semantic_router.encoders import BaseEncoder
from semantic_router.utils.logger import logger
class OpenAIEncoder(BaseEncoder): class OpenAIEncoder(BaseEncoder):
...@@ -19,17 +20,20 @@ class OpenAIEncoder(BaseEncoder): ...@@ -19,17 +20,20 @@ class OpenAIEncoder(BaseEncoder):
vector embeddings. vector embeddings.
""" """
res = None res = None
# exponential backoff in case of RateLimitError error_message = ""
# exponential backoff
for j in range(5): for j in range(5):
try: try:
logger.info(f"Encoding {len(docs)} docs...")
res = openai.Embedding.create(input=docs, engine=self.name) res = openai.Embedding.create(input=docs, engine=self.name)
if isinstance(res, dict) and "data" in res: if isinstance(res, dict) and "data" in res:
break break
except RateLimitError: except (RateLimitError, ServiceUnavailableError) as e:
sleep(2**j) sleep(2**j)
error_message = str(e)
if not res or not isinstance(res, dict) or "data" not in res: if not res or not isinstance(res, dict) or "data" not in res:
raise ValueError("Failed to create embeddings.") raise ValueError(f"OpenAI API call failed. Error: {error_message}")
# get embeddings
embeds = [r["embedding"] for r in res["data"]] embeds = [r["embedding"] for r in res["data"]]
return embeds return embeds
...@@ -4,9 +4,9 @@ from tqdm.auto import tqdm ...@@ -4,9 +4,9 @@ from tqdm.auto import tqdm
from semantic_router.encoders import ( from semantic_router.encoders import (
BaseEncoder, BaseEncoder,
BM25Encoder,
CohereEncoder, CohereEncoder,
OpenAIEncoder, OpenAIEncoder,
BM25Encoder,
) )
from semantic_router.linear import similarity_matrix, top_scores from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.schema import Decision from semantic_router.schema import Decision
...@@ -29,8 +29,7 @@ class DecisionLayer: ...@@ -29,8 +29,7 @@ class DecisionLayer:
# if decisions list has been passed, we initialize index now # if decisions list has been passed, we initialize index now
if decisions: if decisions:
# initialize index now # initialize index now
for decision in tqdm(decisions): self._add_decisions(decisions=decisions)
self._add_decision(decision=decision)
def __call__(self, text: str) -> str | None: def __call__(self, text: str) -> str | None:
results = self._query(text) results = self._query(text)
...@@ -61,6 +60,32 @@ class DecisionLayer: ...@@ -61,6 +60,32 @@ class DecisionLayer:
embed_arr = np.array(embeds) embed_arr = np.array(embeds)
self.index = np.concatenate([self.index, embed_arr]) self.index = np.concatenate([self.index, embed_arr])
def _add_decisions(self, decisions: list[Decision]):
# create embeddings for all decisions
all_utterances = [
utterance for decision in decisions for utterance in decision.utterances
]
embedded_utterance = self.encoder(all_utterances)
# create decision array
decision_names = [
decision.name for decision in decisions for _ in decision.utterances
]
decision_array = np.array(decision_names)
self.categories = (
np.concatenate([self.categories, decision_array])
if self.categories is not None
else decision_array
)
# create utterance array (the index)
embed_utterance_arr = np.array(embedded_utterance)
self.index = (
np.concatenate([self.index, embed_utterance_arr])
if self.index is not None
else embed_utterance_arr
)
def _query(self, text: str, top_k: int = 5): def _query(self, text: str, top_k: int = 5):
"""Given some text, encodes and searches the index vector space to """Given some text, encodes and searches the index vector space to
retrieve the top_k most similar records. retrieve the top_k most similar records.
...@@ -172,6 +197,9 @@ class HybridDecisionLayer: ...@@ -172,6 +197,9 @@ class HybridDecisionLayer:
else: else:
self.sparse_index = np.concatenate([self.sparse_index, sparse_embeds]) self.sparse_index = np.concatenate([self.sparse_index, sparse_embeds])
def _add_decisions(self, decisions: list[Decision]):
raise NotImplementedError
def _query(self, text: str, top_k: int = 5): def _query(self, text: str, top_k: int = 5):
"""Given some text, encodes and searches the index vector space to """Given some text, encodes and searches the index vector space to
retrieve the top_k most similar records. retrieve the top_k most similar records.
......
import logging
import colorlog
class CustomFormatter(colorlog.ColoredFormatter):
def __init__(self):
super().__init__(
"%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
log_colors={
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "bold_red",
},
reset=True,
style="%",
)
def add_coloured_handler(logger):
formatter = CustomFormatter()
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logging.basicConfig(
datefmt="%Y-%m-%d %H:%M:%S",
format="%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
force=True,
)
logger.addHandler(console_handler)
return logger
def setup_custom_logger(name):
logger = logging.getLogger(name)
logger.handlers = []
add_coloured_handler(logger)
logger.setLevel(logging.INFO)
logger.propagate = False
return logger
logger = setup_custom_logger(__name__)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment