Skip to content
Snippets Groups Projects
Unverified Commit 4c49baf5 authored by James Briggs's avatar James Briggs
Browse files

lint

parent ccdef6d3
No related branches found
No related tags found
No related merge requests found
...@@ -55,7 +55,7 @@ class LayerConfig: ...@@ -55,7 +55,7 @@ class LayerConfig:
def __init__( def __init__(
self, self,
routes: list[Route] = [], routes: list[Route] = [],
encoder_type: EncoderType = "openai", encoder_type: str = "openai",
encoder_name: str | None = None, encoder_name: str | None = None,
): ):
self.encoder_type = encoder_type self.encoder_type = encoder_type
...@@ -184,18 +184,18 @@ class RouteLayer: ...@@ -184,18 +184,18 @@ class RouteLayer:
@classmethod @classmethod
def from_json(cls, file_path: str): def from_json(cls, file_path: str):
config = LayerConfig.from_file(file_path) config = LayerConfig.from_file(file_path)
encoder = Encoder(type=config.encoder_type, name=config.encoder_name) encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
return cls(encoder=encoder, routes=config.routes) return cls(encoder=encoder, routes=config.routes)
@classmethod @classmethod
def from_yaml(cls, file_path: str): def from_yaml(cls, file_path: str):
config = LayerConfig.from_file(file_path) config = LayerConfig.from_file(file_path)
encoder = Encoder(type=config.encoder_type, name=config.encoder_name) encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
return cls(encoder=encoder, routes=config.routes) return cls(encoder=encoder, routes=config.routes)
@classmethod @classmethod
def from_config(cls, config: LayerConfig): def from_config(cls, config: LayerConfig):
encoder = Encoder(type=config.encoder_type, name=config.encoder_name) encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
return cls(encoder=encoder, routes=config.routes) return cls(encoder=encoder, routes=config.routes)
def add(self, route: Route): def add(self, route: Route):
......
...@@ -113,7 +113,7 @@ class Route(BaseModel): ...@@ -113,7 +113,7 @@ class Route(BaseModel):
{function_schema} {function_schema}
""" """
output = await llm(prompt) output = llm(prompt)
if not output: if not output:
raise Exception("No output generated for dynamic route") raise Exception("No output generated for dynamic route")
......
...@@ -23,12 +23,12 @@ class RouteChoice(BaseModel): ...@@ -23,12 +23,12 @@ class RouteChoice(BaseModel):
@dataclass @dataclass
class Encoder: class Encoder:
type: EncoderType type: str
name: str name: str | None
model: BaseEncoder model: BaseEncoder
def __init__(self, type: str, name: str): def __init__(self, type: str, name: str | None):
self.type = EncoderType(type) self.type = type
self.name = name self.name = name
if self.type == EncoderType.HUGGINGFACE: if self.type == EncoderType.HUGGINGFACE:
raise NotImplementedError raise NotImplementedError
...@@ -36,6 +36,8 @@ class Encoder: ...@@ -36,6 +36,8 @@ class Encoder:
self.model = OpenAIEncoder(name) self.model = OpenAIEncoder(name)
elif self.type == EncoderType.COHERE: elif self.type == EncoderType.COHERE:
self.model = CohereEncoder(name) self.model = CohereEncoder(name)
else:
raise NotImplementedError
def __call__(self, texts: list[str]) -> list[list[float]]: def __call__(self, texts: list[str]) -> list[list[float]]:
return self.model(texts) return self.model(texts)
...@@ -117,11 +117,11 @@ async def route_and_execute(query: str, functions: list[Callable], route_layer): ...@@ -117,11 +117,11 @@ async def route_and_execute(query: str, functions: list[Callable], route_layer):
function_name = route_layer(query) function_name = route_layer(query)
if not function_name: if not function_name:
logger.warning("No function found, calling LLM...") logger.warning("No function found, calling LLM...")
return await llm(query) return llm(query)
for function in functions: for function in functions:
if function.__name__ == function_name: if function.__name__ == function_name:
print(f"Calling function: {function.__name__}") print(f"Calling function: {function.__name__}")
schema = get_schema(function) schema = get_schema(function)
inputs = await extract_function_inputs(query, schema) inputs = extract_function_inputs(query, schema)
call_function(function, inputs) call_function(function, inputs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment