diff --git a/docker/.env.example b/docker/.env.example index 5fd93ab9e90..117a46133b2 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -105,6 +105,9 @@ GID='1000' # AWS_BEDROCK_LLM_REGION=us-west-2 # AWS_BEDROCK_LLM_MODEL_PREFERENCE=meta.llama3-1-8b-instruct-v1:0 # AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT=8191 +# AWS_BEDROCK_LLM_CONNECTION_METHOD=iam +# AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS=4096 +# AWS_BEDROCK_LLM_SESSION_TOKEN= # Only required if CONNECTION_METHOD is 'sessionToken' # LLM_PROVIDER='fireworksai' # FIREWORKS_AI_LLM_API_KEY='my-fireworks-ai-key' diff --git a/frontend/src/components/LLMSelection/AwsBedrockLLMOptions/index.jsx b/frontend/src/components/LLMSelection/AwsBedrockLLMOptions/index.jsx index 779d487b929..ec4a3c6f747 100644 --- a/frontend/src/components/LLMSelection/AwsBedrockLLMOptions/index.jsx +++ b/frontend/src/components/LLMSelection/AwsBedrockLLMOptions/index.jsx @@ -175,7 +175,7 @@ export default function AwsBedrockLLMOptions({ settings }) { type="number" name="AwsBedrockLLMTokenLimit" className="border-none bg-theme-settings-input-bg text-white placeholder:text-theme-settings-input-placeholder text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5" - placeholder="Content window limit (eg: 4096)" + placeholder="Content window limit (eg: 8192)" min={1} onScroll={(e) => e.target.blur()} defaultValue={settings?.AwsBedrockLLMTokenLimit} @@ -183,6 +183,22 @@ export default function AwsBedrockLLMOptions({ settings }) { autoComplete="off" /> +
+ + e.target.blur()} + defaultValue={settings?.AwsBedrockLLMMaxOutputTokens} + required={true} + autoComplete="off" + /> +
)} diff --git a/server/.env.example b/server/.env.example index 0c219e9324d..51f8c05d4ef 100644 --- a/server/.env.example +++ b/server/.env.example @@ -105,6 +105,16 @@ SIG_SALT='salt' # Please generate random string at least 32 chars long. # COHERE_API_KEY= # COHERE_MODEL_PREF='command-r' +# LLM_PROVIDER='bedrock' +# AWS_BEDROCK_LLM_ACCESS_KEY_ID= +# AWS_BEDROCK_LLM_ACCESS_KEY= +# AWS_BEDROCK_LLM_REGION=us-west-2 +# AWS_BEDROCK_LLM_MODEL_PREFERENCE=meta.llama3-1-8b-instruct-v1:0 +# AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT=8191 +# AWS_BEDROCK_LLM_CONNECTION_METHOD=iam +# AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS=4096 +# AWS_BEDROCK_LLM_SESSION_TOKEN= # Only required if CONNECTION_METHOD is 'sessionToken' + # LLM_PROVIDER='apipie' # APIPIE_LLM_API_KEY='sk-123abc' # APIPIE_LLM_MODEL_PREF='openrouter/llama-3.1-8b-instruct' diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js index 0ac29f19564..566632aaf83 100644 --- a/server/models/systemSettings.js +++ b/server/models/systemSettings.js @@ -540,7 +540,10 @@ const SystemSettings = { AwsBedrockLLMSessionToken: !!process.env.AWS_BEDROCK_LLM_SESSION_TOKEN, AwsBedrockLLMRegion: process.env.AWS_BEDROCK_LLM_REGION, AwsBedrockLLMModel: process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE, - AwsBedrockLLMTokenLimit: process.env.AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT, + AwsBedrockLLMTokenLimit: + process.env.AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT || 8192, // NOTE(review): bedrock/utils.js DEFAULT_CONTEXT_WINDOW_TOKENS and the .env examples use 8191 - confirm which fallback is intended + AwsBedrockLLMMaxOutputTokens: + process.env.AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS || 4096, // Cohere API Keys CohereApiKey: !!process.env.COHERE_API_KEY, diff --git a/server/utils/AiProviders/bedrock/index.js b/server/utils/AiProviders/bedrock/index.js index c5a9b8dbb68..0fd60c7f72a 100644 --- a/server/utils/AiProviders/bedrock/index.js +++ b/server/utils/AiProviders/bedrock/index.js @@ -12,14 +12,18 @@ const { LLMPerformanceMonitor, } = require("../../helpers/chat/LLMPerformanceMonitor"); const { v4: uuidv4 } = require("uuid"); +const { + DEFAULT_MAX_OUTPUT_TOKENS, +
DEFAULT_CONTEXT_WINDOW_TOKENS, + SUPPORTED_CONNECTION_METHODS, + getImageFormatFromMime, + base64ToUint8Array, +} = require("./utils"); class AWSBedrockLLM { /** - * These models do not support system prompts - * It is not explicitly stated but it is observed that they do not use the system prompt - * in their responses and will crash when a system prompt is provided. - * We can add more models to this list as we discover them or new models are added. - * We may want to extend this list or make a user-config if using custom bedrock models. + * List of Bedrock models observed to not support system prompts when using the Converse API. + * @type {string[]} */ noSystemPromptModels = [ "amazon.titan-text-express-v1", @@ -27,32 +31,46 @@ class AWSBedrockLLM { "cohere.command-text-v14", "cohere.command-light-text-v14", "us.deepseek.r1-v1:0", + // Add other models here if identified ]; + /** + * Initializes the AWS Bedrock LLM connector. + * @param {object | null} [embedder=null] - An optional embedder instance. Defaults to NativeEmbedder. + * @param {string | null} [modelPreference=null] - Optional model ID override. Defaults to environment variable. + * @throws {Error} If required environment variables are missing or invalid.
+ */ constructor(embedder = null, modelPreference = null) { - if (!process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID) - throw new Error("No AWS Bedrock LLM profile id was set."); - - if (!process.env.AWS_BEDROCK_LLM_ACCESS_KEY) - throw new Error("No AWS Bedrock LLM access key was set."); + const requiredEnvVars = [ + "AWS_BEDROCK_LLM_ACCESS_KEY_ID", + "AWS_BEDROCK_LLM_ACCESS_KEY", + "AWS_BEDROCK_LLM_REGION", + ...(modelPreference ? [] : ["AWS_BEDROCK_LLM_MODEL_PREFERENCE"]), // the model env var is only needed when no modelPreference override is passed + ]; - if (!process.env.AWS_BEDROCK_LLM_REGION) - throw new Error("No AWS Bedrock LLM region was set."); + // Validate required environment variables + for (const envVar of requiredEnvVars) { + if (!process.env[envVar]) + throw new Error(`Required environment variable ${envVar} is not set.`); + } if ( process.env.AWS_BEDROCK_LLM_CONNECTION_METHOD === "sessionToken" && !process.env.AWS_BEDROCK_LLM_SESSION_TOKEN - ) + ) { throw new Error( - "No AWS Bedrock LLM session token was set while using session token as the authentication method." + "AWS_BEDROCK_LLM_SESSION_TOKEN is not set for sessionToken authentication method." ); + } this.model = modelPreference || process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE; + + const contextWindowLimit = this.promptWindowLimit(); this.limits = { - history: this.promptWindowLimit() * 0.15, - system: this.promptWindowLimit() * 0.15, - user: this.promptWindowLimit() * 0.7, + history: Math.floor(contextWindowLimit * 0.15), + system: Math.floor(contextWindowLimit * 0.15), + user: Math.floor(contextWindowLimit * 0.7), }; this.bedrockClient = new BedrockRuntimeClient({ @@ -69,156 +87,304 @@ class AWSBedrockLLM { this.embedder = embedder ?? new NativeEmbedder(); this.defaultTemp = 0.7; this.#log( - `Loaded with model: ${this.model}. Will communicate with AWS Bedrock using ${this.authMethod} authentication.` + `Initialized with model: ${this.model}. Auth: ${this.authMethod}. Context Window: ${contextWindowLimit}.` ); } /** - * Get the authentication method for the AWS Bedrock LLM.
- * There are only two valid values for this setting - anything else will default to "iam". - * @returns {"iam"|"sessionToken"} + * Gets the configured AWS authentication method ('iam' or 'sessionToken'). + * Defaults to 'iam' if the environment variable is invalid. + * @returns {"iam" | "sessionToken"} The authentication method. */ get authMethod() { const method = process.env.AWS_BEDROCK_LLM_CONNECTION_METHOD || "iam"; - if (!["iam", "sessionToken"].includes(method)) return "iam"; - return method; + return SUPPORTED_CONNECTION_METHODS.includes(method) ? method : "iam"; } + /** + * Appends context texts to a string with standard formatting. + * @param {string[]} contextTexts - An array of context text snippets. + * @returns {string} Formatted context string or empty string if no context provided. + * @private + */ #appendContext(contextTexts = []) { - if (!contextTexts || !contextTexts.length) return ""; + if (!contextTexts?.length) return ""; return ( "\nContext:\n" + contextTexts - .map((text, i) => { - return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; - }) + .map((text, i) => `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`) .join("") ); } + /** + * Internal logging helper with provider prefix. + * @param {string} text - The log message. + * @param {...any} args - Additional arguments to log. + * @private + */ #log(text, ...args) { console.log(`\x1b[32m[AWSBedrock]\x1b[0m ${text}`, ...args); } + /** + * Internal logging helper with provider prefix for static methods. + * @private + */ + static #slog(text, ...args) { + console.log(`\x1b[32m[AWSBedrock]\x1b[0m ${text}`, ...args); + } + + /** + * Indicates if the provider supports streaming responses. + * @returns {boolean} True.
+ */ streamingEnabled() { return "streamGetChatCompletion" in this; } - static promptWindowLimit(_modelName) { - const limit = process.env.AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT || 8191; - if (!limit || isNaN(Number(limit))) - throw new Error("No valid token context limit was set."); - return Number(limit); + /** + * @static + * Gets the total prompt window limit (total context window: input + output) from the environment variable. + * This value is used for calculating input limits, NOT for setting the max output tokens in API calls. + * @returns {number} The total context window token limit. Defaults to 8191. + */ + static promptWindowLimit() { + const limit = + process.env.AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT ?? + DEFAULT_CONTEXT_WINDOW_TOKENS; + const numericLimit = Number(limit); + if (isNaN(numericLimit) || numericLimit <= 0) { + AWSBedrockLLM.#slog( // explicit class reference: this.#slog throws a TypeError when the receiver is a subclass or the method is called detached + `[AWSBedrock ERROR] Invalid AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT found: "${limit}". Must be a positive number - returning default ${DEFAULT_CONTEXT_WINDOW_TOKENS}.` + ); + return DEFAULT_CONTEXT_WINDOW_TOKENS; + } + return numericLimit; } - // Ensure the user set a value for the token limit - // and if undefined - assume 4096 window. + /** + * Gets the total prompt window limit (total context window) for the current model instance. + * @returns {number} The token limit. + */ promptWindowLimit() { - const limit = process.env.AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT || 8191; - if (!limit || isNaN(Number(limit))) - throw new Error("No valid token context limit was set."); - return Number(limit); + return AWSBedrockLLM.promptWindowLimit(); } - async isValidChatCompletionModel(_ = "") { + /** + * Gets the maximum number of tokens the model should generate in its response. + * Reads from the AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS environment variable or uses a default. + * This is distinct from the total context window limit. + * @returns {number} The maximum output tokens limit for API calls.
+ */ getMaxOutputTokens() { + const outputLimitSource = process.env.AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS ?? DEFAULT_MAX_OUTPUT_TOKENS; // fall back quietly when unset - Number(undefined) is NaN and would otherwise log a spurious error on every call + if (isNaN(Number(outputLimitSource))) { + this.#log( + `[AWSBedrock ERROR] Invalid AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS found: "${outputLimitSource}". Must be a positive number - returning default ${DEFAULT_MAX_OUTPUT_TOKENS}.` + ); + return DEFAULT_MAX_OUTPUT_TOKENS; + } + + const numericOutputLimit = Number(outputLimitSource); + if (numericOutputLimit <= 0) { + this.#log( + `[AWSBedrock ERROR] Invalid AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS found: "${outputLimitSource}". Must be a greater than 0 - returning default ${DEFAULT_MAX_OUTPUT_TOKENS}.` + ); + return DEFAULT_MAX_OUTPUT_TOKENS; + } + + return numericOutputLimit; + } + + /** Stubbed method for compatibility with LLM interface. */ + async isValidChatCompletionModel(_modelName = "") { return true; } /** - * Generates appropriate content array for a message + attachments. - * TODO: Implement this - attachments are not supported yet for Bedrock. - * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}} - * @returns {string|object[]} + * Validates attachments array and returns a new array with valid attachments. + * @param {Array<{contentString: string, mime: string}>} attachments - Array of attachments. + * @returns {Array<{image: {format: string, source: {bytes: Uint8Array}}>} Array of valid attachments.
+ * @private + */ + #validateAttachments(attachments = []) { + if (!Array.isArray(attachments) || !attachments?.length) return []; + const validAttachments = []; + for (const attachment of attachments) { + if ( + !attachment || + typeof attachment.mime !== "string" || + typeof attachment.contentString !== "string" + ) { + this.#log("Skipping invalid attachment object.", attachment); + continue; + } + + // Strip data URI prefix (e.g., "data:image/png;base64,") + const base64Data = attachment.contentString.replace( + /^data:image\/\w+;base64,/, + "" + ); + + const format = getImageFormatFromMime(attachment.mime); + const attachmentInfo = { + valid: format !== null, + format, + imageBytes: base64ToUint8Array(base64Data), + }; + + if (!attachmentInfo.valid || !attachmentInfo.imageBytes) { // base64ToUint8Array returns null on decode failure - never send null bytes to Bedrock + this.#log( + `Skipping attachment with unsupported/invalid MIME type or undecodable data: ${attachment.mime}` + ); + continue; + } + + validAttachments.push({ + image: { + format: format, + source: { bytes: attachmentInfo.imageBytes }, + }, + }); + } + + return validAttachments; + } + + /** + * Generates the Bedrock Converse API content array for a message, + * processing text and formatting valid image attachments. + * @param {object} params + * @param {string} params.userPrompt - The text part of the message. + * @param {Array<{contentString: string, mime: string}>} params.attachments - Array of attachments for the message. + * @returns {Array} Array of content blocks (e.g., [{text: "..."}, {image: {...}}]).
+ * @private */ - #generateContent({ userPrompt, attachments = [] }) { - if (!attachments.length) return [{ text: userPrompt }]; - - // const content = [{ type: "text", text: userPrompt }]; - // for (let attachment of attachments) { - // content.push({ - // type: "image_url", - // image_url: attachment.contentString, - // }); - // } - // return { content: content.flat() }; + #generateContent({ userPrompt = "", attachments = [] }) { + const content = []; + // Add text block only if the prompt has non-whitespace text (NOTE(review): assumes userPrompt is a string or nullish - .trim() would throw on any other type) + if (!!userPrompt?.trim()?.length) content.push({ text: userPrompt }); + + // Validate attachments and add valid attachments to content + const validAttachments = this.#validateAttachments(attachments); + if (validAttachments?.length) content.push(...validAttachments); + + // Ensure content array is never empty (Bedrock requires at least one block) - NOTE(review): confirm Bedrock accepts a blank text block here; it may reject it as empty + if (content.length === 0) content.push({ text: "" }); + return content; } /** - * Construct the user prompt for this model. - * @param {{attachments: import("../../helpers").Attachment[]}} param0 - * @returns + * Constructs the complete message array in the format expected by the Bedrock Converse API. + * @param {object} params + * @param {string} params.systemPrompt - The system prompt text. + * @param {string[]} params.contextTexts - Array of context text snippets. + * @param {Array<{role: 'user' | 'assistant', content: string, attachments?: Array<{contentString: string, mime: string}>}>} params.chatHistory - Previous messages. + * @param {string} params.userPrompt - The latest user prompt text. + * @param {Array<{contentString: string, mime: string}>} params.attachments - Attachments for the latest user prompt. + * @returns {Array} The formatted message array for the API call.
+ */ constructPrompt({ systemPrompt = "", contextTexts = [], chatHistory = [], userPrompt = "", - _attachments = [], + attachments = [], }) { - let prompt = [ - { - role: "system", - content: [ - { text: `${systemPrompt}${this.#appendContext(contextTexts)}` }, - ], - }, - ]; + const systemMessageContent = `${systemPrompt}${this.#appendContext(contextTexts)}`; + let messages = []; - // If the model does not support system prompts, we need to add a user message and assistant message + // Handle system prompt (either real or simulated) if (this.noSystemPromptModels.includes(this.model)) { - prompt = [ - { - role: "user", - content: [ - { text: `${systemPrompt}${this.#appendContext(contextTexts)}` }, - ], - }, - { - role: "assistant", - content: [{ text: "Okay." }], - }, - ]; + if (systemMessageContent.trim().length > 0) { + this.#log( + `Model ${this.model} doesn't support system prompts; simulating.` + ); + messages.push( + { + role: "user", + content: this.#generateContent({ + userPrompt: systemMessageContent, + }), + }, + { role: "assistant", content: [{ text: "Okay." }] } + ); + } + } else if (systemMessageContent.trim().length > 0) { + messages.push({ + role: "system", + content: this.#generateContent({ userPrompt: systemMessageContent }), + }); } - return [ - ...prompt, - ...chatHistory.map((msg) => ({ + // Add chat history + messages = messages.concat( + chatHistory.map((msg) => ({ role: msg.role, content: this.#generateContent({ userPrompt: msg.content, - attachments: msg.attachments, + attachments: Array.isArray(msg.attachments) ? msg.attachments : [], }), - })), - { - role: "user", - content: this.#generateContent({ - userPrompt: userPrompt, - attachments: [], - }), - }, - ]; + })) + ); + + // Add final user prompt + messages.push({ + role: "user", + content: this.#generateContent({ + userPrompt: userPrompt, + attachments: Array.isArray(attachments) ?
attachments : [], + }), + }); + + return messages; } /** - * Parses and prepends reasoning from the response and returns the full text response. - * @param {Object} response - * @returns {string} + * Parses reasoning steps from the response and prepends them in tags. + * @param {object} message - The message object from the Bedrock response. + * @returns {string} The text response, potentially with reasoning prepended. + * @private */ #parseReasoningFromResponse({ content = [] }) { - let textResponse = content[0]?.text; + if (!content?.length) return ""; - if ( - !!content?.[1]?.reasoningContent && - content?.[1]?.reasoningContent?.reasoningText?.text?.trim().length > 0 - ) - textResponse = `${content?.[1]?.reasoningContent?.reasoningText?.text}${textResponse}`; + // Find the text block and grab the text + const textBlock = content.find((block) => block.text !== undefined); + let textResponse = textBlock?.text || ""; + // Find the reasoning block and grab the reasoning text + const reasoningBlock = content.find( + (block) => block.reasoningContent?.reasoningText?.text + ); + if (reasoningBlock) { + const reasoningText = + reasoningBlock.reasoningContent.reasoningText.text.trim(); + if (!!reasoningText?.length) + textResponse = `${reasoningText}${textResponse}`; + } return textResponse; } - async getChatCompletion(messages = null, { temperature = 0.7 }) { + /** + * Sends a request for chat completion (non-streaming). + * @param {Array | null} messages - Formatted message array from constructPrompt. + * @param {object} [options={}] - Request options; may be omitted. + * @param {number} [options.temperature] - Sampling temperature; falls back to this.defaultTemp. + * @returns {Promise} Response object with textResponse and metrics, or null. + * @throws {Error} If the API call fails or validation errors occur. + */ + async getChatCompletion(messages = null, { temperature } = {}) { + if (!messages?.length) + throw new Error( + "AWSBedrock::getChatCompletion requires a non-empty messages array."
+ ); + const hasSystem = messages[0]?.role === "system"; - const [system, ...history] = hasSystem ? messages : [null, ...messages]; + const systemBlock = hasSystem ? messages[0].content : undefined; + const history = hasSystem ? messages.slice(1) : messages; + const maxTokensToSend = this.getMaxOutputTokens(); const result = await LLMPerformanceMonitor.measureAsyncFunction( this.bedrockClient @@ -227,188 +393,318 @@ class AWSBedrockLLM { modelId: this.model, messages: history, inferenceConfig: { - maxTokens: this.promptWindowLimit(), - temperature, + maxTokens: maxTokensToSend, + temperature: temperature ?? this.defaultTemp, }, - system: !!system ? system.content : undefined, + system: systemBlock, }) ) .catch((e) => { - throw new Error( - `AWSBedrock::getChatCompletion failed to communicate with Bedrock client. ${e.message}` + this.#log( + `Bedrock Converse API Error (getChatCompletion): ${e.message}`, + e ); + if ( + e.name === "ValidationException" && + e.message.includes("maximum tokens") + ) { + throw new Error( + `AWSBedrock::getChatCompletion failed. Model ${this.model} rejected maxTokens value of ${maxTokensToSend}. Check model documentation for its maximum output token limit and set AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS if needed. Original error: ${e.message}` + ); + } + throw new Error(`AWSBedrock::getChatCompletion failed. ${e.message}`); }), messages, false ); const response = result.output; - if (!response || !response?.output) return null; + if (!response?.output?.message) { + this.#log( + "Bedrock response missing expected output.message structure.", + response + ); + return null; + } + + const latencyMs = response?.metrics?.latencyMs; + const outputTokens = response?.usage?.outputTokens; + const outputTps = + latencyMs > 0 && outputTokens ?
outputTokens / (latencyMs / 1000) : 0; + return { - textResponse: this.#parseReasoningFromResponse(response.output?.message), + textResponse: this.#parseReasoningFromResponse(response.output.message), metrics: { - prompt_tokens: response?.usage?.inputTokens, - completion_tokens: response?.usage?.outputTokens, - total_tokens: response?.usage?.totalTokens, - outputTps: - response?.usage?.outputTokens / (response?.metrics?.latencyMs / 1000), + prompt_tokens: response?.usage?.inputTokens ?? 0, + completion_tokens: outputTokens ?? 0, + total_tokens: response?.usage?.totalTokens ?? 0, + outputTps: outputTps, duration: result.duration, }, }; } - async streamGetChatCompletion(messages = null, { temperature = 0.7 }) { + /** + * Sends a request for streaming chat completion. + * @param {Array | null} messages - Formatted message array from constructPrompt. + * @param {object} [options={}] - Request options; may be omitted. + * @param {number} [options.temperature] - Sampling temperature; falls back to this.defaultTemp. + * @returns {Promise} The monitored stream object. + * @throws {Error} If the API call setup fails or validation errors occur. + */ + async streamGetChatCompletion(messages = null, { temperature } = {}) { + if (!Array.isArray(messages) || messages.length === 0) { + throw new Error( + "AWSBedrock::streamGetChatCompletion requires a non-empty messages array." + ); + } + const hasSystem = messages[0]?.role === "system"; - const [system, ...history] = hasSystem ? messages : [null, ...messages]; + const systemBlock = hasSystem ? messages[0].content : undefined; + const history = hasSystem ? messages.slice(1) : messages; + const maxTokensToSend = this.getMaxOutputTokens(); - const measuredStreamRequest = await LLMPerformanceMonitor.measureStream( - this.bedrockClient.send( + try { + // Attempt to initiate the stream + const stream = await this.bedrockClient.send( new ConverseStreamCommand({ modelId: this.model, messages: history, - inferenceConfig: { maxTokens: this.promptWindowLimit(), temperature }, - system: !!system ?
system.content : undefined, + inferenceConfig: { + maxTokens: maxTokensToSend, + temperature: temperature ?? this.defaultTemp, + }, + system: systemBlock, }) - ), - messages, - false - ); - return measuredStreamRequest; + ); + + // If successful, wrap the stream with performance monitoring + const measuredStreamRequest = await LLMPerformanceMonitor.measureStream( + stream, + messages, + false // Indicate it's not a function call measurement + ); + return measuredStreamRequest; + } catch (e) { + // Catch errors during the initial .send() call (e.g., validation errors) + this.#log( + `Bedrock Converse API Error (streamGetChatCompletion setup): ${e.message}`, + e + ); + if ( + e.name === "ValidationException" && + e.message.includes("maximum tokens") + ) { + throw new Error( + `AWSBedrock::streamGetChatCompletion failed during setup. Model ${this.model} rejected maxTokens value of ${maxTokensToSend}. Check model documentation for its maximum output token limit and set AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS if needed. Original error: ${e.message}` + ); + } + + throw new Error( + `AWSBedrock::streamGetChatCompletion failed during setup. ${e.message}` + ); + } } /** - * Handles the stream response from the AWS Bedrock API. - * Bedrock does not support usage metrics in the stream response so we need to estimate them. - * @param {Object} response - the response object - * @param {import('../../helpers/chat/LLMPerformanceMonitor').MonitoredStream} stream - the stream response from the AWS Bedrock API w/tracking - * @param {Object} responseProps - the response properties - * @returns {Promise} + * Handles the stream response from the AWS Bedrock API ConverseStreamCommand. + * Parses chunks, handles reasoning tags, and estimates token usage if not provided. + * @param {object} response - The HTTP response object to write chunks to. + * @param {import('../../helpers/chat/LLMPerformanceMonitor').MonitoredStream} stream - The monitored stream object from streamGetChatCompletion.
+ * @param {object} responseProps - Additional properties for the response chunks. + * @param {string} responseProps.uuid - Unique ID for the response. + * @param {Array} responseProps.sources - Source documents used (if any). + * @returns {Promise} A promise that resolves with the complete text response when the stream ends. */ handleStream(response, stream, responseProps) { const { uuid = uuidv4(), sources = [] } = responseProps; let hasUsageMetrics = false; - let usage = { - prompt_tokens: 0, - completion_tokens: 0, - }; + let usage = { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }; return new Promise(async (resolve) => { let fullText = ""; let reasoningText = ""; - // Establish listener to early-abort a streaming response - // in case things go sideways or the user does not like the response. - // We preserve the generated text but continue as if chat was completed - // to preserve previously generated content. + // Abort handler for client closing connection const handleAbort = () => { - stream?.endMeasurement(usage); - clientAbortedHandler(resolve, fullText); + this.#log(`Client closed connection for stream ${uuid}. Aborting.`); + stream?.endMeasurement(usage); // Finalize metrics - NOTE(review): the finally block also calls endMeasurement once the loop exits, so an aborted stream is measured twice; confirm endMeasurement tolerates this + clientAbortedHandler(resolve, fullText); // Resolve with partial text }; response.on("close", handleAbort); try { + // Process stream chunks for await (const chunk of stream.stream) { - if (chunk === undefined) - throw new Error( - "Stream returned undefined chunk. Aborting reply - check model provider logs." - ); - - const action = Object.keys(chunk)[0]; - if (action === "metadata") { - hasUsageMetrics = true; - usage.prompt_tokens = chunk.metadata?.usage?.inputTokens ?? 0; - usage.completion_tokens = chunk.metadata?.usage?.outputTokens ?? 0; - usage.total_tokens = chunk.metadata?.usage?.totalTokens ??
0; + if (!chunk) { + this.#log("Stream returned null/undefined chunk."); + continue; } + const action = Object.keys(chunk)[0]; - if (action === "contentBlockDelta") { - const token = chunk.contentBlockDelta?.delta?.text; - const reasoningToken = - chunk.contentBlockDelta?.delta?.reasoningContent?.text; - - // Reasoning models will always return the reasoning text before the token text. - if (reasoningToken) { - // If the reasoning text is empty (''), we need to initialize it - // and send the first chunk of reasoning text. - if (reasoningText.length === 0) { + switch (action) { + case "metadata": // Contains usage metrics at the end + if (chunk.metadata?.usage) { + hasUsageMetrics = true; + usage = { + // Overwrite with final metrics + prompt_tokens: chunk.metadata.usage.inputTokens ?? 0, + completion_tokens: chunk.metadata.usage.outputTokens ?? 0, + total_tokens: chunk.metadata.usage.totalTokens ?? 0, + }; + } + break; + case "contentBlockDelta": { + // Contains text or reasoning deltas + const delta = chunk.contentBlockDelta?.delta; + if (!delta) break; + const token = delta.text; + const reasoningToken = delta.reasoningContent?.text; + + if (reasoningToken) { + // Handle reasoning text + if (reasoningText.length === 0) { + // Start of reasoning block + const startTag = ""; + writeResponseChunk(response, { + uuid, + sources, + type: "textResponseChunk", + textResponse: startTag + reasoningToken, + close: false, + error: false, + }); + reasoningText += startTag + reasoningToken; + } else { + // Continuation of reasoning block + writeResponseChunk(response, { + uuid, + sources, + type: "textResponseChunk", + textResponse: reasoningToken, + close: false, + error: false, + }); + reasoningText += reasoningToken; + } + } else if (token) { + // Handle regular text + if (reasoningText.length > 0) { + // If reasoning was just output, close the tag + const endTag = ""; + writeResponseChunk(response, { + uuid, + sources, + type: "textResponseChunk", + textResponse: endTag,
+ close: false, + error: false, + }); + fullText += reasoningText + endTag; // Add completed reasoning to final text + reasoningText = ""; // Reset reasoning buffer + } + fullText += token; // Append regular text + if (!hasUsageMetrics) usage.completion_tokens++; // Estimate usage if no metrics yet writeResponseChunk(response, { uuid, - sources: [], + sources, type: "textResponseChunk", - textResponse: `${reasoningToken}`, + textResponse: token, close: false, error: false, }); - reasoningText += `${reasoningToken}`; - continue; - } else { + } + break; + } + case "messageStop": // End of message event + if (chunk.messageStop?.usage) { + // Check for final metrics here too + hasUsageMetrics = true; + usage = { + // Overwrite with final metrics if available + prompt_tokens: + chunk.messageStop.usage.inputTokens ?? usage.prompt_tokens, + completion_tokens: + chunk.messageStop.usage.outputTokens ?? + usage.completion_tokens, + total_tokens: + chunk.messageStop.usage.totalTokens ?? usage.total_tokens, + }; + } + // Ensure reasoning tag is closed if message stops mid-reasoning + if (reasoningText.length > 0) { + const endTag = ""; writeResponseChunk(response, { uuid, - sources: [], + sources, type: "textResponseChunk", - textResponse: reasoningToken, + textResponse: endTag, close: false, error: false, }); - reasoningText += reasoningToken; + fullText += reasoningText + endTag; + reasoningText = ""; } - } - - // If the reasoning text is not empty, but the reasoning token is empty
- // and the token text is not empty we need to close the reasoning text and begin sending the token text. - if (!!reasoningText && !reasoningToken && token) { - writeResponseChunk(response, { - uuid, - sources: [], - type: "textResponseChunk", - textResponse: ``, - close: false, - error: false, - }); - fullText += `${reasoningText}`; - reasoningText = ""; - } - - if (token) { - fullText += token; - // If we never saw a usage metric, we can estimate them by number of completion chunks - if (!hasUsageMetrics) usage.completion_tokens++; - writeResponseChunk(response, { - uuid, - sources: [], - type: "textResponseChunk", - textResponse: token, - close: false, - error: false, - }); - } + break; + // Ignore other event types for now + case "messageStart": + case "contentBlockStart": + case "contentBlockStop": + break; + default: + this.#log(`Unhandled stream action: ${action}`, chunk); } + } // End for await loop + + // Final cleanup for reasoning tag in case stream ended abruptly - NOTE(review): with the tag written as an empty string literal here, fullText.endsWith("") is always true, so this branch never runs; confirm the intended tag literal + if (reasoningText.length > 0 && !fullText.endsWith("")) { + const endTag = ""; + if (!response.writableEnded) { + writeResponseChunk(response, { + uuid, + sources, + type: "textResponseChunk", + textResponse: endTag, + close: false, + error: false, + }); + } + fullText += reasoningText + endTag; } - writeResponseChunk(response, { - uuid, - sources, - type: "textResponseChunk", - textResponse: "", - close: true, - error: false, - }); - response.removeListener("close", handleAbort); - stream?.endMeasurement(usage); - resolve(fullText); + // Send final closing chunk to signal end of stream + if (!response.writableEnded) { + writeResponseChunk(response, { + uuid, + sources, + type: "textResponseChunk", + textResponse: "", + close: true, + error: false, + }); + } } catch (error) { - console.log(`\x1b[43m\x1b[34m[STREAMING ERROR]\x1b[0m ${e.message}`); - writeResponseChunk(response, { - uuid, - type: "abort", - textResponse: null, - sources: [], - close: true, - error: `AWSBedrock:streaming - could not stream chat. ${error?.cause ??
error.message}`, - }); + // Handle errors during stream processing + this.#log( + `\x1b[43m\x1b[34m[STREAMING ERROR]\x1b[0m ${error.message}`, + error + ); + if (response && !response.writableEnded) { + writeResponseChunk(response, { + uuid, + type: "abort", + textResponse: null, + sources, + close: true, + error: `AWSBedrock:streaming - error. ${ + error?.message ?? "Unknown error" + }`, + }); + } + } finally { response.removeListener("close", handleAbort); stream?.endMeasurement(usage); - resolve(fullText); // Return what we currently have - if anything. + resolve(fullText); // Resolve with the accumulated text } }); } diff --git a/server/utils/AiProviders/bedrock/utils.js b/server/utils/AiProviders/bedrock/utils.js new file mode 100644 index 00000000000..f1529fb94ca --- /dev/null +++ b/server/utils/AiProviders/bedrock/utils.js @@ -0,0 +1,68 @@ +/** @type {Array<'jpeg' | 'png' | 'gif' | 'webp'>} Image formats getImageFormatFromMime will accept. */ +const SUPPORTED_BEDROCK_IMAGE_FORMATS = ["jpeg", "png", "gif", "webp"]; + +/** @type {number} */ +const DEFAULT_MAX_OUTPUT_TOKENS = 4096; + +/** @type {number} */ +const DEFAULT_CONTEXT_WINDOW_TOKENS = 8191; + +/** @type {Array<'iam' | 'sessionToken'>} Accepted AWS_BEDROCK_LLM_CONNECTION_METHOD values. */ +const SUPPORTED_CONNECTION_METHODS = ["iam", "sessionToken"]; + +/** + * Parses a MIME type string (e.g., "image/jpeg") to extract and validate the image format + * supported by Bedrock Converse. Handles 'image/jpg' as 'jpeg'. + * @param {string | null | undefined} mimeType - The MIME type string. + * @returns {string | null} The validated image format (e.g., "jpeg") or null if invalid/unsupported.
+ */ +function getImageFormatFromMime(mimeType = "") { + if (!mimeType) return null; + const parts = mimeType.toLowerCase().split("/"); + if (parts?.[0] !== "image") return null; + let format = parts?.[1]; // `let`: reassigned below when remapping "jpg" (a const here threw a TypeError for image/jpg) + + if (!format) return null; + + // Remap jpg to jpeg + switch (format) { + case "jpg": + format = "jpeg"; + break; + default: + break; + } + + if (!SUPPORTED_BEDROCK_IMAGE_FORMATS.includes(format)) return null; + return format; +} + +/** + * Decodes a pure base64 string (without data URI prefix) into a Uint8Array using the atob method. + * This approach matches the technique previously used by Langchain's implementation. + * @param {string} base64String - The pure base64 encoded data. + * @returns {Uint8Array | null} The resulting byte array or null on decoding error. + */ +function base64ToUint8Array(base64String) { + try { + const binaryString = atob(base64String); + const len = binaryString.length; + const bytes = new Uint8Array(len); + for (let i = 0; i < len; i++) bytes[i] = binaryString.charCodeAt(i); + return bytes; + } catch (e) { + console.error( + `[AWSBedrock] Error decoding base64 string with atob: ${e.message}` + ); + return null; + } +} + +module.exports = { + SUPPORTED_CONNECTION_METHODS, + SUPPORTED_BEDROCK_IMAGE_FORMATS, + DEFAULT_MAX_OUTPUT_TOKENS, + DEFAULT_CONTEXT_WINDOW_TOKENS, + getImageFormatFromMime, + base64ToUint8Array, +}; diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index bffab102609..32dca2d0bbc 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -254,6 +254,10 @@ const KEY_MAPPING = { envKey: "AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT", checks: [nonZero], }, + AwsBedrockLLMMaxOutputTokens: { + envKey: "AWS_BEDROCK_LLM_MAX_OUTPUT_TOKENS", + checks: [nonZero], + }, EmbeddingEngine: { envKey: "EMBEDDING_ENGINE",