Skip to content
Snippets Groups Projects
Unverified Commit 4c49baf5 authored by James Briggs's avatar James Briggs
Browse files

lint

parent ccdef6d3
No related branches found
No related tags found
No related merge requests found
......@@ -55,7 +55,7 @@ class LayerConfig:
def __init__(
self,
routes: list[Route] = [],
encoder_type: EncoderType = "openai",
encoder_type: str = "openai",
encoder_name: str | None = None,
):
self.encoder_type = encoder_type
......@@ -184,18 +184,18 @@ class RouteLayer:
@classmethod
def from_json(cls, file_path: str):
config = LayerConfig.from_file(file_path)
encoder = Encoder(type=config.encoder_type, name=config.encoder_name)
encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
return cls(encoder=encoder, routes=config.routes)
@classmethod
def from_yaml(cls, file_path: str):
config = LayerConfig.from_file(file_path)
encoder = Encoder(type=config.encoder_type, name=config.encoder_name)
encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
return cls(encoder=encoder, routes=config.routes)
@classmethod
def from_config(cls, config: LayerConfig):
encoder = Encoder(type=config.encoder_type, name=config.encoder_name)
encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
return cls(encoder=encoder, routes=config.routes)
def add(self, route: Route):
......
......@@ -113,7 +113,7 @@ class Route(BaseModel):
{function_schema}
"""
output = await llm(prompt)
output = llm(prompt)
if not output:
raise Exception("No output generated for dynamic route")
......
......@@ -23,12 +23,12 @@ class RouteChoice(BaseModel):
@dataclass
class Encoder:
type: EncoderType
name: str
type: str
name: str | None
model: BaseEncoder
def __init__(self, type: str, name: str):
self.type = EncoderType(type)
def __init__(self, type: str, name: str | None):
self.type = type
self.name = name
if self.type == EncoderType.HUGGINGFACE:
raise NotImplementedError
......@@ -36,6 +36,8 @@ class Encoder:
self.model = OpenAIEncoder(name)
elif self.type == EncoderType.COHERE:
self.model = CohereEncoder(name)
else:
raise NotImplementedError
def __call__(self, texts: list[str]) -> list[list[float]]:
return self.model(texts)
......@@ -117,11 +117,11 @@ async def route_and_execute(query: str, functions: list[Callable], route_layer):
function_name = route_layer(query)
if not function_name:
logger.warning("No function found, calling LLM...")
return await llm(query)
return llm(query)
for function in functions:
if function.__name__ == function_name:
print(f"Calling function: {function.__name__}")
schema = get_schema(function)
inputs = await extract_function_inputs(query, schema)
inputs = extract_function_inputs(query, schema)
call_function(function, inputs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment