Skip to content
Snippets Groups Projects
Unverified Commit 4913a447 authored by James Briggs's avatar James Briggs
Browse files

change id logic for no-overwrites and easy route deletion

parent f3b2535d
Branches
Tags
No related merge requests found
from pydantic.v1 import BaseModel, Field from pydantic.v1 import BaseModel, Field
import requests
import time import time
import hashlib
import os import os
from typing import Any, List, Tuple, Optional, Union from typing import Any, List, Tuple, Optional, Union
from semantic_router.index.base import BaseIndex from semantic_router.index.base import BaseIndex
from semantic_router.utils.logger import logger
import numpy as np import numpy as np
import uuid
def clean_route_name(route_name: str) -> str:
return route_name.strip().replace(" ", "-")
class PineconeRecord(BaseModel): class PineconeRecord(BaseModel):
id: str = Field(default_factory=lambda: f"utt_{uuid.uuid4().hex}") id: str = ""
values: List[float] values: List[float]
route: str route: str
utterance: str utterance: str
def __init__(self, **data):
super().__init__(**data)
# generate ID based on route name and utterances to prevent duplicates
clean_route = clean_route_name(self.route)
utterance_id = hashlib.md5(self.utterance.encode()).hexdigest()
self.id = f"{clean_route}#{utterance_id}"
def to_dict(self): def to_dict(self):
return { return {
"id": f"{self.route}#{self.id}", "id": self.id,
"values": self.values, "values": self.values,
"metadata": { "metadata": {
"sr_route": self.route, "sr_route": self.route,
...@@ -32,8 +43,10 @@ class PineconeIndex(BaseIndex): ...@@ -32,8 +43,10 @@ class PineconeIndex(BaseIndex):
metric: str = "cosine" metric: str = "cosine"
cloud: str = "aws" cloud: str = "aws"
region: str = "us-west-2" region: str = "us-west-2"
host: str = ""
client: Any = Field(default=None, exclude=True) client: Any = Field(default=None, exclude=True)
index: Optional[Any] = Field(default=None, exclude=True) index: Optional[Any] = Field(default=None, exclude=True)
ServerlessSpec: Any = Field(default=None, exclude=True)
def __init__(self, **data): def __init__(self, **data):
super().__init__(**data) super().__init__(**data)
...@@ -91,7 +104,10 @@ class PineconeIndex(BaseIndex): ...@@ -91,7 +104,10 @@ class PineconeIndex(BaseIndex):
else: else:
# if the index doesn't exist and we don't have the dimensions # if the index doesn't exist and we don't have the dimensions
# we return None # we return None
logger.warning("Index could not be initialized.")
index = None index = None
if index is not None:
self.host = self.client.describe_index(self.index_name)["host"]
return index return index
def add(self, embeddings: List[List[float]], routes: List[str], utterances: List[str]): def add(self, embeddings: List[List[float]], routes: List[str], utterances: List[str]):
...@@ -105,8 +121,17 @@ class PineconeIndex(BaseIndex): ...@@ -105,8 +121,17 @@ class PineconeIndex(BaseIndex):
vectors_to_upsert.append(record.to_dict()) vectors_to_upsert.append(record.to_dict())
self.index.upsert(vectors=vectors_to_upsert) self.index.upsert(vectors=vectors_to_upsert)
def _get_route_vecs(self, route_name: str):
clean_route = clean_route_name(route_name)
res = requests.get(
f"https://{self.host}/vectors/list?prefix={clean_route}#",
headers={"Api-Key": os.environ["PINECONE_API_KEY"]}
)
return [vec["id"] for vec in res.json()["vectors"]]
def delete(self, route_name: str): def delete(self, route_name: str):
self.index.delete(ids=ids_to_remove) route_vec_ids = self._get_route_vecs(route_name=route_name)
self.index.delete(ids=route_vec_ids)
def delete_all(self): def delete_all(self):
self.index.delete(delete_all=True) self.index.delete(delete_all=True)
...@@ -124,11 +149,14 @@ class PineconeIndex(BaseIndex): ...@@ -124,11 +149,14 @@ class PineconeIndex(BaseIndex):
results = self.index.query( results = self.index.query(
vector=[query_vector_list], vector=[query_vector_list],
top_k=top_k, top_k=top_k,
include_metadata=True include_metadata=True,
) )
scores = [result["score"] for result in results["matches"]] scores = [result["score"] for result in results["matches"]]
route_names = [result["metadata"]["sr_route"] for result in results["matches"]] route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
return np.array(scores), route_names return np.array(scores), route_names
def delete_index(self): def delete_index(self):
self.client.delete_index(self.index_name) self.client.delete_index(self.index_name)
\ No newline at end of file
def __len__(self):
return self.index.describe_index_stats()["total_vector_count"]
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment