Commit ff2ed5f9 authored by Simonas

chore: lint

parent 8e2626af
format:
-	poetry run black --target-version py39 .
+	poetry run black --target-version py39 -l 88 .
	poetry run ruff --select I --fix .

PYTHON_FILES=.

@@ -7,7 +7,7 @@ lint: PYTHON_FILES=.
lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$')

lint lint_diff:
-	poetry run black --target-version py39 $(PYTHON_FILES) --check
+	poetry run black --target-version py39 -l 88 $(PYTHON_FILES) --check
	poetry run ruff .
	poetry run mypy $(PYTHON_FILES)
......
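With a standard Poetry checkout, these targets would presumably be invoked as make format before committing and make lint (or make lint_diff, which only checks files changed relative to main) in CI; functionally, the change just makes black's default 88-character line length explicit via -l 88.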
-from typing import List, Optional
+from typing import List
 from colorama import Fore, Style
 from pydantic.v1 import BaseModel, Extra

@@ -10,7 +10,6 @@ from semantic_router.schema import DocumentSplit
 class BaseSplitter(BaseModel):
     name: str
     encoder: BaseEncoder
-    score_threshold: Optional[float]

     class Config:
         extra = Extra.allow
......
@@ -19,8 +19,9 @@ class ConsecutiveSimSplitter(BaseSplitter):
         name: str = "consecutive_similarity_splitter",
         score_threshold: float = 0.45,
     ):
-        super().__init__(name=name, score_threshold=score_threshold, encoder=encoder)
+        super().__init__(name=name, encoder=encoder)
+        encoder.score_threshold = score_threshold
+        self.score_threshold = score_threshold

     def __call__(self, docs: List[Any]):
         # Check if there's only a single document
......
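Read together, these splitter changes move the threshold onto the encoder while still exposing it on the splitter. A minimal construction sketch follows; the ToyEncoder class, the semantic_router.encoders import path, and the semantic_router.splitters.consecutive_sim module path are assumptions made for illustration (only BaseEncoder, the constructor keywords, and the score_threshold handling come from the diff itself).

from typing import Any, List

from semantic_router.encoders import BaseEncoder  # import path assumed
from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter  # module path assumed


class ToyEncoder(BaseEncoder):
    """Deterministic stand-in encoder, defined only for this sketch."""

    name: str = "toy-encoder"
    score_threshold: float = 0.5

    def __call__(self, docs: List[Any]) -> List[List[float]]:
        # Trivial fixed-size "embeddings" so the sketch needs no external services.
        return [[float(len(str(doc))), 1.0] for doc in docs]


encoder = ToyEncoder()
splitter = ConsecutiveSimSplitter(encoder=encoder, score_threshold=0.45)

# After this commit the threshold is written onto the encoder as well as
# kept on the splitter instance.
assert encoder.score_threshold == 0.45
assert splitter.score_threshold == 0.45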
@@ -8,9 +8,9 @@ from semantic_router.splitters.base import BaseSplitter
 class CumulativeSimSplitter(BaseSplitter):
     """
-    Called "cumulative sim" because we check the similarities of the embeddings of cumulative concatenated documents with the next document.
+    Called "cumulative sim" because we check the similarities of the
+    embeddings of cumulative concatenated documents with the next document.
     """

     def __init__(
@@ -19,15 +19,17 @@ class CumulativeSimSplitter(BaseSplitter):
         name: str = "cumulative_similarity_splitter",
         score_threshold: float = 0.45,
     ):
-        super().__init__(name=name, score_threshold=score_threshold, encoder=encoder)
+        super().__init__(name=name, encoder=encoder)
+        encoder.score_threshold = score_threshold
+        self.score_threshold = score_threshold

     def __call__(self, docs: List[str]):
         total_docs = len(docs)
         # Check if there's only a single document
         if total_docs == 1:
             raise ValueError(
-                "There is only one document provided; at least two are required to determine topics based on similarity."
+                "There is only one document provided; at least two are required "
+                "to determine topics based on similarity."
             )
         splits = []
         curr_split_start_idx = 0
@@ -35,10 +37,12 @@ class CumulativeSimSplitter(BaseSplitter):
         for idx in range(0, total_docs):
             if idx + 1 < total_docs:  # Ensure there is a next document to compare with.
                 if idx == 0:
-                    # On the first iteration, compare the first document directly to the second.
+                    # On the first iteration, compare the
+                    # first document directly to the second.
                     curr_split_docs = docs[idx]
                 else:
-                    # For subsequent iterations, compare cumulative documents up to the current one with the next.
+                    # For subsequent iterations, compare cumulative
+                    # documents up to the current one with the next.
                     curr_split_docs = "\n".join(docs[curr_split_start_idx : idx + 1])
                 next_doc = docs[idx + 1]
......
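The excerpt ends before the similarity check itself, so the following is a self-contained sketch of the cumulative strategy described in the docstring rather than the library's implementation: documents accumulated since the last split are concatenated, compared against the next document by cosine similarity, and a new split is started when the score falls below the threshold. The embed function is a placeholder, and the thresholding step is an illustrative guess since it is not shown in the hunks above.

from typing import List

import numpy as np


def embed(text: str) -> np.ndarray:
    # Placeholder embedding; a real encoder would be used here.
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    return rng.random(8)


def cumulative_splits(docs: List[str], score_threshold: float = 0.45) -> List[List[str]]:
    splits: List[List[str]] = []
    curr_split_start_idx = 0
    for idx in range(len(docs)):
        if idx + 1 < len(docs):  # Ensure there is a next document to compare with.
            # Concatenate everything accumulated since the last split...
            curr_split_docs = "\n".join(docs[curr_split_start_idx : idx + 1])
            next_doc = docs[idx + 1]
            # ...and compare it with the next document via cosine similarity.
            a, b = embed(curr_split_docs), embed(next_doc)
            sim = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
            if sim < score_threshold:
                splits.append(docs[curr_split_start_idx : idx + 1])
                curr_split_start_idx = idx + 1
    # Whatever remains after the last split boundary forms the final group.
    splits.append(docs[curr_split_start_idx:])
    return splits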
@@ -9,7 +9,6 @@ from semantic_router.splitters.utils import split_to_sentences, tiktoken_length
 from semantic_router.utils.logger import logger


 class RollingWindowSplitter(BaseSplitter):
     def __init__(
         self,

@@ -20,7 +19,7 @@ class RollingWindowSplitter(BaseSplitter):
         max_split_tokens=300,
         split_tokens_tolerance=10,
         plot_splits=False,
-        name = "rolling_window_splitter",
+        name="rolling_window_splitter",
     ):
         super().__init__(name=name, encoder=encoder)
         self.calculated_threshold: float
......
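Only part of this constructor is visible, so the construction sketch below is restricted to the keyword arguments actually shown in the hunk; the semantic_router.splitters.rolling_window module path and the encoder are assumptions (the ToyEncoder from the earlier sketch would do).

from semantic_router.splitters.rolling_window import RollingWindowSplitter  # module path assumed

splitter = RollingWindowSplitter(
    encoder=ToyEncoder(),  # any BaseEncoder instance; ToyEncoder comes from the sketch above
    max_split_tokens=300,
    split_tokens_tolerance=10,
    plot_splits=False,
    name="rolling_window_splitter",
)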
from typing import List
import regex
import tiktoken
from colorama import Fore, Style
from semantic_router.schema import DocumentSplit
def split_to_sentences(text: str) -> list[str]:
@@ -66,4 +61,3 @@ def tiktoken_length(text: str) -> int:
tokenizer = tiktoken.get_encoding("cl100k_base")
tokens = tokenizer.encode(text, disallowed_special=())
return len(tokens)
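As a quick check of what tiktoken_length returns, the snippet below reproduces its two-line body on a sample string; cl100k_base is the encoding named in the function above, and the printed count depends entirely on the input text.

import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")
text = "Semantic routing splits documents by meaning."
tokens = tokenizer.encode(text, disallowed_special=())
print(len(tokens))  # token count of the sample sentence under cl100k_base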