From 156948c4960b34e579a2047eb7b570960767c2fa Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen <balloob@gmail.com>
Date: Mon, 26 Aug 2024 19:39:09 +0200
Subject: [PATCH] Fix defaults for cloud STT/TTS (#121229)

* Fix defaults for cloud STT/TTS

* Prefer entity over legacy provider

* Remove unrealistic tests

* Add tests which show cloud stt/tts entity is preferred

---------

Co-authored-by: Erik <erik@montnemery.com>
---
 homeassistant/components/stt/__init__.py | 15 ++++-
 homeassistant/components/tts/__init__.py | 13 +++--
 tests/components/stt/test_init.py        | 72 ++++++++++++++++++++----
 tests/components/tts/common.py           | 10 ++--
 tests/components/tts/conftest.py         | 20 +++++--
 tests/components/tts/test_init.py        | 44 +++++++++++++--
 6 files changed, 142 insertions(+), 32 deletions(-)

diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py
index 227c92f2b98..f6c38c1e0b7 100644
--- a/homeassistant/components/stt/__init__.py
+++ b/homeassistant/components/stt/__init__.py
@@ -72,9 +72,18 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
 @callback
 def async_default_engine(hass: HomeAssistant) -> str | None:
     """Return the domain or entity id of the default engine."""
-    return next(
-        iter(hass.states.async_entity_ids(DOMAIN)), None
-    ) or async_default_provider(hass)
+    component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
+
+    default_entity_id: str | None = None
+
+    for entity in component.entities:
+        if entity.platform and entity.platform.platform_name == "cloud":
+            return entity.entity_id
+
+        if default_entity_id is None:
+            default_entity_id = entity.entity_id
+
+    return default_entity_id or async_default_provider(hass)
 
 
 @callback
diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py
index 5286b01f67f..583db4472d4 100644
--- a/homeassistant/components/tts/__init__.py
+++ b/homeassistant/components/tts/__init__.py
@@ -137,15 +137,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None:
     component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
     manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
 
-    if "cloud" in manager.providers:
-        return "cloud"
+    default_entity_id: str | None = None
 
-    entity = next(iter(component.entities), None)
+    for entity in component.entities:
+        if entity.platform and entity.platform.platform_name == "cloud":
+            return entity.entity_id
 
-    if entity is not None:
-        return entity.entity_id
+        if default_entity_id is None:
+            default_entity_id = entity.entity_id
 
-    return next(iter(manager.providers), None)
+    return default_entity_id or next(iter(manager.providers), None)
 
 
 @callback
diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py
index a42ac44112e..e5d75d3c4a5 100644
--- a/tests/components/stt/test_init.py
+++ b/tests/components/stt/test_init.py
@@ -1,6 +1,7 @@
 """Test STT component setup."""
 
-from collections.abc import AsyncIterable, Generator
+from collections.abc import AsyncIterable, Generator, Iterable
+from contextlib import ExitStack
 from http import HTTPStatus
 from pathlib import Path
 from unittest.mock import AsyncMock
@@ -122,20 +123,23 @@ class STTFlow(ConfigFlow):
     """Test flow."""
 
 
-@pytest.fixture(name="config_flow_test_domain")
-def config_flow_test_domain_fixture() -> str:
+@pytest.fixture(name="config_flow_test_domains")
+def config_flow_test_domain_fixture() -> Iterable[str]:
     """Test domain fixture."""
-    return TEST_DOMAIN
+    return (TEST_DOMAIN,)
 
 
 @pytest.fixture(autouse=True)
 def config_flow_fixture(
-    hass: HomeAssistant, config_flow_test_domain: str
+    hass: HomeAssistant, config_flow_test_domains: Iterable[str]
 ) -> Generator[None]:
     """Mock config flow."""
-    mock_platform(hass, f"{config_flow_test_domain}.config_flow")
+    for domain in config_flow_test_domains:
+        mock_platform(hass, f"{domain}.config_flow")
 
-    with mock_config_flow(config_flow_test_domain, STTFlow):
+    with ExitStack() as stack:
+        for domain in config_flow_test_domains:
+            stack.enter_context(mock_config_flow(domain, STTFlow))
         yield
 
 
@@ -496,21 +500,25 @@ async def test_default_engine_entity(
     assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
 
 
-@pytest.mark.parametrize("config_flow_test_domain", ["new_test"])
+@pytest.mark.parametrize("config_flow_test_domains", [("new_test",)])
 async def test_default_engine_prefer_entity(
     hass: HomeAssistant,
     tmp_path: Path,
     mock_provider_entity: MockProviderEntity,
     mock_provider: MockProvider,
-    config_flow_test_domain: str,
+    config_flow_test_domains: str,
 ) -> None:
-    """Test async_default_engine."""
+    """Test async_default_engine.
+
+    In this tests there's an entity and a legacy provider.
+    The test asserts async_default_engine returns the entity.
+    """
     mock_provider_entity.url_path = "stt.new_test"
     mock_provider_entity._attr_name = "New test"
 
     await mock_setup(hass, tmp_path, mock_provider)
     await mock_config_entry_setup(
-        hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domain
+        hass, tmp_path, mock_provider_entity, test_domain=config_flow_test_domains[0]
     )
     await hass.async_block_till_done()
 
@@ -523,6 +531,48 @@ async def test_default_engine_prefer_entity(
     assert async_default_engine(hass) == "stt.new_test"
 
 
+@pytest.mark.parametrize(
+    "config_flow_test_domains",
+    [
+        # Test different setup order to ensure the default is not influenced
+        # by setup order.
+        ("cloud", "new_test"),
+        ("new_test", "cloud"),
+    ],
+)
+async def test_default_engine_prefer_cloud_entity(
+    hass: HomeAssistant,
+    tmp_path: Path,
+    mock_provider: MockProvider,
+    config_flow_test_domains: str,
+) -> None:
+    """Test async_default_engine.
+
+    In this tests there's an entity from domain cloud, an entity from domain new_test
+    and a legacy provider.
+    The test asserts async_default_engine returns the entity from domain cloud.
+    """
+    await mock_setup(hass, tmp_path, mock_provider)
+    for domain in config_flow_test_domains:
+        entity = MockProviderEntity()
+        entity.url_path = f"stt.{domain}"
+        entity._attr_name = f"{domain} STT entity"
+        await mock_config_entry_setup(hass, tmp_path, entity, test_domain=domain)
+    await hass.async_block_till_done()
+
+    for domain in config_flow_test_domains:
+        entity_engine = async_get_speech_to_text_engine(
+            hass, f"stt.{domain}_stt_entity"
+        )
+        assert entity_engine is not None
+        assert entity_engine.name == f"{domain} STT entity"
+
+    provider_engine = async_get_speech_to_text_engine(hass, "test")
+    assert provider_engine is not None
+    assert provider_engine.name == "test"
+    assert async_default_engine(hass) == "stt.cloud_stt_entity"
+
+
 async def test_get_engine_legacy(
     hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
 ) -> None:
diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py
index 71edf29721f..4acba401fad 100644
--- a/tests/components/tts/common.py
+++ b/tests/components/tts/common.py
@@ -215,7 +215,9 @@ async def mock_setup(
 
 
 async def mock_config_entry_setup(
-    hass: HomeAssistant, tts_entity: MockTTSEntity
+    hass: HomeAssistant,
+    tts_entity: MockTTSEntity,
+    test_domain: str = TEST_DOMAIN,
 ) -> MockConfigEntry:
     """Set up a test tts platform via config entry."""
 
@@ -236,7 +238,7 @@ async def mock_config_entry_setup(
     mock_integration(
         hass,
         MockModule(
-            TEST_DOMAIN,
+            test_domain,
             async_setup_entry=async_setup_entry_init,
             async_unload_entry=async_unload_entry_init,
         ),
@@ -251,9 +253,9 @@ async def mock_config_entry_setup(
         async_add_entities([tts_entity])
 
     loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
-    mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", loaded_platform)
+    mock_platform(hass, f"{test_domain}.{TTS_DOMAIN}", loaded_platform)
 
-    config_entry = MockConfigEntry(domain=TEST_DOMAIN)
+    config_entry = MockConfigEntry(domain=test_domain)
     config_entry.add_to_hass(hass)
     assert await hass.config_entries.async_setup(config_entry.entry_id)
     await hass.async_block_till_done()
diff --git a/tests/components/tts/conftest.py b/tests/components/tts/conftest.py
index d9a4499f544..91ddd7742af 100644
--- a/tests/components/tts/conftest.py
+++ b/tests/components/tts/conftest.py
@@ -3,7 +3,8 @@
 From http://doc.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
 """
 
-from collections.abc import Generator
+from collections.abc import Generator, Iterable
+from contextlib import ExitStack
 from pathlib import Path
 from unittest.mock import MagicMock
 
@@ -81,12 +82,23 @@ class TTSFlow(ConfigFlow):
     """Test flow."""
 
 
+@pytest.fixture(name="config_flow_test_domains")
+def config_flow_test_domain_fixture() -> Iterable[str]:
+    """Test domain fixture."""
+    return (TEST_DOMAIN,)
+
+
 @pytest.fixture(autouse=True)
-def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
+def config_flow_fixture(
+    hass: HomeAssistant, config_flow_test_domains: Iterable[str]
+) -> Generator[None]:
     """Mock config flow."""
-    mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
+    for domain in config_flow_test_domains:
+        mock_platform(hass, f"{domain}.config_flow")
 
-    with mock_config_flow(TEST_DOMAIN, TTSFlow):
+    with ExitStack() as stack:
+        for domain in config_flow_test_domains:
+            stack.enter_context(mock_config_flow(domain, TTSFlow))
         yield
 
 
diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py
index 55ff4492e80..05c19622e84 100644
--- a/tests/components/tts/test_init.py
+++ b/tests/components/tts/test_init.py
@@ -1389,9 +1389,6 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
     ):
         assert tts.async_resolve_engine(hass, None) is None
 
-    with patch.dict(hass.data[tts.DATA_TTS_MANAGER].providers, {"cloud": object()}):
-        assert tts.async_resolve_engine(hass, None) == "cloud"
-
 
 @pytest.mark.parametrize(
     ("setup", "engine_id"),
@@ -1845,7 +1842,11 @@ async def test_default_engine_prefer_entity(
     mock_tts_entity: MockTTSEntity,
     mock_provider: MockProvider,
 ) -> None:
-    """Test async_default_engine."""
+    """Test async_default_engine.
+
+    In this tests there's an entity and a legacy provider.
+    The test asserts async_default_engine returns the entity.
+    """
     mock_tts_entity._attr_name = "New test"
 
     await mock_setup(hass, mock_provider)
@@ -1857,3 +1858,38 @@ async def test_default_engine_prefer_entity(
     provider_engine = tts.async_resolve_engine(hass, "test")
     assert provider_engine == "test"
     assert tts.async_default_engine(hass) == "tts.new_test"
+
+
+@pytest.mark.parametrize(
+    "config_flow_test_domains",
+    [
+        # Test different setup order to ensure the default is not influenced
+        # by setup order.
+        ("cloud", "new_test"),
+        ("new_test", "cloud"),
+    ],
+)
+async def test_default_engine_prefer_cloud_entity(
+    hass: HomeAssistant,
+    mock_provider: MockProvider,
+    config_flow_test_domains: str,
+) -> None:
+    """Test async_default_engine.
+
+    In this tests there's an entity from domain cloud, an entity from domain new_test
+    and a legacy provider.
+    The test asserts async_default_engine returns the entity from domain cloud.
+    """
+    await mock_setup(hass, mock_provider)
+    for domain in config_flow_test_domains:
+        entity = MockTTSEntity(DEFAULT_LANG)
+        entity._attr_name = f"{domain} TTS entity"
+        await mock_config_entry_setup(hass, entity, test_domain=domain)
+    await hass.async_block_till_done()
+
+    for domain in config_flow_test_domains:
+        entity_engine = tts.async_resolve_engine(hass, f"tts.{domain}_tts_entity")
+        assert entity_engine == f"tts.{domain}_tts_entity"
+    provider_engine = tts.async_resolve_engine(hass, "test")
+    assert provider_engine == "test"
+    assert tts.async_default_engine(hass) == "tts.cloud_tts_entity"
-- 
GitLab