diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py index 170fe2557a0bb81bfd5d66ed96cf7924910a962e..8c80ed86316164632314bcc1ad93f7eababdaf95 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py @@ -126,7 +126,8 @@ class CohereEmbedding(BaseEmbedding): callback_manager: Optional[CallbackManager] = None, base_url: Optional[str] = None, timeout: Optional[float] = None, - httpx_client: Optional[httpx.AsyncClient] = None, + httpx_client: Optional[httpx.Client] = None, + httpx_async_client: Optional[httpx.AsyncClient] = None, ): """ A class representation for generating embeddings using the Cohere API. @@ -170,7 +171,7 @@ class CohereEmbedding(BaseEmbedding): client_name="llama_index", base_url=base_url, timeout=timeout, - httpx_client=httpx_client, + httpx_client=httpx_async_client, ), cohere_api_key=cohere_api_key, model_name=model_name, diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml index 39e3d005f39903145a988797a38a97cb1c45e3b1..be9b60c11f6bb52874bba53120eb477603d97e22 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-cohere" readme = "README.md" -version = "0.1.7" +version = "0.1.8" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" @@ -41,6 +41,7 @@ mypy = "0.991" pre-commit = "3.2.0" pylint = "2.15.10" pytest = "7.2.1" +pytest-asyncio = "^0.23.6" pytest-mock = "3.11.1" ruff = "0.0.292" tree-sitter-languages = "^1.8.0" diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..dabf212d7e7162849c24a733909ac4f645d75a31 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..1f8e826c777e2457a27dbe9ccf961bd39f7ae3e8 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py @@ -0,0 +1,43 @@ +import os + +import httpx +import pytest +from llama_index.core.base.embeddings.base import BaseEmbedding + +from llama_index.embeddings.cohere import CohereEmbedding + + +def test_embedding_class(): + emb = CohereEmbedding(cohere_api_key="token") + assert isinstance(emb, BaseEmbedding) + + +@pytest.mark.skipif( + os.environ.get("CO_API_KEY") is None, reason="Cohere API key required" +) +def test_sync_embedding(): + emb = CohereEmbedding( + cohere_api_key=os.environ["CO_API_KEY"], + model_name="embed-english-v3.0", + input_type="clustering", + embedding_type="float", + httpx_client=httpx.Client(), + ) + + emb.get_query_embedding("I love Cohere!") + + +@pytest.mark.skipif( + os.environ.get("CO_API_KEY") is None, reason="Cohere API key required" +) +@pytest.mark.asyncio() +async def test_async_embedding(): + emb = CohereEmbedding( + cohere_api_key=os.environ["CO_API_KEY"], + model_name="embed-english-v3.0", + input_type="clustering", + embedding_type="float", + httpx_async_client=httpx.AsyncClient(), + ) + + await emb.aget_query_embedding("I love Cohere!")