Skip to content
Snippets Groups Projects
Unverified Commit 4913a447 authored by James Briggs
Browse files

change id logic for no-overwrites and easy route deletion

parent f3b2535d
No related branches found
No related tags found
No related merge requests found
from pydantic.v1 import BaseModel, Field
import requests
import time
import hashlib
import os
from typing import Any, List, Tuple, Optional, Union
from semantic_router.index.base import BaseIndex
from semantic_router.utils.logger import logger
import numpy as np
import uuid
def clean_route_name(route_name: str) -> str:
    """Normalise a route name for use as the leading part of a vector ID.

    Leading and trailing whitespace is removed, then every remaining space
    character is turned into a hyphen.

    :param route_name: Route name as supplied by the caller.
    :return: The trimmed, hyphenated route name.
    """
    trimmed = route_name.strip()
    # Splitting on a literal " " keeps empty pieces, so consecutive spaces
    # become consecutive hyphens.
    return "-".join(trimmed.split(" "))
class PineconeRecord(BaseModel):
id: str = Field(default_factory=lambda: f"utt_{uuid.uuid4().hex}")
id: str = ""
values: List[float]
route: str
utterance: str
def __init__(self, **data):
    """Build the record, then derive its ID from the route and utterance.

    The ID has the form ``<clean route name>#<md5 hex digest of utterance>``.
    It is generated from the route name and utterance to prevent duplicates.
    """
    super().__init__(**data)
    route_prefix = clean_route_name(self.route)
    utterance_digest = hashlib.md5(self.utterance.encode()).hexdigest()
    # Assigned after the base initialiser has run, so whatever ``id`` value
    # was set there is replaced by the derived one.
    self.id = "#".join((route_prefix, utterance_digest))
def to_dict(self):
return {
"id": f"{self.route}#{self.id}",
"id": self.id,
"values": self.values,
"metadata": {
"sr_route": self.route,
......@@ -32,8 +43,10 @@ class PineconeIndex(BaseIndex):
metric: str = "cosine"
cloud: str = "aws"
region: str = "us-west-2"
host: str = ""
client: Any = Field(default=None, exclude=True)
index: Optional[Any] = Field(default=None, exclude=True)
ServerlessSpec: Any = Field(default=None, exclude=True)
def __init__(self, **data):
super().__init__(**data)
......@@ -91,7 +104,10 @@ class PineconeIndex(BaseIndex):
else:
# if the index doesn't exist and we don't have the dimensions
# we return None
logger.warning("Index could not be initialized.")
index = None
if index is not None:
self.host = self.client.describe_index(self.index_name)["host"]
return index
def add(self, embeddings: List[List[float]], routes: List[str], utterances: List[str]):
......@@ -105,8 +121,17 @@ class PineconeIndex(BaseIndex):
vectors_to_upsert.append(record.to_dict())
self.index.upsert(vectors=vectors_to_upsert)
def _get_route_vecs(self, route_name: str):
    """Return the IDs of every vector stored under ``route_name``.

    Calls the ``/vectors/list`` endpoint on ``self.host`` for IDs starting
    with ``<clean route name>#`` and follows pagination tokens until the
    listing is exhausted.

    :param route_name: Name of the route whose vector IDs are wanted.
    :return: List of vector ID strings; empty when nothing matches.
    :raises KeyError: If ``PINECONE_API_KEY`` is not set in the environment.
    :raises requests.HTTPError: If the endpoint answers with an error status.
    """
    clean_route = clean_route_name(route_name)
    url = f"https://{self.host}/vectors/list"
    headers = {"Api-Key": os.environ["PINECONE_API_KEY"]}
    # The prefix goes through ``params`` so the trailing "#" is
    # percent-encoded. Written directly into the URL it would begin the
    # fragment, never reach the server, and the shortened prefix would also
    # match other routes whose names start with the same text.
    params = {"prefix": f"{clean_route}#"}
    vec_ids = []
    while True:
        res = requests.get(url, params=params, headers=headers)
        res.raise_for_status()
        payload = res.json()
        vec_ids.extend(vec["id"] for vec in payload.get("vectors", []))
        next_token = (payload.get("pagination") or {}).get("next")
        # Stop when there is no further page, or if the same token comes
        # back again (guards against looping forever).
        if not next_token or next_token == params.get("paginationToken"):
            return vec_ids
        params["paginationToken"] = next_token
def delete(self, route_name: str):
self.index.delete(ids=ids_to_remove)
route_vec_ids = self._get_route_vecs(route_name=route_name)
self.index.delete(ids=route_vec_ids)
def delete_all(self):
self.index.delete(delete_all=True)
......@@ -124,11 +149,14 @@ class PineconeIndex(BaseIndex):
results = self.index.query(
vector=[query_vector_list],
top_k=top_k,
include_metadata=True
include_metadata=True,
)
scores = [result["score"] for result in results["matches"]]
route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
return np.array(scores), route_names
def delete_index(self):
self.client.delete_index(self.index_name)
\ No newline at end of file
self.client.delete_index(self.index_name)
def __len__(self):
return self.index.describe_index_stats()["total_vector_count"]
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment