From 13c68c9a9a84d5c07a9e8f2bb65397e03f2be93e Mon Sep 17 00:00:00 2001 From: Logan <logan.markewich@live.com> Date: Fri, 13 Oct 2023 09:57:46 -0600 Subject: [PATCH] let huggingface embeddings load (#8119) --- llama_index/embeddings/huggingface.py | 3 +++ llama_index/embeddings/loading.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/llama_index/embeddings/huggingface.py b/llama_index/embeddings/huggingface.py index 0f4ed7fdc3..5ca9856437 100644 --- a/llama_index/embeddings/huggingface.py +++ b/llama_index/embeddings/huggingface.py @@ -8,6 +8,7 @@ from llama_index.embeddings.huggingface_utils import ( get_query_instruct_for_model_name, get_text_instruct_for_model_name, ) +from llama_index.utils import get_cache_dir class HuggingFaceEmbedding(BaseEmbedding): @@ -60,6 +61,8 @@ class HuggingFaceEmbedding(BaseEmbedding): device = "cpu" self._device = device + cache_folder = cache_folder or get_cache_dir() + if model is None: model_name = model_name or DEFAULT_HUGGINGFACE_EMBEDDING_MODEL self._model = AutoModel.from_pretrained( diff --git a/llama_index/embeddings/loading.py b/llama_index/embeddings/loading.py index 7543d3256e..891cc84797 100644 --- a/llama_index/embeddings/loading.py +++ b/llama_index/embeddings/loading.py @@ -2,6 +2,8 @@ from typing import Dict, Type from llama_index.embeddings.base import BaseEmbedding from llama_index.embeddings.google import GoogleUnivSentEncoderEmbedding +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from llama_index.embeddings.huggingface_optimum import OptimumEmbedding from llama_index.embeddings.langchain import LangchainEmbedding from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.embeddings.utils import resolve_embed_model @@ -12,6 +14,8 @@ RECOGNIZED_EMBEDDINGS: Dict[str, Type[BaseEmbedding]] = { OpenAIEmbedding.class_name(): OpenAIEmbedding, LangchainEmbedding.class_name(): LangchainEmbedding, MockEmbedding.class_name(): MockEmbedding, + HuggingFaceEmbedding.class_name(): HuggingFaceEmbedding, + OpenAIEmbedding.class_name(): OpenAIEmbedding, } -- GitLab