From 0ea9581cfc3a7c151540dd7e29cc0a421828d9e5 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen <balloob@gmail.com>
Date: Tue, 11 Jun 2024 01:49:14 -0400
Subject: [PATCH] OpenAI to respect custom conversation IDs (#119307)

---
 .../openai_conversation/conversation.py       | 18 ++++++++-
 .../openai_conversation/test_conversation.py  | 39 +++++++++++++++++++
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py
index d5e566678f1..d0b3ef8f895 100644
--- a/homeassistant/components/openai_conversation/conversation.py
+++ b/homeassistant/components/openai_conversation/conversation.py
@@ -141,11 +141,25 @@ class OpenAIConversationEntity(
                 )
             tools = [_format_tool(tool) for tool in llm_api.tools]
 
-        if user_input.conversation_id in self.history:
+        if user_input.conversation_id is None:
+            conversation_id = ulid.ulid_now()
+            messages = []
+
+        elif user_input.conversation_id in self.history:
             conversation_id = user_input.conversation_id
             messages = self.history[conversation_id]
+
         else:
-            conversation_id = ulid.ulid_now()
+            # Conversation IDs are ULIDs. We generate a new one if not provided.
+            # If an old OLID is passed in, we will generate a new one to indicate
+            # a new conversation was started. If the user picks their own, they
+            # want to track a conversation and we respect it.
+            try:
+                ulid.ulid_to_bytes(user_input.conversation_id)
+                conversation_id = ulid.ulid_now()
+            except ValueError:
+                conversation_id = user_input.conversation_id
+
             messages = []
 
         if (
diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py
index 002b2df186b..5ca54611c91 100644
--- a/tests/components/openai_conversation/test_conversation.py
+++ b/tests/components/openai_conversation/test_conversation.py
@@ -22,6 +22,7 @@ from homeassistant.core import Context, HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import intent, llm
 from homeassistant.setup import async_setup_component
+from homeassistant.util import ulid
 
 from tests.common import MockConfigEntry
 
@@ -497,3 +498,41 @@ async def test_unknown_hass_api(
     )
 
     assert result == snapshot
+
+
+@patch(
+    "openai.resources.chat.completions.AsyncCompletions.create",
+    new_callable=AsyncMock,
+)
+async def test_conversation_id(
+    mock_create,
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test conversation ID is honored."""
+    result = await conversation.async_converse(
+        hass, "hello", None, None, agent_id=mock_config_entry.entry_id
+    )
+
+    conversation_id = result.conversation_id
+
+    result = await conversation.async_converse(
+        hass, "hello", conversation_id, None, agent_id=mock_config_entry.entry_id
+    )
+
+    assert result.conversation_id == conversation_id
+
+    unknown_id = ulid.ulid()
+
+    result = await conversation.async_converse(
+        hass, "hello", unknown_id, None, agent_id=mock_config_entry.entry_id
+    )
+
+    assert result.conversation_id != unknown_id
+
+    result = await conversation.async_converse(
+        hass, "hello", "koala", None, agent_id=mock_config_entry.entry_id
+    )
+
+    assert result.conversation_id == "koala"
-- 
GitLab