diff --git a/CODEOWNERS b/CODEOWNERS index 5aef36e2b950e64891e0b85210e72e0bc37cfa24..c2eba386420498c9246c91a985164cd527318712 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -451,6 +451,8 @@ build.json @home-assistant/supervisor /homeassistant/components/google_assistant_sdk/ @tronikos /tests/components/google_assistant_sdk/ @tronikos /homeassistant/components/google_cloud/ @lufton +/homeassistant/components/google_generative_ai_conversation/ @tronikos +/tests/components/google_generative_ai_conversation/ @tronikos /homeassistant/components/google_mail/ @tkdrob /tests/components/google_mail/ @tkdrob /homeassistant/components/google_sheets/ @tkdrob diff --git a/homeassistant/brands/google.json b/homeassistant/brands/google.json index 0d396ca05ed37ee721de622333c7d96455d758f5..3eb2e9e64f014aa6a0dc0e0e847cdc926a424e6e 100644 --- a/homeassistant/brands/google.json +++ b/homeassistant/brands/google.json @@ -6,6 +6,7 @@ "google_assistant_sdk", "google_cloud", "google_domains", + "google_generative_ai_conversation", "google_mail", "google_maps", "google_pubsub", diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0fac634207d7263a516b366444eda8e27a88f9 --- /dev/null +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -0,0 +1,157 @@ +"""The Google Generative AI Conversation integration.""" +from __future__ import annotations + +from functools import partial +import logging +from typing import Literal + +from google.api_core.exceptions import ClientError +import google.generativeai as palm +from google.generativeai.types.discuss_types import ChatResponse + +from homeassistant.components import conversation +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_API_KEY, MATCH_ALL +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryNotReady, TemplateError +from homeassistant.helpers import intent, template +from homeassistant.util import ulid + +from .const import ( + CONF_CHAT_MODEL, + CONF_PROMPT, + CONF_TEMPERATURE, + CONF_TOP_K, + CONF_TOP_P, + DEFAULT_CHAT_MODEL, + DEFAULT_PROMPT, + DEFAULT_TEMPERATURE, + DEFAULT_TOP_K, + DEFAULT_TOP_P, +) + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up Google Generative AI Conversation from a config entry.""" + palm.configure(api_key=entry.data[CONF_API_KEY]) + + try: + await hass.async_add_executor_job( + partial( + palm.get_model, entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) + ) + ) + except ClientError as err: + if err.reason == "API_KEY_INVALID": + _LOGGER.error("Invalid API key: %s", err) + return False + raise ConfigEntryNotReady(err) from err + + conversation.async_set_agent(hass, entry, GoogleGenerativeAIAgent(hass, entry)) + return True + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload GoogleGenerativeAI.""" + palm.configure(api_key=None) + conversation.async_unset_agent(hass, entry) + return True + + +class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent): + """Google Generative AI conversation agent.""" + + def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: + """Initialize the agent.""" + self.hass = hass + self.entry = entry + self.history: dict[str, list[dict]] = {} + + @property + def attribution(self): + """Return the attribution.""" + return { + "name": "Powered by Google Generative AI", + "url": "https://developers.generativeai.google/", + } + + @property + def supported_languages(self) -> list[str] | Literal["*"]: + """Return a list of supported languages.""" + return MATCH_ALL + + async def async_process( + self, user_input: conversation.ConversationInput + ) -> conversation.ConversationResult: + """Process a sentence.""" + raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) + model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) + temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) + top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) + top_k = self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K) + + if user_input.conversation_id in self.history: + conversation_id = user_input.conversation_id + messages = self.history[conversation_id] + else: + conversation_id = ulid.ulid() + messages = [] + + try: + prompt = self._async_generate_prompt(raw_prompt) + except TemplateError as err: + _LOGGER.error("Error rendering prompt: %s", err) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem with my template: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + messages.append({"author": "0", "content": user_input.text}) + + _LOGGER.debug("Prompt for %s: %s", model, messages) + + try: + chat_response: ChatResponse = await palm.chat_async( + model=model, + context=prompt, + messages=messages, + temperature=temperature, + top_p=top_p, + top_k=top_k, + ) + except ClientError as err: + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem talking to Google Generative AI: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + _LOGGER.debug("Response %s", chat_response) + # For some queries the response is empty. In that case don't update history to avoid + # "google.generativeai.types.discuss_types.AuthorError: Authors are not strictly alternating" + if chat_response.last: + self.history[conversation_id] = chat_response.messages + + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_speech(chat_response.last) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + def _async_generate_prompt(self, raw_prompt: str) -> str: + """Generate a prompt for the user.""" + return template.Template(raw_prompt, self.hass).async_render( + { + "ha_name": self.hass.config.location_name, + }, + parse_result=False, + ) diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..94639177a42f64bb28f8f7cc666cd98d1e8ff23a --- /dev/null +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -0,0 +1,165 @@ +"""Config flow for Google Generative AI Conversation integration.""" +from __future__ import annotations + +from functools import partial +import logging +import types +from types import MappingProxyType +from typing import Any + +from google.api_core.exceptions import ClientError +import google.generativeai as palm +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.const import CONF_API_KEY +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResult +from homeassistant.helpers.selector import ( + NumberSelector, + NumberSelectorConfig, + TemplateSelector, +) + +from .const import ( + CONF_CHAT_MODEL, + CONF_PROMPT, + CONF_TEMPERATURE, + CONF_TOP_K, + CONF_TOP_P, + DEFAULT_CHAT_MODEL, + DEFAULT_PROMPT, + DEFAULT_TEMPERATURE, + DEFAULT_TOP_K, + DEFAULT_TOP_P, + DOMAIN, +) + +_LOGGER = logging.getLogger(__name__) + +STEP_USER_DATA_SCHEMA = vol.Schema( + { + vol.Required(CONF_API_KEY): str, + } +) + +DEFAULT_OPTIONS = types.MappingProxyType( + { + CONF_PROMPT: DEFAULT_PROMPT, + CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL, + CONF_TEMPERATURE: DEFAULT_TEMPERATURE, + CONF_TOP_P: DEFAULT_TOP_P, + CONF_TOP_K: DEFAULT_TOP_K, + } +) + + +async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: + """Validate the user input allows us to connect. + + Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user. + """ + palm.configure(api_key=data[CONF_API_KEY]) + await hass.async_add_executor_job(partial(palm.list_models)) + + +class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): + """Handle a config flow for Google Generative AI Conversation.""" + + VERSION = 1 + + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Handle the initial step.""" + if user_input is None: + return self.async_show_form( + step_id="user", data_schema=STEP_USER_DATA_SCHEMA + ) + + errors = {} + + try: + await validate_input(self.hass, user_input) + except ClientError as err: + if err.reason == "API_KEY_INVALID": + errors["base"] = "invalid_auth" + else: + errors["base"] = "cannot_connect" + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Unexpected exception") + errors["base"] = "unknown" + else: + return self.async_create_entry( + title="Google Generative AI Conversation", data=user_input + ) + + return self.async_show_form( + step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors + ) + + @staticmethod + def async_get_options_flow( + config_entry: config_entries.ConfigEntry, + ) -> config_entries.OptionsFlow: + """Create the options flow.""" + return OptionsFlow(config_entry) + + +class OptionsFlow(config_entries.OptionsFlow): + """Google Generative AI config flow options handler.""" + + def __init__(self, config_entry: config_entries.ConfigEntry) -> None: + """Initialize options flow.""" + self.config_entry = config_entry + + async def async_step_init( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Manage the options.""" + if user_input is not None: + return self.async_create_entry( + title="Google Generative AI Conversation", data=user_input + ) + schema = google_generative_ai_config_option_schema(self.config_entry.options) + return self.async_show_form( + step_id="init", + data_schema=vol.Schema(schema), + ) + + +def google_generative_ai_config_option_schema( + options: MappingProxyType[str, Any] +) -> dict: + """Return a schema for Google Generative AI completion options.""" + if not options: + options = DEFAULT_OPTIONS + return { + vol.Optional( + CONF_PROMPT, + description={"suggested_value": options[CONF_PROMPT]}, + default=DEFAULT_PROMPT, + ): TemplateSelector(), + vol.Optional( + CONF_CHAT_MODEL, + description={ + "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) + }, + default=DEFAULT_CHAT_MODEL, + ): str, + vol.Optional( + CONF_TEMPERATURE, + description={"suggested_value": options[CONF_TEMPERATURE]}, + default=DEFAULT_TEMPERATURE, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TOP_P, + description={"suggested_value": options[CONF_TOP_P]}, + default=DEFAULT_TOP_P, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TOP_K, + description={"suggested_value": options[CONF_TOP_K]}, + default=DEFAULT_TOP_K, + ): int, + } diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py new file mode 100644 index 0000000000000000000000000000000000000000..9664552e436236d7cb0f8190ff301492b446b954 --- /dev/null +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -0,0 +1,33 @@ +"""Constants for the Google Generative AI Conversation integration.""" + +DOMAIN = "google_generative_ai_conversation" +CONF_PROMPT = "prompt" +DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. + +An overview of the areas and the devices in this smart home: +{%- for area in areas() %} + {%- set area_info = namespace(printed=false) %} + {%- for device in area_devices(area) -%} + {%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %} + {%- if not area_info.printed %} + +{{ area_name(area) }}: + {%- set area_info.printed = true %} + {%- endif %} +- {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %} + {%- endif %} + {%- endfor %} +{%- endfor %} + +Answer the user's questions about the world truthfully. + +If the user wants to control a device, reject the request and suggest using the Home Assistant app. +""" +CONF_CHAT_MODEL = "chat_model" +DEFAULT_CHAT_MODEL = "models/chat-bison-001" +CONF_TEMPERATURE = "temperature" +DEFAULT_TEMPERATURE = 0.25 +CONF_TOP_P = "top_p" +DEFAULT_TOP_P = 0.95 +CONF_TOP_K = "top_k" +DEFAULT_TOP_K = 40 diff --git a/homeassistant/components/google_generative_ai_conversation/manifest.json b/homeassistant/components/google_generative_ai_conversation/manifest.json new file mode 100644 index 0000000000000000000000000000000000000000..52de921553544e186e36d2bb86b08e2e696d02e2 --- /dev/null +++ b/homeassistant/components/google_generative_ai_conversation/manifest.json @@ -0,0 +1,11 @@ +{ + "domain": "google_generative_ai_conversation", + "name": "Google Generative AI Conversation", + "codeowners": ["@tronikos"], + "config_flow": true, + "dependencies": ["conversation"], + "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation", + "integration_type": "service", + "iot_class": "cloud_polling", + "requirements": ["google-generativeai==0.1.0rc2"] +} diff --git a/homeassistant/components/google_generative_ai_conversation/strings.json b/homeassistant/components/google_generative_ai_conversation/strings.json new file mode 100644 index 0000000000000000000000000000000000000000..2df5398222c877fbb5fc1abf25d6f444c536c189 --- /dev/null +++ b/homeassistant/components/google_generative_ai_conversation/strings.json @@ -0,0 +1,29 @@ +{ + "config": { + "step": { + "user": { + "data": { + "api_key": "[%key:common::config_flow::data::api_key%]" + } + } + }, + "error": { + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", + "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", + "unknown": "[%key:common::config_flow::error::unknown%]" + } + }, + "options": { + "step": { + "init": { + "data": { + "prompt": "Prompt Template", + "model": "Model", + "temperature": "Temperature", + "top_p": "Top P", + "top_k": "Top K" + } + } + } + } +} diff --git a/homeassistant/generated/config_flows.py b/homeassistant/generated/config_flows.py index 82401badcd2fceadbafe96518f458a97a4219013..48c4051bb84143827860f11a2968d65974b1bd21 100644 --- a/homeassistant/generated/config_flows.py +++ b/homeassistant/generated/config_flows.py @@ -166,6 +166,7 @@ FLOWS = { "goodwe", "google", "google_assistant_sdk", + "google_generative_ai_conversation", "google_mail", "google_sheets", "google_travel_time", diff --git a/homeassistant/generated/integrations.json b/homeassistant/generated/integrations.json index 490f3cf9a077676286ecd5504c1a6c99a5b17592..7be4a9d5a1ee693d85cce68403e8a3a6beb7c38c 100644 --- a/homeassistant/generated/integrations.json +++ b/homeassistant/generated/integrations.json @@ -2029,6 +2029,12 @@ "iot_class": "cloud_polling", "name": "Google Domains" }, + "google_generative_ai_conversation": { + "integration_type": "service", + "config_flow": true, + "iot_class": "cloud_polling", + "name": "Google Generative AI Conversation" + }, "google_mail": { "integration_type": "service", "config_flow": true, diff --git a/requirements_all.txt b/requirements_all.txt index ce8f8ae767a4dd56571ddf5ac312e0ca213ea3e1..df86675a7c92c51b01bc3e58162584db1852d64a 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -816,6 +816,9 @@ google-cloud-pubsub==2.13.11 # homeassistant.components.google_cloud google-cloud-texttospeech==2.12.3 +# homeassistant.components.google_generative_ai_conversation +google-generativeai==0.1.0rc2 + # homeassistant.components.nest google-nest-sdm==2.2.4 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index fe1e3107c124ba89df33916b05b3cc52dbc31bc7..1239b3010c9ab3dc2f153f92bbdf36eef4b1b0bb 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -635,6 +635,9 @@ google-api-python-client==2.71.0 # homeassistant.components.google_pubsub google-cloud-pubsub==2.13.11 +# homeassistant.components.google_generative_ai_conversation +google-generativeai==0.1.0rc2 + # homeassistant.components.nest google-nest-sdm==2.2.4 diff --git a/tests/components/google_generative_ai_conversation/__init__.py b/tests/components/google_generative_ai_conversation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f789d9737e30a499d1baba0ff666542c6d8ab46 --- /dev/null +++ b/tests/components/google_generative_ai_conversation/__init__.py @@ -0,0 +1 @@ +"""Tests for the Google Generative AI Conversation integration.""" diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..0a45a991bf8cb1e57a0486ea6b57238c0a6c910e --- /dev/null +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -0,0 +1,31 @@ +"""Tests helpers.""" +from unittest.mock import patch + +import pytest + +from homeassistant.setup import async_setup_component + +from tests.common import MockConfigEntry + + +@pytest.fixture +def mock_config_entry(hass): + """Mock a config entry.""" + entry = MockConfigEntry( + domain="google_generative_ai_conversation", + data={ + "api_key": "bla", + }, + ) + entry.add_to_hass(hass) + return entry + + +@pytest.fixture +async def mock_init_component(hass, mock_config_entry): + """Initialize integration.""" + with patch("google.generativeai.get_model"): + assert await async_setup_component( + hass, "google_generative_ai_conversation", {} + ) + await hass.async_block_till_done() diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr new file mode 100644 index 0000000000000000000000000000000000000000..636a46e42f56cad566b3f2a00378827a3d60838a --- /dev/null +++ b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr @@ -0,0 +1,33 @@ +# serializer version: 1 +# name: test_default_prompt + dict({ + 'context': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Answer the user's questions about the world truthfully. + + If the user wants to control a device, reject the request and suggest using the Home Assistant app. + ''', + 'messages': list([ + dict({ + 'author': '0', + 'content': 'hello', + }), + ]), + 'model': 'models/chat-bison-001', + 'temperature': 0.25, + 'top_k': 40, + 'top_p': 0.95, + }) +# --- diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..0b7072f4ef08c5f83d9b73c4ff7b0077535253b1 --- /dev/null +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -0,0 +1,118 @@ +"""Test the Google Generative AI Conversation config flow.""" +from unittest.mock import patch + +from google.api_core.exceptions import ClientError +from google.rpc.error_details_pb2 import ErrorInfo +import pytest + +from homeassistant import config_entries +from homeassistant.components.google_generative_ai_conversation.const import ( + CONF_CHAT_MODEL, + CONF_TOP_K, + CONF_TOP_P, + DEFAULT_CHAT_MODEL, + DEFAULT_TOP_K, + DEFAULT_TOP_P, + DOMAIN, +) +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResultType + +from tests.common import MockConfigEntry + + +async def test_form(hass: HomeAssistant) -> None: + """Test we get the form.""" + # Pretend we already set up a config entry. + hass.config.components.add("google_generative_ai_conversation") + MockConfigEntry( + domain=DOMAIN, + state=config_entries.ConfigEntryState.LOADED, + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert result["errors"] is None + + with patch( + "homeassistant.components.google_generative_ai_conversation.config_flow.palm.list_models", + ), patch( + "homeassistant.components.google_generative_ai_conversation.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "api_key": "bla", + }, + ) + await hass.async_block_till_done() + + assert result2["type"] == FlowResultType.CREATE_ENTRY + assert result2["data"] == { + "api_key": "bla", + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_options( + hass: HomeAssistant, mock_config_entry, mock_init_component +) -> None: + """Test the options form.""" + options_flow = await hass.config_entries.options.async_init( + mock_config_entry.entry_id + ) + options = await hass.config_entries.options.async_configure( + options_flow["flow_id"], + { + "prompt": "Speak like a pirate", + "temperature": 0.3, + }, + ) + await hass.async_block_till_done() + assert options["type"] == FlowResultType.CREATE_ENTRY + assert options["data"]["prompt"] == "Speak like a pirate" + assert options["data"]["temperature"] == 0.3 + assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL + assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P + assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K + + +@pytest.mark.parametrize( + ("side_effect", "error"), + [ + ( + ClientError(message="some error"), + "cannot_connect", + ), + ( + ClientError( + message="invalid api key", + error_info=ErrorInfo(reason="API_KEY_INVALID"), + ), + "invalid_auth", + ), + (Exception, "unknown"), + ], +) +async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None: + """Test we handle errors.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + with patch( + "homeassistant.components.google_generative_ai_conversation.config_flow.palm.list_models", + side_effect=side_effect, + ): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "api_key": "bla", + }, + ) + + assert result2["type"] == FlowResultType.FORM + assert result2["errors"] == {"base": error} diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py new file mode 100644 index 0000000000000000000000000000000000000000..7335903b43b4c234d45f8770410facab883989dc --- /dev/null +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -0,0 +1,137 @@ +"""Tests for the Google Generative AI Conversation integration.""" +from unittest.mock import patch + +from google.api_core.exceptions import ClientError +from syrupy.assertion import SnapshotAssertion + +from homeassistant.components import conversation +from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers import area_registry as ar, device_registry as dr, intent + +from tests.common import MockConfigEntry + + +async def test_default_prompt( + hass: HomeAssistant, + mock_init_component, + area_registry: ar.AreaRegistry, + device_registry: dr.DeviceRegistry, + snapshot: SnapshotAssertion, +) -> None: + """Test that the default prompt works.""" + for i in range(3): + area_registry.async_create(f"{i}Empty Area") + + device_registry.async_get_or_create( + config_entry_id="1234", + connections={("test", "1234")}, + name="Test Device", + manufacturer="Test Manufacturer", + model="Test Model", + suggested_area="Test Area", + ) + for i in range(3): + device_registry.async_get_or_create( + config_entry_id="1234", + connections={("test", f"{i}abcd")}, + name="Test Service", + manufacturer="Test Manufacturer", + model="Test Model", + suggested_area="Test Area", + entry_type=dr.DeviceEntryType.SERVICE, + ) + device_registry.async_get_or_create( + config_entry_id="1234", + connections={("test", "5678")}, + name="Test Device 2", + manufacturer="Test Manufacturer 2", + model="Device 2", + suggested_area="Test Area 2", + ) + device_registry.async_get_or_create( + config_entry_id="1234", + connections={("test", "9876")}, + name="Test Device 3", + manufacturer="Test Manufacturer 3", + model="Test Model 3A", + suggested_area="Test Area 2", + ) + device_registry.async_get_or_create( + config_entry_id="1234", + connections={("test", "qwer")}, + name="Test Device 4", + suggested_area="Test Area 2", + ) + device = device_registry.async_get_or_create( + config_entry_id="1234", + connections={("test", "9876-disabled")}, + name="Test Device 3", + manufacturer="Test Manufacturer 3", + model="Test Model 3A", + suggested_area="Test Area 2", + ) + device_registry.async_update_device( + device.id, disabled_by=dr.DeviceEntryDisabler.USER + ) + device_registry.async_get_or_create( + config_entry_id="1234", + connections={("test", "9876-no-name")}, + manufacturer="Test Manufacturer NoName", + model="Test Model NoName", + suggested_area="Test Area 2", + ) + device_registry.async_get_or_create( + config_entry_id="1234", + connections={("test", "9876-integer-values")}, + name=1, + manufacturer=2, + model=3, + suggested_area="Test Area 2", + ) + with patch("google.generativeai.chat_async") as mock_chat: + result = await conversation.async_converse(hass, "hello", None, Context()) + + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert mock_chat.mock_calls[0][2] == snapshot + + +async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None: + """Test that the default prompt works.""" + with patch("google.generativeai.chat_async", side_effect=ClientError("")): + result = await conversation.async_converse(hass, "hello", None, Context()) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_template_error( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test that template error handling works.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", + }, + ) + with patch( + "google.generativeai.get_model", + ), patch("google.generativeai.chat_async"): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + result = await conversation.async_converse(hass, "hello", None, Context()) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_conversation_agent( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test GoogleGenerativeAIAgent.""" + agent = await conversation._get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + assert agent.supported_languages == "*"