Skip to content
Snippets Groups Projects
Unverified Commit cfb76b9c authored by Alen Joses R's avatar Alen Joses R Committed by GitHub
Browse files

feat: add client_kwargs Parameter to OllamaEmbedding Class (#18012)

parent e486157f
No related branches found
No related tags found
No related merge requests found
...@@ -34,6 +34,7 @@ class OllamaEmbedding(BaseEmbedding): ...@@ -34,6 +34,7 @@ class OllamaEmbedding(BaseEmbedding):
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
ollama_additional_kwargs: Optional[Dict[str, Any]] = None, ollama_additional_kwargs: Optional[Dict[str, Any]] = None,
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
client_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -45,8 +46,9 @@ class OllamaEmbedding(BaseEmbedding): ...@@ -45,8 +46,9 @@ class OllamaEmbedding(BaseEmbedding):
**kwargs, **kwargs,
) )
self._client = Client(host=self.base_url) client_kwargs = client_kwargs or {}
self._async_client = AsyncClient(host=self.base_url) self._client = Client(host=self.base_url, **client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **client_kwargs)
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
......
...@@ -27,7 +27,7 @@ exclude = ["**/BUILD"] ...@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT" license = "MIT"
name = "llama-index-embeddings-ollama" name = "llama-index-embeddings-ollama"
readme = "README.md" readme = "README.md"
version = "0.5.0" version = "0.6.0"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.9,<4.0" python = ">=3.9,<4.0"
......
...@@ -3,5 +3,7 @@ from llama_index.embeddings.ollama import OllamaEmbedding ...@@ -3,5 +3,7 @@ from llama_index.embeddings.ollama import OllamaEmbedding
def test_embedding_class(): def test_embedding_class():
emb = OllamaEmbedding(model_name="") emb = OllamaEmbedding(
model_name="", client_kwargs={"headers": {"Authorization": "Bearer token"}}
)
assert isinstance(emb, BaseEmbedding) assert isinstance(emb, BaseEmbedding)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment