From cd6700a274e3c77e41215d39e4b5f8799276a3ef Mon Sep 17 00:00:00 2001
From: hkristof03 <hkristof03@gmail.com>
Date: Fri, 15 Mar 2024 17:24:33 +0100
Subject: [PATCH] =?UTF-8?q?Expose=20"safe=5Fserialization"=20parameter=20f?=
 =?UTF-8?q?rom=20AutoModel=20to=20HuggingFaceEm=E2=80=A6=20(#11939)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Expose "safe_serialization" parameter from AutoModel to HuggingFaceEmbedding

* Change "safe_serialization" parameter to Optional.

* Add comma for linter check

* cr

---------

Co-authored-by: Haotian Zhang <socool.king@gmail.com>
---
 .../llama_index/embeddings/huggingface/base.py              | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py
index cca7da667c..4c1a81f2a2 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py
@@ -68,6 +68,7 @@ class HuggingFaceEmbedding(BaseEmbedding):
         trust_remote_code: bool = False,
         device: Optional[str] = None,
         callback_manager: Optional[CallbackManager] = None,
+        safe_serialization: Optional[bool] = None,
     ):
         self._device = device or infer_torch_device()
 
@@ -80,7 +81,10 @@ class HuggingFaceEmbedding(BaseEmbedding):
                 else DEFAULT_HUGGINGFACE_EMBEDDING_MODEL
             )
             model = AutoModel.from_pretrained(
-                model_name, cache_dir=cache_folder, trust_remote_code=trust_remote_code
+                model_name,
+                cache_dir=cache_folder,
+                trust_remote_code=trust_remote_code,
+                safe_serialization=safe_serialization,
             )
         elif model_name is None:  # Extract model_name from model
             model_name = model.name_or_path
-- 
GitLab