From a1fa7419c5847235e2ca9f6ad84dfc287a5d4d3d Mon Sep 17 00:00:00 2001
From: hananel <hananel.hadad@accenture.com>
Date: Thu, 11 Jan 2024 19:05:19 +0200
Subject: [PATCH] additional improvements

---
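A note for reviewers: split_method is now constrained to a typing.Literal in
both Conversation.split_by_topic and semantic_splitter, so an unsupported
value is flagged by a static type checker instead of only failing at runtime.
A minimal usage sketch under the new signature; OpenAIEncoder is just an
illustrative choice (it needs OPENAI_API_KEY set) and any BaseEncoder
subclass works the same way:

    from semantic_router.encoders import OpenAIEncoder
    from semantic_router.schema import Conversation, Message

    convo = Conversation(
        messages=[
            Message(role="user", content="How do I reset my password?"),
            Message(role="assistant", content="Use the 'Forgot password' link."),
            Message(role="user", content="Thanks! Unrelated: any good pasta recipes?"),
        ]
    )

    # split_method accepts only the two Literal values; the default is
    # "consecutive_similarity_drop", passed explicitly here for illustration.
    topics = convo.split_by_topic(
        encoder=OpenAIEncoder(),
        threshold=0.5,
        split_method="consecutive_similarity_drop",
    )
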
 semantic_router/layer.py           |  7 ++++---
 semantic_router/route.py           |  2 +-
 semantic_router/schema.py          | 18 ++++++++++--------
 semantic_router/utils/logger.py    |  2 +-
 semantic_router/utils/splitters.py |  5 ++++-
 5 files changed, 20 insertions(+), 14 deletions(-)
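
A second note: the new is_valid docstring spells out the expected layer-config
shape. Purely for illustration, a config string that satisfies the check (the
field values below are placeholders, not taken from the repo):

    import json

    from semantic_router.layer import is_valid

    layer_config = json.dumps(
        {
            "encoder_type": "openai",
            "encoder_name": "text-embedding-ada-002",
            "routes": [],
        }
    )

    assert is_valid(layer_config)  # all three required keys are present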

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index cf546bfc..7ff7a15b 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import numpy as np
 import yaml
@@ -14,6 +14,7 @@ from semantic_router.utils.logger import logger
 
 
 def is_valid(layer_config: str) -> bool:
+    """Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]"""
     try:
         output_json = json.loads(layer_config)
         required_keys = ["encoder_name", "encoder_type", "routes"]
@@ -73,7 +74,7 @@ class LayerConfig:
         self.routes = routes
 
     @classmethod
-    def from_file(cls, path: str):
+    def from_file(cls, path: str) -> "LayerConfig":
         """Load the routes from a file in JSON or YAML format"""
         logger.info(f"Loading route config from {path}")
         _, ext = os.path.splitext(path)
@@ -98,7 +99,7 @@ class LayerConfig:
             else:
                 raise Exception("Invalid config JSON or YAML")
 
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         return {
             "encoder_type": self.encoder_type,
             "encoder_name": self.encoder_name,
diff --git a/semantic_router/route.py b/semantic_router/route.py
index 6cca7eaf..b492ae13 100644
--- a/semantic_router/route.py
+++ b/semantic_router/route.py
@@ -62,7 +62,7 @@ class Route(BaseModel):
             func_call = None
         return RouteChoice(name=self.name, function_call=func_call)
 
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         return self.dict()
 
     @classmethod
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index bb1a4c6a..8d479ec9 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional
+from typing import Literal, Optional
 
 from pydantic import BaseModel
 from pydantic.dataclasses import dataclass
@@ -55,12 +55,12 @@ class Message(BaseModel):
     role: str
     content: str
 
-    def to_openai(self):
+    def to_openai(self) -> dict[str, str]:
         if self.role.lower() not in ["user", "assistant", "system"]:
             raise ValueError("Role must be either 'user', 'assistant' or 'system'")
         return {"role": self.role, "content": self.content}
 
-    def to_cohere(self):
+    def to_cohere(self) -> dict[str, str]:
         return {"role": self.role, "message": self.content}
 
 
@@ -68,11 +68,13 @@ class Conversation(BaseModel):
     messages: list[Message]
 
     def split_by_topic(
-        self,
-        encoder: BaseEncoder,
-        threshold: float = 0.5,
-        split_method: str = "consecutive_similarity_drop",
-    ):
+        self,
+        encoder: BaseEncoder,
+        threshold: float = 0.5,
+        split_method: Literal[
+            "consecutive_similarity_drop", "cumulative_similarity_drop"
+        ] = "consecutive_similarity_drop",
+    ) -> dict[str, list[str]]:
         docs = [f"{m.role}: {m.content}" for m in self.messages]
         return semantic_splitter(
             encoder=encoder, docs=docs, threshold=threshold, split_method=split_method
diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py
index 00c83693..607f09d5 100644
--- a/semantic_router/utils/logger.py
+++ b/semantic_router/utils/logger.py
@@ -40,4 +40,4 @@ def setup_custom_logger(name):
     return logger
 
 
-logger = setup_custom_logger(__name__)
+logger: logging.Logger = setup_custom_logger(__name__)
diff --git a/semantic_router/utils/splitters.py b/semantic_router/utils/splitters.py
index f469fbcc..20160319 100644
--- a/semantic_router/utils/splitters.py
+++ b/semantic_router/utils/splitters.py
@@ -1,4 +1,5 @@
+from typing import Literal
 import numpy as np
 
 from semantic_router.encoders import BaseEncoder
 
@@ -7,7 +8,9 @@ def semantic_splitter(
     encoder: BaseEncoder,
     docs: list[str],
     threshold: float,
-    split_method: str = "consecutive_similarity_drop",
+    split_method: Literal[
+        "consecutive_similarity_drop", "cumulative_similarity_drop"
+    ] = "consecutive_similarity_drop",
 ) -> dict[str, list[str]]:
     """
     Splits a list of documents base on semantic similarity changes.
-- 
GitLab