From c18b32348583e2ec79500830125103030734b16d Mon Sep 17 00:00:00 2001
From: abhiram1809 <53875874+abhiram1809@users.noreply.github.com>
Date: Thu, 15 Feb 2024 05:48:59 +0530
Subject: [PATCH] Added Embeddings Option on custom Triplets (#10629)

---
 .../core/indices/knowledge_graph/base.py      | 28 +++++++++++++++----
 .../legacy/indices/knowledge_graph/base.py    | 28 +++++++++++++++----
 2 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/llama-index-core/llama_index/core/indices/knowledge_graph/base.py b/llama-index-core/llama_index/core/indices/knowledge_graph/base.py
index 594b1d577..86f29a5de 100644
--- a/llama-index-core/llama_index/core/indices/knowledge_graph/base.py
+++ b/llama-index-core/llama_index/core/indices/knowledge_graph/base.py
@@ -244,17 +244,25 @@ class KnowledgeGraphIndex(BaseIndex[KG]):
         # Update the storage context's index_store
         self._storage_context.index_store.add_index_struct(self._index_struct)
 
-    def upsert_triplet(self, triplet: Tuple[str, str, str]) -> None:
-        """Insert triplets.
+    def upsert_triplet(
+        self, triplet: Tuple[str, str, str], include_embeddings: bool = False
+    ) -> None:
+        """Insert triplets and optionally embeddings.
 
         Used for manual insertion of KG triplets (in the form
         of (subject, relationship, object)).
 
         Args:
-            triplet (str): Knowledge triplet
-
+            triplet (tuple): Knowledge triplet
+            embedding (Any, optional): Embedding option for the triplet. Defaults to None.
         """
         self._graph_store.upsert_triplet(*triplet)
+        triplet_str = str(triplet)
+        if include_embeddings:
+            set_embedding = self._service_context.embed_model.get_text_embedding(
+                triplet_str
+            )
+            self._index_struct.add_to_embedding_dict(str(triplet), set_embedding)
 
     def add_node(self, keywords: List[str], node: BaseNode) -> None:
         """Add node.
@@ -270,7 +278,10 @@ class KnowledgeGraphIndex(BaseIndex[KG]):
         self._docstore.add_documents([node], allow_update=True)
 
     def upsert_triplet_and_node(
-        self, triplet: Tuple[str, str, str], node: BaseNode
+        self,
+        triplet: Tuple[str, str, str],
+        node: BaseNode,
+        include_embeddings: bool = False,
     ) -> None:
         """Upsert KG triplet and node.
 
@@ -281,11 +292,18 @@ class KnowledgeGraphIndex(BaseIndex[KG]):
         Args:
             keywords (List[str]): Keywords to index the node.
             node (Node): Node to be indexed.
+            include_embeddings (bool): Option to add embeddings for triplets. Defaults to False
 
         """
         subj, _, obj = triplet
         self.upsert_triplet(triplet)
         self.add_node([subj, obj], node)
+        triplet_str = str(triplet)
+        if include_embeddings:
+            set_embedding = self._service_context.embed_model.get_text_embedding(
+                triplet_str
+            )
+            self._index_struct.add_to_embedding_dict(str(triplet), set_embedding)
 
     def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None:
         """Delete a node."""
diff --git a/llama-index-legacy/llama_index/legacy/indices/knowledge_graph/base.py b/llama-index-legacy/llama_index/legacy/indices/knowledge_graph/base.py
index a09fd5b93..6d94aad25 100644
--- a/llama-index-legacy/llama_index/legacy/indices/knowledge_graph/base.py
+++ b/llama-index-legacy/llama_index/legacy/indices/knowledge_graph/base.py
@@ -217,17 +217,25 @@ class KnowledgeGraphIndex(BaseIndex[KG]):
                     )
                     self._index_struct.add_to_embedding_dict(triplet_str, rel_embedding)
 
-    def upsert_triplet(self, triplet: Tuple[str, str, str]) -> None:
-        """Insert triplets.
+    def upsert_triplet(
+        self, triplet: Tuple[str, str, str], include_embeddings: bool = False
+    ) -> None:
+        """Insert triplets and optionally embeddings.
 
         Used for manual insertion of KG triplets (in the form
         of (subject, relationship, object)).
 
         Args:
-            triplet (str): Knowledge triplet
-
+            triplet (tuple): Knowledge triplet
+            embedding (Any, optional): Embedding option for the triplet. Defaults to None.
         """
         self._graph_store.upsert_triplet(*triplet)
+        triplet_str = str(triplet)
+        if include_embeddings:
+            set_embedding = self._service_context.embed_model.get_text_embedding(
+                triplet_str
+            )
+            self._index_struct.add_to_embedding_dict(str(triplet), set_embedding)
 
     def add_node(self, keywords: List[str], node: BaseNode) -> None:
         """Add node.
@@ -243,7 +251,10 @@ class KnowledgeGraphIndex(BaseIndex[KG]):
         self._docstore.add_documents([node], allow_update=True)
 
     def upsert_triplet_and_node(
-        self, triplet: Tuple[str, str, str], node: BaseNode
+        self,
+        triplet: Tuple[str, str, str],
+        node: BaseNode,
+        include_embeddings: bool = False,
     ) -> None:
         """Upsert KG triplet and node.
 
@@ -254,11 +265,18 @@ class KnowledgeGraphIndex(BaseIndex[KG]):
         Args:
             keywords (List[str]): Keywords to index the node.
             node (Node): Node to be indexed.
+            include_embeddings (bool): Option to add embeddings for triplets. Defaults to False
 
         """
         subj, _, obj = triplet
         self.upsert_triplet(triplet)
         self.add_node([subj, obj], node)
+        triplet_str = str(triplet)
+        if include_embeddings:
+            set_embedding = self._service_context.embed_model.get_text_embedding(
+                triplet_str
+            )
+            self._index_struct.add_to_embedding_dict(str(triplet), set_embedding)
 
     def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None:
         """Delete a node."""
-- 
GitLab