Skip to content
Snippets Groups Projects
Unverified Commit cd6700a2 authored by hkristof03's avatar hkristof03 Committed by GitHub
Browse files

Expose "safe_serialization" parameter from AutoModel to HuggingFaceEm… (#11939)


* Expose "safe_serialization" parameter from AutoModel to HuggingFaceEmbedding

* Change "safe_serialization" parameter to Optional.

* Add comma for linter check

* cr

---------

Co-authored-by: default avatarHaotian Zhang <socool.king@gmail.com>
parent 1bd378ca
No related branches found
No related tags found
No related merge requests found
......@@ -68,6 +68,7 @@ class HuggingFaceEmbedding(BaseEmbedding):
trust_remote_code: bool = False,
device: Optional[str] = None,
callback_manager: Optional[CallbackManager] = None,
safe_serialization: Optional[bool] = None,
):
self._device = device or infer_torch_device()
......@@ -80,7 +81,10 @@ class HuggingFaceEmbedding(BaseEmbedding):
else DEFAULT_HUGGINGFACE_EMBEDDING_MODEL
)
model = AutoModel.from_pretrained(
model_name, cache_dir=cache_folder, trust_remote_code=trust_remote_code
model_name,
cache_dir=cache_folder,
trust_remote_code=trust_remote_code,
safe_serialization=safe_serialization,
)
elif model_name is None: # Extract model_name from model
model_name = model.name_or_path
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment