Skip to content
Snippets Groups Projects
Unverified Commit 649fda11 authored by Logan's avatar Logan Committed by GitHub
Browse files

add sync httpx client support for openai/azure (#13370)

parent da1732ef
Branches
Tags
No related merge requests found
...@@ -102,6 +102,7 @@ class AzureOpenAI(OpenAI): ...@@ -102,6 +102,7 @@ class AzureOpenAI(OpenAI):
deployment: Optional[str] = None, deployment: Optional[str] = None,
# custom httpx client # custom httpx client
http_client: Optional[httpx.Client] = None, http_client: Optional[httpx.Client] = None,
async_http_client: Optional[httpx.AsyncClient] = None,
# base class # base class
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
...@@ -138,6 +139,7 @@ class AzureOpenAI(OpenAI): ...@@ -138,6 +139,7 @@ class AzureOpenAI(OpenAI):
api_version=api_version, api_version=api_version,
callback_manager=callback_manager, callback_manager=callback_manager,
http_client=http_client, http_client=http_client,
async_http_client=async_http_client,
system_prompt=system_prompt, system_prompt=system_prompt,
messages_to_prompt=messages_to_prompt, messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt, completion_to_prompt=completion_to_prompt,
...@@ -182,7 +184,9 @@ class AzureOpenAI(OpenAI): ...@@ -182,7 +184,9 @@ class AzureOpenAI(OpenAI):
) )
return self._aclient return self._aclient
def _get_credential_kwargs(self, **kwargs: Any) -> Dict[str, Any]: def _get_credential_kwargs(
self, is_async: bool = False, **kwargs: Any
) -> Dict[str, Any]:
if self.use_azure_ad: if self.use_azure_ad:
self._azure_ad_token = refresh_openai_azuread_token(self._azure_ad_token) self._azure_ad_token = refresh_openai_azuread_token(self._azure_ad_token)
self.api_key = self._azure_ad_token.token self.api_key = self._azure_ad_token.token
...@@ -206,7 +210,7 @@ class AzureOpenAI(OpenAI): ...@@ -206,7 +210,7 @@ class AzureOpenAI(OpenAI):
"azure_ad_token_provider": self.azure_ad_token_provider, "azure_ad_token_provider": self.azure_ad_token_provider,
"api_version": self.api_version, "api_version": self.api_version,
"default_headers": self.default_headers, "default_headers": self.default_headers,
"http_client": self._http_client, "http_client": self._async_http_client if is_async else self._http_client,
**kwargs, **kwargs,
} }
......
...@@ -29,7 +29,7 @@ exclude = ["**/BUILD"] ...@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
license = "MIT" license = "MIT"
name = "llama-index-llms-azure-openai" name = "llama-index-llms-azure-openai"
readme = "README.md" readme = "README.md"
version = "0.1.7" version = "0.1.8"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
......
...@@ -178,6 +178,7 @@ class OpenAI(FunctionCallingLLM): ...@@ -178,6 +178,7 @@ class OpenAI(FunctionCallingLLM):
_client: Optional[SyncOpenAI] = PrivateAttr() _client: Optional[SyncOpenAI] = PrivateAttr()
_aclient: Optional[AsyncOpenAI] = PrivateAttr() _aclient: Optional[AsyncOpenAI] = PrivateAttr()
_http_client: Optional[httpx.Client] = PrivateAttr() _http_client: Optional[httpx.Client] = PrivateAttr()
_async_http_client: Optional[httpx.AsyncClient] = PrivateAttr()
def __init__( def __init__(
self, self,
...@@ -194,6 +195,7 @@ class OpenAI(FunctionCallingLLM): ...@@ -194,6 +195,7 @@ class OpenAI(FunctionCallingLLM):
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
default_headers: Optional[Dict[str, str]] = None, default_headers: Optional[Dict[str, str]] = None,
http_client: Optional[httpx.Client] = None, http_client: Optional[httpx.Client] = None,
async_http_client: Optional[httpx.AsyncClient] = None,
# base class # base class
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
...@@ -234,6 +236,7 @@ class OpenAI(FunctionCallingLLM): ...@@ -234,6 +236,7 @@ class OpenAI(FunctionCallingLLM):
self._client = None self._client = None
self._aclient = None self._aclient = None
self._http_client = http_client self._http_client = http_client
self._async_http_client = async_http_client
def _get_client(self) -> SyncOpenAI: def _get_client(self) -> SyncOpenAI:
if not self.reuse_client: if not self.reuse_client:
...@@ -245,10 +248,10 @@ class OpenAI(FunctionCallingLLM): ...@@ -245,10 +248,10 @@ class OpenAI(FunctionCallingLLM):
def _get_aclient(self) -> AsyncOpenAI: def _get_aclient(self) -> AsyncOpenAI:
if not self.reuse_client: if not self.reuse_client:
return AsyncOpenAI(**self._get_credential_kwargs()) return AsyncOpenAI(**self._get_credential_kwargs(is_async=True))
if self._aclient is None: if self._aclient is None:
self._aclient = AsyncOpenAI(**self._get_credential_kwargs()) self._aclient = AsyncOpenAI(**self._get_credential_kwargs(is_async=True))
return self._aclient return self._aclient
def _get_model_name(self) -> str: def _get_model_name(self) -> str:
...@@ -331,14 +334,14 @@ class OpenAI(FunctionCallingLLM): ...@@ -331,14 +334,14 @@ class OpenAI(FunctionCallingLLM):
return kwargs["use_chat_completions"] return kwargs["use_chat_completions"]
return self.metadata.is_chat_model return self.metadata.is_chat_model
def _get_credential_kwargs(self) -> Dict[str, Any]: def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]:
return { return {
"api_key": self.api_key, "api_key": self.api_key,
"base_url": self.api_base, "base_url": self.api_base,
"max_retries": self.max_retries, "max_retries": self.max_retries,
"timeout": self.timeout, "timeout": self.timeout,
"default_headers": self.default_headers, "default_headers": self.default_headers,
"http_client": self._http_client, "http_client": self._async_http_client if is_async else self._http_client,
} }
def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]: def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
......
...@@ -29,7 +29,7 @@ exclude = ["**/BUILD"] ...@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
license = "MIT" license = "MIT"
name = "llama-index-llms-openai" name = "llama-index-llms-openai"
readme = "README.md" readme = "README.md"
version = "0.1.17" version = "0.1.18"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment