diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
index 481e80a348b4359598b89aeb6bb4922a1eca83e5..5338e6c8ce8477f4caf25adb262afb321675887a 100644
--- a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
@@ -102,6 +102,7 @@ class AzureOpenAI(OpenAI):
         deployment: Optional[str] = None,
         # custom httpx client
         http_client: Optional[httpx.Client] = None,
+        async_http_client: Optional[httpx.AsyncClient] = None,
         # base class
         system_prompt: Optional[str] = None,
         messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
@@ -138,6 +139,7 @@ class AzureOpenAI(OpenAI):
             api_version=api_version,
             callback_manager=callback_manager,
             http_client=http_client,
+            async_http_client=async_http_client,
             system_prompt=system_prompt,
             messages_to_prompt=messages_to_prompt,
             completion_to_prompt=completion_to_prompt,
@@ -182,7 +184,9 @@ class AzureOpenAI(OpenAI):
         )
         return self._aclient
 
-    def _get_credential_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
+    def _get_credential_kwargs(
+        self, is_async: bool = False, **kwargs: Any
+    ) -> Dict[str, Any]:
         if self.use_azure_ad:
             self._azure_ad_token = refresh_openai_azuread_token(self._azure_ad_token)
             self.api_key = self._azure_ad_token.token
@@ -206,7 +210,7 @@ class AzureOpenAI(OpenAI):
             "azure_ad_token_provider": self.azure_ad_token_provider,
             "api_version": self.api_version,
             "default_headers": self.default_headers,
-            "http_client": self._http_client,
+            "http_client": self._async_http_client if is_async else self._http_client,
             **kwargs,
         }
 
diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
index 0924c6a43c3fc36b5d57f069f9d960507c598c04..7a211c4b89eea6297a7378dbc919db3c097697e1 100644
--- a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-azure-openai"
 readme = "README.md"
-version = "0.1.7"
+version = "0.1.8"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
index f522e68a01c2fac4255149632f758f9272b25ac1..49eeb1b7b16b6a1aeb5dabb272ff4a9df6ee0e7b 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
@@ -178,6 +178,7 @@ class OpenAI(FunctionCallingLLM):
     _client: Optional[SyncOpenAI] = PrivateAttr()
     _aclient: Optional[AsyncOpenAI] = PrivateAttr()
     _http_client: Optional[httpx.Client] = PrivateAttr()
+    _async_http_client: Optional[httpx.AsyncClient] = PrivateAttr()
 
     def __init__(
         self,
@@ -194,6 +195,7 @@ class OpenAI(FunctionCallingLLM):
         callback_manager: Optional[CallbackManager] = None,
         default_headers: Optional[Dict[str, str]] = None,
         http_client: Optional[httpx.Client] = None,
+        async_http_client: Optional[httpx.AsyncClient] = None,
         # base class
         system_prompt: Optional[str] = None,
         messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
@@ -234,6 +236,7 @@ class OpenAI(FunctionCallingLLM):
         self._client = None
         self._aclient = None
         self._http_client = http_client
+        self._async_http_client = async_http_client
 
     def _get_client(self) -> SyncOpenAI:
         if not self.reuse_client:
@@ -245,10 +248,10 @@ class OpenAI(FunctionCallingLLM):
 
     def _get_aclient(self) -> AsyncOpenAI:
         if not self.reuse_client:
-            return AsyncOpenAI(**self._get_credential_kwargs())
+            return AsyncOpenAI(**self._get_credential_kwargs(is_async=True))
 
         if self._aclient is None:
-            self._aclient = AsyncOpenAI(**self._get_credential_kwargs())
+            self._aclient = AsyncOpenAI(**self._get_credential_kwargs(is_async=True))
         return self._aclient
 
     def _get_model_name(self) -> str:
@@ -331,14 +334,14 @@ class OpenAI(FunctionCallingLLM):
             return kwargs["use_chat_completions"]
         return self.metadata.is_chat_model
 
-    def _get_credential_kwargs(self) -> Dict[str, Any]:
+    def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]:
         return {
             "api_key": self.api_key,
             "base_url": self.api_base,
             "max_retries": self.max_retries,
             "timeout": self.timeout,
             "default_headers": self.default_headers,
-            "http_client": self._http_client,
+            "http_client": self._async_http_client if is_async else self._http_client,
         }
 
     def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml
index bdbd0f82a9c28aa183db2fc8fc459fade69641c8..30716b32ec0836a53ad88f7abbeafbe1d2696483 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml
@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-openai"
 readme = "README.md"
-version = "0.1.17"
+version = "0.1.18"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
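For reference, a minimal usage sketch of the new parameter based on the constructor signature added above; the model name and the httpx timeout/limits values are illustrative assumptions, not part of the change:

```python
# Usage sketch (not part of the diff): wiring a custom httpx.AsyncClient
# through the new async_http_client parameter alongside the existing
# sync http_client. Timeout/limits values are illustrative assumptions.
import httpx
from llama_index.llms.openai import OpenAI

sync_client = httpx.Client(timeout=httpx.Timeout(60.0))
async_client = httpx.AsyncClient(
    timeout=httpx.Timeout(60.0),
    limits=httpx.Limits(max_keepalive_connections=20),
)

llm = OpenAI(
    model="gpt-3.5-turbo",
    http_client=sync_client,         # used by sync entry points (chat/complete)
    async_http_client=async_client,  # used by async entry points (achat/acomplete)
)

# With this change, async calls construct AsyncOpenAI with
# _get_credential_kwargs(is_async=True), so they use async_client
# instead of falling back to the sync http_client:
# response = await llm.acomplete("Hello")
```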