From 3530c854ed049a80e597bd7126b2a39c522651ae Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Tue, 26 Dec 2023 21:30:24 +0100
Subject: [PATCH] restructure for how we use config object and fix circular
 imports

---
 semantic_router/__init__.py            |   6 +-
 semantic_router/encoders/__init__.py   |   8 +-
 semantic_router/hybrid_layer.py        |   2 +-
 semantic_router/layer.py               | 165 +++++++++++++++++++++++--
 semantic_router/route.py               | 113 +++--------------
 semantic_router/schema.py              |  21 +---
 semantic_router/utils/function_call.py |   4 +-
 semantic_router/utils/llm.py           |  31 ++++-
 tests/unit/test_schema.py              |  18 ---
 9 files changed, 216 insertions(+), 152 deletions(-)

diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py
index 2659bfe3..fd3198cf 100644
--- a/semantic_router/__init__.py
+++ b/semantic_router/__init__.py
@@ -1,5 +1,5 @@
-from .hybrid_layer import HybridRouteLayer
-from .layer import RouteLayer
-from .route import Route, RouteConfig
+from semantic_router.hybrid_layer import HybridRouteLayer
+from semantic_router.layer import LayerConfig as RouteConfig, RouteLayer
+from semantic_router.route import Route  # RouteConfig moved to layer.LayerConfig
 
 __all__ = ["RouteLayer", "HybridRouteLayer", "Route", "RouteConfig"]
diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index 30ad624a..ac27ebb4 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -1,6 +1,6 @@
-from .base import BaseEncoder
-from .bm25 import BM25Encoder
-from .cohere import CohereEncoder
-from .openai import OpenAIEncoder
+from semantic_router.encoders.base import BaseEncoder
+from semantic_router.encoders.bm25 import BM25Encoder
+from semantic_router.encoders.cohere import CohereEncoder
+from semantic_router.encoders.openai import OpenAIEncoder
 
 __all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"]
diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py
index 475a12f0..22f6573c 100644
--- a/semantic_router/hybrid_layer.py
+++ b/semantic_router/hybrid_layer.py
@@ -9,7 +9,7 @@ from semantic_router.encoders import (
 )
 from semantic_router.utils.logger import logger
 
-from .route import Route
+from semantic_router.route import Route
 
 
 class HybridRouteLayer:
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 2fa3b863..00dde5c2 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -1,4 +1,5 @@
 import json
+import os
 
 import numpy as np
 import yaml
@@ -11,7 +12,121 @@ from semantic_router.encoders import (
 from semantic_router.linear import similarity_matrix, top_scores
 from semantic_router.utils.logger import logger
 
-from .route import Route
+from semantic_router.route import Route
+from semantic_router.schema import Encoder, EncoderType, RouteChoice
+
+
+def is_valid(route_config: str) -> bool:
+    """Check that a JSON route config has the keys every route requires.
+
+    :param route_config: JSON string holding one route object or a list of them.
+    :return: True when every route is an object with `name` and `utterances`;
+        False (after logging) on malformed JSON, non-object entries or
+        missing keys.
+    """
+    required_keys = ["name", "utterances"]
+    try:
+        output_json = json.loads(route_config)
+    except json.JSONDecodeError as e:
+        logger.error(e)
+        return False
+    # a single route object is validated the same way as a one-item list
+    items = output_json if isinstance(output_json, list) else [output_json]
+    for item in items:
+        if not isinstance(item, dict):
+            # `key in item` would do substring/element matching or raise here
+            logger.warning(f"Route config entry is not an object: {item!r}")
+            return False
+        missing_keys = [key for key in required_keys if key not in item]
+        if missing_keys:
+            logger.warning(f"Missing keys in route config: {', '.join(missing_keys)}")
+            return False
+    return True
+
+
+class LayerConfig:
+    """
+    Generates a LayerConfig object that can be used for initializing a
+    RouteLayer.
+    """
+
+    routes: list[Route] = []
+
+    def __init__(
+        self,
+        routes: list[Route] | None = None,
+        encoder_type: EncoderType | str = "openai",
+        encoder_name: str | None = None,
+    ):
+        self.encoder_type = encoder_type
+        if encoder_name is None:
+            kind = EncoderType(encoder_type)  # a plain str never equals a member
+            if kind == EncoderType.OPENAI:
+                encoder_name = "text-embedding-ada-002"
+            elif kind == EncoderType.COHERE:
+                encoder_name = "embed-english-v3.0"
+            elif kind == EncoderType.HUGGINGFACE:
+                raise NotImplementedError
+            logger.info(f"Using default {encoder_type} encoder: {encoder_name}")
+        self.encoder_name = encoder_name
+        self.routes = routes if routes is not None else []  # no shared default list
+
+    @classmethod
+    def from_file(cls, path: str):
+        """Load the routes from a file in JSON or YAML format"""
+        logger.info(f"Loading route config from {path}")
+        _, ext = os.path.splitext(path)
+        with open(path, "r") as f:
+            if ext == ".json":
+                routes = json.load(f)
+            elif ext in [".yaml", ".yml"]:
+                routes = yaml.safe_load(f)
+            else:
+                raise ValueError(
+                    "Unsupported file type. Only .json and .yaml are supported"
+                )
+
+            route_config_str = json.dumps(routes)
+            if is_valid(route_config_str):
+                routes = [Route.from_dict(route) for route in routes]
+                return cls(routes=routes)
+            else:
+                raise Exception("Invalid config JSON or YAML")
+
+    def to_dict(self):
+        return [route.to_dict() for route in self.routes]
+
+    def to_file(self, path: str):
+        """Save the routes to a file in JSON or YAML format"""
+        logger.info(f"Saving route config to {path}")
+        _, ext = os.path.splitext(path)
+        with open(path, "w") as f:
+            if ext == ".json":
+                json.dump(self.to_dict(), f)
+            elif ext in [".yaml", ".yml"]:
+                yaml.safe_dump(self.to_dict(), f)
+            else:
+                raise ValueError(
+                    "Unsupported file type. Only .json and .yaml are supported"
+                )
+
+    def add(self, route: Route):
+        self.routes.append(route)
+        logger.info(f"Added route `{route.name}`")
+
+    def get(self, name: str) -> Route | None:
+        for route in self.routes:
+            if route.name == name:
+                return route
+        logger.error(f"Route `{name}` not found")
+        return None
+
+    def remove(self, name: str):
+        if name not in [route.name for route in self.routes]:
+            logger.error(f"Route `{name}` not found")
+        else:
+            self.routes = [route for route in self.routes if route.name != name]
+            logger.info(f"Removed route `{name}`")
 
 
 class RouteLayer:
@@ -34,28 +149,52 @@ class RouteLayer:
             # initialize index now
             self._add_routes(routes=routes)
 
-    def __call__(self, text: str) -> str | None:
+    def __call__(self, text: str) -> RouteChoice:
         results = self._query(text)
         top_class, top_class_scores = self._semantic_classify(results)
         passed = self._pass_threshold(top_class_scores, self.score_threshold)
         if passed:
-            return top_class
+            # get chosen route object
+            route = [route for route in self.routes if route.name == top_class][0]
+            return route(text)
         else:
-            return None
+            # if no route passes threshold, return empty route choice
+            return RouteChoice()
 
     @classmethod
     def from_json(cls, file_path: str):
-        with open(file_path, "r") as f:
-            routes_data = json.load(f)
-        routes = [Route.from_dict(route_data) for route_data in routes_data]
-        return cls(routes=routes)
+        """Create a RouteLayer from a JSON route config file.
+
+        :param file_path: path handed to LayerConfig.from_file to load the routes.
+        """
+        config = LayerConfig.from_file(file_path)
+        # Encoder is built with `type`/`name`; `encoder_type`/`encoder_name` are
+        # only LayerConfig attribute names and are not accepted as keywords
+        encoder = Encoder(type=config.encoder_type, name=config.encoder_name)
+        return cls(encoder=encoder, routes=config.routes)
 
     @classmethod
     def from_yaml(cls, file_path: str):
-        with open(file_path, "r") as f:
-            routes_data = yaml.load(f, Loader=yaml.FullLoader)
-        routes = [Route.from_dict(route_data) for route_data in routes_data]
-        return cls(routes=routes)
+        """Create a RouteLayer from a YAML route config file.
+
+        :param file_path: path handed to LayerConfig.from_file to load the routes.
+        """
+        config = LayerConfig.from_file(file_path)
+        # Encoder is built with `type`/`name`; `encoder_type`/`encoder_name` are
+        # only LayerConfig attribute names and are not accepted as keywords
+        encoder = Encoder(type=config.encoder_type, name=config.encoder_name)
+        return cls(encoder=encoder, routes=config.routes)
+    
+    @classmethod
+    def from_config(cls, config: LayerConfig):
+        """Create a RouteLayer from an existing LayerConfig.
+
+        :param config: supplies the encoder type/name and the routes to load.
+        """
+        # Encoder is built with `type`/`name`; `encoder_type`/`encoder_name` are
+        # only LayerConfig attribute names and are not accepted as keywords
+        encoder = Encoder(type=config.encoder_type, name=config.encoder_name)
+        return cls(encoder=encoder, routes=config.routes)
 
     def add(self, route: Route):
         # create embeddings
@@ -73,6 +212,8 @@ class RouteLayer:
         else:
             embed_arr = np.array(embeds)
             self.index = np.concatenate([self.index, embed_arr])
+        # add route to routes list
+        self.routes.append(route)
 
     def _add_routes(self, routes: list[Route]):
         # create embeddings for all routes
diff --git a/semantic_router/route.py b/semantic_router/route.py
index 99a7945b..be520da9 100644
--- a/semantic_router/route.py
+++ b/semantic_router/route.py
@@ -1,48 +1,35 @@
 import json
-import os
 import re
 from typing import Any, Callable, Union
 
-import yaml
 from pydantic import BaseModel
 
 from semantic_router.utils import function_call
 from semantic_router.utils.llm import llm
 from semantic_router.utils.logger import logger
-
-
-def is_valid(route_config: str) -> bool:
-    try:
-        output_json = json.loads(route_config)
-        required_keys = ["name", "utterances"]
-
-        if isinstance(output_json, list):
-            for item in output_json:
-                missing_keys = [key for key in required_keys if key not in item]
-                if missing_keys:
-                    logger.warning(
-                        f"Missing keys in route config: {', '.join(missing_keys)}"
-                    )
-                    return False
-            return True
-        else:
-            missing_keys = [key for key in required_keys if key not in output_json]
-            if missing_keys:
-                logger.warning(
-                    f"Missing keys in route config: {', '.join(missing_keys)}"
-                )
-                return False
-            else:
-                return True
-    except json.JSONDecodeError as e:
-        logger.error(e)
-        return False
+from semantic_router.schema import RouteChoice
 
 
 class Route(BaseModel):
     name: str
     utterances: list[str]
     description: str | None = None
+    function_schema: dict[str, Any] | None = None
+
+    def __call__(self, query: str) -> RouteChoice:
+        """Build the RouteChoice for `query`, extracting function inputs if needed."""
+        if self.function_schema:
+            # if a function schema is provided we generate the inputs
+            extracted_inputs = function_call.extract_function_inputs(
+                query=query, function_schema=self.function_schema
+            )
+            func_call = extracted_inputs
+        else:
+            # otherwise we just pass None for the call
+            func_call = None
+        # the local is deliberately not named `function_call`: assigning to that
+        # name would shadow the imported module and raise UnboundLocalError above
+        return RouteChoice(name=self.name, function_call=func_call)
 
     def to_dict(self):
         return self.dict()
@@ -114,69 +101,3 @@ class Route(BaseModel):
         raise Exception("No config generated")
 
 
-class RouteConfig:
-    """
-    Generates a RouteConfig object from a list of Route objects
-    """
-
-    routes: list[Route] = []
-
-    def __init__(self, routes: list[Route] = []):
-        self.routes = routes
-
-    @classmethod
-    def from_file(cls, path: str):
-        """Load the routes from a file in JSON or YAML format"""
-        logger.info(f"Loading route config from {path}")
-        _, ext = os.path.splitext(path)
-        with open(path, "r") as f:
-            if ext == ".json":
-                routes = json.load(f)
-            elif ext in [".yaml", ".yml"]:
-                routes = yaml.safe_load(f)
-            else:
-                raise ValueError(
-                    "Unsupported file type. Only .json and .yaml are supported"
-                )
-
-            route_config_str = json.dumps(routes)
-            if is_valid(route_config_str):
-                routes = [Route.from_dict(route) for route in routes]
-                return cls(routes=routes)
-            else:
-                raise Exception("Invalid config JSON or YAML")
-
-    def to_dict(self):
-        return [route.to_dict() for route in self.routes]
-
-    def to_file(self, path: str):
-        """Save the routes to a file in JSON or YAML format"""
-        logger.info(f"Saving route config to {path}")
-        _, ext = os.path.splitext(path)
-        with open(path, "w") as f:
-            if ext == ".json":
-                json.dump(self.to_dict(), f)
-            elif ext in [".yaml", ".yml"]:
-                yaml.safe_dump(self.to_dict(), f)
-            else:
-                raise ValueError(
-                    "Unsupported file type. Only .json and .yaml are supported"
-                )
-
-    def add(self, route: Route):
-        self.routes.append(route)
-        logger.info(f"Added route `{route.name}`")
-
-    def get(self, name: str) -> Route | None:
-        for route in self.routes:
-            if route.name == name:
-                return route
-        logger.error(f"Route `{name}` not found")
-        return None
-
-    def remove(self, name: str):
-        if name not in [route.name for route in self.routes]:
-            logger.error(f"Route `{name}` not found")
-        else:
-            self.routes = [route for route in self.routes if route.name != name]
-            logger.info(f"Removed route `{name}`")
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 4646a637..a3d786db 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -1,8 +1,8 @@
 from enum import Enum
 
 from pydantic.dataclasses import dataclass
+from pydantic import BaseModel
 
-from semantic_router import Route
 from semantic_router.encoders import (
     BaseEncoder,
     CohereEncoder,
@@ -16,6 +16,11 @@ class EncoderType(Enum):
     COHERE = "cohere"
 
 
+class RouteChoice(BaseModel):
+    name: str | None = None
+    function_call: dict | None = None
+
+
 @dataclass
 class Encoder:
     type: EncoderType
@@ -34,17 +39,3 @@ class Encoder:
 
     def __call__(self, texts: list[str]) -> list[list[float]]:
         return self.model(texts)
-
-
-@dataclass
-class SemanticSpace:
-    id: str
-    routes: list[Route]
-    encoder: str = ""
-
-    def __init__(self, routes: list[Route] = []):
-        self.id = ""
-        self.routes = routes
-
-    def add(self, route: Route):
-        self.routes.append(route)
diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py
index c1b4fcee..9504dfb8 100644
--- a/semantic_router/utils/function_call.py
+++ b/semantic_router/utils/function_call.py
@@ -40,7 +40,7 @@ def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]:
     return schema
 
 
-async def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict:
+def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict:
     logger.info("Extracting function input...")
 
     prompt = f"""
@@ -72,7 +72,7 @@ async def extract_function_inputs(query: str, function_schema: dict[str, Any]) -
     Result:
     """
 
-    output = await llm(prompt)
+    output = llm(prompt)
     if not output:
         raise Exception("No output generated for extract function input")
 
diff --git a/semantic_router/utils/llm.py b/semantic_router/utils/llm.py
index e912ee1f..0d22b9a6 100644
--- a/semantic_router/utils/llm.py
+++ b/semantic_router/utils/llm.py
@@ -5,7 +5,36 @@ import openai
 from semantic_router.utils.logger import logger
 
 
-async def llm(prompt: str) -> str | None:
+def llm(prompt: str) -> str | None:
+    try:
+        client = openai.OpenAI(
+            base_url="https://openrouter.ai/api/v1",
+            api_key=os.getenv("OPENROUTER_API_KEY"),
+        )
+
+        completion = client.chat.completions.create(
+            model="mistralai/mistral-7b-instruct",
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt,
+                },
+            ],
+            temperature=0.01,
+            max_tokens=200,
+        )
+
+        output = completion.choices[0].message.content
+
+        if not output:
+            raise Exception("No output generated")
+        return output
+    except Exception as e:
+        logger.error(f"LLM error: {e}")
+        raise Exception(f"LLM error: {e}")
+
+
+async def allm(prompt: str) -> str | None:
     try:
         client = openai.AsyncOpenAI(
             base_url="https://openrouter.ai/api/v1",
diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py
index 27c73c9f..46799ee8 100644
--- a/tests/unit/test_schema.py
+++ b/tests/unit/test_schema.py
@@ -6,7 +6,6 @@ from semantic_router.schema import (
     Encoder,
     EncoderType,
     OpenAIEncoder,
-    SemanticSpace,
 )
 
 
@@ -40,20 +39,3 @@ class TestEncoderDataclass:
         encoder = Encoder(type="openai", name="test-engine")
         result = encoder(["test"])
         assert result == [0.1, 0.2, 0.3]
-
-
-class TestSemanticSpaceDataclass:
-    def test_semanticspace_initialization(self):
-        semantic_space = SemanticSpace()
-        assert semantic_space.id == ""
-        assert semantic_space.routes == []
-
-    def test_semanticspace_add_route(self):
-        route = Route(name="test", utterances=["hello", "hi"], description="greeting")
-        semantic_space = SemanticSpace()
-        semantic_space.add(route)
-
-        assert len(semantic_space.routes) == 1
-        assert semantic_space.routes[0].name == "test"
-        assert semantic_space.routes[0].utterances == ["hello", "hi"]
-        assert semantic_space.routes[0].description == "greeting"
-- 
GitLab