Skip to content
Snippets Groups Projects
Unverified Commit 59600647 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Linting.

parent 8d2de77e
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,7 @@ from semantic_router.schema import Encoder, EncoderType, RouteChoice
from semantic_router.utils.logger import logger
import importlib
def is_valid(layer_config: str) -> bool:
"""Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]"""
try:
......@@ -85,7 +86,9 @@ class LayerConfig:
elif ext in [".yaml", ".yml"]:
layer = yaml.safe_load(f)
else:
raise ValueError("Unsupported file type. Only .json and .yaml are supported")
raise ValueError(
"Unsupported file type. Only .json and .yaml are supported"
)
if not is_valid(json.dumps(layer)):
raise Exception("Invalid config JSON or YAML")
......@@ -95,22 +98,28 @@ class LayerConfig:
routes = []
for route_data in layer["routes"]:
# Handle the 'llm' field specially if it exists
if 'llm' in route_data:
llm_data = route_data.pop('llm') # Remove 'llm' from route_data and handle it separately
if "llm" in route_data:
llm_data = route_data.pop(
"llm"
) # Remove 'llm' from route_data and handle it separately
# Use the module path directly from llm_data without modification
llm_module_path = llm_data['module']
llm_module_path = llm_data["module"]
# Dynamically import the module and then the class from that module
llm_module = importlib.import_module(llm_module_path)
llm_class = getattr(llm_module, llm_data['class'])
llm_class = getattr(llm_module, llm_data["class"])
# Instantiate the LLM class with the provided model name
llm = llm_class(name=llm_data['model'])
route_data['llm'] = llm # Reassign the instantiated llm object back to route_data
llm = llm_class(name=llm_data["model"])
route_data[
"llm"
] = llm # Reassign the instantiated llm object back to route_data
# Dynamically create the Route object using the remaining route_data
route = Route(**route_data)
routes.append(route)
return cls(encoder_type=encoder_type, encoder_name=encoder_name, routes=routes)
return cls(
encoder_type=encoder_type, encoder_name=encoder_name, routes=routes
)
def to_dict(self) -> Dict[str, Any]:
return {
......
......@@ -70,14 +70,14 @@ class Route(BaseModel):
# def to_dict(self) -> Dict[str, Any]:
# return self.dict()
def to_dict(self) -> Dict[str, Any]:
data = self.dict()
if self.llm is not None:
data['llm'] = {
'module': self.llm.__module__,
'class': self.llm.__class__.__name__,
'model': self.llm.name
data["llm"] = {
"module": self.llm.__module__,
"class": self.llm.__class__.__name__,
"model": self.llm.name,
}
return data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment