+
+
{isLastMessage && !isEditing && (
);
}
+/**
+ * Message action that lets the user fork the conversation at this message.
+ * Renders nothing (returns null) when there is no chatId, while the message
+ * is being edited, or when the message role is "user".
+ * NOTE(review): the JSX returned below appears to have been stripped from
+ * this copy of the patch; presumably it wires a button to forkThread(chatId)
+ * — confirm against the original source.
+ */
+function ForkThread({ chatId, forkThread, isEditing, role }) {
+ if (!chatId || isEditing || role === "user") return null;
+ return (
+
+
+
+
+ );
+}
export default memo(Actions);
diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx
index 7b509e86325..d88a75f3ff4 100644
--- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx
+++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx
@@ -23,6 +23,7 @@ const HistoricalMessage = ({
isLastMessage = false,
regenerateMessage,
saveEditedMessage,
+ forkThread,
}) => {
const { isEditing } = useEditMessage({ chatId, role });
const adjustTextArea = (event) => {
@@ -95,6 +96,7 @@ const HistoricalMessage = ({
regenerateMessage={regenerateMessage}
isEditing={isEditing}
role={role}
+ forkThread={forkThread}
/>
{role === "assistant" &&
}
diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx
index fef556b143d..53cbeb64f63 100644
--- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx
+++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx
@@ -9,6 +9,7 @@ import useUser from "@/hooks/useUser";
import Chartable from "./Chartable";
import Workspace from "@/models/workspace";
import { useParams } from "react-router-dom";
+import paths from "@/utils/paths";
export default function ChatHistory({
history = [],
@@ -131,6 +132,18 @@ export default function ChatHistory({
}
};
+ const forkThread = async (chatId) => {
+ const newThreadSlug = await Workspace.forkThread(
+ workspace.slug,
+ threadSlug,
+ chatId
+ );
+ window.location.href = paths.workspace.thread(
+ workspace.slug,
+ newThreadSlug
+ );
+ };
+
if (history.length === 0) {
return (
@@ -217,6 +230,7 @@ export default function ChatHistory({
regenerateMessage={regenerateAssistantMessage}
isLastMessage={isLastBotReply}
saveEditedMessage={saveEditedMessage}
+ forkThread={forkThread}
/>
);
})}
diff --git a/frontend/src/models/workspace.js b/frontend/src/models/workspace.js
index cfbde704a1b..43c723f7901 100644
--- a/frontend/src/models/workspace.js
+++ b/frontend/src/models/workspace.js
@@ -384,6 +384,22 @@ const Workspace = {
return false;
});
},
+ forkThread: async function (slug = "", threadSlug = null, chatId = null) {
+ return await fetch(`${API_BASE}/workspace/${slug}/thread/fork`, {
+ method: "POST",
+ headers: baseHeaders(),
+ body: JSON.stringify({ threadSlug, chatId }),
+ })
+ .then((res) => {
+ if (!res.ok) throw new Error("Failed to fork thread.");
+ return res.json();
+ })
+ .then((data) => data.newThreadSlug)
+ .catch((e) => {
+ console.error("Error forking thread:", e);
+ return null;
+ });
+ },
threads: WorkspaceThread,
};
diff --git a/server/endpoints/workspaces.js b/server/endpoints/workspaces.js
index 6d6f29bbd51..e013a430506 100644
--- a/server/endpoints/workspaces.js
+++ b/server/endpoints/workspaces.js
@@ -31,6 +31,8 @@ const {
fetchPfp,
} = require("../utils/files/pfp");
const { getTTSProvider } = require("../utils/TextToSpeech");
+const { WorkspaceThread } = require("../models/workspaceThread");
+const truncate = require("truncate");
function workspaceEndpoints(app) {
if (!app) return;
@@ -761,6 +763,81 @@ function workspaceEndpoints(app) {
}
}
);
+
+ app.post(
+ "/workspace/:slug/thread/fork",
+ [validatedRequest, flexUserRoleValid([ROLES.all]), validWorkspaceSlug],
+ async (request, response) => {
+ try {
+ const user = await userFromSession(request, response);
+ const workspace = response.locals.workspace;
+ const { chatId, threadSlug } = reqBody(request);
+ if (!chatId)
+ return response.status(400).json({ message: "chatId is required" });
+
+ // Get threadId we are branching from if that request body is sent
+ // and is a valid thread slug.
+ const threadId = !!threadSlug
+ ? (
+ await WorkspaceThread.get({
+ slug: String(threadSlug),
+ workspace_id: workspace.id,
+ })
+ )?.id ?? null
+ : null;
+ const chatsToFork = await WorkspaceChats.where(
+ {
+ workspaceId: workspace.id,
+ user_id: user?.id,
+ include: true, // only duplicate visible chats
+ thread_id: threadId,
+ id: { lte: Number(chatId) },
+ },
+ null,
+ { id: "asc" }
+ );
+
+ const { thread: newThread, message: threadError } =
+ await WorkspaceThread.new(workspace, user?.id);
+ if (threadError)
+ return response.status(500).json({ error: threadError });
+
+ let lastMessageText = "";
+ const chatsData = chatsToFork.map((chat) => {
+ const chatResponse = safeJsonParse(chat.response, {});
+ if (chatResponse?.text) lastMessageText = chatResponse.text;
+
+ return {
+ workspaceId: workspace.id,
+ prompt: chat.prompt,
+ response: JSON.stringify(chatResponse),
+ user_id: user?.id,
+ thread_id: newThread.id,
+ };
+ });
+ await WorkspaceChats.bulkCreate(chatsData);
+ await WorkspaceThread.update(newThread, {
+ name: !!lastMessageText
+ ? truncate(lastMessageText, 22)
+ : "Forked Thread",
+ });
+
+ await Telemetry.sendTelemetry("thread_forked");
+ await EventLogs.logEvent(
+ "thread_forked",
+ {
+ workspaceName: workspace?.name || "Unknown Workspace",
+ threadName: newThread.name,
+ },
+ user?.id
+ );
+ response.status(200).json({ newThreadSlug: newThread.slug });
+ } catch (e) {
+ console.log(e.message, e);
+ response.status(500).json({ message: "Internal server error" });
+ }
+ }
+ );
}
module.exports = { workspaceEndpoints };
diff --git a/server/models/workspaceChats.js b/server/models/workspaceChats.js
index 951245204fe..52d96c400e6 100644
--- a/server/models/workspaceChats.js
+++ b/server/models/workspaceChats.js
@@ -240,6 +240,23 @@ const WorkspaceChats = {
return false;
}
},
+ bulkCreate: async function (chatsData) {
+ // TODO: Replace with createMany when we update prisma to latest version
+ // The version of prisma that we are currently using does not support createMany with SQLite
+ try {
+ const createdChats = [];
+ for (const chatData of chatsData) {
+ const chat = await prisma.workspace_chats.create({
+ data: chatData,
+ });
+ createdChats.push(chat);
+ }
+ return { chats: createdChats, message: null };
+ } catch (error) {
+ console.error(error.message);
+ return { chats: null, message: error.message };
+ }
+ },
};
module.exports = { WorkspaceChats };