diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 77053ef08100a5d7dd5f06048ce9360b0987f091..5271b897298af0d99efb5e6f23f87c94da4e7b12 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -1,6 +1,7 @@
-from pydantic.v1 import BaseModel
-from typing import Any, List, Tuple, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
+
 import numpy as np
+from pydantic.v1 import BaseModel
 
 
 class BaseIndex(BaseModel):
@@ -36,7 +37,8 @@ class BaseIndex(BaseModel):
 
     def describe(self) -> dict:
         """
-        Returns a dictionary with index details such as type, dimensions, and total vector count.
+        Returns a dictionary with index details such as type, dimensions, and total
+        vector count.
         This method should be implemented by subclasses.
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 058ee2bc08e562a56327c61850471b14409dfa6a..0c2a7140d2dd50099fa9389db255133452f11dfd 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -1,7 +1,9 @@
+from typing import List, Optional, Tuple
+
 import numpy as np
-from typing import List, Tuple, Optional
-from semantic_router.linear import similarity_matrix, top_scores
+
 from semantic_router.index.base import BaseIndex
+from semantic_router.linear import similarity_matrix, top_scores
 
 
 class LocalIndex(BaseIndex):
@@ -14,7 +16,8 @@ class LocalIndex(BaseIndex):
         super().__init__(index=index, routes=routes, utterances=utterances)
         self.type = "local"
 
-    class Config:  # Stop pydantic from complaining about Optional[np.ndarray] type hints.
+    class Config:
+        # Stop pydantic from complaining about Optional[np.ndarray] type hints.
         arbitrary_types_allowed = True
 
     def add(
@@ -78,7 +81,8 @@ class LocalIndex(BaseIndex):
             self.utterances = np.delete(self.utterances, delete_idx, axis=0)
         else:
             raise ValueError(
-                "Attempted to delete route records but either index, routes or utterances is None."
+                "Attempted to delete route records but either index, routes or "
+                "utterances is None."
             )
 
     def delete_index(self):
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index a6492226a46ebd2550f9c467fa64aa44af36d0a3..0fa6137df351243d9894d59dee3a3c911e503929 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -1,12 +1,14 @@
-from pydantic.v1 import BaseModel, Field
-import requests
-import time
 import hashlib
 import os
-from typing import Any, Dict, List, Tuple, Optional, Union
+import time
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import requests
+from pydantic.v1 import BaseModel, Field
+
 from semantic_router.index.base import BaseIndex
 from semantic_router.utils.logger import logger
-import numpy as np
 
 
 def clean_route_name(route_name: str) -> str:
@@ -21,9 +23,9 @@ class PineconeRecord(BaseModel):
 
     def __init__(self, **data):
         super().__init__(**data)
-        # generate ID based on route name and utterances to prevent duplicates
         clean_route = clean_route_name(self.route)
-        utterance_id = hashlib.md5(self.utterance.encode()).hexdigest()
+        # Use SHA-256 for a more secure hash
+        utterance_id = hashlib.sha256(self.utterance.encode()).hexdigest()
         self.id = f"{clean_route}#{utterance_id}"
 
     def to_dict(self):
@@ -48,14 +50,12 @@ class PineconeIndex(BaseIndex):
 
     def __init__(self, **data):
         super().__init__(**data)
-        self._initialize_client()
-
-        self.type = "pinecone"
-        self.client = self._initialize_client()
-        if not self.index_name.startswith(self.index_prefix):
-            self.index_name = f"{self.index_prefix}{self.index_name}"
-        # Create or connect to an existing Pinecone index
-        self.index = self._init_index()
+        self.type = "pinecone"
+        self.client = self._initialize_client()
+        if not self.index_name.startswith(self.index_prefix):
+            self.index_name = f"{self.index_prefix}{self.index_name}"
+        # Create or connect to an existing Pinecone index
+        self.index = self._init_index()
 
     def _initialize_client(self, api_key: Optional[str] = None):
         try:
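
For reference, a minimal sketch (not part of the patch) of the deterministic ID scheme PineconeRecord now uses; the route and utterance values below are made up:

    import hashlib

    route, utterance = "greeting", "hello world"  # hypothetical inputs
    # SHA-256 of the utterance, namespaced by the cleaned route name
    record_id = f"{route}#{hashlib.sha256(utterance.encode()).hexdigest()}"
    print(record_id)  # greeting#<64-character hex digest>
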
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 1e8a813066960f2428dc04df4b4780112c21847d..9573aa0a7fa1e758a803cffbfc1b53beabe00fe8 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -1,3 +1,4 @@
+import importlib
 import json
 import os
 import random
@@ -6,19 +7,19 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import yaml
 from tqdm.auto import tqdm
-import importlib
 
 from semantic_router.encoders import BaseEncoder, OpenAIEncoder
+from semantic_router.index.base import BaseIndex
+from semantic_router.index.local import LocalIndex
 from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
 from semantic_router.schema import Encoder, EncoderType, RouteChoice
 from semantic_router.utils.logger import logger
-from semantic_router.index.base import BaseIndex
-from semantic_router.index.local import LocalIndex
 
 
 def is_valid(layer_config: str) -> bool:
-    """Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]"""
+    """Make sure the given string is json format and contains the 3 keys:
+    ["encoder_name", "encoder_type", "routes"]"""
     try:
         output_json = json.loads(layer_config)
         required_keys = ["encoder_name", "encoder_type", "routes"]
@@ -209,7 +210,8 @@ class RouteLayer:
         matching_routes = [route for route in self.routes if route.name == top_class]
         if not matching_routes:
             logger.error(
-                f"No route found with name {top_class}. Check to see if any Routes have been defined."
+                f"No route found with name {top_class}. Check to see if any Routes "
+                "have been defined."
             )
             return None
         return matching_routes[0]
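
For reference, a minimal config string carrying the three keys that is_valid checks for; the encoder values are placeholders and the empty routes list only illustrates the shape:

    import json

    layer_config = json.dumps(
        {
            "encoder_name": "text-embedding-ada-002",  # placeholder
            "encoder_type": "openai",  # placeholder
            "routes": [],  # route definitions omitted for brevity
        }
    )
    # is_valid(layer_config) verifies JSON parsing and these three top-level keys
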
diff --git a/semantic_router/splitters/base.py b/semantic_router/splitters/base.py
index c4df5785da827cd27847dc39afa58cf0a8c8aa41..c64d384c0726e7595333c61007f4d9d5d7bc82d5 100644
--- a/semantic_router/splitters/base.py
+++ b/semantic_router/splitters/base.py
@@ -1,10 +1,11 @@
+from itertools import cycle
 from typing import List
 
 from pydantic.v1 import BaseModel
+from termcolor import colored
 
 from semantic_router.encoders import BaseEncoder
 from semantic_router.schema import DocumentSplit
-from termcolor import colored
 
 
 class BaseSplitter(BaseModel):
@@ -12,14 +13,21 @@ class BaseSplitter(BaseModel):
     encoder: BaseEncoder
     score_threshold: float
 
-    def __call__(self, docs: List[str]) -> List[List[float]]:
+    def __call__(self, docs: List[str]) -> List[DocumentSplit]:
         raise NotImplementedError("Subclasses must implement this method")
 
-    def print_colored_splits(self, splits: List[DocumentSplit]):
-        colors = ["red", "green", "blue", "magenta", "cyan"]
+    def print_splits(self, splits: list[DocumentSplit]):
+        colors = cycle(["red", "green", "blue", "magenta", "cyan"])
         for i, split in enumerate(splits):
-            color = colors[i % len(colors)]
-            for doc in split.docs:
-                print(colored(doc, color))  # type: ignore
-            print("Triggered score:", split.triggered_score)
-            print("\n")
+            triggered_text = (
+                "Triggered " + format(split.triggered_score, ".2f")
+                if split.triggered_score
+                else "Not Triggered"
+            )
+            header = f"Split {i+1} - ({triggered_text})"
+            if split.triggered_score:
+                print(colored(header, "red"))
+            else:
+                print(colored(header, "blue"))
+            print(colored(split.docs, next(colors)))  # type: ignore
+            print("\n" + "-" * 50 + "\n")
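
A small usage sketch for the reworked print_splits helper; the DocumentSplit values are invented and `splitter` is assumed to be any concrete BaseSplitter subclass instance:

    from semantic_router.schema import DocumentSplit

    splits = [
        DocumentSplit(docs=["intro text", "more intro"], is_triggered=True, triggered_score=0.12),
        DocumentSplit(docs=["closing remarks"], is_triggered=False),
    ]
    # Triggered splits get a red header with the score, untriggered ones a blue header,
    # then the split's docs are printed in a cycling colour followed by a divider line.
    splitter.print_splits(splits)  # `splitter` assumed to exist, e.g. a concrete splitter
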
- """ - - def __init__( - self, - encoder: BaseEncoder, - name: str = "dynamic_cumulative_similarity_splitter", - score_threshold: float = 0.9, - ): - super().__init__(name=name, encoder=encoder, score_threshold=score_threshold) - # Log the initialization details - logger.info( - f"Initialized {self.name} with score threshold: {self.score_threshold}" - ) - - def encode_documents(self, docs: List[str]) -> np.ndarray: - # Encode the documents using the provided encoder and return as a numpy array - encoded_docs = self.encoder(docs) - logger.info(f"Encoded {len(docs)} documents") - return np.array(encoded_docs) - - def adjust_threshold(self, similarities): - # Adjust the similarity threshold based on recent similarities - if len(similarities) <= 5: - # If not enough data, return the default score threshold - return self.score_threshold - - # Calculate mean and standard deviation of the last 5 similarities - recent_similarities = similarities[-5:] - mean_similarity, std_dev_similarity = np.mean(recent_similarities), np.std( - recent_similarities - ) - - # Calculate the change in mean and standard deviation if enough data is - # available - delta_mean = delta_std_dev = 0 - if len(similarities) > 10: - previous_similarities = similarities[-10:-5] - delta_mean = mean_similarity - np.mean(previous_similarities) - delta_std_dev = std_dev_similarity - np.std(previous_similarities) - - # Adjust the threshold based on the calculated metrics - adjustment_factor = std_dev_similarity + abs(delta_mean) + abs(delta_std_dev) - adjusted_threshold = mean_similarity - adjustment_factor - dynamic_lower_bound = max(0.2, 0.2 + delta_mean - delta_std_dev) - min_split_threshold = 0.3 - - # Ensure the new threshold is within a sensible range - new_threshold = max( - np.clip(adjusted_threshold, dynamic_lower_bound, self.score_threshold), - min_split_threshold, - ) - logger.debug( - f"Adjusted threshold to {new_threshold}, with dynamic lower " - f"bound {dynamic_lower_bound}" - ) - return new_threshold - - def calculate_dynamic_context_similarity(self, encoded_docs): - # Calculate the dynamic context similarity to determine split indices - split_indices, similarities = [0], [] - dynamic_window_size = 5 # Initial window size - norms = np.linalg.norm( - encoded_docs, axis=1 - ) # Pre-calculate norms for efficiency - - for idx in range(1, len(encoded_docs)): - # Adjust window size based on the standard deviation of recent similarities - if len(similarities) > 10: - std_dev_recent = np.std(similarities[-10:]) - dynamic_window_size = 5 if std_dev_recent < 0.05 else 10 - - # Calculate the similarity for the current document - window_start = max(0, idx - dynamic_window_size) - cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0) - cumulative_norm = np.linalg.norm(cumulative_context) - curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / ( - cumulative_norm * norms[idx] + 1e-10 - ) - - similarities.append(curr_sim_score) - # If the similarity is below the dynamically adjusted threshold, - # mark a new split - if curr_sim_score < self.adjust_threshold(similarities): - split_indices.append(idx) - - return split_indices, similarities - - def __call__(self, docs: List[str]): - # Main method to split the documents - logger.info(f"Splitting {len(docs)} documents") - encoded_docs = self.encode_documents(docs) - split_indices, similarities = self.calculate_dynamic_context_similarity( - encoded_docs - ) - splits = [] - - # Create DocumentSplit objects for each identified split - last_idx = 0 - for idx in 
diff --git a/semantic_router/splitters/rolling_window.py b/semantic_router/splitters/rolling_window.py
new file mode 100644
index 0000000000000000000000000000000000000000..57bf4733347b74b56db9511430fa1bccb3c559c9
--- /dev/null
+++ b/semantic_router/splitters/rolling_window.py
@@ -0,0 +1,235 @@
+from typing import List
+
+import numpy as np
+from matplotlib import pyplot as plt
+from nltk.tokenize import word_tokenize
+
+from semantic_router.schema import DocumentSplit
+from semantic_router.splitters.base import BaseSplitter
+from semantic_router.utils.logger import logger
+
+
+class RollingWindowSplitter(BaseSplitter):
+    """
+    A splitter that divides documents into segments based on semantic similarity
+    using a rolling window approach.
+    It adjusts the similarity threshold dynamically.
+    Splitting is based on:
+    - the similarity threshold
+    - the maximum token limit for a split
+
+    Attributes:
+        encoder (Callable): A function to encode documents into semantic vectors.
+        score_threshold (float): Initial threshold for similarity scores to decide
+            splits.
+        window_size (int): Size of the rolling window to calculate document context.
+        plot_splits (bool): Whether to plot the similarity scores and splits for
+            visualization.
+        min_split_tokens (int): Minimum number of tokens for a valid document split.
+        max_split_tokens (int): Maximum number of tokens a split can contain.
+        split_tokens_tolerance (int): Tolerance in token count to still consider a
+            split valid.
+        threshold_step_size (float): Step size to adjust the similarity threshold
+            during optimization.
+    """
+
+    def __init__(
+        self,
+        encoder,
+        score_threshold=0.3,
+        window_size=5,
+        plot_splits=False,
+        min_split_tokens=100,
+        max_split_tokens=300,
+        split_tokens_tolerance=10,
+        threshold_step_size=0.01,
+    ):
+        self.encoder = encoder
+        self.score_threshold = score_threshold
+        self.window_size = window_size
+        self.plot_splits = plot_splits
+        self.min_split_tokens = min_split_tokens
+        self.max_split_tokens = max_split_tokens
+        self.split_tokens_tolerance = split_tokens_tolerance
+        self.threshold_step_size = threshold_step_size
+
+    def encode_documents(self, docs: list[str]) -> np.ndarray:
+        return np.array(self.encoder(docs))
+
+    def find_optimal_threshold(self, docs: list[str], encoded_docs: np.ndarray):
+        logger.info(f"Number of documents for finding optimal threshold: {len(docs)}")
+        token_counts = [len(word_tokenize(doc)) for doc in docs]
+        low, high = 0, 1
+        while low <= high:
+            self.score_threshold = (low + high) / 2
+            similarity_scores = self.calculate_similarity_scores(encoded_docs)
+            split_indices = self.find_split_indices(similarity_scores)
+            average_tokens = np.mean(
+                [
+                    sum(token_counts[start:end])
+                    for start, end in zip(
+                        [0] + split_indices, split_indices + [len(token_counts)]
+                    )
+                ]
+            )
+            if (
+                self.min_split_tokens - self.split_tokens_tolerance
+                <= average_tokens
+                <= self.max_split_tokens + self.split_tokens_tolerance
+            ):
+                break
+            elif average_tokens < self.min_split_tokens:
+                high = self.score_threshold - self.threshold_step_size
+            else:
+                low = self.score_threshold + self.threshold_step_size
+
+    def calculate_similarity_scores(self, encoded_docs: np.ndarray) -> list[float]:
+        raw_similarities = []
+        for idx in range(1, len(encoded_docs)):
+            window_start = max(0, idx - self.window_size)
+            cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
+            curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
+                np.linalg.norm(cumulative_context) * np.linalg.norm(encoded_docs[idx])
+                + 1e-10
+            )
+            raw_similarities.append(curr_sim_score)
+        return raw_similarities
+
+    def find_split_indices(self, similarities: list[float]) -> list[int]:
+        return [
+            idx + 1
+            for idx, sim in enumerate(similarities)
+            if sim < self.score_threshold
+        ]
+
+    def split_documents(
+        self, docs: list[str], split_indices: list[int], similarities: list[float]
+    ) -> list[DocumentSplit]:
+        """
+        This method iterates through each document, appending it to the current split
+        until it either reaches a split point (determined by split_indices) or exceeds
+        the maximum token limit for a split (self.max_split_tokens).
+        When a document causes the current token count to exceed this limit,
+        or when a split point is reached and the minimum token requirement is met,
+        the current split is finalized and added to the list of splits.
+        """
+        token_counts = [len(word_tokenize(doc)) for doc in docs]
+        splits: List[DocumentSplit] = []
+        current_split: List[str] = []
+        current_tokens_count = 0
+
+        for doc_idx, doc in enumerate(docs):
+            doc_token_count = token_counts[doc_idx]
+            # Check if current document causes token count to exceed max limit
+            if (
+                current_tokens_count + doc_token_count > self.max_split_tokens
+                and current_tokens_count >= self.min_split_tokens
+            ):
+                splits.append(
+                    DocumentSplit(docs=current_split.copy(), is_triggered=True)
+                )
+                logger.info(
+                    f"Split finalized with {current_tokens_count} tokens due to "
+                    f"exceeding token limit of {self.max_split_tokens}."
+                )
+                current_split, current_tokens_count = [], 0
+
+            current_split.append(doc)
+            current_tokens_count += doc_token_count
+
+            # Check if current index is a split point based on similarity
+            if doc_idx + 1 in split_indices or doc_idx == len(docs) - 1:
+                if current_tokens_count >= self.min_split_tokens:
+                    if doc_idx < len(similarities):
+                        triggered_score = similarities[doc_idx]
+                        splits.append(
+                            DocumentSplit(
+                                docs=current_split.copy(),
+                                is_triggered=True,
+                                triggered_score=triggered_score,
+                            )
+                        )
+                        logger.info(
+                            f"Split finalized with {current_tokens_count} tokens due to"
+                            f" similarity score {triggered_score:.2f}."
+                        )
+                    else:
+                        # This case handles the end of the document list
+                        # where there's no similarity score
+                        splits.append(
+                            DocumentSplit(docs=current_split.copy(), is_triggered=False)
+                        )
+                        logger.info(
+                            f"Split finalized with {current_tokens_count} tokens "
+                            "at the end of the document list."
+                        )
+                    current_split, current_tokens_count = [], 0
+
+        # Ensure any remaining documents are included in the final token count
+        if current_split:
+            splits.append(DocumentSplit(docs=current_split.copy(), is_triggered=False))
+            logger.info(
+                f"Final split added with {current_tokens_count} tokens "
+                "due to remaining documents."
+            )
+
+        # Validation
+        original_token_count = sum(token_counts)
+        split_token_count = sum(
+            [len(word_tokenize(doc)) for split in splits for doc in split.docs]
+        )
+        logger.debug(
+            f"Original Token Count: {original_token_count}, "
+            f"Split Token Count: {split_token_count}"
+        )
+
+        if original_token_count != split_token_count:
+            logger.error(
+                f"Token count mismatch: {original_token_count} != {split_token_count}"
+            )
+            for i, split in enumerate(splits):
+                split_token_count = sum([len(word_tokenize(doc)) for doc in split.docs])
+                logger.error(f"Split {i} Token Count: {split_token_count}")
+            raise ValueError(
+                f"Token count mismatch: {original_token_count} != {split_token_count}"
+            )
+
+        return splits
+
+    # TODO: fix to plot split based on token count and final split
+    def plot_similarity_scores(
+        self, similarities: list[float], split_indices: list[int]
+    ):
+        if not self.plot_splits:
+            return
+        plt.figure(figsize=(12, 6))
+        plt.plot(similarities, label="Similarity Scores", marker="o")
+        for split_index in split_indices:
+            plt.axvline(
+                x=split_index - 1,
+                color="r",
+                linestyle="--",
+                label="Split" if split_index == split_indices[0] else "",
+            )
+        plt.axhline(
+            y=self.score_threshold,
+            color="g",
+            linestyle="-.",
+            label="Threshold Similarity Score",
+        )
+        plt.xlabel("Document Segment Index")
+        plt.ylabel("Similarity Score")
+        plt.title(f"Threshold: {self.score_threshold}", loc="right", fontsize=10)
+        plt.suptitle("Document Similarity Scores", fontsize=14)
+        plt.legend()
+        plt.show()
+
+    def __call__(self, docs: list[str]) -> list[DocumentSplit]:
+        encoded_docs = self.encode_documents(docs)
+        self.find_optimal_threshold(docs, encoded_docs)
+        similarities = self.calculate_similarity_scores(encoded_docs)
+        split_indices = self.find_split_indices(similarities=similarities)
+        splits = self.split_documents(docs, split_indices, similarities)
+
+        self.plot_similarity_scores(similarities, split_indices)
+        return splits
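
A sketch of how the new splitter is intended to be driven; `encoder` is assumed to be an already-constructed semantic_router encoder, the documents are placeholders, and nltk's punkt data must be available for word_tokenize:

    from semantic_router.splitters.rolling_window import RollingWindowSplitter

    docs = ["First paragraph...", "Second paragraph...", "Third paragraph..."]
    splitter = RollingWindowSplitter(
        encoder=encoder,  # assumed to exist, e.g. one of the semantic_router encoders
        window_size=5,
        min_split_tokens=100,
        max_split_tokens=300,
    )
    # __call__ finds a threshold, scores the rolling-window similarities,
    # picks split points, and returns DocumentSplit objects.
    splits = splitter(docs)
    for split in splits:
        print(len(split.docs), split.triggered_score)
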
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 88a4679a8dc71c016817472c9a504ff1ab56dac8..113a5d4a69bca2f02e14c086c30eef1750526a5e 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -6,8 +6,8 @@ import pytest
 
 from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
 from semantic_router.layer import LayerConfig, RouteLayer
-from semantic_router.route import Route
 from semantic_router.llms.base import BaseLLM
+from semantic_router.route import Route
 
 
 def mock_encoder_call(utterances):
@@ -399,10 +400,12 @@ class TestRouteLayer:
 
         # Load the LayerConfig from the temporary file
         layer_config = LayerConfig.from_file(str(config_path))
-        # Using BaseLLM because trying to create a useable Mock LLM is a nightmare.
+        # Using BaseLLM because trying to create a usable Mock LLM is a nightmare.
         assert isinstance(
             layer_config.routes[0].llm, BaseLLM
-        ), "LLM should be instantiated and associated with the route based on the config"
+        ), (
+            "LLM should be instantiated and associated with the route based on the "
+            "config"
+        )
         assert (
             layer_config.routes[0].llm.name == "fake-model-v1"
         ), "LLM instance should have the 'name' attribute set correctly"
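
One detail worth noting about the message-splitting pattern above: adjacent string literals only concatenate implicitly inside enclosing parentheses, so an assert message continued on the next line needs its own parentheses. A minimal illustration with made-up names:

    value = 3  # illustrative

    # Without parentheses the second literal is a separate, discarded statement:
    assert isinstance(value, int), "value must be an int, got something "
    "else entirely"  # not part of the assert message

    # Wrapping the message in parentheses restores implicit concatenation:
    assert isinstance(value, int), (
        "value must be an int, got something "
        "else entirely"
    )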