Skip to content
Snippets Groups Projects
Commit d96c5522 authored by “Daniel Griffiths”'s avatar “Daniel Griffiths”
Browse files

fix: created embedding helper functions

parent d6c421e3
Branches
Tags
No related merge requests found
...@@ -56,45 +56,33 @@ class HybridRouteLayer: ...@@ -56,45 +56,33 @@ class HybridRouteLayer:
return None return None
def add(self, route: Route): def add(self, route: Route):
self._add_route(route=route)
def _add_route(self, route: Route):
self.routes += [route]
self.update_dense_embeddings_index(route.utterances)
if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(
self.sparse_encoder, "fit" self.sparse_encoder, "fit"
): ):
self.sparse_encoder.fit(self.routes + [route]) self.sparse_encoder.fit(self.routes)
# re-build index
self.sparse_index = None self.sparse_index = None
for r in self.routes: all_utterances = [
self.compute_and_store_sparse_embeddings(r) utterance for route in self.routes for utterance in route.utterances
self.routes.append(route) ]
self._add_route(route=route) self.update_sparse_embeddings_index(all_utterances)
else:
self.update_sparse_embeddings_index(route.utterances)
def _add_route(self, route: Route):
# create embeddings
dense_embeds = np.array(self.dense_encoder(route.utterances)) # * self.alpha
self.compute_and_store_sparse_embeddings(route)
# create route array # create route array
if self.categories is None: if self.categories is None:
self.categories = np.array([route.name] * len(route.utterances)) self.categories = np.array([route.name] * len(route.utterances))
self.utterances = np.array(route.utterances)
else: else:
str_arr = np.array([route.name] * len(route.utterances)) str_arr = np.array([route.name] * len(route.utterances))
self.categories = np.concatenate([self.categories, str_arr]) self.categories = np.concatenate([self.categories, str_arr])
self.utterances = np.concatenate( self.routes.append(route)
[self.utterances, np.array(route.utterances)]
)
# create utterance array (the dense index)
if self.index is None:
self.index = dense_embeds
else:
self.index = np.concatenate([self.index, dense_embeds])
def compute_and_store_sparse_embeddings(self, route: Route):
sparse_embeds = np.array(
self.sparse_encoder(route.utterances)
) # * (1 - self.alpha)
# create sparse utterance array
if self.sparse_index is None:
self.sparse_index = sparse_embeds
else:
self.sparse_index = np.concatenate([self.sparse_index, sparse_embeds])
def _add_routes(self, routes: list[Route]): def _add_routes(self, routes: list[Route]):
# create embeddings for all routes # create embeddings for all routes
...@@ -102,8 +90,8 @@ class HybridRouteLayer: ...@@ -102,8 +90,8 @@ class HybridRouteLayer:
all_utterances = [ all_utterances = [
utterance for route in routes for utterance in route.utterances utterance for route in routes for utterance in route.utterances
] ]
dense_embeds = np.array(self.dense_encoder(all_utterances)) self.update_dense_embeddings_index(all_utterances)
sparse_embeds = np.array(self.sparse_encoder(all_utterances)) self.update_sparse_embeddings_index(all_utterances)
# create route array # create route array
route_names = [route.name for route in routes for _ in route.utterances] route_names = [route.name for route in routes for _ in route.utterances]
...@@ -114,6 +102,8 @@ class HybridRouteLayer: ...@@ -114,6 +102,8 @@ class HybridRouteLayer:
else route_array else route_array
) )
def update_dense_embeddings_index(self, utterances: list):
dense_embeds = np.array(self.dense_encoder(utterances))
# create utterance array (the dense index) # create utterance array (the dense index)
self.index = ( self.index = (
np.concatenate([self.index, dense_embeds]) np.concatenate([self.index, dense_embeds])
...@@ -121,6 +111,8 @@ class HybridRouteLayer: ...@@ -121,6 +111,8 @@ class HybridRouteLayer:
else dense_embeds else dense_embeds
) )
def update_sparse_embeddings_index(self, utterances: list):
sparse_embeds = np.array(self.sparse_encoder(utterances))
# create sparse utterance array # create sparse utterance array
self.sparse_index = ( self.sparse_index = (
np.concatenate([self.sparse_index, sparse_embeds]) np.concatenate([self.sparse_index, sparse_embeds])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment