From 22ef01d5353055ca349b39158c6714ee323b0f2d Mon Sep 17 00:00:00 2001
From: Ravi Theja <ravi03071991@gmail.com>
Date: Fri, 9 Feb 2024 02:21:26 +0530
Subject: [PATCH] Update pooling strategy for embedding models (#10536)

Update pooling strategy for embedding models
---
 llama_index/embeddings/huggingface.py         | 24 ++++++++++--------
 llama_index/embeddings/huggingface_optimum.py | 20 ++++++++++++---
 llama_index/embeddings/huggingface_utils.py   | 25 +++++++++++++++++++
 3 files changed, 54 insertions(+), 15 deletions(-)

diff --git a/llama_index/embeddings/huggingface.py b/llama_index/embeddings/huggingface.py
index 56da5bb27f..20842d6fdb 100644
--- a/llama_index/embeddings/huggingface.py
+++ b/llama_index/embeddings/huggingface.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence
 
 from llama_index.bridge.pydantic import Field, PrivateAttr
 from llama_index.callbacks import CallbackManager
@@ -12,6 +12,7 @@ from llama_index.embeddings.huggingface_utils import (
     DEFAULT_HUGGINGFACE_EMBEDDING_MODEL,
     format_query,
     format_text,
+    get_pooling_mode,
 )
 from llama_index.embeddings.pooling import Pooling
 from llama_index.llms.huggingface import HuggingFaceInferenceAPI
@@ -28,7 +29,7 @@ class HuggingFaceEmbedding(BaseEmbedding):
     max_length: int = Field(
         default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0
     )
-    pooling: Pooling = Field(default=Pooling.CLS, description="Pooling strategy.")
+    pooling: Pooling = Field(default=None, description="Pooling strategy.")
     normalize: bool = Field(default=True, description="Normalize embeddings or not.")
     query_instruction: Optional[str] = Field(
         description="Instruction to prepend to query text."
@@ -48,7 +49,7 @@ class HuggingFaceEmbedding(BaseEmbedding):
         self,
         model_name: Optional[str] = None,
         tokenizer_name: Optional[str] = None,
-        pooling: Union[str, Pooling] = "cls",
+        pooling: Optional[str] = None,
         max_length: Optional[int] = None,
         query_instruction: Optional[str] = None,
         text_instruction: Optional[str] = None,
@@ -105,14 +106,15 @@ class HuggingFaceEmbedding(BaseEmbedding):
                     "Unable to find max_length from model config. Please specify max_length."
                 ) from exc
 
-        if isinstance(pooling, str):
-            try:
-                pooling = Pooling(pooling)
-            except ValueError as exc:
-                raise NotImplementedError(
-                    f"Pooling {pooling} unsupported, please pick one in"
-                    f" {[p.value for p in Pooling]}."
-                ) from exc
+        if not pooling:
+            pooling = get_pooling_mode(model_name)
+        try:
+            pooling = Pooling(pooling)
+        except ValueError as exc:
+            raise NotImplementedError(
+                f"Pooling {pooling} unsupported, please pick one in"
+                f" {[p.value for p in Pooling]}."
+            ) from exc
 
         super().__init__(
             embed_batch_size=embed_batch_size,
diff --git a/llama_index/embeddings/huggingface_optimum.py b/llama_index/embeddings/huggingface_optimum.py
index 668dd58386..6c69e37f97 100644
--- a/llama_index/embeddings/huggingface_optimum.py
+++ b/llama_index/embeddings/huggingface_optimum.py
@@ -3,7 +3,12 @@ from typing import Any, List, Optional
 from llama_index.bridge.pydantic import Field, PrivateAttr
 from llama_index.callbacks import CallbackManager
 from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
-from llama_index.embeddings.huggingface_utils import format_query, format_text
+from llama_index.embeddings.huggingface_utils import (
+    format_query,
+    format_text,
+    get_pooling_mode,
+)
+from llama_index.embeddings.pooling import Pooling
 from llama_index.utils import infer_torch_device
 
 
@@ -29,7 +34,7 @@ class OptimumEmbedding(BaseEmbedding):
     def __init__(
         self,
         folder_name: str,
-        pooling: str = "cls",
+        pooling: Optional[str] = None,
         max_length: Optional[int] = None,
         normalize: bool = True,
         query_instruction: Optional[str] = None,
@@ -63,8 +68,15 @@ class OptimumEmbedding(BaseEmbedding):
                     "Please provide max_length."
                 )
 
-        if pooling not in ["cls", "mean"]:
-            raise ValueError(f"Pooling {pooling} not supported.")
+        if not pooling:
+            pooling = get_pooling_mode(model)
+        try:
+            pooling = Pooling(pooling)
+        except ValueError as exc:
+            raise NotImplementedError(
+                f"Pooling {pooling} unsupported, please pick one in"
+                f" {[p.value for p in Pooling]}."
+            ) from exc
 
         super().__init__(
             embed_batch_size=embed_batch_size,
diff --git a/llama_index/embeddings/huggingface_utils.py b/llama_index/embeddings/huggingface_utils.py
index 606bced13b..009aaab764 100644
--- a/llama_index/embeddings/huggingface_utils.py
+++ b/llama_index/embeddings/huggingface_utils.py
@@ -1,5 +1,7 @@
 from typing import Optional
 
+import requests
+
 DEFAULT_HUGGINGFACE_EMBEDDING_MODEL = "BAAI/bge-small-en"
 DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-base"
 
@@ -72,3 +74,26 @@ def format_text(
     # NOTE: strip() enables backdoor for defeating instruction prepend by
     # passing empty string
     return f"{instruction} {text}".strip()
+
+
+def get_pooling_mode(model_name: Optional[str]) -> str:
+    pooling_config_url = (
+        f"https://huggingface.co/{model_name}/raw/main/1_Pooling/config.json"
+    )
+
+    try:
+        response = requests.get(pooling_config_url)
+        config_data = response.json()
+
+        cls_token = config_data.get("pooling_mode_cls_token", False)
+        mean_tokens = config_data.get("pooling_mode_mean_tokens", False)
+
+        if mean_tokens:
+            return "mean"
+        elif cls_token:
+            return "cls"
+    except requests.exceptions.RequestException:
+        print(
+            "Warning: Pooling config file not found; pooling mode is defaulted to 'cls'."
+        )
+    return "cls"
-- 
GitLab