diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py index c4112f2c7d27a3987d2e95bd52ba56178d0a786e..83f93644989cbe7e1d17f9b71126ba349eaead70 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-ollama/llama_index/embeddings/ollama/base.py @@ -34,6 +34,7 @@ class OllamaEmbedding(BaseEmbedding): embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, ollama_additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, + client_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: super().__init__( @@ -45,8 +46,9 @@ class OllamaEmbedding(BaseEmbedding): **kwargs, ) - self._client = Client(host=self.base_url) - self._async_client = AsyncClient(host=self.base_url) + client_kwargs = client_kwargs or {} + self._client = Client(host=self.base_url, **client_kwargs) + self._async_client = AsyncClient(host=self.base_url, **client_kwargs) @classmethod def class_name(cls) -> str: diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-ollama/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-ollama/pyproject.toml index f0ea73fb70c373aa2907b7dfec073586cec7211d..7cad77517ab4e386bb77648bfb76fecf8f52c3ab 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-ollama/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-ollama/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-ollama" readme = "README.md" -version = "0.5.0" +version = "0.6.0" [tool.poetry.dependencies] python = ">=3.9,<4.0" diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-ollama/tests/test_embeddings_ollama.py b/llama-index-integrations/embeddings/llama-index-embeddings-ollama/tests/test_embeddings_ollama.py index 74ec40ecea635bf50b9d1523a5e96f029145fd8c..8c9ac39533b068e0b5780d250d1b0a5a16015ec6 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-ollama/tests/test_embeddings_ollama.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-ollama/tests/test_embeddings_ollama.py @@ -3,5 +3,7 @@ from llama_index.embeddings.ollama import OllamaEmbedding def test_embedding_class(): - emb = OllamaEmbedding(model_name="") + emb = OllamaEmbedding( + model_name="", client_kwargs={"headers": {"Authorization": "Bearer token"}} + ) assert isinstance(emb, BaseEmbedding)