diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/AddPresetModal.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/AddPresetModal.jsx new file mode 100644 index 00000000000..e5154580bfd --- /dev/null +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/AddPresetModal.jsx @@ -0,0 +1,111 @@ +import { useState } from "react"; +import { X } from "@phosphor-icons/react"; +import ModalWrapper from "@/components/ModalWrapper"; +import { CMD_REGEX } from "."; + +export default function AddPresetModal({ isOpen, onClose, onSave }) { + const [command, setCommand] = useState(""); + + const handleSubmit = async (e) => { + e.preventDefault(); + const form = new FormData(e.target); + const sanitizedCommand = command.replace(CMD_REGEX, ""); + const saved = await onSave({ + command: `/${sanitizedCommand}`, + prompt: form.get("prompt"), + description: form.get("description"), + }); + if (saved) setCommand(""); + }; + + const handleCommandChange = (e) => { + const value = e.target.value.replace(CMD_REGEX, ""); + setCommand(value); + }; + + return ( + +
+
+
+

Add New Preset

+ +
+
+
+
+ +
+ / + +
+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+
+
+ ); +} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/EditPresetModal.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/EditPresetModal.jsx new file mode 100644 index 00000000000..fdffbe609c7 --- /dev/null +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/EditPresetModal.jsx @@ -0,0 +1,158 @@ +import { useEffect, useState } from "react"; +import { X } from "@phosphor-icons/react"; +import ModalWrapper from "@/components/ModalWrapper"; +import { CMD_REGEX } from "."; + +export default function EditPresetModal({ + isOpen, + onClose, + onSave, + onDelete, + preset, +}) { + const [command, setCommand] = useState(preset?.command?.slice(1) || ""); + const [deleting, setDeleting] = useState(false); + + // useState only reads `preset` on first mount, so re-sync the controlled + // command input whenever the modal opens or a different preset is passed in. + useEffect(() => { + if (isOpen) setCommand(preset?.command?.slice(1) || ""); + }, [isOpen, preset]); + + const handleSubmit = (e) => { + e.preventDefault(); + const form = new FormData(e.target); + const sanitizedCommand = command.replace(CMD_REGEX, ""); + onSave({ + id: preset.id, + command: `/${sanitizedCommand}`, + prompt: form.get("prompt"), + description: form.get("description"), + }); + }; + + const handleCommandChange = (e) => { + const value = e.target.value.replace(CMD_REGEX, ""); + setCommand(value); + }; + + const handleDelete = async () => { + const confirmDelete = window.confirm( + "Are you sure you want to delete this preset?" + ); + if (!confirmDelete) return; + + setDeleting(true); + try { + await onDelete(preset.id); + } finally { + // Always clear the busy state, even if onDelete rejects. + setDeleting(false); + } + onClose(); + }; + + return ( + +
+
+
+

Edit Preset

+ +
+
+
+
+ +
+ / + +
+
+
+ + +
+
+ + +
+
+
+
+
+ +
+
+ + +
+
+
+
+
+ ); +} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/index.jsx new file mode 100644 index 00000000000..ca39b68a823 --- /dev/null +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/SlashPresets/index.jsx @@ -0,0 +1,134 @@ +import { useEffect, useState } from "react"; +import { useIsAgentSessionActive } from "@/utils/chat/agent"; +import AddPresetModal from "./AddPresetModal"; +import EditPresetModal from "./EditPresetModal"; +import { useModal } from "@/hooks/useModal"; +import System from "@/models/system"; +import { DotsThree, Plus } from "@phosphor-icons/react"; +import showToast from "@/utils/toast"; + +export const CMD_REGEX = new RegExp(/[^a-zA-Z0-9_-]/g); +export default function SlashPresets({ setShowing, sendCommand }) { + const isActiveAgentSession = useIsAgentSessionActive(); + const { + isOpen: isAddModalOpen, + openModal: openAddModal, + closeModal: closeAddModal, + } = useModal(); + const { + isOpen: isEditModalOpen, + openModal: openEditModal, + closeModal: closeEditModal, + } = useModal(); + const [presets, setPresets] = useState([]); + const [selectedPreset, setSelectedPreset] = useState(null); + + const fetchPresets = async () => { + const presets = await System.getSlashCommandPresets(); + setPresets(presets); + }; + + // Load presets on mount; fetchPresets is already initialized here (no early return above). + useEffect(() => { + fetchPresets(); + }, []); + + const handleSavePreset = async (preset) => { + const { error } = await System.createSlashCommandPreset(preset); + if (!!error) { + showToast(error, "error"); + return false; + } + + fetchPresets(); + closeAddModal(); + return true; + }; + + const handleEditPreset = (preset) => { + setSelectedPreset(preset); + openEditModal(); + }; + + const handleUpdatePreset = async (updatedPreset) => { + const { error } = await System.updateSlashCommandPreset( +
updatedPreset.id, + updatedPreset + ); + + if (!!error) { + showToast(error, "error"); + return; + } + + fetchPresets(); + closeEditModal(); + }; + + const handleDeletePreset = async (presetId) => { + const deleted = await System.deleteSlashCommandPreset(presetId); + if (!deleted) { + showToast("Failed to delete preset.", "error"); + return; + } + fetchPresets(); + closeEditModal(); + }; + + // Bail out only after every hook above has run, keeping hook order stable. + if (isActiveAgentSession) return null; + + return ( + <> + {presets.map((preset) => ( + + + ))} + + + {selectedPreset && ( + + )} + + ); +} diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/index.jsx index 5a606af6da6..9b626372c5f 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/PromptInput/SlashCommands/index.jsx @@ -3,6 +3,7 @@ import SlashCommandIcon from "./icons/slash-commands-icon.svg"; import { Tooltip } from "react-tooltip"; import ResetCommand from "./reset"; import EndAgentSession from "./endAgentSession"; +import SlashPresets from "./SlashPresets"; export default function SlashCommandsButton({ showing, setShowSlashCommand }) { return ( @@ -52,10 +53,11 @@ export function SlashCommands({ showing, setShowing, sendCommand }) {
+
diff --git a/frontend/src/models/system.js b/frontend/src/models/system.js index af532a0474d..e64b0119986 100644 --- a/frontend/src/models/system.js +++ b/frontend/src/models/system.js @@ -567,6 +567,75 @@ const System = { }); }, dataConnectors: DataConnector, + + getSlashCommandPresets: async function () { + return await fetch(`${API_BASE}/system/slash-command-presets`, { + method: "GET", + headers: baseHeaders(), + }) + .then((res) => { + if (!res.ok) throw new Error("Could not fetch slash command presets."); + return res.json(); + }) + .then((res) => res.presets) + .catch((e) => { + console.error(e); + return []; + }); + }, + + createSlashCommandPreset: async function (presetData) { + return await fetch(`${API_BASE}/system/slash-command-presets`, { + method: "POST", + headers: baseHeaders(), + body: JSON.stringify(presetData), + }) + .then((res) => { + if (!res.ok) throw new Error("Could not create slash command preset."); + return res.json(); + }) + .then((res) => { + return { preset: res.preset, error: null }; + }) + .catch((e) => { + console.error(e); + return { preset: null, error: e.message }; + }); + }, + + updateSlashCommandPreset: async function (presetId, presetData) { + return await fetch(`${API_BASE}/system/slash-command-presets/${presetId}`, { + method: "POST", + headers: baseHeaders(), + body: JSON.stringify(presetData), + }) + .then((res) => { + if (!res.ok) throw new Error("Could not update slash command preset."); + return res.json(); + }) + .then((res) => { + return { preset: res.preset, error: null }; + }) + .catch((e) => { + console.error(e); + return { preset: null, error: "Failed to update this command." 
}; + }); + }, + + deleteSlashCommandPreset: async function (presetId) { + return await fetch(`${API_BASE}/system/slash-command-presets/${presetId}`, { + method: "DELETE", + headers: baseHeaders(), + }) + .then((res) => { + if (!res.ok) throw new Error("Could not delete slash command preset."); + return true; + }) + .catch((e) => { + console.error(e); + return false; + }); + }, }; export default System; diff --git a/server/endpoints/system.js b/server/endpoints/system.js index 60d51e35fd7..4538ee0601f 100644 --- a/server/endpoints/system.js +++ b/server/endpoints/system.js @@ -50,6 +50,7 @@ const { resetPassword, generateRecoveryCodes, } = require("../utils/PasswordRecovery"); +const { SlashCommandPresets } = require("../models/slashCommandsPresets"); function systemEndpoints(app) { if (!app) return; @@ -1044,6 +1045,115 @@ function systemEndpoints(app) { response.sendStatus(500).end(); } }); + + app.get( + "/system/slash-command-presets", + [validatedRequest, flexUserRoleValid([ROLES.all])], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const userPresets = await SlashCommandPresets.getUserPresets(user?.id); + response.status(200).json({ presets: userPresets }); + } catch (error) { + console.error("Error fetching slash command presets:", error); + response.status(500).json({ message: "Internal server error" }); + } + } + ); + + app.post( + "/system/slash-command-presets", + [validatedRequest, flexUserRoleValid([ROLES.all])], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const { command, prompt, description } = reqBody(request); + const presetData = { + command: SlashCommandPresets.formatCommand(String(command)), + prompt: String(prompt), + description: String(description), + }; + + const preset = await SlashCommandPresets.create(user?.id, presetData); + if (!preset) { + return response + .status(500) + .json({ message: "Failed to create preset" }); + } +
response.status(201).json({ preset }); + } catch (error) { + console.error("Error creating slash command preset:", error); + response.status(500).json({ message: "Internal server error" }); + } + } + ); + + app.post( + "/system/slash-command-presets/:slashCommandId", + [validatedRequest, flexUserRoleValid([ROLES.all])], + async (request, response) => { + try { + const user = await userFromSession(request, response); + const { slashCommandId } = request.params; + const { command, prompt, description } = reqBody(request); + + // Confirm the requesting user (null in single-user mode) owns this preset. + const ownsPreset = await SlashCommandPresets.get({ + userId: user?.id ?? null, + id: Number(slashCommandId), + }); + if (!ownsPreset) + return response.status(404).json({ message: "Preset not found" }); + + const updates = { + command: SlashCommandPresets.formatCommand(String(command)), + prompt: String(prompt), + description: String(description), + }; + + const preset = await SlashCommandPresets.update( + Number(slashCommandId), + updates + ); + if (!preset) return response.sendStatus(422); + response.status(200).json({ preset: { ...ownsPreset, ...updates } }); + } catch (error) { + console.error("Error updating slash command preset:", error); + response.status(500).json({ message: "Internal server error" }); + } + } + ); + + app.delete( + "/system/slash-command-presets/:slashCommandId", + [validatedRequest, flexUserRoleValid([ROLES.all])], + async (request, response) => { + try { + const { slashCommandId } = request.params; + const user = await userFromSession(request, response); + + // Confirm the requesting user (null in single-user mode) owns this preset. + const ownsPreset = await SlashCommandPresets.get({ + userId: user?.id ?? 
null, + id: Number(slashCommandId), + }); + if (!ownsPreset) + return response + .status(403) + .json({ message: "Failed to delete preset" }); + + const deleted = await SlashCommandPresets.delete(Number(slashCommandId)); + if (!deleted) + return response + .status(500) + .json({ message: "Failed to delete preset" }); + response.sendStatus(204); + } catch (error) { + console.error("Error deleting slash command preset:", error); + response.status(500).json({ message: "Internal server error" }); + } + } + ); } module.exports = { systemEndpoints }; diff --git a/server/models/slashCommandsPresets.js b/server/models/slashCommandsPresets.js new file mode 100644 index 00000000000..4828c77d59d --- /dev/null +++ b/server/models/slashCommandsPresets.js @@ -0,0 +1,106 @@ +const { v4 } = require("uuid"); +const prisma = require("../utils/prisma"); +const CMD_REGEX = new RegExp(/[^a-zA-Z0-9_-]/g); + +const SlashCommandPresets = { + formatCommand: function (command = "") { + if (!command || command.length < 2) return `/${v4().split("-")[0]}`; + + let adjustedCmd = command.toLowerCase(); // force lowercase + if (!adjustedCmd.startsWith("/")) adjustedCmd = `/${adjustedCmd}`; // Fix if no preceding / is found. + return `/${adjustedCmd.slice(1).toLowerCase().replace(CMD_REGEX, "-")}`; // replace any invalid chars with '-' + }, + + get: async function (clause = {}) { + try { + const preset = await prisma.slash_command_presets.findFirst({ + where: clause, + }); + return preset || null; + } catch (error) { + console.error(error.message); + return null; + } + }, + + where: async function (clause = {}, limit) { + try { + const presets = await prisma.slash_command_presets.findMany({ + where: clause, + take: limit || undefined, + }); + return presets; + } catch (error) { + console.error(error.message); + return []; + } + }, + + // Command + userId must be unique combination. 
+ create: async function (userId = null, presetData = {}) { + try { + const preset = await prisma.slash_command_presets.create({ + data: { + ...presetData, + // This field (uid) is either the user_id or 0 (for non-multi-user mode).
+ // The @@unique([uid, command]) constraint is keyed on uid rather than userId because + // userId is nullable and SQLite treats NULLs as distinct in a unique index; this + // non-null 'dummy' column is what actually enforces one command per user. + uid: userId ? Number(userId) : 0, + userId: userId ? Number(userId) : null, + }, + }); + return preset; + } catch (error) { + console.error("Failed to create preset", error.message); + return null; + } + }, + + getUserPresets: async function (userId = null) { + try { + return ( + await prisma.slash_command_presets.findMany({ + where: { userId: !!userId ? Number(userId) : null }, + orderBy: { createdAt: "asc" }, + }) + )?.map((preset) => ({ + id: preset.id, + command: preset.command, + prompt: preset.prompt, + description: preset.description, + })); + } catch (error) { + console.error("Failed to get user presets", error.message); + return []; + } + }, + + update: async function (presetId = null, presetData = {}) { + try { + const preset = await prisma.slash_command_presets.update({ + where: { id: Number(presetId) }, + // lastUpdatedAt only defaults at insert time, so bump it on every edit. + data: { ...presetData, lastUpdatedAt: new Date() }, + }); + return preset; + } catch (error) { + console.error("Failed to update preset", error.message); + return null; + } + }, + + delete: async function (presetId = null) { + try { + await prisma.slash_command_presets.delete({ + where: { id: Number(presetId) }, + }); + return true; + } catch (error) { + console.error("Failed to delete preset", error.message); + return false; + } + }, +}; + +module.exports.SlashCommandPresets = SlashCommandPresets; diff --git a/server/prisma/migrations/20240510032311_init/migration.sql b/server/prisma/migrations/20240510032311_init/migration.sql new file mode 100644 index 00000000000..3b82efb885a --- /dev/null +++ b/server/prisma/migrations/20240510032311_init/migration.sql @@ -0,0 +1,15 @@ +-- CreateTable +CREATE TABLE "slash_command_presets" ( + "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + "command" TEXT NOT NULL, + "prompt" TEXT NOT NULL, +
"description" TEXT NOT NULL, + "uid" INTEGER NOT NULL DEFAULT 0, + "userId" INTEGER, + "createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + "lastUpdatedAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT "slash_command_presets_userId_fkey" FOREIGN KEY ("userId") REFERENCES "users" ("id") ON DELETE CASCADE ON UPDATE CASCADE +); + +-- CreateIndex +CREATE UNIQUE INDEX "slash_command_presets_uid_command_key" ON "slash_command_presets"("uid", "command"); diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index b830de9b778..0ded65be634 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -73,6 +73,7 @@ model users { recovery_codes recovery_codes[] password_reset_tokens password_reset_tokens[] workspace_agent_invocations workspace_agent_invocations[] + slash_command_presets slash_command_presets[] } model recovery_codes { @@ -260,3 +261,17 @@ model event_logs { @@index([event]) } + +model slash_command_presets { + id Int @id @default(autoincrement()) + command String + prompt String + description String + uid Int @default(0) // 0 is null user + userId Int? + createdAt DateTime @default(now()) + lastUpdatedAt DateTime @default(now()) + user users? 
@relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([uid, command]) +} diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js index 76f98e0df86..55e8fbe5fd6 100644 --- a/server/utils/chats/index.js +++ b/server/utils/chats/index.js @@ -4,14 +4,33 @@ const { resetMemory } = require("./commands/reset"); const { getVectorDbClass, getLLMProvider } = require("../helpers"); const { convertToPromptHistory } = require("../helpers/chat/responses"); const { DocumentManager } = require("../DocumentManager"); +const { SlashCommandPresets } = require("../../models/slashCommandsPresets"); const VALID_COMMANDS = { "/reset": resetMemory, }; -function grepCommand(message) { +async function grepCommand(message, user = null) { + // Presets and built-ins all start with "/", so ordinary chat messages + // skip the preset lookup entirely. + if (typeof message !== "string" || !message.startsWith("/")) return message; + + const userPresets = await SlashCommandPresets.getUserPresets(user?.id); const availableCommands = Object.keys(VALID_COMMANDS); + // Match a preset only as a whole leading token so "/sum" never fires on "/summary". + const foundPreset = userPresets.find( + (p) => + message.startsWith(p.command) && + (message.length === p.command.length || + /\s/.test(message[p.command.length])) + ); + if (!!foundPreset) { + // Splice rather than String#replace so "$&"/"$1" in the prompt stay literal. 
+ return foundPreset.prompt + message.slice(foundPreset.command.length); + } + + // Check if the message starts with any built-in command for (let i = 0; i < availableCommands.length; i++) { const cmd = availableCommands[i]; const re = new RegExp(`^(${cmd})`, "i"); @@ -20,7 +39,7 @@ function grepCommand(message) { } } - return null; + return message; } async function chatWithWorkspace( @@ -31,10 +50,10 @@ async function chatWithWorkspace( thread = null ) { const uuid = uuidv4(); - const command = grepCommand(message); + const updatedMessage = await grepCommand(message, user); - if (!!command && Object.keys(VALID_COMMANDS).includes(command)) { - return await VALID_COMMANDS[command](workspace, message, uuid, user); + if (Object.keys(VALID_COMMANDS).includes(updatedMessage)) { + return await VALID_COMMANDS[updatedMessage](workspace,
message, uuid, user); } const LLMConnector = getLLMProvider({ @@ -164,7 +183,7 @@ async function chatWithWorkspace( const messages = await LLMConnector.compressMessages( { systemPrompt: chatPrompt(workspace), - userPrompt: message, + userPrompt: updatedMessage, contextTexts, chatHistory, }, diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js index ba4dea163f7..ec8fdbfac14 100644 --- a/server/utils/chats/stream.js +++ b/server/utils/chats/stream.js @@ -23,10 +23,10 @@ async function streamChatWithWorkspace( thread = null ) { const uuid = uuidv4(); - const command = grepCommand(message); + const updatedMessage = await grepCommand(message, user); - if (!!command && Object.keys(VALID_COMMANDS).includes(command)) { - const data = await VALID_COMMANDS[command]( + if (Object.keys(VALID_COMMANDS).includes(updatedMessage)) { + const data = await VALID_COMMANDS[updatedMessage]( workspace, message, uuid, @@ -185,7 +185,7 @@ async function streamChatWithWorkspace( const messages = await LLMConnector.compressMessages( { systemPrompt: chatPrompt(workspace), - userPrompt: message, + userPrompt: updatedMessage, contextTexts, chatHistory, },