From b98552f9d20b9546b8b11e87a9d54895ea9bb8a7 Mon Sep 17 00:00:00 2001 From: billytrend-cohere <144115527+billytrend-cohere@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:23:04 -0500 Subject: [PATCH] Add httpx_async_client option (#12896) --- .../llama_index/embeddings/cohere/base.py | 5 ++- .../pyproject.toml | 3 +- .../llama-index-embeddings-cohere/tests/BUILD | 1 + .../tests/__init__.py | 0 .../tests/test_embeddings.py | 43 +++++++++++++++++++ 5 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/BUILD create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/__init__.py create mode 100644 llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py index 170fe2557a..8c80ed8631 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py @@ -126,7 +126,8 @@ class CohereEmbedding(BaseEmbedding): callback_manager: Optional[CallbackManager] = None, base_url: Optional[str] = None, timeout: Optional[float] = None, - httpx_client: Optional[httpx.AsyncClient] = None, + httpx_client: Optional[httpx.Client] = None, + httpx_async_client: Optional[httpx.AsyncClient] = None, ): """ A class representation for generating embeddings using the Cohere API. @@ -170,7 +171,7 @@ class CohereEmbedding(BaseEmbedding): client_name="llama_index", base_url=base_url, timeout=timeout, - httpx_client=httpx_client, + httpx_client=httpx_async_client, ), cohere_api_key=cohere_api_key, model_name=model_name, diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml index 39e3d005f3..be9b60c11f 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-cohere" readme = "README.md" -version = "0.1.7" +version = "0.1.8" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" @@ -41,6 +41,7 @@ mypy = "0.991" pre-commit = "3.2.0" pylint = "2.15.10" pytest = "7.2.1" +pytest-asyncio = "^0.23.6" pytest-mock = "3.11.1" ruff = "0.0.292" tree-sitter-languages = "^1.8.0" diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/BUILD b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/BUILD new file mode 100644 index 0000000000..dabf212d7e --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/__init__.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py new file mode 100644 index 0000000000..1f8e826c77 --- /dev/null +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py @@ -0,0 +1,43 @@ +import os + +import httpx +import pytest +from llama_index.core.base.embeddings.base import BaseEmbedding + +from llama_index.embeddings.cohere import CohereEmbedding + + +def test_embedding_class(): + emb = CohereEmbedding(cohere_api_key="token") + assert isinstance(emb, BaseEmbedding) + + +@pytest.mark.skipif( + os.environ.get("CO_API_KEY") is None, reason="Cohere API key required" +) +def test_sync_embedding(): + emb = CohereEmbedding( + cohere_api_key=os.environ["CO_API_KEY"], + model_name="embed-english-v3.0", + input_type="clustering", + embedding_type="float", + httpx_client=httpx.Client(), + ) + + emb.get_query_embedding("I love Cohere!") + + +@pytest.mark.skipif( + os.environ.get("CO_API_KEY") is None, reason="Cohere API key required" +) +@pytest.mark.asyncio() +async def test_async_embedding(): + emb = CohereEmbedding( + cohere_api_key=os.environ["CO_API_KEY"], + model_name="embed-english-v3.0", + input_type="clustering", + embedding_type="float", + httpx_async_client=httpx.AsyncClient(), + ) + + await emb.aget_query_embedding("I love Cohere!") -- GitLab