From 37fe08cb4259eba102b5b1970c46bb9444224c16 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Tue, 6 Feb 2024 23:01:33 +0400
Subject: [PATCH] Bug Fixes

Making pydantic leave me alone and giving the correct index name.
---
 semantic_router/indices/local_index.py | 18 +++++++-----------
 semantic_router/layer.py               |  3 ++-
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/semantic_router/indices/local_index.py b/semantic_router/indices/local_index.py
index 784c9967..eef81b91 100644
--- a/semantic_router/indices/local_index.py
+++ b/semantic_router/indices/local_index.py
@@ -1,21 +1,17 @@
 import numpy as np
 from typing import List, Any
-from .base import BaseIndex
 from semantic_router.linear import similarity_matrix, top_scores
-from typing import Tuple
+from pydantic import BaseModel
+import numpy as np
+from typing import List, Any, Tuple, Optional
 
-class LocalIndex(BaseIndex):
-    """
-    Local index implementation using numpy arrays.
-    """
+class LocalIndex(BaseModel):
+    index: Optional[np.ndarray] = None
 
-    def __init__(self):
-        self.index = None
+    class Config: # Stop pydantic from complaining about  Optional[np.ndarray] type hints.
+        arbitrary_types_allowed = True
 
     def add(self, embeds: List[Any]):
-        """
-        Add items to the index.
-        """
         embeds = np.array(embeds)
         if self.index is None:
             self.index = embeds
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 487d8b24..ff1168f2 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -13,6 +13,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
 from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index
 from semantic_router.utils.logger import logger
+from semantic_router.indices.local_index import LocalIndex
 
 IndexType = Union[LocalIndex, None]
 
@@ -166,7 +167,7 @@ class RouteLayer:
         index_name: Optional[str] = None,
     ):
         logger.info("local")
-        self.index = Index.get_by_name(index_name="index")
+        self.index = Index.get_by_name(index_name="local")
         self.categories = None
         if encoder is None:
             logger.warning(
-- 
GitLab