Skip to content
Snippets Groups Projects
Commit a1fa7419 authored by hananel's avatar hananel
Browse files

additional improvements

parent da1ff80e
Branches
Tags
No related merge requests found
import json import json
import os import os
from typing import Optional from typing import Optional, Any
import numpy as np import numpy as np
import yaml import yaml
...@@ -14,6 +14,7 @@ from semantic_router.utils.logger import logger ...@@ -14,6 +14,7 @@ from semantic_router.utils.logger import logger
def is_valid(layer_config: str) -> bool: def is_valid(layer_config: str) -> bool:
"""Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]"""
try: try:
output_json = json.loads(layer_config) output_json = json.loads(layer_config)
required_keys = ["encoder_name", "encoder_type", "routes"] required_keys = ["encoder_name", "encoder_type", "routes"]
...@@ -73,7 +74,7 @@ class LayerConfig: ...@@ -73,7 +74,7 @@ class LayerConfig:
self.routes = routes self.routes = routes
@classmethod @classmethod
def from_file(cls, path: str): def from_file(cls, path: str) -> "LayerConfig":
"""Load the routes from a file in JSON or YAML format""" """Load the routes from a file in JSON or YAML format"""
logger.info(f"Loading route config from {path}") logger.info(f"Loading route config from {path}")
_, ext = os.path.splitext(path) _, ext = os.path.splitext(path)
...@@ -98,7 +99,7 @@ class LayerConfig: ...@@ -98,7 +99,7 @@ class LayerConfig:
else: else:
raise Exception("Invalid config JSON or YAML") raise Exception("Invalid config JSON or YAML")
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return { return {
"encoder_type": self.encoder_type, "encoder_type": self.encoder_type,
"encoder_name": self.encoder_name, "encoder_name": self.encoder_name,
......
...@@ -62,7 +62,7 @@ class Route(BaseModel): ...@@ -62,7 +62,7 @@ class Route(BaseModel):
func_call = None func_call = None
return RouteChoice(name=self.name, function_call=func_call) return RouteChoice(name=self.name, function_call=func_call)
def to_dict(self) -> dict[str, Any]:
    """Serialize this route to a plain dictionary via pydantic's export."""
    return self.dict()
@classmethod @classmethod
......
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional, Literal
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
...@@ -55,12 +55,12 @@ class Message(BaseModel): ...@@ -55,12 +55,12 @@ class Message(BaseModel):
role: str role: str
content: str content: str
def to_openai(self) -> dict[str, str]:
    """Convert this message to the OpenAI chat-message format.

    Returns:
        A ``{"role": ..., "content": ...}`` dict with the role normalized
        to lowercase.

    Raises:
        ValueError: If the role is not one of user/assistant/system.
    """
    # Bug fix: validation was case-insensitive but the payload kept the
    # original casing, so e.g. role="User" passed the check yet produced an
    # invalid OpenAI role. Normalize once and emit the normalized value.
    role = self.role.lower()
    if role not in ("user", "assistant", "system"):
        raise ValueError("Role must be either 'user', 'assistant' or 'system'")
    return {"role": role, "content": self.content}
def to_cohere(self) -> dict[str, str]:
    """Convert this message to the Cohere chat format (uses the
    ``message`` key instead of ``content``)."""
    return {
        "role": self.role,
        "message": self.content,
    }
...@@ -68,11 +68,13 @@ class Conversation(BaseModel): ...@@ -68,11 +68,13 @@ class Conversation(BaseModel):
messages: list[Message] messages: list[Message]
def split_by_topic(
    self,
    encoder: BaseEncoder,
    threshold: float = 0.5,
    split_method: Literal[
        "consecutive_similarity_drop", "cumulative_similarity_drop"
    ] = "consecutive_similarity_drop",
) -> dict[str, list[str]]:
    """Partition the conversation into topical segments.

    Each message is flattened to a ``"role: content"`` string and the
    whole sequence is handed to ``semantic_splitter``, which groups
    consecutive messages whose embeddings stay within ``threshold``
    similarity of each other (per ``split_method``).
    """
    flattened = [f"{msg.role}: {msg.content}" for msg in self.messages]
    return semantic_splitter(
        encoder=encoder,
        docs=flattened,
        threshold=threshold,
        split_method=split_method,
    )
......
...@@ -40,4 +40,4 @@ def setup_custom_logger(name): ...@@ -40,4 +40,4 @@ def setup_custom_logger(name):
return logger return logger
logger = setup_custom_logger(__name__) logger: logging.Logger = setup_custom_logger(__name__)
import numpy as np import numpy as np
from typing import Literal
from semantic_router.encoders import BaseEncoder from semantic_router.encoders import BaseEncoder
...@@ -7,7 +8,9 @@ def semantic_splitter( ...@@ -7,7 +8,9 @@ def semantic_splitter(
encoder: BaseEncoder, encoder: BaseEncoder,
docs: list[str], docs: list[str],
threshold: float, threshold: float,
split_method: str = "consecutive_similarity_drop", split_method: Literal[
"consecutive_similarity_drop", "cumulative_similarity_drop"
] = "consecutive_similarity_drop",
) -> dict[str, list[str]]: ) -> dict[str, list[str]]:
""" """
Splits a list of documents base on semantic similarity changes. Splits a list of documents base on semantic similarity changes.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment