diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 507ea3494248c83b6e7feef458371107560aebd2..9596af5061a7bf2aef3b902d1540fcd604645e82 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -412,6 +412,7 @@ class SparseEmbedding(BaseModel): """Sparse embedding interface. Primarily uses numpy operations for faster operations. """ + embedding: np.ndarray class Config: @@ -425,36 +426,31 @@ class SparseEmbedding(BaseModel): "Column 0 should contain index positions, and column 1 should contain respective values." ) return cls(embedding=array) - + @classmethod def from_aurelio(cls, embedding: BM25Embedding): arr = np.array([embedding.indices, embedding.values]).T return cls.from_array(arr) - + @classmethod def from_dict(cls, sparse_dict: dict): arr = np.array([list(sparse_dict.keys()), list(sparse_dict.values())]).T return cls.from_array(arr) - + def to_dict(self): return { - i: v for i, v in zip( - self.embedding[:,0].astype(int), - self.embedding[:,1] - ) + i: v for i, v in zip(self.embedding[:, 0].astype(int), self.embedding[:, 1]) } - + def to_pinecone(self): return { "indices": self.embedding[:, 0].astype(int).tolist(), "values": self.embedding[:, 1].tolist(), } - + # dictionary interface def items(self): return [ - (i, v) for i, v in zip( - self.embedding[:,0].astype(int), - self.embedding[:,1] - ) + (i, v) + for i, v in zip(self.embedding[:, 0].astype(int), self.embedding[:, 1]) ] diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index ea4b8d41b5571918b5c01c80200fc14eeb6a9443..35093a6b481fafe0a74373e59b7acaf065491c29 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -7,7 +7,7 @@ from typing import Optional from semantic_router.encoders import DenseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.index.pinecone import PineconeIndex from semantic_router.schema import Utterance -from semantic_router.routers.base import SemanticRouter +from semantic_router.routers import SemanticRouter from semantic_router.route import Route from platform import python_version