diff --git a/.strict-typing b/.strict-typing index 348385f04d590e023cd46a60b0fe6e4e3eb65f04..0d160982a07b5ae4192ef0fda413afa86b129a28 100644 --- a/.strict-typing +++ b/.strict-typing @@ -311,6 +311,7 @@ homeassistant.components.manual.* homeassistant.components.mastodon.* homeassistant.components.matrix.* homeassistant.components.matter.* +homeassistant.components.mcp_server.* homeassistant.components.mealie.* homeassistant.components.media_extractor.* homeassistant.components.media_player.* diff --git a/CODEOWNERS b/CODEOWNERS index d531e1ccebd53eff5c354a0e171fbbb7dee12770..f43cdf457c87bea1485358974d532a1ef165c5dd 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -889,6 +889,8 @@ build.json @home-assistant/supervisor /tests/components/matrix/ @PaarthShah /homeassistant/components/matter/ @home-assistant/matter /tests/components/matter/ @home-assistant/matter +/homeassistant/components/mcp_server/ @allenporter +/tests/components/mcp_server/ @allenporter /homeassistant/components/mealie/ @joostlek @andrew-codechimp /tests/components/mealie/ @joostlek @andrew-codechimp /homeassistant/components/meater/ @Sotolotl @emontnemery diff --git a/homeassistant/components/mcp_server/__init__.py b/homeassistant/components/mcp_server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e523f46228fd66b49892a6e1cb05139f7ad37889 --- /dev/null +++ b/homeassistant/components/mcp_server/__init__.py @@ -0,0 +1,43 @@ +"""The Model Context Protocol Server integration.""" + +from __future__ import annotations + +from homeassistant.core import HomeAssistant +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.typing import ConfigType + +from . import http +from .const import DOMAIN +from .session import SessionManager +from .types import MCPServerConfigEntry + +__all__ = [ + "CONFIG_SCHEMA", + "DOMAIN", + "async_setup", + "async_setup_entry", + "async_unload_entry", +] + +CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up the Model Context Protocol component.""" + http.async_register(hass) + return True + + +async def async_setup_entry(hass: HomeAssistant, entry: MCPServerConfigEntry) -> bool: + """Set up Model Context Protocol Server from a config entry.""" + + entry.runtime_data = SessionManager() + + return True + + +async def async_unload_entry(hass: HomeAssistant, entry: MCPServerConfigEntry) -> bool: + """Unload a config entry.""" + session_manager = entry.runtime_data + session_manager.close() + return True diff --git a/homeassistant/components/mcp_server/config_flow.py b/homeassistant/components/mcp_server/config_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..8d68c6a868a5c89a81640876a882a7cb5c8d89b5 --- /dev/null +++ b/homeassistant/components/mcp_server/config_flow.py @@ -0,0 +1,63 @@ +"""Config flow for the Model Context Protocol Server integration.""" + +from __future__ import annotations + +import logging +from typing import Any + +import voluptuous as vol + +from homeassistant.config_entries import ConfigFlow, ConfigFlowResult +from homeassistant.const import CONF_LLM_HASS_API +from homeassistant.helpers import llm +from homeassistant.helpers.selector import ( + SelectOptionDict, + SelectSelector, + SelectSelectorConfig, +) + +from .const import DOMAIN + +_LOGGER = logging.getLogger(__name__) + +MORE_INFO_URL = "https://www.home-assistant.io/integrations/mcp_server/#configuration" + + +class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN): + """Handle a config flow for Model Context Protocol Server.""" + + VERSION = 1 + + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Handle the initial step.""" + llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)} + + if user_input is not None: + return self.async_create_entry( + title=llm_apis[user_input[CONF_LLM_HASS_API]], data=user_input + ) + + return self.async_show_form( + step_id="user", + data_schema=vol.Schema( + { + vol.Optional( + CONF_LLM_HASS_API, + default=llm.LLM_API_ASSIST, + ): SelectSelector( + SelectSelectorConfig( + options=[ + SelectOptionDict( + label=name, + value=llm_api_id, + ) + for llm_api_id, name in llm_apis.items() + ] + ) + ), + } + ), + description_placeholders={"more_info_url": MORE_INFO_URL}, + ) diff --git a/homeassistant/components/mcp_server/const.py b/homeassistant/components/mcp_server/const.py new file mode 100644 index 0000000000000000000000000000000000000000..1aa81f445a1a289738bc1decce7fe227ed8b0773 --- /dev/null +++ b/homeassistant/components/mcp_server/const.py @@ -0,0 +1,4 @@ +"""Constants for the Model Context Protocol Server integration.""" + +DOMAIN = "mcp_server" +TITLE = "Model Context Protocol Server" diff --git a/homeassistant/components/mcp_server/http.py b/homeassistant/components/mcp_server/http.py new file mode 100644 index 0000000000000000000000000000000000000000..da706d4a73b1018847607533c67b53a3913cf0da --- /dev/null +++ b/homeassistant/components/mcp_server/http.py @@ -0,0 +1,170 @@ +"""Model Context Protocol transport portocol for Server Sent Events (SSE). + +This registers HTTP endpoints that supports SSE as a transport layer +for the Model Context Protocol. There are two HTTP endpoints: + +- /mcp_server/sse: The SSE endpoint that is used to establish a session + with the client and glue to the MCP server. This is used to push responses + to the client. +- /mcp_server/messages: The endpoint that is used by the client to send + POST requests with new requests for the MCP server. The request contains + a session identifier. The response to the client is passed over the SSE + session started on the other endpoint. + +See https://modelcontextprotocol.io/docs/concepts/transports +""" + +import logging + +from aiohttp import web +from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound +from aiohttp_sse import sse_response +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import types + +from homeassistant.components import conversation +from homeassistant.components.http import KEY_HASS, HomeAssistantView +from homeassistant.config_entries import ConfigEntryState +from homeassistant.const import CONF_LLM_HASS_API +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import llm + +from .const import DOMAIN +from .server import create_server +from .session import Session +from .types import MCPServerConfigEntry + +_LOGGER = logging.getLogger(__name__) + +SSE_API = f"/{DOMAIN}/sse" +MESSAGES_API = f"/{DOMAIN}/messages/{{session_id}}" + + +@callback +def async_register(hass: HomeAssistant) -> None: + """Register the websocket API.""" + hass.http.register_view(ModelContextProtocolSSEView()) + hass.http.register_view(ModelContextProtocolMessagesView()) + + +def async_get_config_entry(hass: HomeAssistant) -> MCPServerConfigEntry: + """Get the first enabled MCP server config entry. + + The ConfigEntry contains a reference to the actual MCP server used to + serve the Model Context Protocol. + + Will raise an HTTP error if the expected configuration is not present. + """ + config_entries: list[MCPServerConfigEntry] = [ + config_entry + for config_entry in hass.config_entries.async_entries(DOMAIN) + if config_entry.state == ConfigEntryState.LOADED + ] + if not config_entries: + raise HTTPNotFound(body="Model Context Protocol server is not configured") + if len(config_entries) > 1: + raise HTTPNotFound(body="Found multiple Model Context Protocol configurations") + return config_entries[0] + + +class ModelContextProtocolSSEView(HomeAssistantView): + """Model Context Protocol SSE endpoint.""" + + name = f"{DOMAIN}:sse" + url = SSE_API + + async def get(self, request: web.Request) -> web.StreamResponse: + """Process SSE messages for the Model Context Protocol. + + This is a long running request for the lifetime of the client session + and is the primary transport layer between the client and server. + + Pairs of buffered streams act as a bridge between the transport protocol + (SSE over HTTP views) and the Model Context Protocol. The MCP SDK + manages all protocol details and invokes commands on our MCP server. + """ + hass = request.app[KEY_HASS] + entry = async_get_config_entry(hass) + session_manager = entry.runtime_data + + context = llm.LLMContext( + platform=DOMAIN, + context=self.context(request), + user_prompt=None, + language="*", + assistant=conversation.DOMAIN, + device_id=None, + ) + llm_api_id = entry.data[CONF_LLM_HASS_API] + server = await create_server(hass, llm_api_id, context) + options = await hass.async_add_executor_job( + server.create_initialization_options # Reads package for version info + ) + + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async with ( + sse_response(request) as response, + session_manager.create(Session(read_stream_writer)) as session_id, + ): + session_uri = MESSAGES_API.format(session_id=session_id) + _LOGGER.debug("Sending SSE endpoint: %s", session_uri) + await response.send(session_uri, event="endpoint") + + async def sse_reader() -> None: + """Forward MCP server responses to the client.""" + async for message in write_stream_reader: + _LOGGER.debug("Sending SSE message: %s", message) + await response.send( + message.model_dump_json(by_alias=True, exclude_none=True), + event="message", + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(sse_reader) + await server.run(read_stream, write_stream, options) + return response + + +class ModelContextProtocolMessagesView(HomeAssistantView): + """Model Context Protocol messages endpoint.""" + + name = f"{DOMAIN}:messages" + url = MESSAGES_API + + async def post( + self, + request: web.Request, + session_id: str, + ) -> web.StreamResponse: + """Process incoming messages for the Model Context Protocol. + + The request passes a session ID which is used to identify the original + SSE connection. This view parses incoming messagess from the transport + layer then writes them to the MCP server stream for the session. + """ + hass = request.app[KEY_HASS] + config_entry = async_get_config_entry(hass) + + session_manager = config_entry.runtime_data + if (session := session_manager.get(session_id)) is None: + _LOGGER.info("Could not find session ID: '%s'", session_id) + raise HTTPNotFound(body=f"Could not find session ID '{session_id}'") + + json_data = await request.json() + try: + message = types.JSONRPCMessage.model_validate(json_data) + except ValueError as err: + _LOGGER.info("Failed to parse message: %s", err) + raise HTTPBadRequest(body="Could not parse message") from err + + _LOGGER.debug("Received client message: %s", message) + await session.read_stream_writer.send(message) + return web.Response(status=200) diff --git a/homeassistant/components/mcp_server/manifest.json b/homeassistant/components/mcp_server/manifest.json new file mode 100644 index 0000000000000000000000000000000000000000..755d2c39065bd46ea08ad33d7ab7069dd5d79daf --- /dev/null +++ b/homeassistant/components/mcp_server/manifest.json @@ -0,0 +1,13 @@ +{ + "domain": "mcp_server", + "name": "Model Context Protocol Server", + "codeowners": ["@allenporter"], + "config_flow": true, + "dependencies": ["homeassistant", "http", "conversation"], + "documentation": "https://www.home-assistant.io/integrations/mcp_server", + "integration_type": "service", + "iot_class": "local_push", + "quality_scale": "silver", + "requirements": ["mcp==1.1.2", "aiohttp_sse==2.2.0", "anyio==4.7.0"], + "single_config_entry": true +} diff --git a/homeassistant/components/mcp_server/quality_scale.yaml b/homeassistant/components/mcp_server/quality_scale.yaml new file mode 100644 index 0000000000000000000000000000000000000000..546b4147285a28d5a08e614dca89a4c9c4d2b3fe --- /dev/null +++ b/homeassistant/components/mcp_server/quality_scale.yaml @@ -0,0 +1,118 @@ +rules: + # Bronze + action-setup: + status: exempt + comment: Service does not register actions + appropriate-polling: + status: exempt + comment: Service is not polling + brands: done + common-modules: + status: exempt + comment: Service does not have entities or coordinators + config-flow-test-coverage: done + config-flow: done + dependency-transparency: done + docs-actions: + status: exempt + comment: Service does not register actions + docs-high-level-description: done + docs-installation-instructions: done + docs-removal-instructions: done + entity-event-setup: + status: exempt + comment: Service does not subscribe to events + entity-unique-id: + status: exempt + comment: Service does not have entities + has-entity-name: + status: exempt + comment: Service does not have entities + runtime-data: + status: exempt + comment: No configuration state is used by the integration + test-before-configure: + status: exempt + comment: Service does not a connection + test-before-setup: + status: exempt + comment: Service does not a connection + unique-config-entry: + status: done + comment: Integration requires a single config entry. + + # Silver + action-exceptions: + status: exempt + comment: Service does not register actions + config-entry-unloading: done + docs-configuration-parameters: done + docs-installation-parameters: done + entity-unavailable: + status: exempt + comment: Service does not have entities + integration-owner: done + log-when-unavailable: + status: exempt + comment: Service does not have entities + parallel-updates: + status: exempt + comment: Service does not have entities + reauthentication-flow: + status: exempt + comment: Service does not require authentication + test-coverage: done + + # Gold + devices: + status: exempt + comment: Service does not have entities + diagnostics: todo + discovery-update-info: + status: exempt + comment: Service does not support discovery + discovery: + status: exempt + comment: Service does not support discovery + docs-data-update: done + docs-examples: done + docs-known-limitations: done + docs-supported-devices: done + docs-supported-functions: done + docs-troubleshooting: todo + docs-use-cases: done + dynamic-devices: + status: exempt + comment: Service does not support devices + entity-category: + status: exempt + comment: Service does not have entities + entity-device-class: + status: exempt + comment: Service does not have entities + entity-disabled-by-default: + status: exempt + comment: Service does not have entities + entity-translations: + status: exempt + comment: Service does not have entities + exception-translations: todo + icon-translations: + status: exempt + comment: Service does not have entities + reconfiguration-flow: todo + repair-issues: + status: exempt + comment: Service does not have anything to repair + stale-devices: + status: exempt + comment: Service does not have devices + + # Platinum + async-dependency: + status: exempt + comment: Service does not communicate with devices + inject-websession: + status: exempt + comment: Service does not communicate with devices + strict-typing: done diff --git a/homeassistant/components/mcp_server/server.py b/homeassistant/components/mcp_server/server.py new file mode 100644 index 0000000000000000000000000000000000000000..a52a0f92c0befc0ccad974030019aca93ea5df5e --- /dev/null +++ b/homeassistant/components/mcp_server/server.py @@ -0,0 +1,77 @@ +"""The Model Context Protocol Server implementation. + +The Model Context Protocol python sdk defines a Server API that provides the +MCP message handling logic and error handling. The server implementation provided +here is independent of the lower level transport protocol. + +See https://modelcontextprotocol.io/docs/concepts/architecture#implementation-example +""" + +from collections.abc import Callable, Sequence +import json +import logging +from typing import Any + +from mcp import types +from mcp.server import Server +import voluptuous as vol +from voluptuous_openapi import convert + +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import llm + +_LOGGER = logging.getLogger(__name__) + + +def _format_tool( + tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None +) -> types.Tool: + """Format tool specification.""" + input_schema = convert(tool.parameters, custom_serializer=custom_serializer) + return types.Tool( + name=tool.name, + description=tool.description or "", + inputSchema={ + "type": "object", + "properties": input_schema["properties"], + }, + ) + + +async def create_server( + hass: HomeAssistant, llm_api_id: str, llm_context: llm.LLMContext +) -> Server: + """Create a new Model Context Protocol Server. + + A Model Context Protocol Server object is associated with a single session. + The MCP SDK handles the details of the protocol. + """ + + server = Server("home-assistant") + + @server.list_tools() # type: ignore[no-untyped-call, misc] + async def list_tools() -> list[types.Tool]: + """List available time tools.""" + llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) + return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools] + + @server.call_tool() # type: ignore[no-untyped-call, misc] + async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]: + """Handle calling tools.""" + llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) + tool_input = llm.ToolInput(tool_name=name, tool_args=arguments) + _LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args) + + try: + tool_response = await llm_api.async_call_tool(tool_input) + except (HomeAssistantError, vol.Invalid) as e: + raise HomeAssistantError(f"Error calling tool: {e}") from e + return [ + types.TextContent( + type="text", + text=json.dumps(tool_response), + ) + ] + + return server diff --git a/homeassistant/components/mcp_server/session.py b/homeassistant/components/mcp_server/session.py new file mode 100644 index 0000000000000000000000000000000000000000..6f6622de9f7e544436a43baf5b196b8f2ad1e738 --- /dev/null +++ b/homeassistant/components/mcp_server/session.py @@ -0,0 +1,60 @@ +"""Model Context Protocol sessions. + +A session is a long-lived connection between the client and server that is used +to exchange messages. The server pushes messages to the client over the session +and the client sends messages to the server over the session. +""" + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from dataclasses import dataclass +import logging + +from anyio.streams.memory import MemoryObjectSendStream +from mcp import types + +from homeassistant.util import ulid + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class Session: + """A session for the Model Context Protocol.""" + + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + + +class SessionManager: + """Manage SSE sessions for the MCP transport layer. + + This class is used to manage the lifecycle of SSE sessions. It is responsible for + creating new sessions, resuming existing sessions, and closing sessions. + """ + + def __init__(self) -> None: + """Initialize the SSE server transport.""" + self._sessions: dict[str, Session] = {} + + @asynccontextmanager + async def create(self, session: Session) -> AsyncGenerator[str]: + """Context manager to create a new session ID and close when done.""" + session_id = ulid.ulid_now() + _LOGGER.debug("Creating session: %s", session_id) + self._sessions[session_id] = session + try: + yield session_id + finally: + _LOGGER.debug("Closing session: %s", session_id) + if session_id in self._sessions: # close() may have already been called + self._sessions.pop(session_id) + + def get(self, session_id: str) -> Session | None: + """Get an existing session.""" + return self._sessions.get(session_id) + + def close(self) -> None: + """Close any open sessions.""" + for session in self._sessions.values(): + session.read_stream_writer.close() + self._sessions.clear() diff --git a/homeassistant/components/mcp_server/strings.json b/homeassistant/components/mcp_server/strings.json new file mode 100644 index 0000000000000000000000000000000000000000..fbd14038ddc34400784c5578d1bbf30cd467503e --- /dev/null +++ b/homeassistant/components/mcp_server/strings.json @@ -0,0 +1,18 @@ +{ + "config": { + "step": { + "user": { + "description": "See the [integration documentation]({more_info_url}) for setup instructions.", + "data": { + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" + }, + "data_description": { + "llm_hass_api": "The method for controling Home Assistant to expose with the Model Context Protocol." + } + } + }, + "abort": { + "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" + } + } +} diff --git a/homeassistant/components/mcp_server/types.py b/homeassistant/components/mcp_server/types.py new file mode 100644 index 0000000000000000000000000000000000000000..56ce0469e255380e01caf6733a20f9dc1a3f80b5 --- /dev/null +++ b/homeassistant/components/mcp_server/types.py @@ -0,0 +1,7 @@ +"""Types for the MCP server integration.""" + +from homeassistant.config_entries import ConfigEntry + +from .session import SessionManager + +type MCPServerConfigEntry = ConfigEntry[SessionManager] diff --git a/homeassistant/generated/config_flows.py b/homeassistant/generated/config_flows.py index d33198cb023d47b9bb87cba30d5de27642fdc6d4..f3e82d4d085adce3457f543edc6f5c58c6f7d274 100644 --- a/homeassistant/generated/config_flows.py +++ b/homeassistant/generated/config_flows.py @@ -356,6 +356,7 @@ FLOWS = { "mailgun", "mastodon", "matter", + "mcp_server", "mealie", "meater", "medcom_ble", diff --git a/homeassistant/generated/integrations.json b/homeassistant/generated/integrations.json index 5f0fdc0618f7e7e341c6a5247212a1737f7cc989..8343b7fde9d7da2f4cfb77e8df97742a2c1df125 100644 --- a/homeassistant/generated/integrations.json +++ b/homeassistant/generated/integrations.json @@ -3590,6 +3590,13 @@ "config_flow": true, "iot_class": "local_push" }, + "mcp_server": { + "name": "Model Context Protocol Server", + "integration_type": "service", + "config_flow": true, + "iot_class": "local_push", + "single_config_entry": true + }, "mealie": { "name": "Mealie", "integration_type": "service", diff --git a/mypy.ini b/mypy.ini index f4a0a67a6c7b696bb54fed875d003078eeff1fe3..8600e5ba165eb371d2a0a83c95574c91b72a0980 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2866,6 +2866,16 @@ disallow_untyped_defs = true warn_return_any = true warn_unreachable = true +[mypy-homeassistant.components.mcp_server.*] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +warn_return_any = true +warn_unreachable = true + [mypy-homeassistant.components.mealie.*] check_untyped_defs = true disallow_incomplete_defs = true diff --git a/requirements_all.txt b/requirements_all.txt index 30127795e9df8eca3df133c0601d8d22c820ca12..6c8ba57e7dc347e6ea84a91c06535c3458fd294d 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -266,6 +266,9 @@ aiohasupervisor==0.2.2b5 # homeassistant.components.homekit_controller aiohomekit==3.2.7 +# homeassistant.components.mcp_server +aiohttp_sse==2.2.0 + # homeassistant.components.hue aiohue==4.7.3 @@ -466,6 +469,9 @@ anthemav==1.4.1 # homeassistant.components.anthropic anthropic==0.31.2 +# homeassistant.components.mcp_server +anyio==4.7.0 + # homeassistant.components.weatherkit apple_weatherkit==1.1.3 @@ -1355,6 +1361,9 @@ maxcube-api==0.4.3 # homeassistant.components.mythicbeastsdns mbddns==0.1.2 +# homeassistant.components.mcp_server +mcp==1.1.2 + # homeassistant.components.minecraft_server mcstatus==11.1.1 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index dc84da16f7a7614a36d1efdf255a955ad4b0e63e..c1af9f75787f43d47471f1b042d780243d8fbc69 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -251,6 +251,9 @@ aiohasupervisor==0.2.2b5 # homeassistant.components.homekit_controller aiohomekit==3.2.7 +# homeassistant.components.mcp_server +aiohttp_sse==2.2.0 + # homeassistant.components.hue aiohue==4.7.3 @@ -439,6 +442,9 @@ anthemav==1.4.1 # homeassistant.components.anthropic anthropic==0.31.2 +# homeassistant.components.mcp_server +anyio==4.7.0 + # homeassistant.components.weatherkit apple_weatherkit==1.1.3 @@ -1133,6 +1139,9 @@ maxcube-api==0.4.3 # homeassistant.components.mythicbeastsdns mbddns==0.1.2 +# homeassistant.components.mcp_server +mcp==1.1.2 + # homeassistant.components.minecraft_server mcstatus==11.1.1 diff --git a/tests/components/mcp_server/__init__.py b/tests/components/mcp_server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6bdcc66de17a72c26e63044974bbe9c9bbe5848e --- /dev/null +++ b/tests/components/mcp_server/__init__.py @@ -0,0 +1 @@ +"""Tests for the Model Context Protocol Server integration.""" diff --git a/tests/components/mcp_server/conftest.py b/tests/components/mcp_server/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..149073f3645077db2cf8ff6bfea215ac64807440 --- /dev/null +++ b/tests/components/mcp_server/conftest.py @@ -0,0 +1,35 @@ +"""Common fixtures for the Model Context Protocol Server tests.""" + +from collections.abc import Generator +from unittest.mock import AsyncMock, patch + +import pytest + +from homeassistant.components.mcp_server.const import DOMAIN +from homeassistant.const import CONF_LLM_HASS_API +from homeassistant.core import HomeAssistant +from homeassistant.helpers import llm + +from tests.common import MockConfigEntry + + +@pytest.fixture +def mock_setup_entry() -> Generator[AsyncMock]: + """Override async_setup_entry.""" + with patch( + "homeassistant.components.mcp_server.async_setup_entry", return_value=True + ) as mock_setup_entry: + yield mock_setup_entry + + +@pytest.fixture(name="config_entry") +def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: + """Fixture to load the integration.""" + config_entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_LLM_HASS_API: llm.LLM_API_ASSIST, + }, + ) + config_entry.add_to_hass(hass) + return config_entry diff --git a/tests/components/mcp_server/test_config_flow.py b/tests/components/mcp_server/test_config_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9f5bee663b11d3c34d23193c9ee27d41f30163 --- /dev/null +++ b/tests/components/mcp_server/test_config_flow.py @@ -0,0 +1,41 @@ +"""Test the Model Context Protocol Server config flow.""" + +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from homeassistant import config_entries +from homeassistant.components.mcp_server.const import DOMAIN +from homeassistant.const import CONF_LLM_HASS_API +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResultType + + +@pytest.mark.parametrize( + "params", + [ + {}, + {CONF_LLM_HASS_API: "assist"}, + ], +) +async def test_form( + hass: HomeAssistant, mock_setup_entry: AsyncMock, params: dict[str, Any] +) -> None: + """Test we get the form.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert not result["errors"] + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + params, + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == "Assist" + assert len(mock_setup_entry.mock_calls) == 1 + assert result["data"] == {CONF_LLM_HASS_API: "assist"} diff --git a/tests/components/mcp_server/test_http.py b/tests/components/mcp_server/test_http.py new file mode 100644 index 0000000000000000000000000000000000000000..78f1364502dd369b1476ad53b164a3565096a91c --- /dev/null +++ b/tests/components/mcp_server/test_http.py @@ -0,0 +1,356 @@ +"""Test the Model Context Protocol Server init module.""" + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from http import HTTPStatus +import json +import logging + +import aiohttp +import mcp +import mcp.client.session +import mcp.client.sse +import pytest + +from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN +from homeassistant.components.homeassistant.exposed_entities import async_expose_entity +from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN +from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API +from homeassistant.config_entries import ConfigEntryState +from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON +from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.setup import async_setup_component + +from tests.common import MockConfigEntry, setup_test_component_platform +from tests.components.light.common import MockLight +from tests.typing import ClientSessionGenerator + +_LOGGER = logging.getLogger(__name__) + +TEST_ENTITY = "light.kitchen" +INITIALIZE_MESSAGE = { + "jsonrpc": "2.0", + "id": "request-id-1", + "method": "initialize", + "params": { + "protocolVersion": "1.0", + "capabilities": {}, + "clientInfo": { + "name": "test", + "version": "1", + }, + }, +} +EVENT_PREFIX = "event: " +DATA_PREFIX = "data: " + + +@pytest.fixture +async def setup_integration(hass: HomeAssistant, config_entry: MockConfigEntry) -> None: + """Set up the config entry.""" + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.LOADED + + +@pytest.fixture(autouse=True) +async def mock_entities( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, + setup_integration: None, +) -> None: + """Fixture to expose entities to the conversation agent.""" + entity = MockLight("kitchen", STATE_OFF) + entity.entity_id = TEST_ENTITY + setup_test_component_platform(hass, LIGHT_DOMAIN, [entity]) + + assert await async_setup_component( + hass, + LIGHT_DOMAIN, + {LIGHT_DOMAIN: [{"platform": "test"}]}, + ) + + async_expose_entity(hass, CONVERSATION_DOMAIN, TEST_ENTITY, True) + + +async def sse_response_reader( + response: aiohttp.ClientResponse, +) -> AsyncGenerator[tuple[str, str]]: + """Read SSE responses from the server and emit event messages. + + SSE responses are formatted as: + event: event-name + data: event-data + and this function emits each event-name and event-data as a tuple. + """ + it = aiter(response.content) + while True: + line = (await anext(it)).decode() + if not line.startswith(EVENT_PREFIX): + raise ValueError("Expected event") + event = line[len(EVENT_PREFIX) :].strip() + line = (await anext(it)).decode() + if not line.startswith(DATA_PREFIX): + raise ValueError("Expected data") + data = line[len(DATA_PREFIX) :].strip() + line = (await anext(it)).decode() + assert line == "\r\n" + yield event, data + + +async def test_http_sse( + hass: HomeAssistant, + setup_integration: None, + hass_client: ClientSessionGenerator, +) -> None: + """Test SSE endpoint can be used to receive MCP messages.""" + + client = await hass_client() + + # Start an SSE session + response = await client.get(SSE_API) + assert response.status == HTTPStatus.OK + + # Decode a single SSE response that sends the messages endpoint + reader = sse_response_reader(response) + event, endpoint_url = await anext(reader) + assert event == "endpoint" + + # Send an initialize message on the messages endpoint + response = await client.post(endpoint_url, json=INITIALIZE_MESSAGE) + assert response.status == HTTPStatus.OK + + # Decode the initialize response event message from the SSE stream + event, data = await anext(reader) + assert event == "message" + message = json.loads(data) + assert message.get("jsonrpc") == "2.0" + assert message.get("id") == "request-id-1" + assert "serverInfo" in message.get("result", {}) + assert "protocolVersion" in message.get("result", {}) + + +async def test_http_messages_missing_session_id( + hass: HomeAssistant, + setup_integration: None, + hass_client: ClientSessionGenerator, +) -> None: + """Test the tools list endpoint.""" + + client = await hass_client() + response = await client.post(MESSAGES_API.format(session_id="invalid-session-id")) + assert response.status == HTTPStatus.NOT_FOUND + response_data = await response.text() + assert response_data == "Could not find session ID 'invalid-session-id'" + + +async def test_http_messages_invalid_message_format( + hass: HomeAssistant, + setup_integration: None, + hass_client: ClientSessionGenerator, +) -> None: + """Test the tools list endpoint.""" + + client = await hass_client() + response = await client.get(SSE_API) + assert response.status == HTTPStatus.OK + reader = sse_response_reader(response) + event, endpoint_url = await anext(reader) + assert event == "endpoint" + + response = await client.post(endpoint_url, json={"invalid": "message"}) + assert response.status == HTTPStatus.BAD_REQUEST + response_data = await response.text() + assert response_data == "Could not parse message" + + +async def test_http_sse_multiple_config_entries( + hass: HomeAssistant, + setup_integration: None, + hass_client: ClientSessionGenerator, +) -> None: + """Test the SSE endpoint will fail with multiple config entries. + + This cannot happen in practice as the integration only supports a single + config entry, but this is added for test coverage. + """ + + config_entry = MockConfigEntry( + domain="mcp_server", data={CONF_LLM_HASS_API: "llm-api-id"} + ) + config_entry.add_to_hass(hass) + await hass.config_entries.async_setup(config_entry.entry_id) + + client = await hass_client() + + # Attempt to start an SSE session will fail + response = await client.get(SSE_API) + assert response.status == HTTPStatus.NOT_FOUND + response_data = await response.text() + assert "Found multiple Model Context Protocol" in response_data + + +async def test_http_sse_no_config_entry( + hass: HomeAssistant, + setup_integration: None, + config_entry: MockConfigEntry, + hass_client: ClientSessionGenerator, +) -> None: + """Test the SSE endpoint fails with a missing config entry.""" + + await hass.config_entries.async_unload(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.NOT_LOADED + + client = await hass_client() + + # Start an SSE session + response = await client.get(SSE_API) + assert response.status == HTTPStatus.NOT_FOUND + response_data = await response.text() + assert "Model Context Protocol server is not configured" in response_data + + +async def test_http_messages_no_config_entry( + hass: HomeAssistant, + setup_integration: None, + config_entry: MockConfigEntry, + hass_client: ClientSessionGenerator, +) -> None: + """Test the message endpoint will fail if the config entry is unloaded.""" + + client = await hass_client() + + # Start an SSE session + response = await client.get(SSE_API) + assert response.status == HTTPStatus.OK + reader = sse_response_reader(response) + event, endpoint_url = await anext(reader) + assert event == "endpoint" + + # Invalidate the session by unloading the config entry + await hass.config_entries.async_unload(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.NOT_LOADED + + # Reload the config entry and ensure the session is not found + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.LOADED + + response = await client.post(endpoint_url, json=INITIALIZE_MESSAGE) + assert response.status == HTTPStatus.NOT_FOUND + response_data = await response.text() + assert "Could not find session ID" in response_data + + +async def test_http_requires_authentication( + hass: HomeAssistant, + setup_integration: None, + hass_client_no_auth: ClientSessionGenerator, +) -> None: + """Test the SSE endpoint requires authentication.""" + + client = await hass_client_no_auth() + + response = await client.get(SSE_API) + assert response.status == HTTPStatus.UNAUTHORIZED + + response = await client.post(MESSAGES_API.format(session_id="session-id")) + assert response.status == HTTPStatus.UNAUTHORIZED + + +@pytest.fixture +async def mcp_sse_url(hass_client: ClientSessionGenerator) -> str: + """Fixture to get the MCP integration SSE URL.""" + client = await hass_client() + return str(client.make_url(SSE_API)) + + +@asynccontextmanager +async def mcp_session( + mcp_sse_url: str, + hass_supervisor_access_token: str, +) -> AsyncGenerator[mcp.client.session.ClientSession]: + """Create an MCP session.""" + + headers = {"Authorization": f"Bearer {hass_supervisor_access_token}"} + + async with ( + mcp.client.sse.sse_client(mcp_sse_url, headers=headers) as streams, + mcp.client.session.ClientSession(*streams) as session, + ): + await session.initialize() + yield session + + +async def test_mcp_tools_list( + hass: HomeAssistant, + setup_integration: None, + mcp_sse_url: str, + hass_supervisor_access_token: str, +) -> None: + """Test the tools list endpoint.""" + + async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: + result = await session.list_tools() + + # Pick a single arbitrary tool and test that description and parameters + # are converted correctly. + tool = next(iter(tool for tool in result.tools if tool.name == "HassTurnOn")) + assert tool.name == "HassTurnOn" + assert tool.description == "Turns on/opens a device or entity" + assert tool.inputSchema + assert tool.inputSchema.get("type") == "object" + properties = tool.inputSchema.get("properties") + assert properties.get("name") == {"type": "string"} + + +async def test_mcp_tool_call( + hass: HomeAssistant, + setup_integration: None, + mcp_sse_url: str, + hass_supervisor_access_token: str, +) -> None: + """Test the tool call endpoint.""" + + state = hass.states.get("light.kitchen") + assert state + assert state.state == STATE_OFF + + async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: + result = await session.call_tool( + name="HassTurnOn", + arguments={"name": "kitchen"}, + ) + + assert not result.isError + assert len(result.content) == 1 + assert result.content[0].type == "text" + # The content is the raw tool call payload + content = json.loads(result.content[0].text) + assert content.get("data", {}).get("success") + assert not content.get("data", {}).get("failed") + + # Verify tool call invocation + state = hass.states.get("light.kitchen") + assert state + assert state.state == STATE_ON + + +async def test_mcp_tool_call_failed( + hass: HomeAssistant, + setup_integration: None, + mcp_sse_url: str, + hass_supervisor_access_token: str, +) -> None: + """Test the tool call endpoint with a failure.""" + + async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: + result = await session.call_tool( + name="HassTurnOn", + arguments={"name": "backyard"}, + ) + + assert result.isError + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert "Error calling tool" in result.content[0].text diff --git a/tests/components/mcp_server/test_init.py b/tests/components/mcp_server/test_init.py new file mode 100644 index 0000000000000000000000000000000000000000..af6a8a55e595f47180ceaa83efddf62ab7d133e3 --- /dev/null +++ b/tests/components/mcp_server/test_init.py @@ -0,0 +1,15 @@ +"""Test the Model Context Protocol Server init module.""" + +from homeassistant.config_entries import ConfigEntryState +from homeassistant.core import HomeAssistant + +from tests.common import MockConfigEntry + + +async def test_init(hass: HomeAssistant, config_entry: MockConfigEntry) -> None: + """Test the integration is initialized and can be unloaded cleanly.""" + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.LOADED + + await hass.config_entries.async_unload(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.NOT_LOADED