Skip to content
Snippets Groups Projects
Commit 8e2626af authored by Simonas's avatar Simonas
Browse files

fix: Fix inheritance with BaseSplitter

parent a063d42b
No related branches found
No related tags found
No related merge requests found
Source diff could not be displayed: it is too large. Options to address this: view the blob.
from itertools import cycle
from typing import List, Optional
from pydantic.v1 import BaseModel
from termcolor import colored
from colorama import Fore, Style
from pydantic.v1 import BaseModel, Extra
from semantic_router.encoders import BaseEncoder
from semantic_router.schema import DocumentSplit
......@@ -11,25 +10,30 @@ from semantic_router.schema import DocumentSplit
class BaseSplitter(BaseModel):
name: str
encoder: BaseEncoder
score_threshold: float
min_split_tokens: Optional[int] = None
max_split_tokens: Optional[int] = None
score_threshold: Optional[float]
class Config:
extra = Extra.allow
def __call__(self, docs: List[str]) -> List[DocumentSplit]:
raise NotImplementedError("Subclasses must implement this method")
def print_splits(self, splits: list[DocumentSplit]):
colors = cycle(["red", "green", "blue", "magenta", "cyan"])
for i, split in enumerate(splits):
triggered_text = (
"Triggered " + format(split.triggered_score, ".2f")
if split.triggered_score
else "Not Triggered"
)
header = f"Split {i+1} - ({triggered_text})"
if split.triggered_score:
print(colored(header, "red"))
def print(self, document_splits: List[DocumentSplit]) -> None:
colors = [Fore.RED, Fore.GREEN, Fore.BLUE, Fore.MAGENTA]
for i, split in enumerate(document_splits):
color = colors[i % len(colors)]
colored_content = f"{color}{split.content}{Style.RESET_ALL}"
if split.is_triggered:
triggered = f"{split.triggered_score:.2f}"
elif i == len(document_splits) - 1:
triggered = "final split"
else:
print(colored(header, "blue"))
print(colored(split.docs, next(colors))) # type: ignore
print("\n" + "-" * 50 + "\n")
triggered = "token limit"
print(
f"Split {i + 1}, "
f"tokens {split.token_count}, "
f"triggered by: {triggered}"
)
print(colored_content)
print("-" * 88)
print("\n")
......@@ -4,21 +4,25 @@ import numpy as np
from semantic_router.encoders.base import BaseEncoder
from semantic_router.schema import DocumentSplit
from semantic_router.splitters.base import BaseSplitter
from semantic_router.splitters.utils import split_to_sentences, tiktoken_length
from semantic_router.utils.logger import logger
class RollingWindowSplitter:
class RollingWindowSplitter(BaseSplitter):
def __init__(
self,
encoder: BaseEncoder,
threshold_adjustment: float = 0.01,
threshold_adjustment=0.01,
window_size=5,
min_split_tokens=100,
max_split_tokens=300,
split_tokens_tolerance=10,
plot_splits=False,
name = "rolling_window_splitter",
):
super().__init__(name=name, encoder=encoder)
self.calculated_threshold: float
self.encoder = encoder
self.threshold_adjustment = threshold_adjustment
......
......@@ -67,23 +67,3 @@ def tiktoken_length(text: str) -> int:
tokens = tokenizer.encode(text, disallowed_special=())
return len(tokens)
def plot_splits(document_splits: List[DocumentSplit]) -> None:
colors = [Fore.RED, Fore.GREEN, Fore.BLUE, Fore.MAGENTA]
for i, split in enumerate(document_splits):
color = colors[i % len(colors)]
colored_content = f"{color}{split.content}{Style.RESET_ALL}"
if split.is_triggered:
triggered = f"{split.triggered_score:.2f}"
elif i == len(document_splits) - 1:
triggered = "final split"
else:
triggered = "token limit"
print(
f"Split {i + 1}, "
f"tokens {split.token_count}, "
f"triggered by: {triggered}"
)
print(colored_content)
print("-" * 88)
print("\n")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment