From 85ecdaed797725123bc85130e7c51e202225a550 Mon Sep 17 00:00:00 2001
From: Haotian Zhang <socool.king@gmail.com>
Date: Tue, 19 Mar 2024 14:17:05 -0400
Subject: [PATCH] CLIP embedding with more models (#12063)

* CLIP embedding with more models

* cr

* cr

* cr
---
 llama-index-core/llama_index/core/__init__.py     |  2 +-
 .../llama_index/core/embeddings/utils.py          | 12 +++++++-----
 .../llama_index/core/indices/multi_modal/base.py  | 15 +++++++++------
 llama-index-core/pyproject.toml                   |  2 +-
 4 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/llama-index-core/llama_index/core/__init__.py b/llama-index-core/llama_index/core/__init__.py
index 3d55dfab89..4981cfc768 100644
--- a/llama-index-core/llama_index/core/__init__.py
+++ b/llama-index-core/llama_index/core/__init__.py
@@ -1,6 +1,6 @@
 """Init file of LlamaIndex."""
 
-__version__ = "0.10.20.post2"
+__version__ = "0.10.20.post3"
 
 import logging
 from logging import NullHandler
diff --git a/llama-index-core/llama_index/core/embeddings/utils.py b/llama-index-core/llama_index/core/embeddings/utils.py
index e24d6adf2c..abc8af4645 100644
--- a/llama-index-core/llama_index/core/embeddings/utils.py
+++ b/llama-index-core/llama_index/core/embeddings/utils.py
@@ -75,17 +75,19 @@ def resolve_embed_model(
                 "embeddings.html#modules"
                 "\n******"
             )
-
-    # for image embeddings
-    if embed_model == "clip":
+    # for CLIP multi-modal (image) embeddings: "clip" or "clip:<model_name>"
+    elif isinstance(embed_model, str) and embed_model.startswith("clip"):
         try:
             from llama_index.embeddings.clip import ClipEmbedding  # pants: no-infer-dep
 
-            embed_model = ClipEmbedding()
+            clip_model_name = (
+                embed_model.split(":")[1] if ":" in embed_model else "ViT-B/32"
+            )
+            embed_model = ClipEmbedding(model_name=clip_model_name)
         except ImportError as e:
             raise ImportError(
                 "`llama-index-embeddings-clip` package not found, "
-                "please run `pip install llama-index-embeddings-clip`"
+                "please run `pip install llama-index-embeddings-clip` and `pip install git+https://github.com/openai/CLIP.git`"
             )
 
     if isinstance(embed_model, str):
diff --git a/llama-index-core/llama_index/core/indices/multi_modal/base.py b/llama-index-core/llama_index/core/indices/multi_modal/base.py
index a80231231d..0f2669c49d 100644
--- a/llama-index-core/llama_index/core/indices/multi_modal/base.py
+++ b/llama-index-core/llama_index/core/indices/multi_modal/base.py
@@ -3,6 +3,7 @@
 An index that is built on top of multiple vector stores for different modalities.
 
 """
+
 import logging
 from typing import Any, List, Optional, Sequence, cast
 
@@ -63,7 +64,7 @@ class MultiModalVectorStoreIndex(VectorStoreIndex):
         # image_vector_store going to be deprecated. image_store can be passed from storage_context
         # keep image_vector_store here for backward compatibility
         image_vector_store: Optional[VectorStore] = None,
-        image_embed_model: EmbedType = "clip",
+        image_embed_model: EmbedType = "clip:ViT-B/32",
         is_image_to_text: bool = False,
         # is_image_vector_store_empty is used to indicate whether image_vector_store is empty
         # those flags are used for cases when only one vector store is used
@@ -184,11 +185,13 @@ class MultiModalVectorStoreIndex(VectorStoreIndex):
             storage_context=storage_context,
             image_vector_store=image_vector_store,
             image_embed_model=image_embed_model,
-            embed_model=resolve_embed_model(
-                embed_model, callback_manager=kwargs.get("callback_manager", None)
-            )
-            if embed_model
-            else Settings.embed_model,
+            embed_model=(
+                resolve_embed_model(
+                    embed_model, callback_manager=kwargs.get("callback_manager", None)
+                )
+                if embed_model
+                else Settings.embed_model
+            ),
             **kwargs,
         )
 
diff --git a/llama-index-core/pyproject.toml b/llama-index-core/pyproject.toml
index 4e51eccd39..154d08227a 100644
--- a/llama-index-core/pyproject.toml
+++ b/llama-index-core/pyproject.toml
@@ -43,7 +43,7 @@ name = "llama-index-core"
 packages = [{include = "llama_index"}]
 readme = "README.md"
 repository = "https://github.com/run-llama/llama_index"
-version = "0.10.20.post2"
+version = "0.10.20.post3"
 
 [tool.poetry.dependencies]
 SQLAlchemy = {extras = ["asyncio"], version = ">=1.4.49"}
-- 
GitLab