diff --git a/server/__tests__/utils/chats/openaiCompatible.test.js b/server/__tests__/utils/chats/openaiCompatible.test.js
new file mode 100644
index 00000000000..d17fcb0496b
--- /dev/null
+++ b/server/__tests__/utils/chats/openaiCompatible.test.js
@@ -0,0 +1,249 @@
+/* eslint-env jest, node */
+const { OpenAICompatibleChat } = require('../../../utils/chats/openaiCompatible');
+const { WorkspaceChats } = require('../../../models/workspaceChats');
+const { getVectorDbClass, getLLMProvider } = require('../../../utils/helpers');
+const { extractTextContent, extractAttachments } = require('../../../endpoints/api/openai/helpers');
+
+// Mock dependencies
+jest.mock('../../../models/workspaceChats');
+jest.mock('../../../utils/helpers');
+jest.mock('../../../utils/DocumentManager', () => ({
+  DocumentManager: class {
+    constructor() {
+      this.pinnedDocs = jest.fn().mockResolvedValue([]);
+    }
+  }
+}));
+
+describe('OpenAICompatibleChat', () => {
+  let mockWorkspace;
+  let mockVectorDb;
+  let mockLLMConnector;
+  let mockResponse;
+
+  beforeEach(() => {
+    // Reset all mocks
+    jest.clearAllMocks();
+
+    // Setup mock workspace
+    mockWorkspace = {
+      id: 1,
+      slug: 'test-workspace',
+      chatMode: 'chat',
+      chatProvider: 'openai',
+      chatModel: 'gpt-4',
+    };
+
+    // Setup mock VectorDb
+    mockVectorDb = {
+      hasNamespace: jest.fn().mockResolvedValue(true),
+      namespaceCount: jest.fn().mockResolvedValue(1),
+      performSimilaritySearch: jest.fn().mockResolvedValue({
+        contextTexts: [],
+        sources: [],
+        message: null,
+      }),
+    };
+    getVectorDbClass.mockReturnValue(mockVectorDb);
+
+    // Setup mock LLM connector
+    mockLLMConnector = {
+      promptWindowLimit: jest.fn().mockReturnValue(4000),
+      compressMessages: jest.fn().mockResolvedValue([]),
+      getChatCompletion: jest.fn().mockResolvedValue({
+        textResponse: 'Mock response',
+        metrics: {},
+      }),
+      streamingEnabled: jest.fn().mockReturnValue(true),
+      streamGetChatCompletion: jest.fn().mockResolvedValue({
+        metrics: {},
+      }),
+      handleStream: jest.fn().mockResolvedValue('Mock streamed response'),
+      defaultTemp: 0.7,
+    };
+    getLLMProvider.mockReturnValue(mockLLMConnector);
+
+    // Setup WorkspaceChats mock
+    WorkspaceChats.new.mockResolvedValue({ chat: { id: 'mock-chat-id' } });
+
+    // Setup mock response object for streaming
+    mockResponse = {
+      write: jest.fn(),
+    };
+  });
+
+  describe('chatSync', () => {
+    test('should handle OpenAI vision multimodal messages', async () => {
+      const multiModalPrompt = [
+        {
+          type: 'text',
+          text: 'What do you see in this image?'
+        },
+        {
+          type: 'image_url',
+          image_url: {
+            url: 'data:image/png;base64,abc123',
+            detail: 'low'
+          }
+        }
+      ];
+
+      const prompt = extractTextContent(multiModalPrompt);
+      const attachments = extractAttachments(multiModalPrompt);
+      const result = await OpenAICompatibleChat.chatSync({
+        workspace: mockWorkspace,
+        prompt,
+        attachments,
+        systemPrompt: 'You are a helpful assistant',
+        history: [
+          { role: 'user', content: 'Previous message' },
+          { role: 'assistant', content: 'Previous response' }
+        ],
+        temperature: 0.7
+      });
+
+      // Verify chat was saved with correct format
+      expect(WorkspaceChats.new).toHaveBeenCalledWith(
+        expect.objectContaining({
+          workspaceId: mockWorkspace.id,
+          prompt: multiModalPrompt[0].text,
+          response: expect.objectContaining({
+            text: 'Mock response',
+            attachments: [{
+              name: 'uploaded_image_0',
+              mime: 'image/png',
+              contentString: multiModalPrompt[1].image_url.url
+            }]
+          })
+        })
+      );
+
+      // Verify response format
+      expect(result).toEqual(
+        expect.objectContaining({
+          object: 'chat.completion',
+          choices: expect.arrayContaining([
+            expect.objectContaining({
+              message: expect.objectContaining({
+                role: 'assistant',
+                content: 'Mock response',
+              }),
+            }),
+          ]),
+        })
+      );
+    });
+
+    test('should handle regular text messages in OpenAI format', async () => {
+      const promptString = 'Hello world';
+      const result = await OpenAICompatibleChat.chatSync({
+        workspace: mockWorkspace,
+        prompt: promptString,
+        systemPrompt: 'You are a helpful assistant',
+        history: [
+          { role: 'user', content: 'Previous message' },
+          { role: 'assistant', content: 'Previous response' }
+        ],
+        temperature: 0.7
+      });
+
+      // Verify chat was saved without attachments
+      expect(WorkspaceChats.new).toHaveBeenCalledWith(
+        expect.objectContaining({
+          workspaceId: mockWorkspace.id,
+          prompt: promptString,
+          response: expect.objectContaining({
+            text: 'Mock response',
+            attachments: []
+          })
+        })
+      );
+
+      expect(result).toBeTruthy();
+    });
+  });
+
+  describe('streamChat', () => {
+    test('should handle OpenAI vision multimodal messages in streaming mode', async () => {
+      const multiModalPrompt = [
+        {
+          type: 'text',
+          text: 'What do you see in this image?'
+        },
+        {
+          type: 'image_url',
+          image_url: {
+            url: 'data:image/png;base64,abc123',
+            detail: 'low'
+          }
+        }
+      ];
+
+      const prompt = extractTextContent(multiModalPrompt);
+      const attachments = extractAttachments(multiModalPrompt);
+      await OpenAICompatibleChat.streamChat({
+        workspace: mockWorkspace,
+        response: mockResponse,
+        prompt,
+        attachments,
+        systemPrompt: 'You are a helpful assistant',
+        history: [
+          { role: 'user', content: 'Previous message' },
+          { role: 'assistant', content: 'Previous response' }
+        ],
+        temperature: 0.7
+      });
+
+      // Verify streaming was handled
+      expect(mockLLMConnector.streamGetChatCompletion).toHaveBeenCalled();
+      expect(mockLLMConnector.handleStream).toHaveBeenCalled();
+
+      // Verify chat was saved with attachments
+      expect(WorkspaceChats.new).toHaveBeenCalledWith(
+        expect.objectContaining({
+          workspaceId: mockWorkspace.id,
+          prompt: multiModalPrompt[0].text,
+          response: expect.objectContaining({
+            text: 'Mock streamed response',
+            attachments: [{
+              name: 'uploaded_image_0',
+              mime: 'image/png',
+              contentString: multiModalPrompt[1].image_url.url
+            }]
+          })
+        })
+      );
+    });
+
+    test('should handle regular text messages in streaming mode', async () => {
+      const promptString = 'Hello world';
+      await OpenAICompatibleChat.streamChat({
+        workspace: mockWorkspace,
+        response: mockResponse,
+        prompt: promptString,
+        systemPrompt: 'You are a helpful assistant',
+        history: [
+          { role: 'user', content: 'Previous message' },
+          { role: 'assistant', content: 'Previous response' }
+        ],
+        temperature: 0.7
+      });
+
+      // Verify streaming was handled
+      expect(mockLLMConnector.streamGetChatCompletion).toHaveBeenCalled();
+      expect(mockLLMConnector.handleStream).toHaveBeenCalled();
+
+      // Verify chat was saved without attachments
+      expect(WorkspaceChats.new).toHaveBeenCalledWith(
+        expect.objectContaining({
+          workspaceId: mockWorkspace.id,
+          prompt: promptString,
+          response: expect.objectContaining({
+            text: 'Mock streamed response',
+            attachments: []
+          })
+        })
+      );
+    });
+  });
+});
\ No newline at end of file
diff --git a/server/__tests__/utils/chats/openaiHelpers.test.js b/server/__tests__/utils/chats/openaiHelpers.test.js
new file mode 100644
index 00000000000..5eba2e11264
--- /dev/null
+++ b/server/__tests__/utils/chats/openaiHelpers.test.js
@@ -0,0 +1,128 @@
+/* eslint-env jest, node */
+const { extractTextContent, extractAttachments } = require('../../../endpoints/api/openai/helpers');
+
+describe('OpenAI Helper Functions', () => {
+  describe('extractTextContent', () => {
+    test('should return string content as-is when not an array', () => {
+      const content = 'Hello world';
+      expect(extractTextContent(content)).toBe('Hello world');
+    });
+
+    test('should extract text from multi-modal content array', () => {
+      const content = [
+        {
+          type: 'text',
+          text: 'What do you see in this image?'
+        },
+        {
+          type: 'image_url',
+          image_url: {
+            url: 'data:image/png;base64,abc123',
+            detail: 'low'
+          }
+        },
+        {
+          type: 'text',
+          text: 'And what about this part?'
+        }
+      ];
+      expect(extractTextContent(content)).toBe('What do you see in this image?\nAnd what about this part?');
+    });
+
+    test('should handle empty array', () => {
+      expect(extractTextContent([])).toBe('');
+    });
+
+    test('should handle array with no text content', () => {
+      const content = [
+        {
+          type: 'image_url',
+          image_url: {
+            url: 'data:image/png;base64,abc123',
+            detail: 'low'
+          }
+        }
+      ];
+      expect(extractTextContent(content)).toBe('');
+    });
+  });
+
+  describe('extractAttachments', () => {
+    test('should return empty array for string content', () => {
+      const content = 'Hello world';
+      expect(extractAttachments(content)).toEqual([]);
+    });
+
+    test('should extract image attachments with correct mime types', () => {
+      const content = [
+        {
+          type: 'image_url',
+          image_url: {
+            url: 'data:image/png;base64,abc123',
+            detail: 'low'
+          }
+        },
+        {
+          type: 'text',
+          text: 'Between images'
+        },
+        {
+          type: 'image_url',
+          image_url: {
+            url: 'data:image/jpeg;base64,def456',
+            detail: 'high'
+          }
+        }
+      ];
+      expect(extractAttachments(content)).toEqual([
+        {
+          name: 'uploaded_image_0',
+          mime: 'image/png',
+          contentString: 'data:image/png;base64,abc123'
+        },
+        {
+          name: 'uploaded_image_1',
+          mime: 'image/jpeg',
+          contentString: 'data:image/jpeg;base64,def456'
+        }
+      ]);
+    });
+
+    test('should handle invalid data URLs with PNG fallback', () => {
+      const content = [
+        {
+          type: 'image_url',
+          image_url: {
+            url: 'invalid-data-url',
+            detail: 'low'
+          }
+        }
+      ];
+      expect(extractAttachments(content)).toEqual([
+        {
+          name: 'uploaded_image_0',
+          mime: 'image/png',
+          contentString: 'invalid-data-url'
+        }
+      ]);
+    });
+
+    test('should handle empty array', () => {
+      expect(extractAttachments([])).toEqual([]);
+    });
+
+    test('should handle array with no image content', () => {
+      const content = [
+        {
+          type: 'text',
+          text: 'Just some text'
+        },
+        {
+          type: 'text',
+          text: 'More text'
+        }
+      ];
+      expect(extractAttachments(content)).toEqual([]);
+    });
+  });
+});
\ No newline at end of file
diff --git a/server/endpoints/api/openai/helpers.js b/server/endpoints/api/openai/helpers.js
new file mode 100644
index 00000000000..6d54f770494
--- /dev/null
+++ b/server/endpoints/api/openai/helpers.js
@@ -0,0 +1,50 @@
+/**
+ * Extracts text content from a multimodal message
+ * If the content has multiple text items, it will join them together with a newline.
+ * @param {string|Array} content - Message content that could be string or array of content objects
+ * @returns {string} - The text content
+ */
+function extractTextContent(content) {
+  if (!Array.isArray(content)) return content;
+  return content
+    .filter((item) => item.type === "text")
+    .map((item) => item.text)
+    .join("\n");
+}
+
+/**
+ * Detects mime type from a base64 data URL string, defaults to PNG if not detected
+ * @param {string} dataUrl - The data URL string (e.g. data:image/jpeg;base64,...)
+ * @returns {string} - The mime type or 'image/png' if not detected
+ */
+function getMimeTypeFromDataUrl(dataUrl) {
+  try {
+    const matches = dataUrl.match(/^data:([^;]+);base64,/);
+    return matches ? matches[1].toLowerCase() : "image/png";
+  } catch (e) {
+    return "image/png";
+  }
+}
+
+/**
+ * Extracts attachments from a multimodal message
+ * The attachments provided are in OpenAI format since this util is used in the OpenAI compatible chat.
+ * However, our backend internal chat uses the Attachment type we use elsewhere in the app so we have to convert it.
+ * @param {Array} content - Message content that could be string or array of content objects
+ * @returns {import("../../../utils/helpers").Attachment[]} - The attachments
+ */
+function extractAttachments(content) {
+  if (!Array.isArray(content)) return [];
+  return content
+    .filter((item) => item.type === "image_url")
+    .map((item, index) => ({
+      name: `uploaded_image_${index}`,
+      mime: getMimeTypeFromDataUrl(item.image_url.url),
+      contentString: item.image_url.url,
+    }));
+}
+
+module.exports = {
+  extractTextContent,
+  extractAttachments,
+};
diff --git a/server/endpoints/api/openai/index.js b/server/endpoints/api/openai/index.js
index e3c70171cf1..7962eaefb65 100644
--- a/server/endpoints/api/openai/index.js
+++ b/server/endpoints/api/openai/index.js
@@ -13,6 +13,7 @@ const {
   OpenAICompatibleChat,
 } = require("../../../utils/chats/openaiCompatible");
 const { getModelTag } = require("../../utils");
+const { extractTextContent, extractAttachments } = require("./helpers");
 
 function apiOpenAICompatibleEndpoints(app) {
   if (!app) return;
@@ -145,7 +146,8 @@ function apiOpenAICompatibleEndpoints(app) {
         workspace,
         systemPrompt,
         history,
-        prompt: userMessage.content,
+        prompt: extractTextContent(userMessage.content),
+        attachments: extractAttachments(userMessage.content),
         temperature: Number(temperature),
       });
 
@@ -173,7 +175,8 @@
         workspace,
         systemPrompt,
         history,
-        prompt: userMessage.content,
+        prompt: extractTextContent(userMessage.content),
+        attachments: extractAttachments(userMessage.content),
         temperature: Number(temperature),
         response,
       });
diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js
index ee33a12bac0..c371a1d47d1 100644
--- a/server/utils/AiProviders/openAi/index.js
+++ b/server/utils/AiProviders/openAi/index.js
@@ -102,7 +102,7 @@ class OpenAiLLM {
         type: "image_url",
         image_url: {
           url: attachment.contentString,
-          detail: "high",
+          detail: "auto",
         },
       });
     }
diff --git a/server/utils/chats/openaiCompatible.js b/server/utils/chats/openaiCompatible.js
index 97825b2eb1a..0767dd44cf9 100644
--- a/server/utils/chats/openaiCompatible.js
+++ b/server/utils/chats/openaiCompatible.js
@@ -12,6 +12,7 @@ async function chatSync({
   systemPrompt = null,
   history = [],
   prompt = null,
+  attachments = [],
   temperature = null,
 }) {
   const uuid = uuidv4();
@@ -38,6 +39,7 @@
         text: textResponse,
         sources: [],
         type: chatMode,
+        attachments,
       },
       include: false,
     });
@@ -84,7 +86,7 @@
     embeddingsCount !== 0
       ? await VectorDb.performSimilaritySearch({
           namespace: workspace.slug,
-          input: prompt,
+          input: String(prompt),
           LLMConnector,
           similarityThreshold: workspace?.similarityThreshold,
           topN: workspace?.topN,
@@ -125,11 +127,12 @@
 
     await WorkspaceChats.new({
       workspaceId: workspace.id,
-      prompt: prompt,
+      prompt: String(prompt),
       response: {
         text: textResponse,
         sources: [],
         type: chatMode,
+        attachments,
       },
       include: false,
     });
@@ -151,9 +154,10 @@
   // and build system messages based on inputs and history.
   const messages = await LLMConnector.compressMessages({
     systemPrompt: systemPrompt ?? (await chatPrompt(workspace)),
-    userPrompt: prompt,
+    userPrompt: String(prompt),
     contextTexts,
     chatHistory: history,
+    attachments,
   });
 
   // Send the text completion.
@@ -181,8 +185,14 @@ async function chatSync({
 
   const { chat } = await WorkspaceChats.new({
     workspaceId: workspace.id,
-    prompt: prompt,
-    response: { text: textResponse, sources, type: chatMode, metrics },
+    prompt: String(prompt),
+    response: {
+      text: textResponse,
+      sources,
+      type: chatMode,
+      metrics,
+      attachments,
+    },
   });
 
   return formatJSON(
@@ -205,6 +215,7 @@ async function streamChat({
   systemPrompt = null,
   history = [],
   prompt = null,
+  attachments = [],
   temperature = null,
 }) {
   const uuid = uuidv4();
@@ -250,6 +261,7 @@ async function streamChat({
         text: textResponse,
         sources: [],
         type: chatMode,
+        attachments,
       },
       include: false,
     });
@@ -300,7 +312,7 @@ async function streamChat({
     embeddingsCount !== 0
       ? await VectorDb.performSimilaritySearch({
           namespace: workspace.slug,
-          input: prompt,
+          input: String(prompt),
           LLMConnector,
           similarityThreshold: workspace?.similarityThreshold,
           topN: workspace?.topN,
@@ -345,11 +357,12 @@ async function streamChat({
 
     await WorkspaceChats.new({
       workspaceId: workspace.id,
-      prompt: prompt,
+      prompt: String(prompt),
      response: {
        text: textResponse,
        sources: [],
        type: chatMode,
+        attachments,
      },
      include: false,
    });
@@ -375,9 +388,10 @@ async function streamChat({
   // and build system messages based on inputs and history.
   const messages = await LLMConnector.compressMessages({
     systemPrompt: systemPrompt ?? (await chatPrompt(workspace)),
-    userPrompt: prompt,
+    userPrompt: String(prompt),
     contextTexts,
     chatHistory: history,
+    attachments,
   });
 
   if (!LLMConnector.streamingEnabled()) {
@@ -418,12 +432,13 @@ async function streamChat({
   if (completeText?.length > 0) {
     const { chat } = await WorkspaceChats.new({
       workspaceId: workspace.id,
-      prompt: prompt,
+      prompt: String(prompt),
       response: {
         text: completeText,
         sources,
         type: chatMode,
         metrics: stream.metrics,
+        attachments,
       },
     });