diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/Actions/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/Actions/index.jsx index 85590e7f310ea7772bf248722a6a6bfbfcadfdf8..abe1f00e04daebadb489a451a7f97e946af56486 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/Actions/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/Actions/index.jsx @@ -6,6 +6,7 @@ import { ThumbsDown, ArrowsClockwise, Copy, + GitMerge, } from "@phosphor-icons/react"; import { Tooltip } from "react-tooltip"; import Workspace from "@/models/workspace"; @@ -19,6 +20,7 @@ const Actions = ({ slug, isLastMessage, regenerateMessage, + forkThread, isEditing, role, }) => { @@ -32,8 +34,14 @@ const Actions = ({ return ( <div className="flex w-full justify-between items-center"> - <div className="flex justify-start items-center gap-x-4"> + <div className="flex justify-start items-center gap-x-4 group"> <CopyMessage message={message} /> + <ForkThread + chatId={chatId} + forkThread={forkThread} + isEditing={isEditing} + role={role} + /> <EditMessageAction chatId={chatId} role={role} isEditing={isEditing} /> {isLastMessage && !isEditing && ( <RegenerateMessage @@ -150,5 +158,27 @@ function RegenerateMessage({ regenerateMessage, chatId }) { </div> ); } +function ForkThread({ chatId, forkThread, isEditing, role }) { + if (!chatId || isEditing || role === "user") return null; + return ( + <div className="mt-3 relative"> + <button + onClick={() => forkThread(chatId)} + data-tooltip-id="fork-thread" + data-tooltip-content="Fork chat to new thread" + className="border-none text-zinc-300" + aria-label="Fork" + > + <GitMerge size={18} className="mb-1" weight="fill" /> + </button> + <Tooltip + id="fork-thread" + place="bottom" + delayShow={300} + className="tooltip !text-xs" + /> + </div> + ); +} export default memo(Actions); diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx index 7b509e86325eaf8af95123361c325ec4b2d06d38..d88a75f3ff47f36ffe372a3963c9ce27271a7f75 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx @@ -23,6 +23,7 @@ const HistoricalMessage = ({ isLastMessage = false, regenerateMessage, saveEditedMessage, + forkThread, }) => { const { isEditing } = useEditMessage({ chatId, role }); const adjustTextArea = (event) => { @@ -95,6 +96,7 @@ const HistoricalMessage = ({ regenerateMessage={regenerateMessage} isEditing={isEditing} role={role} + forkThread={forkThread} /> </div> {role === "assistant" && <Citations sources={sources} />} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx index fef556b143d8dcd8f07017a2ddfd815e3d5270f8..53cbeb64f63e3f0fe17f7bf83e61162da3691464 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx @@ -9,6 +9,7 @@ import useUser from "@/hooks/useUser"; import Chartable from "./Chartable"; import Workspace from "@/models/workspace"; import { useParams } from "react-router-dom"; +import paths from "@/utils/paths"; export default function ChatHistory({ history = [], @@ -131,6 +132,18 @@ export default function ChatHistory({ } }; + const forkThread = async (chatId) => { + const newThreadSlug = await Workspace.forkThread( + workspace.slug, + threadSlug, + chatId + ); + window.location.href = paths.workspace.thread( + workspace.slug, + newThreadSlug + ); + }; + if (history.length === 0) { return ( <div className="flex flex-col h-full md:mt-0 pb-44 md:pb-40 w-full justify-end items-center"> @@ -217,6 +230,7 @@ export default function ChatHistory({ regenerateMessage={regenerateAssistantMessage} isLastMessage={isLastBotReply} saveEditedMessage={saveEditedMessage} + forkThread={forkThread} /> ); })} diff --git a/frontend/src/models/workspace.js b/frontend/src/models/workspace.js index cfbde704a1b4d44242553fc999251fb79847af89..43c723f79017b9e934a87a332f7dfa91c344428a 100644 --- a/frontend/src/models/workspace.js +++ b/frontend/src/models/workspace.js @@ -384,6 +384,22 @@ const Workspace = { return false; }); }, + forkThread: async function (slug = "", threadSlug = null, chatId = null) { + return await fetch(`${API_BASE}/workspace/${slug}/thread/fork`, { + method: "POST", + headers: baseHeaders(), + body: JSON.stringify({ threadSlug, chatId }), + }) + .then((res) => { + if (!res.ok) throw new Error("Failed to fork thread."); + return res.json(); + }) + .then((data) => data.newThreadSlug) + .catch((e) => { + console.error("Error forking thread:", e); + return null; + }); + }, threads: WorkspaceThread, }; diff --git a/server/endpoints/workspaces.js b/server/endpoints/workspaces.js index 6d6f29bbd51147f0d7116da6d4c04ee15485e6bf..e013a4305062c2b355e8666da493f9b97506910a 100644 --- a/server/endpoints/workspaces.js +++ b/server/endpoints/workspaces.js @@ -31,6 +31,8 @@ const { fetchPfp, } = require("../utils/files/pfp"); const { getTTSProvider } = require("../utils/TextToSpeech"); +const { WorkspaceThread } = require("../models/workspaceThread"); +const truncate = require("truncate"); function workspaceEndpoints(app) { if (!app) return; @@ -761,6 +763,81 @@ function workspaceEndpoints(app) { } } ); + + app.post( + "/workspace/:slug/thread/fork", + [validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const workspace = response.locals.workspace; + const { chatId, threadSlug } = reqBody(request); + if (!chatId) + return response.status(400).json({ message: "chatId is required" }); + + // Get threadId we are branching from if that request body is sent + // and is a valid thread slug. + const threadId = !!threadSlug + ? ( + await WorkspaceThread.get({ + slug: String(threadSlug), + workspace_id: workspace.id, + }) + )?.id ?? null + : null; + const chatsToFork = await WorkspaceChats.where( + { + workspaceId: workspace.id, + user_id: user?.id, + include: true, // only duplicate visible chats + thread_id: threadId, + id: { lte: Number(chatId) }, + }, + null, + { id: "asc" } + ); + + const { thread: newThread, message: threadError } = + await WorkspaceThread.new(workspace, user?.id); + if (threadError) + return response.status(500).json({ error: threadError }); + + let lastMessageText = ""; + const chatsData = chatsToFork.map((chat) => { + const chatResponse = safeJsonParse(chat.response, {}); + if (chatResponse?.text) lastMessageText = chatResponse.text; + + return { + workspaceId: workspace.id, + prompt: chat.prompt, + response: JSON.stringify(chatResponse), + user_id: user?.id, + thread_id: newThread.id, + }; + }); + await WorkspaceChats.bulkCreate(chatsData); + await WorkspaceThread.update(newThread, { + name: !!lastMessageText + ? truncate(lastMessageText, 22) + : "Forked Thread", + }); + + await Telemetry.sendTelemetry("thread_forked"); + await EventLogs.logEvent( + "thread_forked", + { + workspaceName: workspace?.name || "Unknown Workspace", + threadName: newThread.name, + }, + user?.id + ); + response.status(200).json({ newThreadSlug: newThread.slug }); + } catch (e) { + console.log(e.message, e); + response.status(500).json({ message: "Internal server error" }); + } + } + ); } module.exports = { workspaceEndpoints }; diff --git a/server/models/workspaceChats.js b/server/models/workspaceChats.js index 951245204fe42d8bc9ea5d3e6ede5e34d176ac6d..52d96c400e63e65e997363c341b6598942b34769 100644 --- a/server/models/workspaceChats.js +++ b/server/models/workspaceChats.js @@ -240,6 +240,23 @@ const WorkspaceChats = { return false; } }, + bulkCreate: async function (chatsData) { + // TODO: Replace with createMany when we update prisma to latest version + // The version of prisma that we are currently using does not support createMany with SQLite + try { + const createdChats = []; + for (const chatData of chatsData) { + const chat = await prisma.workspace_chats.create({ + data: chatData, + }); + createdChats.push(chat); + } + return { chats: createdChats, message: null }; + } catch (error) { + console.error(error.message); + return { chats: null, message: error.message }; + } + }, }; module.exports = { WorkspaceChats };