Skip to content
Snippets Groups Projects
Unverified Commit 4f30ca3f authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Fixed Some More Bugs

parent 3cee0236
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,7 @@ class PineconeIndex(BaseIndex):
cloud: str = "aws"
region: str = "us-west-2"
pinecone: Any = Field(default=None, exclude=True)
vector_id_counter: int = 0
vector_id_counter: int = -1
def __init__(self, **data):
super().__init__(**data)
......@@ -33,15 +33,12 @@ class PineconeIndex(BaseIndex):
)
self.index = self.pinecone.Index(self.index_name)
# Store the index name for potential deletion
self.index_name = self.index_name
def add(self, embeds: List[List[float]]):
# Format embeds as a list of dictionaries for Pinecone's upsert method
vectors_to_upsert = []
for i, vector in enumerate(embeds):
# Generate a unique ID for each vector
vector_id = f"vec{i+1}"
for vector in embeds:
self.vector_id_counter += 1 # Increment the counter for each new vector
vector_id = str(self.vector_id_counter) # Convert counter to string ID
# Prepare for upsert
vectors_to_upsert.append({"id": vector_id, "values": vector})
......@@ -52,18 +49,31 @@ class PineconeIndex(BaseIndex):
def remove(self, ids_to_remove: List[str]):
self.index.delete(ids=ids_to_remove)
def remove_all(self):
self.index.delete(delete_all=True)
def is_index_populated(self) -> bool:
stats = self.index.describe_index_stats()
return stats["dimension"] > 0 and stats["index_size"] > 0
return stats["dimension"] > 0 and stats["total_vector_count"] > 0
def query(self, query_vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
results = self.index.query(queries=[query_vector], top_k=top_k)
ids = [result["id"] for result in results["matches"]]
query_vector_list = query_vector.tolist()
results = self.index.query(vector=[query_vector_list], top_k=top_k)
ids = [int(result["id"]) for result in results["matches"]]
scores = [result["score"] for result in results["matches"]]
return np.array(ids), np.array(scores)
# DEBUGGING: Start.
print('#'*50)
print('ids')
print(ids)
print('#'*50)
# DEBUGGING: End.
# DEBUGGING: Start.
print('#'*50)
print('scores')
print(scores)
print('#'*50)
# DEBUGGING: End.
return np.array(scores), np.array(ids)
def delete_index(self):
"""
Deletes the Pinecone index.
"""
pinecone.delete_index(self.index_name)
\ No newline at end of file
......@@ -188,6 +188,9 @@ class RouteLayer:
self._add_routes(routes=self.routes)
def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
# DEBUGGING: Start.
print(f'top_class 2: {top_class}')
# DEBUGGING: End.
matching_routes = [route for route in self.routes if route.name == top_class]
if not matching_routes:
logger.error(
......@@ -210,8 +213,17 @@ class RouteLayer:
vector_arr = np.array(vector)
# get relevant utterances
results = self._retrieve(xq=vector_arr)
# DEBUGGING: Start.
print(f'results: {results}')
# DEBUGGING: End.
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
# DEBUGGING: Start.
print(f'top_class 1: {top_class}')
# DEBUGGING: End.
# DEBUGGING: Start.
print(f'top_class_scores: {top_class_scores}')
# DEBUGGING: End.
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
if route is None:
......@@ -221,6 +233,24 @@ class RouteLayer:
if route.score_threshold is not None
else self.score_threshold
)
# DEBUGGING: Start.
print('#'*50)
print('Chosen route')
print(route)
print('#'*50)
# DEBUGGING: End.
# DEBUGGING: Start.
print('#'*50)
print('top_class_scores')
print(top_class_scores)
print('#'*50)
# DEBUGGING: End.
# DEBUGGING: Start.
print('#'*50)
print('threshold')
print(threshold)
print('#'*50)
# DEBUGGING: End.
passed = self._pass_threshold(top_class_scores, threshold)
if passed:
if route.function_schema and text is None:
......@@ -334,7 +364,17 @@ class RouteLayer:
def _retrieve(self, xq: Any, top_k: int = 5) -> List[dict]:
"""Given a query vector, retrieve the top_k most similar records."""
# DEBUGGING: Start.
print('#'*50)
print('RouteLayer._retrieve - CHECKPOINT 1')
print('#'*50)
# DEBUGGING: End.
if self.index.is_index_populated():
# DEBUGGING: Start.
print('#'*50)
print('RouteLayer._retrieve - CHECKPOINT 2')
print('#'*50)
# DEBUGGING: End.
# calculate similarity matrix
scores, idx = self.index.query(xq, top_k)
# get the utterance categories (route names)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment