diff --git a/server/endpoints/api/workspace/index.js b/server/endpoints/api/workspace/index.js
index f022dd57787..82cc8b5a895 100644
--- a/server/endpoints/api/workspace/index.js
+++ b/server/endpoints/api/workspace/index.js
@@ -610,7 +610,8 @@ function apiWorkspaceEndpoints(app) {
                 mime: "image/png",
                 contentString: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..."
               }
-            ]
+            ],
+            reset: false
           }
         }
       }
@@ -645,6 +646,7 @@ function apiWorkspaceEndpoints(app) {
         mode = "query",
         sessionId = null,
         attachments = [],
+        reset = false,
       } = reqBody(request);
 
       const workspace = await Workspace.get({ slug: String(slug) });
@@ -660,7 +662,7 @@ function apiWorkspaceEndpoints(app) {
         return;
       }
 
-      if (!message?.length || !VALID_CHAT_MODE.includes(mode)) {
+      if ((!message?.length && !reset) || !VALID_CHAT_MODE.includes(mode)) {
         response.status(400).json({
           id: uuidv4(),
           type: "abort",
@@ -668,7 +670,7 @@ function apiWorkspaceEndpoints(app) {
           sources: [],
           close: true,
           error: !message?.length
-            ? "message parameter cannot be empty."
+            ? "Message is empty"
             : `${mode} is not a valid mode.`,
         });
         return;
@@ -682,6 +684,7 @@ function apiWorkspaceEndpoints(app) {
         thread: null,
         sessionId: !!sessionId ? String(sessionId) : null,
         attachments,
+        reset,
       });
 
       await Telemetry.sendTelemetry("sent_chat", {
@@ -732,7 +735,8 @@ function apiWorkspaceEndpoints(app) {
                 mime: "image/png",
                 contentString: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..."
} - ] + ], + reset: false } } } @@ -788,6 +792,7 @@ function apiWorkspaceEndpoints(app) { mode = "query", sessionId = null, attachments = [], + reset = false, } = reqBody(request); const workspace = await Workspace.get({ slug: String(slug) }); @@ -803,7 +808,7 @@ function apiWorkspaceEndpoints(app) { return; } - if (!message?.length || !VALID_CHAT_MODE.includes(mode)) { + if ((!message?.length || !VALID_CHAT_MODE.includes(mode)) && !reset) { response.status(400).json({ id: uuidv4(), type: "abort", @@ -832,6 +837,7 @@ function apiWorkspaceEndpoints(app) { thread: null, sessionId: !!sessionId ? String(sessionId) : null, attachments, + reset, }); await Telemetry.sendTelemetry("sent_chat", { LLMSelection: diff --git a/server/endpoints/api/workspaceThread/index.js b/server/endpoints/api/workspaceThread/index.js index d85cf739b13..3ce65df97b4 100644 --- a/server/endpoints/api/workspaceThread/index.js +++ b/server/endpoints/api/workspaceThread/index.js @@ -351,7 +351,8 @@ function apiWorkspaceThreadEndpoints(app) { mime: "image/png", contentString: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..." } - ] + ], + reset: false } } } @@ -386,6 +387,7 @@ function apiWorkspaceThreadEndpoints(app) { mode = "query", userId, attachments = [], + reset = false, } = reqBody(request); const workspace = await Workspace.get({ slug }); const thread = await WorkspaceThread.get({ @@ -405,7 +407,7 @@ function apiWorkspaceThreadEndpoints(app) { return; } - if (!message?.length || !VALID_CHAT_MODE.includes(mode)) { + if ((!message?.length || !VALID_CHAT_MODE.includes(mode)) && !reset) { response.status(400).json({ id: uuidv4(), type: "abort", @@ -413,7 +415,7 @@ function apiWorkspaceThreadEndpoints(app) { sources: [], close: true, error: !message?.length - ? "message parameter cannot be empty." + ? 
"Message is empty" : `${mode} is not a valid mode.`, }); return; @@ -427,6 +429,7 @@ function apiWorkspaceThreadEndpoints(app) { user, thread, attachments, + reset, }); await Telemetry.sendTelemetry("sent_chat", { LLMSelection: process.env.LLM_PROVIDER || "openai", @@ -489,7 +492,8 @@ function apiWorkspaceThreadEndpoints(app) { mime: "image/png", contentString: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..." } - ] + ], + reset: false } } } @@ -545,6 +549,7 @@ function apiWorkspaceThreadEndpoints(app) { mode = "query", userId, attachments = [], + reset = false, } = reqBody(request); const workspace = await Workspace.get({ slug }); const thread = await WorkspaceThread.get({ @@ -564,7 +569,7 @@ function apiWorkspaceThreadEndpoints(app) { return; } - if (!message?.length || !VALID_CHAT_MODE.includes(mode)) { + if ((!message?.length || !VALID_CHAT_MODE.includes(mode)) && !reset) { response.status(400).json({ id: uuidv4(), type: "abort", @@ -594,6 +599,7 @@ function apiWorkspaceThreadEndpoints(app) { user, thread, attachments, + reset, }); await Telemetry.sendTelemetry("sent_chat", { LLMSelection: process.env.LLM_PROVIDER || "openai", diff --git a/server/models/workspaceChats.js b/server/models/workspaceChats.js index 4a2b884f8b3..dde98d1e899 100644 --- a/server/models/workspaceChats.js +++ b/server/models/workspaceChats.js @@ -104,6 +104,9 @@ const WorkspaceChats = { } }, + /** + * @deprecated Use markThreadHistoryInvalidV2 instead. + */ markHistoryInvalid: async function (workspaceId = null, user = null) { if (!workspaceId) return; try { @@ -123,6 +126,9 @@ const WorkspaceChats = { } }, + /** + * @deprecated Use markThreadHistoryInvalidV2 instead. + */ markThreadHistoryInvalid: async function ( workspaceId = null, user = null, @@ -146,6 +152,28 @@ const WorkspaceChats = { } }, + /** + * @description This function is used to mark a thread's history as invalid. + * and works with an arbitrary where clause. 
+   * @param {Object} whereClause - Prisma where clause selecting the chats to invalidate.
+   * An empty or missing clause is ignored so a bad call cannot invalidate every chat.
+   * @returns {Promise<void>}
+   */
+  markThreadHistoryInvalidV2: async function (whereClause = {}) {
+    if (!whereClause || Object.keys(whereClause).length === 0) return;
+    try {
+      await prisma.workspace_chats.updateMany({
+        where: whereClause,
+        data: {
+          include: false,
+        },
+      });
+      return;
+    } catch (error) {
+      console.error(error.message);
+    }
+  },
+
   get: async function (clause = {}, limit = null, orderBy = null) {
     try {
       const chat = await prisma.workspace_chats.findFirst({
diff --git a/server/swagger/openapi.json b/server/swagger/openapi.json
index f848bbe2139..94a1d71e39a 100644
--- a/server/swagger/openapi.json
+++ b/server/swagger/openapi.json
@@ -2280,7 +2280,8 @@
                     "mime": "image/png",
                     "contentString": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..."
                   }
-                ]
+                ],
+                "reset": false
               }
             }
           }
@@ -2382,7 +2383,8 @@
                     "mime": "image/png",
                     "contentString": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..."
                   }
-                ]
+                ],
+                "reset": false
               }
             }
           }
@@ -3143,7 +3145,8 @@
                     "mime": "image/png",
                     "contentString": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..."
                   }
-                ]
+                ],
+                "reset": false
               }
             }
           }
@@ -3255,7 +3258,8 @@
                     "mime": "image/png",
                     "contentString": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..."
} - ] + ], + "reset": false } } } diff --git a/server/utils/chats/apiChatHandler.js b/server/utils/chats/apiChatHandler.js index 11421ea128e..6bb82a2dd38 100644 --- a/server/utils/chats/apiChatHandler.js +++ b/server/utils/chats/apiChatHandler.js @@ -3,7 +3,12 @@ const { DocumentManager } = require("../DocumentManager"); const { WorkspaceChats } = require("../../models/workspaceChats"); const { getVectorDbClass, getLLMProvider } = require("../helpers"); const { writeResponseChunk } = require("../helpers/chat/responses"); -const { chatPrompt, sourceIdentifier, recentChatHistory } = require("./index"); +const { + chatPrompt, + sourceIdentifier, + recentChatHistory, + grepAllSlashCommands, +} = require("./index"); const { EphemeralAgentHandler, EphemeralEventListener, @@ -31,6 +36,7 @@ const { Telemetry } = require("../../models/telemetry"); * thread: import("@prisma/client").workspace_threads|null, * sessionId: string|null, * attachments: { name: string; mime: string; contentString: string }[], + * reset: boolean, * }} parameters * @returns {Promise} */ @@ -42,10 +48,39 @@ async function chatSync({ thread = null, sessionId = null, attachments = [], + reset = false, }) { const uuid = uuidv4(); const chatMode = mode ?? "chat"; + // If the user wants to reset the chat history we do so pre-flight + // and continue execution. If no message is provided then the user intended + // to reset the chat history only and we can exit early with a confirmation. 
+ if (reset) { + await WorkspaceChats.markThreadHistoryInvalidV2({ + workspaceId: workspace.id, + user_id: user?.id, + thread_id: thread?.id, + api_session_id: sessionId, + }); + if (!message?.length) { + return { + id: uuid, + type: "textResponse", + textResponse: "Chat history was reset!", + sources: [], + close: true, + error: null, + metrics: {}, + }; + } + } + + // Process slash commands + // Since preset commands are not supported in API calls, we can just process the message here + const processedMessage = await grepAllSlashCommands(message); + message = processedMessage; + if (EphemeralAgentHandler.isAgentInvocation({ message })) { await Telemetry.sendTelemetry("agent_chat_started"); @@ -320,6 +355,7 @@ async function chatSync({ * thread: import("@prisma/client").workspace_threads|null, * sessionId: string|null, * attachments: { name: string; mime: string; contentString: string }[], + * reset: boolean, * }} parameters * @returns {Promise} */ @@ -332,10 +368,41 @@ async function streamChat({ thread = null, sessionId = null, attachments = [], + reset = false, }) { const uuid = uuidv4(); const chatMode = mode ?? "chat"; + // If the user wants to reset the chat history we do so pre-flight + // and continue execution. If no message is provided then the user intended + // to reset the chat history only and we can exit early with a confirmation. 
+  if (reset) {
+    await WorkspaceChats.markThreadHistoryInvalidV2({
+      workspaceId: workspace.id,
+      user_id: user?.id ?? null,
+      thread_id: thread?.id ?? null,
+      api_session_id: sessionId,
+    });
+    if (!message?.length) {
+      writeResponseChunk(response, {
+        id: uuid,
+        type: "textResponse",
+        textResponse: "Chat history was reset!",
+        sources: [],
+        attachments: [],
+        close: true,
+        error: null,
+        metrics: {},
+      });
+      return;
+    }
+  }
+
+  // Check for and process slash commands. API calls are not user-scoped, so
+  // presets from every user are expanded here.
+  const processedMessage = await grepAllSlashCommands(message);
+  message = processedMessage;
+
   if (EphemeralAgentHandler.isAgentInvocation({ message })) {
     await Telemetry.sendTelemetry("agent_chat_started");
 
diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js
index 067b968ce57..28f3f3d75a6 100644
--- a/server/utils/chats/index.js
+++ b/server/utils/chats/index.js
@@ -35,6 +35,42 @@ async function grepCommand(message, user = null) {
   return updatedMessage;
 }
 
+/**
+ * @description Replaces every slash-command preset found in `message` with that
+ * preset's prompt. Multiple commands in one message are all expanded; each preset
+ * is applied once, in the order returned by `SlashCommandPresets.where({})`.
+ * @notice Used for API calls and NOT user-scoped: presets belonging to every
+ * user are considered.
+ * @param {string} message - The raw message sent to the API.
+ * @returns {Promise<string>} The message with every matching preset expanded.
+ */
+async function grepAllSlashCommands(message) {
+  // Nothing to expand, so skip the database lookup entirely.
+  if (typeof message !== "string" || !message.length) return message;
+  const allPresets = await SlashCommandPresets.where({});
+
+  // Replace all preset commands with their corresponding prompts
+  // Allows multiple commands in one message
+  let updatedMessage = message;
+  for (const preset of allPresets) {
+    // Escape the command so regex metacharacters are matched literally rather
+    // than producing an invalid (throwing) or unintended pattern.
+    const escapedCommand = preset.command.replace(
+      /[.*+?^${}()|[\]\\]/g,
+      "\\$&"
+    );
+    const regex = new RegExp(
+      `(?:\\b\\s|^)(${escapedCommand})(?:\\b\\s|$)`,
+      "g"
+    );
+    // A replacer function inserts the prompt verbatim; a plain string would
+    // have `$&`, `$1`, `$$`, ... inside the prompt interpreted as patterns.
+    updatedMessage = updatedMessage.replace(regex, () => preset.prompt);
+  }
+
+  return updatedMessage;
+}
+
 async function recentChatHistory({
   user = null,
   workspace,
@@ -80,5 +116,6 @@ module.exports = {
   recentChatHistory,
   chatPrompt,
   grepCommand,
+  grepAllSlashCommands,
   VALID_COMMANDS,
 };