From 649fda11f86ad7f39f459b05b00985e3c4d3a866 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Wed, 8 May 2024 15:15:57 -0600
Subject: [PATCH] add async httpx client support for openai/azure (#13370)

---
 .../llama_index/llms/azure_openai/base.py             |  8 ++++++--
 .../llms/llama-index-llms-azure-openai/pyproject.toml |  2 +-
 .../llama_index/llms/openai/base.py                   | 11 +++++++----
 .../llms/llama-index-llms-openai/pyproject.toml       |  2 +-
 4 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
index 481e80a34..5338e6c8c 100644
--- a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
@@ -102,6 +102,7 @@ class AzureOpenAI(OpenAI):
         deployment: Optional[str] = None,
         # custom httpx client
         http_client: Optional[httpx.Client] = None,
+        async_http_client: Optional[httpx.AsyncClient] = None,
         # base class
         system_prompt: Optional[str] = None,
         messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
@@ -138,6 +139,7 @@ class AzureOpenAI(OpenAI):
             api_version=api_version,
             callback_manager=callback_manager,
             http_client=http_client,
+            async_http_client=async_http_client,
             system_prompt=system_prompt,
             messages_to_prompt=messages_to_prompt,
             completion_to_prompt=completion_to_prompt,
@@ -182,7 +184,9 @@ class AzureOpenAI(OpenAI):
             )
         return self._aclient

-    def _get_credential_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
+    def _get_credential_kwargs(
+        self, is_async: bool = False, **kwargs: Any
+    ) -> Dict[str, Any]:
         if self.use_azure_ad:
             self._azure_ad_token = refresh_openai_azuread_token(self._azure_ad_token)
             self.api_key = self._azure_ad_token.token
@@ -206,7 +210,7 @@ class AzureOpenAI(OpenAI):
             "azure_ad_token_provider": self.azure_ad_token_provider,
             "api_version": self.api_version,
             "default_headers": self.default_headers,
-            "http_client": self._http_client,
+            "http_client": self._async_http_client if is_async else self._http_client,
             **kwargs,
         }

diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
index 0924c6a43..7a211c4b8 100644
--- a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-azure-openai"
 readme = "README.md"
-version = "0.1.7"
+version = "0.1.8"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
index f522e68a0..49eeb1b7b 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
@@ -178,6 +178,7 @@ class OpenAI(FunctionCallingLLM):
     _client: Optional[SyncOpenAI] = PrivateAttr()
     _aclient: Optional[AsyncOpenAI] = PrivateAttr()
     _http_client: Optional[httpx.Client] = PrivateAttr()
+    _async_http_client: Optional[httpx.AsyncClient] = PrivateAttr()

     def __init__(
         self,
@@ -194,6 +195,7 @@ class OpenAI(FunctionCallingLLM):
         callback_manager: Optional[CallbackManager] = None,
         default_headers: Optional[Dict[str, str]] = None,
         http_client: Optional[httpx.Client] = None,
+        async_http_client: Optional[httpx.AsyncClient] = None,
         # base class
         system_prompt: Optional[str] = None,
         messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
@@ -234,6 +236,7 @@ class OpenAI(FunctionCallingLLM):
         self._client = None
         self._aclient = None
         self._http_client = http_client
+        self._async_http_client = async_http_client

     def _get_client(self) -> SyncOpenAI:
         if not self.reuse_client:
@@ -245,10 +248,10 @@ class OpenAI(FunctionCallingLLM):

     def _get_aclient(self) -> AsyncOpenAI:
         if not self.reuse_client:
-            return AsyncOpenAI(**self._get_credential_kwargs())
+            return AsyncOpenAI(**self._get_credential_kwargs(is_async=True))

         if self._aclient is None:
-            self._aclient = AsyncOpenAI(**self._get_credential_kwargs())
+            self._aclient = AsyncOpenAI(**self._get_credential_kwargs(is_async=True))
         return self._aclient

     def _get_model_name(self) -> str:
@@ -331,14 +334,14 @@ class OpenAI(FunctionCallingLLM):
             return kwargs["use_chat_completions"]
         return self.metadata.is_chat_model

-    def _get_credential_kwargs(self) -> Dict[str, Any]:
+    def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]:
         return {
             "api_key": self.api_key,
             "base_url": self.api_base,
             "max_retries": self.max_retries,
             "timeout": self.timeout,
             "default_headers": self.default_headers,
-            "http_client": self._http_client,
+            "http_client": self._async_http_client if is_async else self._http_client,
         }

     def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml
index bdbd0f82a..30716b32e 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml
@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-openai"
 readme = "README.md"
-version = "0.1.17"
+version = "0.1.18"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
--
GitLab
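
A minimal usage sketch of what this patch enables: a caller can now supply separate
sync and async httpx transports, with the async one routed to AsyncOpenAI via
_get_credential_kwargs(is_async=True). The model name, timeout, and connection-limit
values below are illustrative assumptions, not taken from the patch.

import asyncio

import httpx
from llama_index.llms.openai import OpenAI

# Separate transports for the sync and async paths; timeout and limits
# are example values, not library defaults. Assumes OPENAI_API_KEY is
# set in the environment.
sync_client = httpx.Client(timeout=60.0)
async_client = httpx.AsyncClient(limits=httpx.Limits(max_connections=20))

llm = OpenAI(
    model="gpt-3.5-turbo",
    http_client=sync_client,        # existing parameter: used by SyncOpenAI
    async_http_client=async_client, # new in this patch: used by AsyncOpenAI
)

# The sync call goes through sync_client.
print(llm.complete("Hello").text)

async def main() -> None:
    # The async call goes through async_client.
    response = await llm.acomplete("Hello")
    print(response.text)

asyncio.run(main())

Before this change, the single http_client (an httpx.Client) was handed to both
SyncOpenAI and AsyncOpenAI, even though AsyncOpenAI expects an httpx.AsyncClient;
the is_async flag lets each constructor receive the matching transport.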