From d47481a30eadde901a90c196cd9b60c09a187541 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen <balloob@gmail.com>
Date: Thu, 6 Mar 2025 22:52:29 -0500
Subject: [PATCH] Track when an LLM expects to continue a conversation
 (#139810)

* Track when an LLM expects to continue a conversation

* Strip content

* Address comments
---
 .../components/anthropic/conversation.py      |  4 ++-
 .../components/conversation/chat_log.py       | 19 +++++++++++++
 .../conversation.py                           |  4 ++-
 .../components/ollama/conversation.py         |  4 ++-
 .../openai_conversation/conversation.py       |  4 ++-
 .../components/conversation/test_chat_log.py  | 28 +++++++++++++++++++
 6 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/homeassistant/components/anthropic/conversation.py b/homeassistant/components/anthropic/conversation.py
index 5511119d377..8d3ba5085ee 100644
--- a/homeassistant/components/anthropic/conversation.py
+++ b/homeassistant/components/anthropic/conversation.py
@@ -305,7 +305,9 @@ class AnthropicConversationEntity(
         intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(response_content.content or "")
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=chat_log.conversation_id
+            response=intent_response,
+            conversation_id=chat_log.conversation_id,
+            continue_conversation=chat_log.continue_conversation,
         )
 
     async def _async_entry_update_listener(
diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py
index 19482af1983..355f423dbb6 100644
--- a/homeassistant/components/conversation/chat_log.py
+++ b/homeassistant/components/conversation/chat_log.py
@@ -183,6 +183,25 @@ class ChatLog:
     llm_api: llm.APIInstance | None = None
     delta_listener: Callable[[ChatLog, dict], None] | None = None
 
+    @property
+    def continue_conversation(self) -> bool:
+        """Return whether the conversation should continue."""
+        if not self.content:
+            return False
+
+        last_msg = self.content[-1]
+
+        return (
+            last_msg.role == "assistant"
+            and last_msg.content is not None  # type: ignore[union-attr]
+            and last_msg.content.strip().endswith(  # type: ignore[union-attr]
+                (
+                    "?",
+                    ";",  # Greek question mark
+                )
+            )
+        )
+
     @property
     def unresponded_tool_results(self) -> bool:
         """Return if there are unresponded tool results."""
diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py
index 168e867d857..b43558c6768 100644
--- a/homeassistant/components/google_generative_ai_conversation/conversation.py
+++ b/homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -459,7 +459,9 @@ class GoogleGenerativeAIConversationEntity(
             " ".join([part.text.strip() for part in response_parts if part.text])
         )
         return conversation.ConversationResult(
-            response=response, conversation_id=chat_log.conversation_id
+            response=response,
+            conversation_id=chat_log.conversation_id,
+            continue_conversation=chat_log.continue_conversation,
         )
 
     async def _async_entry_update_listener(
diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py
index 90e81544f66..85daf742035 100644
--- a/homeassistant/components/ollama/conversation.py
+++ b/homeassistant/components/ollama/conversation.py
@@ -292,7 +292,9 @@ class OllamaConversationEntity(
             )
         intent_response.async_set_speech(chat_log.content[-1].content or "")
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=chat_log.conversation_id
+            response=intent_response,
+            conversation_id=chat_log.conversation_id,
+            continue_conversation=chat_log.continue_conversation,
         )
 
     def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py
index cc09ec77c0e..37be41947f7 100644
--- a/homeassistant/components/openai_conversation/conversation.py
+++ b/homeassistant/components/openai_conversation/conversation.py
@@ -310,7 +310,9 @@ class OpenAIConversationEntity(
         assert type(chat_log.content[-1]) is conversation.AssistantContent
         intent_response.async_set_speech(chat_log.content[-1].content or "")
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=chat_log.conversation_id
+            response=intent_response,
+            conversation_id=chat_log.conversation_id,
+            continue_conversation=chat_log.continue_conversation,
         )
 
     async def _async_entry_update_listener(
diff --git a/tests/components/conversation/test_chat_log.py b/tests/components/conversation/test_chat_log.py
index c0687ebecfb..97094740af0 100644
--- a/tests/components/conversation/test_chat_log.py
+++ b/tests/components/conversation/test_chat_log.py
@@ -14,6 +14,7 @@ from homeassistant.components.conversation import (
     ConversationInput,
     ConverseError,
     ToolResultContent,
+    UserContent,
     async_get_chat_log,
 )
 from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS
@@ -643,3 +644,30 @@ async def test_chat_log_reuse(
             assert len(chat_log.content) == 2
             assert chat_log.content[1].role == "user"
             assert chat_log.content[1].content == mock_conversation_input.text
+
+
+async def test_chat_log_continue_conversation(
+    hass: HomeAssistant,
+    mock_conversation_input: ConversationInput,
+) -> None:
+    """Test continue conversation."""
+    with (
+        chat_session.async_get_chat_session(hass) as session,
+        async_get_chat_log(hass, session) as chat_log,
+    ):
+        assert chat_log.continue_conversation is False
+        chat_log.async_add_user_content(UserContent(mock_conversation_input.text))
+        assert chat_log.continue_conversation is False
+        chat_log.async_add_assistant_content_without_tools(
+            AssistantContent(
+                agent_id="mock-agent-id",
+                content="Hey? ",
+            )
+        )
+        chat_log.async_add_assistant_content_without_tools(
+            AssistantContent(
+                agent_id="mock-agent-id",
+                content="Ποιο είναι το αγαπημένο σου χρώμα στα ελληνικά;",
+            )
+        )
+        assert chat_log.continue_conversation is True
-- 
GitLab