Skip to content
Snippets Groups Projects
Unverified Commit b98552f9 authored by billytrend-cohere's avatar billytrend-cohere Committed by GitHub
Browse files

Add httpx_async_client option (#12896)

parent 19497b67
Branches
Tags
No related merge requests found
...@@ -126,7 +126,8 @@ class CohereEmbedding(BaseEmbedding): ...@@ -126,7 +126,8 @@ class CohereEmbedding(BaseEmbedding):
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
httpx_client: Optional[httpx.AsyncClient] = None, httpx_client: Optional[httpx.Client] = None,
httpx_async_client: Optional[httpx.AsyncClient] = None,
): ):
""" """
A class representation for generating embeddings using the Cohere API. A class representation for generating embeddings using the Cohere API.
...@@ -170,7 +171,7 @@ class CohereEmbedding(BaseEmbedding): ...@@ -170,7 +171,7 @@ class CohereEmbedding(BaseEmbedding):
client_name="llama_index", client_name="llama_index",
base_url=base_url, base_url=base_url,
timeout=timeout, timeout=timeout,
httpx_client=httpx_client, httpx_client=httpx_async_client,
), ),
cohere_api_key=cohere_api_key, cohere_api_key=cohere_api_key,
model_name=model_name, model_name=model_name,
......
...@@ -27,7 +27,7 @@ exclude = ["**/BUILD"] ...@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT" license = "MIT"
name = "llama-index-embeddings-cohere" name = "llama-index-embeddings-cohere"
readme = "README.md" readme = "README.md"
version = "0.1.7" version = "0.1.8"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
...@@ -41,6 +41,7 @@ mypy = "0.991" ...@@ -41,6 +41,7 @@ mypy = "0.991"
pre-commit = "3.2.0" pre-commit = "3.2.0"
pylint = "2.15.10" pylint = "2.15.10"
pytest = "7.2.1" pytest = "7.2.1"
pytest-asyncio = "^0.23.6"
pytest-mock = "3.11.1" pytest-mock = "3.11.1"
ruff = "0.0.292" ruff = "0.0.292"
tree-sitter-languages = "^1.8.0" tree-sitter-languages = "^1.8.0"
......
python_tests()
import os
import httpx
import pytest
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.cohere import CohereEmbedding
def test_embedding_class():
emb = CohereEmbedding(cohere_api_key="token")
assert isinstance(emb, BaseEmbedding)
@pytest.mark.skipif(
os.environ.get("CO_API_KEY") is None, reason="Cohere API key required"
)
def test_sync_embedding():
emb = CohereEmbedding(
cohere_api_key=os.environ["CO_API_KEY"],
model_name="embed-english-v3.0",
input_type="clustering",
embedding_type="float",
httpx_client=httpx.Client(),
)
emb.get_query_embedding("I love Cohere!")
@pytest.mark.skipif(
os.environ.get("CO_API_KEY") is None, reason="Cohere API key required"
)
@pytest.mark.asyncio()
async def test_async_embedding():
emb = CohereEmbedding(
cohere_api_key=os.environ["CO_API_KEY"],
model_name="embed-english-v3.0",
input_type="clustering",
embedding_type="float",
httpx_async_client=httpx.AsyncClient(),
)
await emb.aget_query_embedding("I love Cohere!")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment