From d838683383140ddcd2f42e61346970366ccd1121 Mon Sep 17 00:00:00 2001
From: Farzad Sunavala <40604067+farzad528@users.noreply.github.com>
Date: Sun, 10 Nov 2024 21:14:29 -0600
Subject: [PATCH] Add UserAgent header "llamaindex-python" for azure search
 (#16895)

---
 .../AzureAISearchIndexDemo.ipynb              |  4 +-
 .../vector_stores/azureaisearch/base.py       | 26 +++++++-
 .../pyproject.toml                            |  2 +-
 .../tests/test_azureaisearch.py               | 61 ++++++++++++++++---
 4 files changed, 79 insertions(+), 14 deletions(-)

diff --git a/docs/docs/examples/vector_stores/AzureAISearchIndexDemo.ipynb b/docs/docs/examples/vector_stores/AzureAISearchIndexDemo.ipynb
index 97ea0fa21c..3422f3e12f 100644
--- a/docs/docs/examples/vector_stores/AzureAISearchIndexDemo.ipynb
+++ b/docs/docs/examples/vector_stores/AzureAISearchIndexDemo.ipynb
@@ -88,7 +88,7 @@
    "source": [
     "aoai_api_key = \"YOUR_AZURE_OPENAI_API_KEY\"\n",
     "aoai_endpoint = \"YOUR_AZURE_OPENAI_ENDPOINT\"\n",
-    "aoai_api_version = \"2024-02-01\"\n",
+    "aoai_api_version = \"2024-10-21\"\n",
     "\n",
     "llm = AzureOpenAI(\n",
     "    model=\"YOUR_AZURE_OPENAI_COMPLETION_MODEL_NAME\",\n",
@@ -697,7 +697,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "llamaindextest01",
+   "display_name": "llama-index-R8MZM3d9-py3.11",
    "language": "python",
    "name": "python3"
   },
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py
index c0da5d1812..0601a408fc 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py
@@ -140,6 +140,7 @@ class AzureAISearchVectorStore(BasePydanticVectorStore):
     ] = PrivateAttr()
     _vector_profile_name: str = PrivateAttr()
     _compression_type: str = PrivateAttr()
+    _user_agent: str = PrivateAttr()
 
     def _normalise_metadata_to_index_fields(
         self,
@@ -547,6 +548,7 @@ class AzureAISearchVectorStore(BasePydanticVectorStore):
         # https://learn.microsoft.com/en-us/azure/search/index-add-language-analyzers
         language_analyzer: str = "en.lucene",
         compression_type: str = "none",
+        user_agent: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         # ruff: noqa: E501
@@ -611,6 +613,10 @@ class AzureAISearchVectorStore(BasePydanticVectorStore):
             raise ImportError(import_err_msg)
 
         super().__init__()
+        base_user_agent = "llamaindex-python"
+        self._user_agent = (
+            f"{base_user_agent} {user_agent}" if user_agent else base_user_agent
+        )
 
         self._index_client: SearchIndexClient = cast(SearchIndexClient, None)
         self._async_index_client: AsyncSearchIndexClient = cast(
@@ -638,7 +644,9 @@ class AzureAISearchVectorStore(BasePydanticVectorStore):
             if isinstance(search_or_index_client, SearchIndexClient):
                 # If SearchIndexClient is supplied so must index_name
                 self._index_client = cast(SearchIndexClient, search_or_index_client)
-
+                self._index_client._client._config.user_agent_policy.add_user_agent(
+                    self._user_agent
+                )
                 if not index_name:
                     raise ValueError(
                         "index_name must be supplied if search_or_index_client is of "
@@ -648,12 +656,18 @@ class AzureAISearchVectorStore(BasePydanticVectorStore):
                 self._search_client = self._index_client.get_search_client(
                     index_name=index_name
                 )
+                self._search_client._client._config.user_agent_policy.add_user_agent(
+                    self._user_agent
+                )
 
             elif isinstance(search_or_index_client, AsyncSearchIndexClient):
                 # If SearchIndexClient is supplied so must index_name
                 self._async_index_client = cast(
                     AsyncSearchIndexClient, search_or_index_client
                 )
+                self._async_index_client._client._config.user_agent_policy.add_user_agent(
+                    self._user_agent
+                )
 
                 if not index_name:
                     raise ValueError(
@@ -664,10 +678,15 @@ class AzureAISearchVectorStore(BasePydanticVectorStore):
                 self._async_search_client = self._async_index_client.get_search_client(
                     index_name=index_name
                 )
+                self._async_search_client._client._config.user_agent_policy.add_user_agent(
+                    self._user_agent
+                )
 
             elif isinstance(search_or_index_client, SearchClient):
                 self._search_client = cast(SearchClient, search_or_index_client)
-
+                self._search_client._client._config.user_agent_policy.add_user_agent(
+                    self._user_agent
+                )
                 # Validate index_name
                 if index_name:
                     raise ValueError(
@@ -679,6 +698,9 @@ class AzureAISearchVectorStore(BasePydanticVectorStore):
                 self._async_search_client = cast(
                     AsyncSearchClient, search_or_index_client
                 )
+                self._async_search_client._client._config.user_agent_policy.add_user_agent(
+                    self._user_agent
+                )
 
                 # Validate index_name
                 if index_name:
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml
index fac7b28b40..aa24513187 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml
@@ -28,7 +28,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-vector-stores-azureaisearch"
 readme = "README.md"
-version = "0.2.8"
+version = "0.2.9"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/tests/test_azureaisearch.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/tests/test_azureaisearch.py
index 8ec813c227..8bfa66ce77 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/tests/test_azureaisearch.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/tests/test_azureaisearch.py
@@ -21,6 +21,22 @@ except ImportError:
     search_client = None  # type: ignore
 
 
+def mock_client_with_user_agent(client_type: str) -> Any:
+    """Helper function to create a mock client with user agent configuration."""
+    if client_type == "search":
+        client = MagicMock(spec=SearchClient)
+    else:
+        client = MagicMock(spec=SearchIndexClient)
+
+    # Mock the configuration chain
+    client._client = MagicMock()
+    client._client._config = MagicMock()
+    client._client._config.user_agent_policy = MagicMock()
+    client._client._config.user_agent_policy.add_user_agent = MagicMock()
+
+    return client
+
+
 def create_mock_vector_store(
     search_client: Any,
     index_name: Optional[str] = None,
@@ -33,7 +49,7 @@ def create_mock_vector_store(
         embedding_field_key="embedding",
         metadata_string_field_key="metadata",
         doc_id_field_key="doc_id",
-        filterable_metadata_field_keys=[],  # Added to match the updated constructor
+        filterable_metadata_field_keys=[],
         hidden_field_keys=["embedding"],
         index_name=index_name,
         index_management=index_management,
@@ -58,11 +74,39 @@ def create_sample_documents(n: int) -> List[TextNode]:
     return nodes
 
 
+@pytest.mark.skipif(
+    not azureaisearch_installed, reason="azure-search-documents package not installed"
+)
+def test_user_agent_configuration() -> None:
+    """Test that user agent is properly configured."""
+    # Test with SearchClient
+    search_client = mock_client_with_user_agent("search")
+    vector_store = create_mock_vector_store(search_client)
+
+    # Verify user agent was added with the correct base agent string
+    search_client._client._config.user_agent_policy.add_user_agent.assert_called_with(
+        "llamaindex-python"
+    )
+
+    # Test with SearchIndexClient
+    index_client = mock_client_with_user_agent("index")
+    vector_store = create_mock_vector_store(
+        index_client,
+        index_name="test-index",
+        index_management=IndexManagement.NO_VALIDATION,
+    )
+
+    # Verify user agent was added with the correct base agent string
+    index_client._client._config.user_agent_policy.add_user_agent.assert_called_with(
+        "llamaindex-python"
+    )
+
+
 @pytest.mark.skipif(
     not azureaisearch_installed, reason="azure-search-documents package not installed"
 )
 def test_azureaisearch_add_two_batches() -> None:
-    search_client = MagicMock(spec=SearchClient)
+    search_client = mock_client_with_user_agent("search")
 
     with patch("azure.search.documents.IndexDocumentsBatch") as MockIndexDocumentsBatch:
         index_documents_batch_instance = MockIndexDocumentsBatch.return_value
@@ -75,7 +119,7 @@ def test_azureaisearch_add_two_batches() -> None:
 
         assert ids is not None
         assert len(ids) == 11
-        assert call_count == 11  # Adjust this value based on your logic
+        assert call_count == 11
         assert search_client.index_documents.call_count == 1
 
 
@@ -83,7 +127,7 @@ def test_azureaisearch_add_two_batches() -> None:
     not azureaisearch_installed, reason="azure-search-documents package not installed"
 )
 def test_azureaisearch_add_one_batch() -> None:
-    search_client = MagicMock(spec=SearchClient)
+    search_client = mock_client_with_user_agent("search")
 
     with patch("azure.search.documents.IndexDocumentsBatch") as MockIndexDocumentsBatch:
         index_documents_batch_instance = MockIndexDocumentsBatch.return_value
@@ -96,7 +140,7 @@ def test_azureaisearch_add_one_batch() -> None:
 
         assert ids is not None
         assert len(ids) == 11
-        assert call_count == 11  # Adjust this value based on your logic
+        assert call_count == 11
         assert search_client.index_documents.call_count == 1
 
 
@@ -104,7 +148,7 @@ def test_azureaisearch_add_one_batch() -> None:
     not azureaisearch_installed, reason="azure-search-documents package not installed"
 )
 def test_invalid_index_management_for_searchclient() -> None:
-    search_client = MagicMock(spec=SearchClient)
+    search_client = mock_client_with_user_agent("search")
 
     # No error
     create_mock_vector_store(
@@ -112,7 +156,6 @@ def test_invalid_index_management_for_searchclient() -> None:
     )
 
     # Cannot supply index name
-    # ruff: noqa: E501
     with pytest.raises(
         ValueError,
         match="index_name cannot be supplied if search_or_index_client is of type azure.search.documents.SearchClient",
@@ -131,7 +174,7 @@ def test_invalid_index_management_for_searchclient() -> None:
     not azureaisearch_installed, reason="azure-search-documents package not installed"
 )
 def test_invalid_index_management_for_searchindexclient() -> None:
-    search_client = MagicMock(spec=SearchIndexClient)
+    search_client = mock_client_with_user_agent("index")
 
     # Index name must be supplied
     with pytest.raises(
@@ -154,7 +197,7 @@ def test_invalid_index_management_for_searchindexclient() -> None:
     not azureaisearch_installed, reason="azure-search-documents package not installed"
 )
 def test_azureaisearch_query() -> None:
-    search_client = MagicMock(spec=SearchClient)
+    search_client = mock_client_with_user_agent("search")
 
     # Mock the search method of the search client
     mock_search_results = [
-- 
GitLab