From a1fa7419c5847235e2ca9f6ad84dfc287a5d4d3d Mon Sep 17 00:00:00 2001 From: hananel <hananel.hadad@accenture.com> Date: Thu, 11 Jan 2024 19:05:19 +0200 Subject: [PATCH] additional improvements --- semantic_router/layer.py | 7 ++++--- semantic_router/route.py | 2 +- semantic_router/schema.py | 18 ++++++++++-------- semantic_router/utils/logger.py | 2 +- semantic_router/utils/splitters.py | 5 ++++- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index cf546bfc..7ff7a15b 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -1,6 +1,6 @@ import json import os -from typing import Optional +from typing import Optional, Any import numpy as np import yaml @@ -14,6 +14,7 @@ from semantic_router.utils.logger import logger def is_valid(layer_config: str) -> bool: + """Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]""" try: output_json = json.loads(layer_config) required_keys = ["encoder_name", "encoder_type", "routes"] @@ -73,7 +74,7 @@ class LayerConfig: self.routes = routes @classmethod - def from_file(cls, path: str): + def from_file(cls, path: str) -> "LayerConfig": """Load the routes from a file in JSON or YAML format""" logger.info(f"Loading route config from {path}") _, ext = os.path.splitext(path) @@ -98,7 +99,7 @@ class LayerConfig: else: raise Exception("Invalid config JSON or YAML") - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "encoder_type": self.encoder_type, "encoder_name": self.encoder_name, diff --git a/semantic_router/route.py b/semantic_router/route.py index 6cca7eaf..b492ae13 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -62,7 +62,7 @@ class Route(BaseModel): func_call = None return RouteChoice(name=self.name, function_call=func_call) - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return self.dict() @classmethod diff --git a/semantic_router/schema.py b/semantic_router/schema.py index bb1a4c6a..8d479ec9 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Optional, Literal from pydantic import BaseModel from pydantic.dataclasses import dataclass @@ -55,12 +55,12 @@ class Message(BaseModel): role: str content: str - def to_openai(self): + def to_openai(self) -> dict[str, str]: if self.role.lower() not in ["user", "assistant", "system"]: raise ValueError("Role must be either 'user', 'assistant' or 'system'") return {"role": self.role, "content": self.content} - def to_cohere(self): + def to_cohere(self) -> dict[str, str]: return {"role": self.role, "message": self.content} @@ -68,11 +68,13 @@ class Conversation(BaseModel): messages: list[Message] def split_by_topic( - self, - encoder: BaseEncoder, - threshold: float = 0.5, - split_method: str = "consecutive_similarity_drop", - ): + self, + encoder: BaseEncoder, + threshold: float = 0.5, + split_method: Literal[ + "consecutive_similarity_drop", "cumulative_similarity_drop" + ] = "consecutive_similarity_drop", + ) -> dict[str, list[str]]: docs = [f"{m.role}: {m.content}" for m in self.messages] return semantic_splitter( encoder=encoder, docs=docs, threshold=threshold, split_method=split_method diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py index 00c83693..607f09d5 100644 --- a/semantic_router/utils/logger.py +++ b/semantic_router/utils/logger.py @@ -40,4 +40,4 @@ def setup_custom_logger(name): return logger -logger = setup_custom_logger(__name__) +logger: logging.Logger = setup_custom_logger(__name__) diff --git a/semantic_router/utils/splitters.py b/semantic_router/utils/splitters.py index f469fbcc..20160319 100644 --- a/semantic_router/utils/splitters.py +++ b/semantic_router/utils/splitters.py @@ -1,4 +1,5 @@ import numpy as np +from typing import Literal from semantic_router.encoders import BaseEncoder @@ -7,7 +8,9 @@ def semantic_splitter( encoder: BaseEncoder, docs: list[str], threshold: float, - split_method: str = "consecutive_similarity_drop", + split_method: Literal[ + "consecutive_similarity_drop", "cumulative_similarity_drop" + ] = "consecutive_similarity_drop", ) -> dict[str, list[str]]: """ Splits a list of documents base on semantic similarity changes. -- GitLab