Commit a1fa7419 authored by hananel

additional improvements

parent da1ff80e
 import json
 import os
-from typing import Optional
+from typing import Optional, Any
 import numpy as np
 import yaml
@@ -14,6 +14,7 @@ from semantic_router.utils.logger import logger
 def is_valid(layer_config: str) -> bool:
+    """Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]"""
     try:
         output_json = json.loads(layer_config)
         required_keys = ["encoder_name", "encoder_type", "routes"]
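A hedged usage sketch for the validity check above: the rest of the try block is elided in this hunk, so the False-on-missing-keys behaviour is an assumption, as is the import path of is_valid.

import json
from semantic_router.layer import is_valid  # import path assumed

good = json.dumps(
    {"encoder_name": "text-embedding-ada-002", "encoder_type": "openai", "routes": []}
)  # hypothetical config payload
bad = json.dumps({"encoder_name": "text-embedding-ada-002"})  # two required keys missing

print(is_valid(good))  # expected: True
print(is_valid(bad))   # expected: False (assumed behaviour for missing keys)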
@@ -73,7 +74,7 @@ class LayerConfig:
         self.routes = routes

     @classmethod
-    def from_file(cls, path: str):
+    def from_file(cls, path: str) -> "LayerConfig":
         """Load the routes from a file in JSON or YAML format"""
         logger.info(f"Loading route config from {path}")
         _, ext = os.path.splitext(path)
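A minimal sketch of loading a config through the newly annotated classmethod; the file name is made up and the import path is assumed:

from semantic_router.layer import LayerConfig  # import path assumed

layer_config = LayerConfig.from_file("routes_config.yaml")  # hypothetical routes file
# With the -> "LayerConfig" annotation, type checkers now know what from_file returns:
print(layer_config.encoder_type, layer_config.encoder_name)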
@@ -98,7 +99,7 @@
         else:
             raise Exception("Invalid config JSON or YAML")

-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         return {
             "encoder_type": self.encoder_type,
             "encoder_name": self.encoder_name,
......
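To illustrate the dict[str, Any] return type, a round-trip sketch under the same assumptions as above (hypothetical file, assumed import paths); the trailing "routes" entry of the dict is elided in the hunk:

import json
from semantic_router.layer import LayerConfig, is_valid  # import paths assumed

config_dict = LayerConfig.from_file("routes_config.yaml").to_dict()  # hypothetical file
print(sorted(config_dict))  # expected keys include encoder_name, encoder_type, routes
print(is_valid(json.dumps(config_dict, default=str)))  # expected: True (required keys present)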
@@ -62,7 +62,7 @@ class Route(BaseModel):
             func_call = None
         return RouteChoice(name=self.name, function_call=func_call)

-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         return self.dict()

     @classmethod
......
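Route.to_dict simply delegates to pydantic's .dict(), so dict[str, Any] mirrors pydantic's own return type. A sketch; the utterances field and the import path are assumptions:

from semantic_router.route import Route  # import path assumed

route = Route(name="chitchat", utterances=["how are you?", "what's up?"])  # fields assumed
route_dict = route.to_dict()   # same mapping as route.dict()
print(route_dict["name"])      # "chitchat"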
 from enum import Enum
-from typing import Optional
+from typing import Optional, Literal
 from pydantic import BaseModel
 from pydantic.dataclasses import dataclass
@@ -55,12 +55,12 @@ class Message(BaseModel):
     role: str
     content: str

-    def to_openai(self):
+    def to_openai(self) -> dict[str, str]:
         if self.role.lower() not in ["user", "assistant", "system"]:
             raise ValueError("Role must be either 'user', 'assistant' or 'system'")
         return {"role": self.role, "content": self.content}

-    def to_cohere(self):
+    def to_cohere(self) -> dict[str, str]:
         return {"role": self.role, "message": self.content}
@@ -68,11 +68,13 @@ class Conversation(BaseModel):
     messages: list[Message]

     def split_by_topic(
-        self,
-        encoder: BaseEncoder,
-        threshold: float = 0.5,
-        split_method: str = "consecutive_similarity_drop",
-    ):
+        self,
+        encoder: BaseEncoder,
+        threshold: float = 0.5,
+        split_method: Literal[
+            "consecutive_similarity_drop", "cumulative_similarity_drop"
+        ] = "consecutive_similarity_drop",
+    ) -> dict[str, list[str]]:
         docs = [f"{m.role}: {m.content}" for m in self.messages]
         return semantic_splitter(
             encoder=encoder, docs=docs, threshold=threshold, split_method=split_method
......
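A usage sketch for the narrowed split_method parameter; the concrete encoder class and the messages are assumptions (any BaseEncoder subclass should do):

from semantic_router.encoders import OpenAIEncoder  # concrete encoder assumed
from semantic_router.schema import Conversation, Message  # import path assumed

conversation = Conversation(
    messages=[
        Message(role="user", content="Tell me about GPUs."),
        Message(role="assistant", content="GPUs excel at parallel workloads."),
        Message(role="user", content="Unrelated: best pizza toppings?"),
    ]
)
topics = conversation.split_by_topic(
    encoder=OpenAIEncoder(),  # assumes an API key is configured
    threshold=0.5,
    split_method="consecutive_similarity_drop",  # or "cumulative_similarity_drop"
)
# Literal[...] lets type checkers reject misspelled split methods before runtime.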
@@ -40,4 +40,4 @@ def setup_custom_logger(name):
     return logger

-logger = setup_custom_logger(__name__)
+logger: logging.Logger = setup_custom_logger(__name__)
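The explicit logging.Logger annotation only aids static analysis; callers import and use the shared logger as before (a minimal sketch, message text is illustrative):

from semantic_router.utils.logger import logger  # path taken from the hunk context above

logger.info("route layer initialized")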
 import numpy as np
+from typing import Literal
 from semantic_router.encoders import BaseEncoder
@@ -7,7 +8,9 @@ def semantic_splitter(
     encoder: BaseEncoder,
     docs: list[str],
     threshold: float,
-    split_method: str = "consecutive_similarity_drop",
+    split_method: Literal[
+        "consecutive_similarity_drop", "cumulative_similarity_drop"
+    ] = "consecutive_similarity_drop",
 ) -> dict[str, list[str]]:
     """
     Splits a list of documents base on semantic similarity changes.
......
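A hedged call sketch for semantic_splitter with the new Literal-typed parameter; the encoder class, documents and import path are assumptions:

from semantic_router.encoders import OpenAIEncoder  # concrete encoder assumed
from semantic_router.utils.splitters import semantic_splitter  # import path assumed

docs = [
    "user: Tell me about GPUs.",
    "assistant: GPUs excel at parallel workloads.",
    "user: Unrelated: what are the best pizza toppings?",
]
splits = semantic_splitter(
    encoder=OpenAIEncoder(),  # assumes an API key is configured
    docs=docs,
    threshold=0.5,
    split_method="cumulative_similarity_drop",
)
print(splits)  # dict[str, list[str]]: split label -> documents in that split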