From 6e111d18ec71c750f9d6d59c49a7abbe2f55c0b8 Mon Sep 17 00:00:00 2001
From: Allen Porter <allen@thebends.org>
Date: Thu, 9 Jan 2025 08:18:25 -0800
Subject: [PATCH] Allow unregistering LLM APIs (#135162)

---
 homeassistant/helpers/llm.py |  9 ++++-
 tests/helpers/test_llm.py    | 66 +++++++++++++++++++++++++++++++++---
 2 files changed, 69 insertions(+), 6 deletions(-)

diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py
index 38d80d5649d..cb303f4aa65 100644
--- a/homeassistant/helpers/llm.py
+++ b/homeassistant/helpers/llm.py
@@ -85,7 +85,7 @@ def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
 
 
 @callback
-def async_register_api(hass: HomeAssistant, api: API) -> None:
+def async_register_api(hass: HomeAssistant, api: API) -> Callable[[], None]:
     """Register an API to be exposed to LLMs."""
     apis = _async_get_apis(hass)
 
@@ -94,6 +94,13 @@ def async_register_api(hass: HomeAssistant, api: API) -> None:
 
     apis[api.id] = api
 
+    @callback
+    def unregister() -> None:
+        """Unregister the API."""
+        apis.pop(api.id)
+
+    return unregister
+
 
 async def async_get_api(
     hass: HomeAssistant, api_id: str, llm_context: LLMContext
diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py
index 3787526c433..5348348bb0d 100644
--- a/tests/helpers/test_llm.py
+++ b/tests/helpers/test_llm.py
@@ -39,6 +39,14 @@ def llm_context() -> llm.LLMContext:
     )
 
 
+class MyAPI(llm.API):
+    """Test API."""
+
+    async def async_get_api_instance(self, _: llm.ToolInput) -> llm.APIInstance:
+        """Return a list of tools."""
+        return llm.APIInstance(self, "", [], llm_context)
+
+
 async def test_get_api_no_existing(
     hass: HomeAssistant, llm_context: llm.LLMContext
 ) -> None:
@@ -50,11 +58,6 @@ async def test_get_api_no_existing(
 async def test_register_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
     """Test registering an llm api."""
 
-    class MyAPI(llm.API):
-        async def async_get_api_instance(self, _: llm.ToolInput) -> llm.APIInstance:
-            """Return a list of tools."""
-            return llm.APIInstance(self, "", [], llm_context)
-
     api = MyAPI(hass=hass, id="test", name="Test")
     llm.async_register_api(hass, api)
 
@@ -66,6 +69,59 @@ async def test_register_api(hass: HomeAssistant, llm_context: llm.LLMContext) ->
         llm.async_register_api(hass, api)
 
 
+async def test_unregister_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
+    """Test unregistering an llm api."""
+
+    unreg = llm.async_register_api(hass, MyAPI(hass=hass, id="test", name="Test"))
+    assert await llm.async_get_api(hass, "test", llm_context)
+    unreg()
+    with pytest.raises(HomeAssistantError):
+        assert await llm.async_get_api(hass, "test", llm_context)
+
+
+async def test_reregister_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
+    """Test unregistering an llm api then re-registering with the same id."""
+
+    unreg = llm.async_register_api(hass, MyAPI(hass=hass, id="test", name="Test"))
+    assert await llm.async_get_api(hass, "test", llm_context)
+    unreg()
+    llm.async_register_api(hass, MyAPI(hass=hass, id="test", name="Test"))
+    assert await llm.async_get_api(hass, "test", llm_context)
+
+
+async def test_unregister_twice(
+    hass: HomeAssistant, llm_context: llm.LLMContext
+) -> None:
+    """Test unregistering an llm api twice."""
+
+    unreg = llm.async_register_api(hass, MyAPI(hass=hass, id="test", name="Test"))
+    assert await llm.async_get_api(hass, "test", llm_context)
+    unreg()
+
+    # Unregistering twice is a bug that should not happen
+    with pytest.raises(KeyError):
+        unreg()
+
+
+async def test_multiple_apis(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
+    """Test registering multiple APIs."""
+
+    unreg1 = llm.async_register_api(hass, MyAPI(hass=hass, id="test-1", name="Test 1"))
+    llm.async_register_api(hass, MyAPI(hass=hass, id="test-2", name="Test 2"))
+
+    # Verify both Apis are registered
+    assert await llm.async_get_api(hass, "test-1", llm_context)
+    assert await llm.async_get_api(hass, "test-2", llm_context)
+
+    # Unregister and verify only one is left
+    unreg1()
+
+    with pytest.raises(HomeAssistantError):
+        assert await llm.async_get_api(hass, "test-1", llm_context)
+
+    assert await llm.async_get_api(hass, "test-2", llm_context)
+
+
 async def test_call_tool_no_existing(
     hass: HomeAssistant, llm_context: llm.LLMContext
 ) -> None:
-- 
GitLab