Commit 22ef01d5 authored by Ravi Theja, committed by GitHub

Update pooling strategy for embedding models (#10536)
parent c5daa1d6
 import asyncio
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence
 from llama_index.bridge.pydantic import Field, PrivateAttr
 from llama_index.callbacks import CallbackManager
@@ -12,6 +12,7 @@ from llama_index.embeddings.huggingface_utils import (
     DEFAULT_HUGGINGFACE_EMBEDDING_MODEL,
     format_query,
     format_text,
+    get_pooling_mode,
 )
 from llama_index.embeddings.pooling import Pooling
 from llama_index.llms.huggingface import HuggingFaceInferenceAPI
@@ -28,7 +29,7 @@ class HuggingFaceEmbedding(BaseEmbedding):
     max_length: int = Field(
         default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0
     )
-    pooling: Pooling = Field(default=Pooling.CLS, description="Pooling strategy.")
+    pooling: Pooling = Field(default=None, description="Pooling strategy.")
     normalize: bool = Field(default=True, description="Normalize embeddings or not.")
     query_instruction: Optional[str] = Field(
         description="Instruction to prepend to query text."
@@ -48,7 +49,7 @@ class HuggingFaceEmbedding(BaseEmbedding):
         self,
         model_name: Optional[str] = None,
         tokenizer_name: Optional[str] = None,
-        pooling: Union[str, Pooling] = "cls",
+        pooling: Optional[str] = None,
         max_length: Optional[int] = None,
         query_instruction: Optional[str] = None,
         text_instruction: Optional[str] = None,
@@ -105,14 +106,15 @@ class HuggingFaceEmbedding(BaseEmbedding):
                 "Unable to find max_length from model config. Please specify max_length."
             ) from exc

-        if isinstance(pooling, str):
-            try:
-                pooling = Pooling(pooling)
-            except ValueError as exc:
-                raise NotImplementedError(
-                    f"Pooling {pooling} unsupported, please pick one in"
-                    f" {[p.value for p in Pooling]}."
-                ) from exc
+        if not pooling:
+            pooling = get_pooling_mode(model_name)
+        try:
+            pooling = Pooling(pooling)
+        except ValueError as exc:
+            raise NotImplementedError(
+                f"Pooling {pooling} unsupported, please pick one in"
+                f" {[p.value for p in Pooling]}."
+            ) from exc

         super().__init__(
             embed_batch_size=embed_batch_size,
...
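Taken together, the hunks above change HuggingFaceEmbedding (judging by the class name and imports, this is llama_index/embeddings/huggingface.py) so that pooling is no longer hard-coded to CLS: when the caller passes nothing, the constructor asks the new get_pooling_mode helper to look the strategy up from the model's published pooling config, and only then validates the result against the Pooling enum. A minimal sketch of the resulting behavior; the top-level import path and model names are illustrative, not taken from the diff:

from llama_index.embeddings import HuggingFaceEmbedding

# No pooling argument: the constructor now resolves the strategy via
# get_pooling_mode(model_name) instead of silently assuming CLS pooling.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")

# An explicit string still works and is validated against the Pooling enum;
# an unsupported value raises NotImplementedError, as before.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en", pooling="mean")

One visible API change: the parameter annotation narrows from Union[str, Pooling] to Optional[str], so passing a Pooling member is no longer part of the advertised signature. The OptimumEmbedding integration follows.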
@@ -3,7 +3,12 @@ from typing import Any, List, Optional
 from llama_index.bridge.pydantic import Field, PrivateAttr
 from llama_index.callbacks import CallbackManager
 from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
-from llama_index.embeddings.huggingface_utils import format_query, format_text
+from llama_index.embeddings.huggingface_utils import (
+    format_query,
+    format_text,
+    get_pooling_mode,
+)
+from llama_index.embeddings.pooling import Pooling
 from llama_index.utils import infer_torch_device
@@ -29,7 +34,7 @@ class OptimumEmbedding(BaseEmbedding):
     def __init__(
         self,
         folder_name: str,
-        pooling: str = "cls",
+        pooling: Optional[str] = None,
         max_length: Optional[int] = None,
         normalize: bool = True,
         query_instruction: Optional[str] = None,
@@ -63,8 +68,15 @@ class OptimumEmbedding(BaseEmbedding):
                 "Please provide max_length."
             )

-        if pooling not in ["cls", "mean"]:
-            raise ValueError(f"Pooling {pooling} not supported.")
+        if not pooling:
+            pooling = get_pooling_mode(model)
+        try:
+            pooling = Pooling(pooling)
+        except ValueError as exc:
+            raise NotImplementedError(
+                f"Pooling {pooling} unsupported, please pick one in"
+                f" {[p.value for p in Pooling]}."
+            ) from exc

         super().__init__(
             embed_batch_size=embed_batch_size,
...
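OptimumEmbedding, the ONNX-backed variant, gets the same treatment: the old hard-coded cls/mean whitelist is replaced by auto-detection plus validation against the shared Pooling enum. Note that here the helper receives a local model value rather than the folder_name argument. An illustrative usage sketch; the "./bge_onnx" folder is hypothetical, and create_and_save_optimum_model is assumed to be the ONNX export helper this class documents:

from llama_index.embeddings import OptimumEmbedding

# Export a model to ONNX once, then load it from the local folder.
OptimumEmbedding.create_and_save_optimum_model("BAAI/bge-small-en-v1.5", "./bge_onnx")

# pooling is omitted, so it is now resolved automatically instead of
# defaulting to "cls" regardless of how the model was trained.
embed_model = OptimumEmbedding(folder_name="./bge_onnx")

The helper both classes now rely on is added in huggingface_utils, shown next.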
 from typing import Optional

+import requests

 DEFAULT_HUGGINGFACE_EMBEDDING_MODEL = "BAAI/bge-small-en"
 DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-base"
@@ -72,3 +74,26 @@ def format_text(
     # NOTE: strip() enables backdoor for defeating instruction prepend by
     # passing empty string
     return f"{instruction} {text}".strip()
+
+
+def get_pooling_mode(model_name: Optional[str]) -> str:
+    pooling_config_url = (
+        f"https://huggingface.co/{model_name}/raw/main/1_Pooling/config.json"
+    )
+
+    try:
+        response = requests.get(pooling_config_url)
+        config_data = response.json()
+
+        cls_token = config_data.get("pooling_mode_cls_token", False)
+        mean_tokens = config_data.get("pooling_mode_mean_tokens", False)
+
+        if mean_tokens:
+            return "mean"
+        elif cls_token:
+            return "cls"
+    except requests.exceptions.RequestException:
+        print(
+            "Warning: Pooling config file not found; pooling mode is defaulted to 'cls'."
+        )
+    return "cls"