Skip to content
Snippets Groups Projects
Unverified Commit 85ecdaed authored by Haotian Zhang's avatar Haotian Zhang Committed by GitHub
Browse files

CLIP embedding with more models (#12063)

* CLIP embedding with more models

* cr

* cr

* cr
parent 05f4329a
No related branches found
No related tags found
No related merge requests found
"""Init file of LlamaIndex.""" """Init file of LlamaIndex."""
__version__ = "0.10.20.post2" __version__ = "0.10.20.post3"
import logging import logging
from logging import NullHandler from logging import NullHandler
......
...@@ -75,17 +75,19 @@ def resolve_embed_model( ...@@ -75,17 +75,19 @@ def resolve_embed_model(
"embeddings.html#modules" "embeddings.html#modules"
"\n******" "\n******"
) )
# for image multi-modal embeddings
# for image embeddings elif isinstance(embed_model, str) and embed_model.startswith("clip"):
if embed_model == "clip":
try: try:
from llama_index.embeddings.clip import ClipEmbedding # pants: no-infer-dep from llama_index.embeddings.clip import ClipEmbedding # pants: no-infer-dep
embed_model = ClipEmbedding() clip_model_name = (
embed_model.split(":")[1] if ":" in embed_model else "ViT-B/32"
)
embed_model = ClipEmbedding(model_name=clip_model_name)
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"`llama-index-embeddings-clip` package not found, " "`llama-index-embeddings-clip` package not found, "
"please run `pip install llama-index-embeddings-clip`" "please run `pip install llama-index-embeddings-clip` and `pip install git+https://github.com/openai/CLIP.git`"
) )
if isinstance(embed_model, str): if isinstance(embed_model, str):
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
An index that is built on top of multiple vector stores for different modalities. An index that is built on top of multiple vector stores for different modalities.
""" """
import logging import logging
from typing import Any, List, Optional, Sequence, cast from typing import Any, List, Optional, Sequence, cast
...@@ -63,7 +64,7 @@ class MultiModalVectorStoreIndex(VectorStoreIndex): ...@@ -63,7 +64,7 @@ class MultiModalVectorStoreIndex(VectorStoreIndex):
# image_vector_store going to be deprecated. image_store can be passed from storage_context # image_vector_store going to be deprecated. image_store can be passed from storage_context
# keep image_vector_store here for backward compatibility # keep image_vector_store here for backward compatibility
image_vector_store: Optional[VectorStore] = None, image_vector_store: Optional[VectorStore] = None,
image_embed_model: EmbedType = "clip", image_embed_model: EmbedType = "clip:ViT-B/32",
is_image_to_text: bool = False, is_image_to_text: bool = False,
# is_image_vector_store_empty is used to indicate whether image_vector_store is empty # is_image_vector_store_empty is used to indicate whether image_vector_store is empty
# those flags are used for cases when only one vector store is used # those flags are used for cases when only one vector store is used
...@@ -184,11 +185,13 @@ class MultiModalVectorStoreIndex(VectorStoreIndex): ...@@ -184,11 +185,13 @@ class MultiModalVectorStoreIndex(VectorStoreIndex):
storage_context=storage_context, storage_context=storage_context,
image_vector_store=image_vector_store, image_vector_store=image_vector_store,
image_embed_model=image_embed_model, image_embed_model=image_embed_model,
embed_model=resolve_embed_model( embed_model=(
embed_model, callback_manager=kwargs.get("callback_manager", None) resolve_embed_model(
) embed_model, callback_manager=kwargs.get("callback_manager", None)
if embed_model )
else Settings.embed_model, if embed_model
else Settings.embed_model
),
**kwargs, **kwargs,
) )
......
...@@ -43,7 +43,7 @@ name = "llama-index-core" ...@@ -43,7 +43,7 @@ name = "llama-index-core"
packages = [{include = "llama_index"}] packages = [{include = "llama_index"}]
readme = "README.md" readme = "README.md"
repository = "https://github.com/run-llama/llama_index" repository = "https://github.com/run-llama/llama_index"
version = "0.10.20.post2" version = "0.10.20.post3"
[tool.poetry.dependencies] [tool.poetry.dependencies]
SQLAlchemy = {extras = ["asyncio"], version = ">=1.4.49"} SQLAlchemy = {extras = ["asyncio"], version = ">=1.4.49"}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment