diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py
index c0fbfae64443da76bbabfab9d9083d78e1820314..2c83720f930c898697717f93c51323b3bffb6c9b 100644
--- a/homeassistant/components/ollama/conversation.py
+++ b/homeassistant/components/ollama/conversation.py
@@ -5,22 +5,18 @@ from __future__ import annotations
 from collections.abc import Callable
 import json
 import logging
-import time
 from typing import Any, Literal
 
 import ollama
-import voluptuous as vol
 from voluptuous_openapi import convert
 
 from homeassistant.components import assist_pipeline, conversation
-from homeassistant.components.conversation import trace
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError, TemplateError
-from homeassistant.helpers import intent, llm, template
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import chat_session, intent, llm
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
-from homeassistant.util import ulid as ulid_util
 
 from .const import (
     CONF_KEEP_ALIVE,
@@ -32,7 +28,6 @@ from .const import (
     DEFAULT_MAX_HISTORY,
     DEFAULT_NUM_CTX,
     DOMAIN,
-    MAX_HISTORY_SECONDS,
 )
 from .models import MessageHistory, MessageRole
 
@@ -93,6 +88,44 @@ def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
     return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
 
 
+def _convert_content(
+    chat_content: conversation.Content
+    | conversation.ToolResultContent
+    | conversation.AssistantContent,
+) -> ollama.Message:
+    """Convert chat log content to an Ollama message."""
+    if isinstance(chat_content, conversation.ToolResultContent):
+        return ollama.Message(
+            role=MessageRole.TOOL.value,
+            content=json.dumps(chat_content.tool_result),
+        )
+    if isinstance(chat_content, conversation.AssistantContent):
+        return ollama.Message(
+            role=MessageRole.ASSISTANT.value,
+            content=chat_content.content,
+            tool_calls=[
+                ollama.Message.ToolCall(
+                    function=ollama.Message.ToolCall.Function(
+                        name=tool_call.tool_name,
+                        arguments=tool_call.tool_args,
+                    )
+                )
+                for tool_call in chat_content.tool_calls or ()
+            ],
+        )
+    if isinstance(chat_content, conversation.UserContent):
+        return ollama.Message(
+            role=MessageRole.USER.value,
+            content=chat_content.content,
+        )
+    if isinstance(chat_content, conversation.SystemContent):
+        return ollama.Message(
+            role=MessageRole.SYSTEM.value,
+            content=chat_content.content,
+        )
+    raise ValueError(f"Unexpected content type: {type(chat_content)}")
+
+
 class OllamaConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
@@ -105,7 +138,6 @@ class OllamaConversationEntity(
         self.entry = entry
 
         # conversation id -> message history
-        self._history: dict[str, MessageHistory] = {}
         self._attr_name = entry.title
         self._attr_unique_id = entry.entry_id
         if self.entry.options.get(CONF_LLM_HASS_API):
@@ -138,121 +170,48 @@ class OllamaConversationEntity(
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
+        with (
+            chat_session.async_get_chat_session(
+                self.hass, user_input.conversation_id
+            ) as session,
+            conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
+        ):
+            return await self._async_handle_message(user_input, chat_log)
+
+    async def _async_handle_message(
+        self,
+        user_input: conversation.ConversationInput,
+        chat_log: conversation.ChatLog,
+    ) -> conversation.ConversationResult:
+        """Call the API."""
         settings = {**self.entry.data, **self.entry.options}
 
         client = self.hass.data[DOMAIN][self.entry.entry_id]
-        conversation_id = user_input.conversation_id or ulid_util.ulid_now()
         model = settings[CONF_MODEL]
 
-        intent_response = intent.IntentResponse(language=user_input.language)
-        llm_api: llm.APIInstance | None = None
-        tools: list[dict[str, Any]] | None = None
-        user_name: str | None = None
-        llm_context = llm.LLMContext(
-            platform=DOMAIN,
-            context=user_input.context,
-            user_prompt=user_input.text,
-            language=user_input.language,
-            assistant=conversation.DOMAIN,
-            device_id=user_input.device_id,
-        )
-
-        if settings.get(CONF_LLM_HASS_API):
-            try:
-                llm_api = await llm.async_get_api(
-                    self.hass,
-                    settings[CONF_LLM_HASS_API],
-                    llm_context,
-                )
-            except HomeAssistantError as err:
-                _LOGGER.error("Error getting LLM API: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Error preparing LLM API: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=user_input.conversation_id
-                )
-            tools = [
-                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
-            ]
-
-        if (
-            user_input.context
-            and user_input.context.user_id
-            and (
-                user := await self.hass.auth.async_get_user(user_input.context.user_id)
-            )
-        ):
-            user_name = user.name
-
-        # Look up message history
-        message_history: MessageHistory | None = None
-        message_history = self._history.get(conversation_id)
-        if message_history is None:
-            # New history
-            #
-            # Render prompt and error out early if there's a problem
-            try:
-                prompt_parts = [
-                    template.Template(
-                        llm.BASE_PROMPT
-                        + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
-                        self.hass,
-                    ).async_render(
-                        {
-                            "ha_name": self.hass.config.location_name,
-                            "user_name": user_name,
-                            "llm_context": llm_context,
-                        },
-                        parse_result=False,
-                    )
-                ]
-
-            except TemplateError as err:
-                _LOGGER.error("Error rendering prompt: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Sorry, I had a problem generating my prompt: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
-                )
-
-            if llm_api:
-                prompt_parts.append(llm_api.api_prompt)
-            prompt = "\n".join(prompt_parts)
-            _LOGGER.debug("Prompt: %s", prompt)
-            _LOGGER.debug("Tools: %s", tools)
-
-            message_history = MessageHistory(
-                timestamp=time.monotonic(),
-                messages=[
-                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
-                ],
+        try:
+            await chat_log.async_update_llm_data(
+                DOMAIN,
+                user_input,
+                settings.get(CONF_LLM_HASS_API),
+                settings.get(CONF_PROMPT),
             )
-            self._history[conversation_id] = message_history
-        else:
-            # Bump timestamp so this conversation won't get cleaned up
-            message_history.timestamp = time.monotonic()
+        except conversation.ConverseError as err:
+            return err.as_conversation_result()
 
-        # Clean up old histories
-        self._prune_old_histories()
+        tools: list[dict[str, Any]] | None = None
+        if chat_log.llm_api:
+            tools = [
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
+            ]
 
-        # Trim this message history to keep a maximum number of *user* messages
+        message_history: MessageHistory = MessageHistory(
+            [_convert_content(content) for content in chat_log.content]
+        )
         max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
         self._trim_history(message_history, max_messages)
 
-        # Add new user message
-        message_history.messages.append(
-            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
-        )
-
-        trace.async_conversation_trace_append(
-            trace.ConversationTraceEventType.AGENT_DETAIL,
-            {"messages": message_history.messages},
-        )
-
         # Get response
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
@@ -269,77 +228,75 @@ class OllamaConversationEntity(
                 )
             except (ollama.RequestError, ollama.ResponseError) as err:
                 _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Sorry, I had a problem talking to the Ollama server: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
-                )
+                raise HomeAssistantError(
+                    f"Sorry, I had a problem talking to the Ollama server: {err}"
+                ) from err
 
             response_message = response["message"]
+            content = response_message.get("content")
+            tool_calls = response_message.get("tool_calls")
             message_history.messages.append(
                 ollama.Message(
                     role=response_message["role"],
-                    content=response_message.get("content"),
-                    tool_calls=response_message.get("tool_calls"),
+                    content=content,
+                    tool_calls=tool_calls,
                 )
             )
-
-            tool_calls = response_message.get("tool_calls")
-            if not tool_calls or not llm_api:
-                break
-
-            for tool_call in tool_calls:
-                tool_input = llm.ToolInput(
+            tool_inputs = [
+                llm.ToolInput(
                     tool_name=tool_call["function"]["name"],
                     tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
                 )
-                _LOGGER.debug(
-                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
-                )
-
-                try:
-                    tool_response = await llm_api.async_call_tool(tool_input)
-                except (HomeAssistantError, vol.Invalid) as e:
-                    tool_response = {"error": type(e).__name__}
-                    if str(e):
-                        tool_response["error_text"] = str(e)
+                for tool_call in tool_calls or ()
+            ]
 
-                _LOGGER.debug("Tool response: %s", tool_response)
-                message_history.messages.append(
+            message_history.messages.extend(
+                [
                     ollama.Message(
                         role=MessageRole.TOOL.value,
-                        content=json.dumps(tool_response),
+                        content=json.dumps(tool_response.tool_result),
                     )
-                )
+                    async for tool_response in chat_log.async_add_assistant_content(
+                        conversation.AssistantContent(
+                            agent_id=user_input.agent_id,
+                            content=content,
+                            tool_calls=tool_inputs or None,
+                        )
+                    )
+                ]
+            )
+
+            if not tool_calls:
+                break
 
         # Create intent response
+        intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(response_message["content"])
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=conversation_id
+            response=intent_response, conversation_id=chat_log.conversation_id
        )
 
-    def _prune_old_histories(self) -> None:
-        """Remove old message histories."""
-        now = time.monotonic()
-        self._history = {
-            conversation_id: message_history
-            for conversation_id, message_history in self._history.items()
-            if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
-        }
-
     def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
-        """Trims excess messages from a single history."""
+        """Trims excess messages from a single history.
+
+        This limits the size that the message history may take up in the
+        context window.
+
+        Note that some messages in the history may not come from Ollama
+        alone; they may come from other agents, so the assumptions here may
+        not strictly hold, though they should generally be effective.
+ """ if max_messages < 1: # Keep all messages return - if message_history.num_user_messages >= max_messages: + # Ignore the in progress user message + num_previous_rounds = message_history.num_user_messages - 1 + if num_previous_rounds >= max_messages: # Trim history but keep system prompt (first message). # Every other message should be an assistant message, so keep 2x - # message objects. - num_keep = 2 * max_messages + # message objects. Also keep the last in progress user message + num_keep = 2 * max_messages + 1 drop_index = len(message_history.messages) - num_keep message_history.messages = [ message_history.messages[0] diff --git a/homeassistant/components/ollama/models.py b/homeassistant/components/ollama/models.py index 3b6fc958587c89efbc2ec23761b570a20e062acb..fd268664919f5781c3c4d5e2d429ef8131fdca4f 100644 --- a/homeassistant/components/ollama/models.py +++ b/homeassistant/components/ollama/models.py @@ -19,9 +19,6 @@ class MessageRole(StrEnum): class MessageHistory: """Chat message history.""" - timestamp: float - """Timestamp of last use in seconds.""" - messages: list[ollama.Message] """List of message history, including system prompt and assistant responses.""" diff --git a/tests/components/ollama/snapshots/test_conversation.ambr b/tests/components/ollama/snapshots/test_conversation.ambr index e4dd7cd00bb809c07057ef145d2c86c9df9e1f85..93f3b03d9afd494b5481ca28de71ebe90bd39f30 100644 --- a/tests/components/ollama/snapshots/test_conversation.ambr +++ b/tests/components/ollama/snapshots/test_conversation.ambr @@ -1,7 +1,7 @@ # serializer version: 1 # name: test_unknown_hass_api dict({ - 'conversation_id': None, + 'conversation_id': '1234', 'response': IntentResponse( card=dict({ }), @@ -20,7 +20,7 @@ speech=dict({ 'plain': dict({ 'extra_data': None, - 'speech': 'Error preparing LLM API: API non-existing not found', + 'speech': 'Error preparing LLM API', }), }), speech_slots=dict({ diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py index b8e299f5e776e0408ecf1f1cc908e9e87654340b..df7c6beca72cdb235112b63275d797b3f77c3607 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -325,7 +325,11 @@ async def test_unknown_hass_api( await hass.async_block_till_done() result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + hass, + "hello", + "1234", + Context(), + agent_id=mock_config_entry.entry_id, ) assert result == snapshot @@ -428,70 +432,17 @@ async def test_message_history_trimming( assert args[4].kwargs["messages"][5]["content"] == "message 5" -async def test_message_history_pruning( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component -) -> None: - """Test that old message histories are pruned.""" - with patch( - "ollama.AsyncClient.chat", - return_value={"message": {"role": "assistant", "content": "test response"}}, - ): - # Create 3 different message histories - conversation_ids: list[str] = [] - for i in range(3): - result = await conversation.async_converse( - hass, - f"message {i + 1}", - conversation_id=None, - context=Context(), - agent_id=mock_config_entry.entry_id, - ) - assert ( - result.response.response_type == intent.IntentResponseType.ACTION_DONE - ), result - assert isinstance(result.conversation_id, str) - conversation_ids.append(result.conversation_id) - - agent = conversation.get_agent_manager(hass).async_get_agent( - mock_config_entry.entry_id - ) - assert 
-        assert agent._history.keys() == set(conversation_ids)
-
-        # Modify the timestamps of the first 2 histories so they will be pruned
-        # on the next cycle.
-        for conversation_id in conversation_ids[:2]:
-            # Move back 2 hours
-            agent._history[conversation_id].timestamp -= 2 * 60 * 60
-
-        # Next cycle
-        result = await conversation.async_converse(
-            hass,
-            "test message",
-            conversation_id=None,
-            context=Context(),
-            agent_id=mock_config_entry.entry_id,
-        )
-        assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
-            result
-        )
-
-        # Only the most recent histories should remain
-        assert len(agent._history) == 2
-        assert conversation_ids[-1] in agent._history
-        assert result.conversation_id in agent._history
-
-
 async def test_message_history_unlimited(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
 ) -> None:
     """Test that message history is not trimmed when max_history = 0."""
     conversation_id = "1234"
+
     with (
         patch(
             "ollama.AsyncClient.chat",
             return_value={"message": {"role": "assistant", "content": "test response"}},
-        ),
+        ) as mock_chat,
     ):
         hass.config_entries.async_update_entry(
             mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
@@ -508,13 +459,13 @@ async def test_message_history_unlimited(
             result.response.response_type == intent.IntentResponseType.ACTION_DONE
         ), result
 
-        agent = conversation.get_agent_manager(hass).async_get_agent(
-            mock_config_entry.entry_id
+        args = mock_chat.call_args_list
+        assert len(args) == 100
+        recorded_messages = args[-1].kwargs["messages"]
+        message_count = sum(
+            (message["role"] == "user") for message in recorded_messages
         )
-
-        assert len(agent._history) == 1
-        assert conversation_id in agent._history
-        assert agent._history[conversation_id].num_user_messages == 100
+        assert message_count == 100
 
 
 async def test_error_handling(
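
As a reviewer aid (not part of the patch): a minimal standalone sketch of the trimming arithmetic that the new `_trim_history` implements. `FakeHistory` and `trim` are hypothetical stand-ins for `MessageHistory` and the patched method; the assumed shape is a system prompt first, alternating user/assistant messages, and the in-progress user message last.

```python
from dataclasses import dataclass, field


@dataclass
class FakeHistory:
    """Stand-in for MessageHistory: system prompt plus user/assistant turns."""

    messages: list[str] = field(default_factory=list)

    @property
    def num_user_messages(self) -> int:
        """Count user messages, mirroring MessageHistory.num_user_messages."""
        return sum(m.startswith("user:") for m in self.messages)


def trim(history: FakeHistory, max_messages: int) -> None:
    """Mirror the patched _trim_history logic."""
    if max_messages < 1:
        return  # max_history = 0 means keep everything
    # The current turn's user message is already in the history, so only
    # earlier user/assistant rounds count against the limit.
    num_previous_rounds = history.num_user_messages - 1
    if num_previous_rounds >= max_messages:
        # Keep the system prompt, 2 messages per retained round, and the
        # in-progress user message.
        num_keep = 2 * max_messages + 1
        drop_index = len(history.messages) - num_keep
        history.messages = [history.messages[0]] + history.messages[drop_index:]


history = FakeHistory(
    ["system: prompt"]
    + [m for i in range(5) for m in (f"user: q{i}", f"assistant: a{i}")]
    + ["user: current"]
)
trim(history, max_messages=2)
# Keeps the system prompt, the last 2 full rounds, and the current user message:
# ['system: prompt', 'user: q3', 'assistant: a3',
#  'user: q4', 'assistant: a4', 'user: current']
print(history.messages)
```

The `+ 1` in `num_keep` exists because the history rebuilt from `chat_log.content` already contains the current turn's user message when `_trim_history` runs, whereas the pre-patch code appended that message only after trimming.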