Skip to content
Snippets Groups Projects
Unverified Commit f3b2535d authored by James Briggs's avatar James Briggs
Browse files

modularization and cleanup

parent e26d120a
No related branches found
No related tags found
No related merge requests found
......@@ -25,9 +25,9 @@ class BaseIndex(BaseModel):
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def delete(self, indices_to_remove: List[int]):
def delete(self, route_name: str):
"""
Remove items from the index by their indices.
Deletes route by route name.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
......@@ -39,7 +39,7 @@ class BaseIndex(BaseModel):
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def query(self, query_vector: Any, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
"""
Search the index for the query_vector and return top_k results.
This method should be implemented by subclasses.
......
......@@ -15,17 +15,35 @@ class LocalIndex(BaseIndex):
def add(self, embeddings: List[List[float]], routes: List[str], utterances: List[str]):
embeds = np.array(embeddings) # type: ignore
routes_arr = np.array(routes)
utterances_arr = np.array(utterances)
if self.index is None:
self.index = embeds # type: ignore
self.routes = routes_arr
self.utterances = utterances_arr
else:
self.index = np.concatenate([self.index, embeds])
self.routes = np.concatenate([self.routes, routes_arr])
self.utterances = np.concatenate([self.utterances, utterances_arr])
def delete(self, indices_to_remove: List[int]):
def _get_indices_for_route(self, route_name: str):
"""Gets an array of indices for a specific route.
"""
Remove all items of a specific category from the index.
idx = [
i for i, route in enumerate(self.routes)
if route == route_name
]
return idx
def delete(self, route_name: str):
"""
Delete all records of a specific route from the index.
"""
if self.index is not None:
self.index = np.delete(self.index, indices_to_remove, axis=0)
delete_idx = self._get_indices_for_route(route_name=route_name)
self.index = np.delete(self.index, delete_idx, axis=0)
self.routes = np.delete(self.routes, delete_idx, axis=0)
self.utterances = np.delete(self.utterances, delete_idx, axis=0)
def describe(self):
return {
......@@ -34,14 +52,18 @@ class LocalIndex(BaseIndex):
"vectors": self.index.shape[0] if self.index is not None else 0
}
def query(self, query_vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
"""
Search the index for the query and return top_k results.
"""
if self.index is None:
raise ValueError("Index is not populated.")
sim = similarity_matrix(query_vector, self.index)
return top_scores(sim, top_k)
sim = similarity_matrix(vector, self.index)
# extract the index values of top scoring vectors
scores, idx = top_scores(sim, top_k)
# get routes from index values
route_names = self.routes[idx].copy()
return scores, route_names
def delete_index(self):
"""
......
......@@ -16,7 +16,7 @@ class PineconeRecord(BaseModel):
def to_dict(self):
return {
"id": self.id,
"id": f"{self.route}#{self.id}",
"values": self.values,
"metadata": {
"sr_route": self.route,
......@@ -105,7 +105,7 @@ class PineconeIndex(BaseIndex):
vectors_to_upsert.append(record.to_dict())
self.index.upsert(vectors=vectors_to_upsert)
def delete(self, ids_to_remove: List[str]):
def delete(self, route_name: str):
self.index.delete(ids=ids_to_remove)
def delete_all(self):
......@@ -119,8 +119,8 @@ class PineconeIndex(BaseIndex):
"vectors": stats["total_vector_count"]
}
def query(self, query_vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
query_vector_list = query_vector.tolist()
def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
query_vector_list = vector.tolist()
results = self.index.query(
vector=[query_vector_list],
top_k=top_k,
......
......@@ -152,7 +152,6 @@ class LayerConfig:
class RouteLayer:
categories: Optional[np.ndarray] = None
score_threshold: float
encoder: BaseEncoder
index: BaseIndex
......@@ -166,7 +165,6 @@ class RouteLayer:
):
logger.info("local")
self.index: BaseIndex = index
self.categories = None
if encoder is None:
logger.warning(
"No encoder provided. Using default OpenAIEncoder. Ensure "
......@@ -208,7 +206,7 @@ class RouteLayer:
vector_arr = self._encode(text=text)
else:
vector_arr = np.array(vector)
# get relevant utterances
# get relevant results (scores and routes)
results = self._retrieve(xq=vector_arr)
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
......@@ -285,24 +283,23 @@ class RouteLayer:
def list_route_names(self) -> List[str]:
return [route.name for route in self.routes]
def remove(self, name: str):
if name not in [route.name for route in self.routes]:
err_msg = f"Route `{name}` not found"
def update(self, route_name: str, utterances: List[str]):
raise NotImplementedError("This method has not yet been implemented.")
def delete(self, route_name: str):
"""Deletes a route given a specific route name.
:param route_name: the name of the route to be deleted
:type str:
"""
if route_name not in [route.name for route in self.routes]:
err_msg = f"Route `{route_name}` not found"
logger.error(err_msg)
raise ValueError(err_msg)
else:
self.routes = [route for route in self.routes if route.name != name]
logger.info(f"Removed route `{name}`")
# Also remove from index and categories
if self.categories is not None and self.index.is_index_populated():
indices_to_remove = [
i
for i, route_name in enumerate(self.categories)
if route_name == name
]
self.index.remove(indices_to_remove)
self.categories = np.delete(self.categories, indices_to_remove, axis=0)
self.routes = [route for route in self.routes if route.name != route_name]
self.index.delete(route_name=route_name)
def _add_routes(self, routes: List[Route]):
# create embeddings for all routes
......@@ -326,13 +323,8 @@ class RouteLayer:
def _retrieve(self, xq: Any, top_k: int = 5) -> List[dict]:
"""Given a query vector, retrieve the top_k most similar records."""
# calculate similarity matrix
if self.index.type == "local":
scores, idx = self.index.query(xq, top_k)
# get the utterance categories (route names)
routes = self.categories[idx] if self.categories is not None else []
elif self.index.type == "pinecone":
scores, routes = self.index.query(xq, top_k)
# get scores and routes
scores, routes = self.index.query(vector=xq, top_k=top_k)
return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
def _semantic_classify(self, query_results: List[dict]) -> Tuple[str, List[float]]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment