From d47481a30eadde901a90c196cd9b60c09a187541 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen <balloob@gmail.com> Date: Thu, 6 Mar 2025 22:52:29 -0500 Subject: [PATCH] Track when an LLM expects to continue a conversation (#139810) * Track when an LLM expects to continue a conversation * Strip content * Address comments --- .../components/anthropic/conversation.py | 4 ++- .../components/conversation/chat_log.py | 19 +++++++++++++ .../conversation.py | 4 ++- .../components/ollama/conversation.py | 4 ++- .../openai_conversation/conversation.py | 4 ++- .../components/conversation/test_chat_log.py | 28 +++++++++++++++++++ 6 files changed, 59 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/anthropic/conversation.py b/homeassistant/components/anthropic/conversation.py index 5511119d377..8d3ba5085ee 100644 --- a/homeassistant/components/anthropic/conversation.py +++ b/homeassistant/components/anthropic/conversation.py @@ -305,7 +305,9 @@ class AnthropicConversationEntity( intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(response_content.content or "") return conversation.ConversationResult( - response=intent_response, conversation_id=chat_log.conversation_id + response=intent_response, + conversation_id=chat_log.conversation_id, + continue_conversation=chat_log.continue_conversation, ) async def _async_entry_update_listener( diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index 19482af1983..355f423dbb6 100644 --- a/homeassistant/components/conversation/chat_log.py +++ b/homeassistant/components/conversation/chat_log.py @@ -183,6 +183,25 @@ class ChatLog: llm_api: llm.APIInstance | None = None delta_listener: Callable[[ChatLog, dict], None] | None = None + @property + def continue_conversation(self) -> bool: + """Return whether the conversation should continue.""" + if not self.content: + return False + + last_msg = self.content[-1] + + return ( + last_msg.role == "assistant" + and last_msg.content is not None # type: ignore[union-attr] + and last_msg.content.strip().endswith( # type: ignore[union-attr] + ( + "?", + ";", # Greek question mark + ) + ) + ) + @property def unresponded_tool_results(self) -> bool: """Return if there are unresponded tool results.""" diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 168e867d857..b43558c6768 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -459,7 +459,9 @@ class GoogleGenerativeAIConversationEntity( " ".join([part.text.strip() for part in response_parts if part.text]) ) return conversation.ConversationResult( - response=response, conversation_id=chat_log.conversation_id + response=response, + conversation_id=chat_log.conversation_id, + continue_conversation=chat_log.continue_conversation, ) async def _async_entry_update_listener( diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index 90e81544f66..85daf742035 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -292,7 +292,9 @@ class OllamaConversationEntity( ) intent_response.async_set_speech(chat_log.content[-1].content or "") return conversation.ConversationResult( - response=intent_response, conversation_id=chat_log.conversation_id + response=intent_response, + conversation_id=chat_log.conversation_id, + continue_conversation=chat_log.continue_conversation, ) def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None: diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index cc09ec77c0e..37be41947f7 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -310,7 +310,9 @@ class OpenAIConversationEntity( assert type(chat_log.content[-1]) is conversation.AssistantContent intent_response.async_set_speech(chat_log.content[-1].content or "") return conversation.ConversationResult( - response=intent_response, conversation_id=chat_log.conversation_id + response=intent_response, + conversation_id=chat_log.conversation_id, + continue_conversation=chat_log.continue_conversation, ) async def _async_entry_update_listener( diff --git a/tests/components/conversation/test_chat_log.py b/tests/components/conversation/test_chat_log.py index c0687ebecfb..97094740af0 100644 --- a/tests/components/conversation/test_chat_log.py +++ b/tests/components/conversation/test_chat_log.py @@ -14,6 +14,7 @@ from homeassistant.components.conversation import ( ConversationInput, ConverseError, ToolResultContent, + UserContent, async_get_chat_log, ) from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS @@ -643,3 +644,30 @@ async def test_chat_log_reuse( assert len(chat_log.content) == 2 assert chat_log.content[1].role == "user" assert chat_log.content[1].content == mock_conversation_input.text + + +async def test_chat_log_continue_conversation( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, +) -> None: + """Test continue conversation.""" + with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session) as chat_log, + ): + assert chat_log.continue_conversation is False + chat_log.async_add_user_content(UserContent(mock_conversation_input.text)) + assert chat_log.continue_conversation is False + chat_log.async_add_assistant_content_without_tools( + AssistantContent( + agent_id="mock-agent-id", + content="Hey? ", + ) + ) + chat_log.async_add_assistant_content_without_tools( + AssistantContent( + agent_id="mock-agent-id", + content="Ποιο είναι το αγαπημένο σου χρώμα στα ελληνικά;", + ) + ) + assert chat_log.continue_conversation is True -- GitLab