Commit 9bfba6b5 authored by Ismail Ashraq

move torch import to init

parent f17cc139
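
This commit moves the `torch` import out of module scope and into `HuggingFaceEncoder.__init__`, so importing the encoder module no longer hard-requires PyTorch; the dependency is only checked when an encoder is actually constructed. The module handle is cached on a Pydantic private attribute, and every later call site goes through `self._torch`. Below is a minimal, hypothetical sketch of the pattern (the class name `LazyTorch` and its `zeros` method are illustrative, not part of the commit). The full diff follows, hunk by hunk.

from typing import Any

from pydantic import BaseModel, PrivateAttr


class LazyTorch(BaseModel):
    # The module handle is bound per instance, at construction time.
    _torch: Any = PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        try:
            import torch  # heavy optional dependency, imported lazily
        except ImportError:
            raise ImportError(
                "Please install Pytorch to use this class. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )
        self._torch = torch

    def zeros(self, n: int):
        # Later uses go through the cached handle instead of a global import.
        return self._torch.zeros(n)
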
 from typing import Any
-import torch
 from pydantic import PrivateAttr
 from semantic_router.encoders import BaseEncoder
@@ -13,6 +12,7 @@ class HuggingFaceEncoder(BaseEncoder):
     device: str | None = None
     _tokenizer: Any = PrivateAttr()
     _model: Any = PrivateAttr()
+    _torch: Any = PrivateAttr()
 
     def __init__(self, **data):
         super().__init__(**data)
@@ -25,9 +25,20 @@ class HuggingFaceEncoder(BaseEncoder):
             raise ImportError(
                 "Please install transformers to use HuggingFaceEncoder. "
                 "You can install it with: "
-                "`pip install semantic-router[transformers]`"
+                "`pip install semantic-router[local]`"
             )
 
+        try:
+            import torch
+        except ImportError:
+            raise ImportError(
+                "Please install Pytorch to use HuggingFaceEncoder. "
+                "You can install it with: "
+                "`pip install semantic-router[local]`"
+            )
+
+        self._torch = torch
+
         tokenizer = AutoTokenizer.from_pretrained(
             self.name,
             **self.tokenizer_kwargs,
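
This hunk also corrects the install hint for the transformers guard, pointing at the `local` extra instead of the outdated `transformers` extra. As an aside, a dependency probe can be written without attempting the import at all; a hedged sketch using only the standard library (this helper is not part of the commit):

import importlib.util

# Hypothetical alternative to the try/except guard above: detect the
# optional dependency without importing it.
if importlib.util.find_spec("torch") is None:
    raise ImportError(
        "Please install Pytorch to use HuggingFaceEncoder. "
        "You can install it with: `pip install semantic-router[local]`"
    )
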
@@ -39,7 +50,7 @@ class HuggingFaceEncoder(BaseEncoder):
             model.to(self.device)
         else:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            device = "cuda" if self._torch.cuda.is_available() else "cpu"
             model.to(device)
             self.device = device
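
Device selection is unchanged in behavior: an explicitly configured device wins, otherwise the encoder falls back to CUDA when present. A standalone illustration of the fallback (the `requested` variable stands in for the encoder's `device` field):

import torch

requested = None  # e.g. the encoder's `device` field, left unset
device = requested or ("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # "cpu" on a machine without a CUDA-capable GPU
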
@@ -60,7 +71,7 @@ class HuggingFaceEncoder(BaseEncoder):
                 batch_docs, padding=True, truncation=True, return_tensors="pt"
             ).to(self.device)
 
-            with torch.no_grad():
+            with self._torch.no_grad():
                 model_output = self._model(**encoded_input)
 
             if pooling_strategy == "mean":
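
The forward pass stays wrapped in `no_grad()`, which turns off autograd bookkeeping and saves memory during pure inference. A quick check of that behavior:

import torch

x = torch.randn(2, 3, requires_grad=True)

with torch.no_grad():
    y = x * 2  # no autograd graph is recorded inside this block

print(y.requires_grad)  # False: y is detached from any graph
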
@@ -71,9 +82,13 @@ class HuggingFaceEncoder(BaseEncoder):
                 embeddings = self._max_pooling(
                     model_output, encoded_input["attention_mask"]
                 )
+            else:
+                raise ValueError(
+                    "Invalid pooling_strategy. Please use 'mean' or 'max'."
+                )
 
             if normalize_embeddings:
-                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+                embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)
 
             embeddings = embeddings.tolist()
             all_embeddings.extend(embeddings)
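
Beyond the `self._torch` rewrite, this hunk adds an explicit `ValueError` for an unsupported `pooling_strategy`, rather than failing later with a less obvious error. The normalization itself is standard L2 row scaling; a quick demonstration that each embedding row ends up with unit length:

import torch
import torch.nn.functional as F

emb = torch.tensor([[3.0, 4.0], [1.0, 0.0]])

unit = F.normalize(emb, p=2, dim=1)  # same call the encoder makes
print(unit)              # rows become [0.6, 0.8] and [1.0, 0.0]
print(unit.norm(dim=1))  # both rows now have L2 norm 1.0
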
@@ -84,9 +99,9 @@ class HuggingFaceEncoder(BaseEncoder):
         input_mask_expanded = (
             attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         )
-        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
-            input_mask_expanded.sum(1), min=1e-9
-        )
+        return self._torch.sum(
+            token_embeddings * input_mask_expanded, 1
+        ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
     def _max_pooling(self, model_output, attention_mask):
         token_embeddings = model_output[0]
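
`_mean_pooling` averages token embeddings while excluding padding: the attention mask is expanded to the embedding shape, zeroes out padded positions, and `clamp(min=1e-9)` prevents division by zero for fully masked rows. A toy check of the arithmetic (the tensors below are made up for illustration):

import torch

token_embeddings = torch.tensor([[[1.0, 2.0], [9.0, 9.0]]])  # (batch, seq, dim)
attention_mask = torch.tensor([[1, 0]])  # the second token is padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(pooled)  # tensor([[1., 2.]]): the padded token contributes nothing
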
@@ -94,4 +109,4 @@
             attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         )
         token_embeddings[input_mask_expanded == 0] = -1e9
-        return torch.max(token_embeddings, 1)[0]
+        return self._torch.max(token_embeddings, 1)[0]
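
`_max_pooling` takes the element-wise maximum over the sequence axis; overwriting padded positions with `-1e9` first guarantees padding can never win the max. The same made-up setup as above:

import torch

token_embeddings = torch.tensor([[[1.0, 5.0], [9.0, 9.0]]])
attention_mask = torch.tensor([[1, 0]])  # the second token is padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
token_embeddings[mask == 0] = -1e9  # padding can no longer be selected
print(torch.max(token_embeddings, 1)[0])  # tensor([[1., 5.]])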