diff --git a/.github/workflows/gemini-automated-issue-triage.yml b/.github/workflows/gemini-automated-issue-triage.yml index 8e033f197d9..2d0603497e6 100644 --- a/.github/workflows/gemini-automated-issue-triage.yml +++ b/.github/workflows/gemini-automated-issue-triage.yml @@ -153,9 +153,6 @@ jobs: settings: |- { "maxSessionTurns": 25, - "coreTools": [ - "run_shell_command(echo)" - ], "telemetry": { "enabled": true, "target": "gcp" @@ -167,8 +164,8 @@ jobs: You are an issue triage assistant. Your role is to analyze a GitHub issue and determine the single most appropriate area/ label and the single most appropriate priority/ label based on the definitions provided. ## Steps - 1. Review the issue title and body provided in the environment variables: ${ISSUE_TITLE} and ${ISSUE_BODY}. - 2. Review the available labels provided in the environment variable: ${AVAILABLE_LABELS}. + 1. Review the issue title and body: ${{ env.ISSUE_TITLE }} and ${{ env.ISSUE_BODY }}. + 2. Review the available labels: ${{ env.AVAILABLE_LABELS }}. 3. Select exactly one area/ label that best matches the issue based on Reference 1: Area Definitions. 4. Select exactly one priority/ label that best matches the issue based on Reference 2: Priority Definitions. 5. Fallback Logic: @@ -180,7 +177,6 @@ jobs: ## Guidelines - Your output must contain exactly one area/ label and exactly one priority/ label. - Triage only the current issue based on its title and body. - - Reference all shell variables as "${VAR}" (with quotes and braces). - Output only valid JSON format. - Do not include any explanation or additional text, just the JSON. diff --git a/docs/changelogs/index.md b/docs/changelogs/index.md index b4f9ad640fd..baa4ec0c7ce 100644 --- a/docs/changelogs/index.md +++ b/docs/changelogs/index.md @@ -3,6 +3,69 @@ Wondering what's new in Gemini CLI? This document provides key highlights and notable changes to Gemini CLI. 
+## v0.12.0 - Gemini CLI weekly update - 2025-10-27 + +![Codebase investigator subagent in Gemini CLI.](https://i.imgur.com/4J1njsx.png) + +- **🎉 New partner extensions:** + - **🤗 Hugging Face extension:** Access the Hugging Face hub. + ([gif](https://drive.google.com/file/d/1LEzIuSH6_igFXq96_tWev11svBNyPJEB/view?usp=sharing&resourcekey=0-LtPTzR1woh-rxGtfPzjjfg)) + + `gemini extensions install https://github.com/huggingface/hf-mcp-server` + + - **Monday.com extension:** Analyze your sprints, update your task boards, + etc. + ([gif](https://drive.google.com/file/d/1cO0g6kY1odiBIrZTaqu5ZakaGZaZgpQv/view?usp=sharing&resourcekey=0-xEr67SIjXmAXRe1PKy7Jlw)) + + `gemini extensions install https://github.com/mondaycom/mcp` + + - **Data Commons extension:** Query public datasets or ground responses on + data from Data Commons + ([gif](https://drive.google.com/file/d/1cuj-B-vmUkeJnoBXrO_Y1CuqphYc6p-O/view?usp=sharing&resourcekey=0-0adXCXDQEd91ZZW63HbW-Q)). + + `gemini extensions install https://github.com/gemini-cli-extensions/datacommons` + +- **Model selection:** Choose the Gemini model for your session with `/model`. + ([pic](https://imgur.com/a/ABFcWWw), + [pr](https://github.com/google-gemini/gemini-cli/pull/8940) by + [@abhipatel12](https://github.com/abhipatel12)). +- **Model routing:** Gemini CLI will now intelligently pick the best model for + the task. Simple queries will be sent to Flash, while complex analytical or + creative tasks will still use the power of Pro. This ensures your quota will + last for a longer period of time. You can always opt out of this via `/model`. + ([pr](https://github.com/google-gemini/gemini-cli/pull/9262) by + [@abhipatel12](https://github.com/abhipatel12)).
+ - Discussion: + [https://github.com/google-gemini/gemini-cli/discussions/12375](https://github.com/google-gemini/gemini-cli/discussions/12375) +- **Codebase investigator subagent:** We now have a new built-in subagent that + will explore your workspace and gather relevant information to improve + overall performance. + ([pr](https://github.com/google-gemini/gemini-cli/pull/9988) by + [@abhipatel12](https://github.com/abhipatel12), + [pr](https://github.com/google-gemini/gemini-cli/pull/10282) by + [@silviojr](https://github.com/silviojr)). + - Enable, disable, or limit turns in `/settings`, plus advanced configs in + `settings.json` ([pic](https://imgur.com/a/yJiggNO), + [pr](https://github.com/google-gemini/gemini-cli/pull/10844) by + [@silviojr](https://github.com/silviojr)). +- **Explore extensions with `/extensions explore`:** Users can now open the extensions + page in their default browser directly from the CLI using the + `/extensions explore` command. ([pr](https://github.com/google-gemini/gemini-cli/pull/11846) + by [@JayadityaGit](https://github.com/JayadityaGit)). +- **Configurable compression:** Users can modify the compression threshold in + `/settings`. The default has been made more proactive + ([pr](https://github.com/google-gemini/gemini-cli/pull/12317) by + [@scidomino](https://github.com/scidomino)). +- **API key authentication:** Users can now securely enter and store their + Gemini API key via a new dialog, eliminating the need for environment + variables and repeated entry. + ([pr](https://github.com/google-gemini/gemini-cli/pull/11760) by + [@galz10](https://github.com/galz10)). +- **Sequential approval:** Users can now approve multiple tool calls + sequentially during execution. + ([pr](https://github.com/google-gemini/gemini-cli/pull/11593) by + [@joshualitt](https://github.com/joshualitt)).
+ ## v0.11.0 - Gemini CLI weekly update - 2025-10-20 ![Gemini CLI and Jules](https://storage.googleapis.com/gweb-developer-goog-blog-assets/images/Jules_Extension_-_Blog_Header_O346JNt.original.png) diff --git a/docs/cli/commands.md b/docs/cli/commands.md index c0dee22245e..1dae60c424a 100644 --- a/docs/cli/commands.md +++ b/docs/cli/commands.md @@ -318,10 +318,20 @@ Gemini CLI. - When exited, the UI reverts to its standard appearance and normal Gemini CLI behavior resumes. -- **Caution for all `!` usage:** Commands you execute in shell mode have the - same permissions and impact as if you ran them directly in your terminal. +### Quick Command Execution (`Ctrl+Enter` / `Meta+Enter`) -- **Environment Variable:** When a command is executed via `!` or in shell mode, - the `GEMINI_CLI=1` environment variable is set in the subprocess's - environment. This allows scripts or tools to detect if they are being run from - within the Gemini CLI. +While in **prompt mode**, you can quickly execute shell commands without switching to shell mode: + +- **`Ctrl+Enter` / `Meta+Enter`** + - **Description:** Execute the current input as a shell command directly. After execution, you remain in prompt mode. + - **Behavior:** + - Works only when you're in prompt mode (not shell mode) + - The command is executed using the same shell processor as shell mode + - After execution, the input is cleared and you stay in prompt mode + - Useful for quick commands without mode switching + - **Examples:** + - Type `ls -la` and press `Ctrl+Enter`/`Meta+Enter` to execute without entering shell mode + - Type `git status` and press `Ctrl+Enter`/`Meta+Enter` to check git status quickly + +- **Caution for all `!` usage and quick command execution:** Commands you execute have the same permissions and impact as if you ran them directly in your terminal. 
+- **Environment Variable:** When a command is executed via `!`, in shell mode, or with `Ctrl+Enter`/`Meta+Enter`, the `GEMINI_CLI=1` environment variable is set in the subprocess's environment. This allows scripts or tools to detect if they are being run from within the Gemini CLI. \ No newline at end of file diff --git a/docs/cli/keyboard-shortcuts.md b/docs/cli/keyboard-shortcuts.md index 05a5683ba28..fb32e80d456 100644 --- a/docs/cli/keyboard-shortcuts.md +++ b/docs/cli/keyboard-shortcuts.md @@ -24,6 +24,7 @@ This document lists the available keyboard shortcuts within Gemini CLI. | -------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------- | | `!` | Toggle shell mode when the input is empty. | | `\` (at end of line) + `Enter` | Insert a newline. | +| `Ctrl+Enter` / `Meta+Enter` | Execute the current input as a shell command directly without switching to shell mode (works in prompt mode). | | `Down Arrow` | Navigate down through the input history. | | `Enter` | Submit the current prompt. | | `Meta+Delete` / `Ctrl+Delete` | Delete the word to the right of the cursor. | diff --git a/docs/get-started/configuration.md b/docs/get-started/configuration.md index d65ec509a99..8d17f10c788 100644 --- a/docs/get-started/configuration.md +++ b/docs/get-started/configuration.md @@ -296,6 +296,20 @@ their corresponding top-level category object in your `settings.json` file. - **Description:** Skip the next speaker check. - **Default:** `true` +#### `modelConfigs` + +- **`modelConfigs.aliases`** (object): + - **Description:** Named presets for model configs. Can be used in place of a + model name and can inherit from other aliases using an `extends` property. 
+ - **Default:** + `{"base":{"modelConfig":{"generateContentConfig":{"temperature":0,"topP":1}}},"chat-base":{"extends":"base","modelConfig":{"generateContentConfig":{"thinkingConfig":{"includeThoughts":true,"thinkingBudget":-1}}}},"gemini-2.5-pro":{"extends":"chat-base","modelConfig":{"model":"gemini-2.5-pro"}},"gemini-2.5-flash":{"extends":"chat-base","modelConfig":{"model":"gemini-2.5-flash"}},"gemini-2.5-flash-lite":{"extends":"chat-base","modelConfig":{"model":"gemini-2.5-flash-lite"}},"classifier":{"extends":"base","modelConfig":{"model":"gemini-2.5-flash-lite","generateContentConfig":{"maxOutputTokens":1024,"thinkingConfig":{"thinkingBudget":512}}}},"prompt-completion":{"extends":"base","modelConfig":{"model":"gemini-2.5-flash-lite","generateContentConfig":{"temperature":0.3,"maxOutputTokens":16000,"thinkingConfig":{"thinkingBudget":0}}}},"edit-corrector":{"extends":"base","modelConfig":{"model":"gemini-2.5-flash-lite","generateContentConfig":{"thinkingConfig":{"thinkingBudget":0}}}},"summarizer-default":{"extends":"base","modelConfig":{"model":"gemini-2.5-flash-lite","generateContentConfig":{"maxOutputTokens":2000}}},"summarizer-shell":{"extends":"base","modelConfig":{"model":"gemini-2.5-flash-lite","generateContentConfig":{"maxOutputTokens":2000}}},"web-search-tool":{"extends":"base","modelConfig":{"model":"gemini-2.5-flash","generateContentConfig":{"tools":[{"googleSearch":{}}]}}},"web-fetch-tool":{"extends":"base","modelConfig":{"model":"gemini-2.5-flash","generateContentConfig":{"tools":[{"urlContext":{}}]}}}}` + +- **`modelConfigs.overrides`** (array): + - **Description:** Apply specific configuration overrides based on matches, + with a primary key of model (or alias). The most specific match will be + used. 
+ - **Default:** `[]` + #### `context` - **`context.fileName`** (string | string[]): diff --git a/integration-tests/list_directory.test.ts b/integration-tests/list_directory.test.ts index eadd9d434b8..1da05b9db65 100644 --- a/integration-tests/list_directory.test.ts +++ b/integration-tests/list_directory.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect } from 'vitest'; +import { describe, it } from 'vitest'; import { TestRig, poll, @@ -17,7 +17,7 @@ import { join } from 'node:path'; describe('list_directory', () => { it('should be able to list a directory', async () => { const rig = new TestRig(); - await rig.setup('should be able to list a directory'); + rig.setup('should be able to list a directory'); rig.createFile('file1.txt', 'file 1 content'); rig.mkdir('subdir'); rig.sync(); @@ -38,33 +38,27 @@ describe('list_directory', () => { const result = await rig.run(prompt); - const foundToolCall = await rig.waitForToolCall('list_directory'); + try { + await rig.expectToolCallSuccess(['list_directory']); + } catch (e) { + // Add debugging information + if (!result.includes('file1.txt') || !result.includes('subdir')) { + const allTools = printDebugInfo(rig, result, { + 'Found tool call': false, + 'Contains file1.txt': result.includes('file1.txt'), + 'Contains subdir': result.includes('subdir'), + }); - // Add debugging information - if ( - !foundToolCall || - !result.includes('file1.txt') || - !result.includes('subdir') - ) { - const allTools = printDebugInfo(rig, result, { - 'Found tool call': foundToolCall, - 'Contains file1.txt': result.includes('file1.txt'), - 'Contains subdir': result.includes('subdir'), - }); - - console.error( - 'List directory calls:', - allTools - .filter((t) => t.toolRequest.name === 'list_directory') - .map((t) => t.toolRequest.args), - ); + console.error( + 'List directory calls:', + allTools + .filter((t) => t.toolRequest.name === 'list_directory') + .map((t) => t.toolRequest.args), + ); + } 
+ throw e; } - expect( - foundToolCall, - 'Expected to find a list_directory tool call', - ).toBeTruthy(); - // Validate model output - will throw if no output, warn if missing expected content validateModelOutput(result, ['file1.txt', 'subdir'], 'List directory test'); }); diff --git a/package-lock.json b/package-lock.json index 74949dc76f1..69ec9df2890 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@google/gemini-cli", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@google/gemini-cli", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "workspaces": [ "packages/*" ], @@ -16959,7 +16959,7 @@ }, "packages/a2a-server": { "name": "@google/gemini-cli-a2a-server", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "dependencies": { "@a2a-js/sdk": "^0.3.2", "@google-cloud/storage": "^7.16.0", @@ -17249,7 +17249,7 @@ }, "packages/cli": { "name": "@google/gemini-cli", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "dependencies": { "@google/gemini-cli-core": "file:../core", "@google/genai": "1.16.0", @@ -17349,7 +17349,7 @@ }, "packages/core": { "name": "@google/gemini-cli-core", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "dependencies": { "@google-cloud/logging": "^11.2.1", "@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.21.0", @@ -17493,7 +17493,7 @@ }, "packages/test-utils": { "name": "@google/gemini-cli-test-utils", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "license": "Apache-2.0", "devDependencies": { "typescript": "^5.3.3" @@ -17504,7 +17504,7 @@ }, "packages/vscode-ide-companion": { "name": "gemini-cli-vscode-ide-companion", - "version": 
"0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "license": "LICENSE", "dependencies": { "@modelcontextprotocol/sdk": "^1.15.1", diff --git a/package.json b/package.json index b8e724ee503..46f84e9b2ed 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@google/gemini-cli", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "engines": { "node": ">=20.0.0" }, @@ -14,7 +14,7 @@ "url": "git+https://github.com/google-gemini/gemini-cli.git" }, "config": { - "sandboxImageUri": "us-docker.pkg.dev/gemini-code-dev/gemini-cli/sandbox:0.13.0-nightly.20251031.c89bc30d" + "sandboxImageUri": "us-docker.pkg.dev/gemini-code-dev/gemini-cli/sandbox:0.14.0-nightly.20251104.da3da198" }, "scripts": { "start": "cross-env NODE_ENV=development node scripts/start.js", diff --git a/packages/a2a-server/package.json b/packages/a2a-server/package.json index bfbc56d8f8f..7b5ff6fc676 100644 --- a/packages/a2a-server/package.json +++ b/packages/a2a-server/package.json @@ -1,6 +1,6 @@ { "name": "@google/gemini-cli-a2a-server", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "description": "Gemini CLI A2A Server", "repository": { "type": "git", diff --git a/packages/cli/package.json b/packages/cli/package.json index 26ac93fc84d..f3d3c79dd1b 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -1,6 +1,6 @@ { "name": "@google/gemini-cli", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "description": "Gemini CLI", "repository": { "type": "git", @@ -25,7 +25,7 @@ "dist" ], "config": { - "sandboxImageUri": "us-docker.pkg.dev/gemini-code-dev/gemini-cli/sandbox:0.13.0-nightly.20251031.c89bc30d" + "sandboxImageUri": "us-docker.pkg.dev/gemini-code-dev/gemini-cli/sandbox:0.14.0-nightly.20251104.da3da198" }, "dependencies": { "@google/gemini-cli-core": "file:../core", diff --git 
a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 553a1ce760a..a241a27d517 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -669,6 +669,7 @@ export async function loadCliConfig( recordResponses: argv.recordResponses, retryFetchErrors: settings.general?.retryFetchErrors ?? false, ptyInfo: ptyInfo?.name, + modelConfigServiceConfig: settings.modelConfigs, // TODO: loading of hooks based on workspace trust enableHooks: settings.tools?.enableHooks ?? false, hooks: settings.hooks || {}, diff --git a/packages/cli/src/config/extension-manager.ts b/packages/cli/src/config/extension-manager.ts index dda9b25c6ec..5111c28c0bb 100644 --- a/packages/cli/src/config/extension-manager.ts +++ b/packages/cli/src/config/extension-manager.ts @@ -616,15 +616,21 @@ export class ExtensionManager extends ExtensionLoader { throw new Error(`Extension with name ${name} does not exist.`); } - const scopePath = - scope === SettingScope.Workspace ? this.workspaceDir : os.homedir(); - this.extensionEnablementManager.disable(name, true, scopePath); - extension.isActive = false; - await this.maybeStopExtension(extension); + if (scope !== SettingScope.Session) { + const scopePath = + scope === SettingScope.Workspace ? this.workspaceDir : os.homedir(); + this.extensionEnablementManager.disable(name, true, scopePath); + } logExtensionDisable( this.telemetryConfig, new ExtensionDisableEvent(hashValue(name), extension.id, scope), ); + if (!this.config || this.config.getEnableExtensionReloading()) { + // Only toggle the isActive state if we are actually going to disable it + // in the current session, or we haven't been initialized yet. + extension.isActive = false; + } + await this.maybeStopExtension(extension); } /** @@ -644,14 +650,21 @@ export class ExtensionManager extends ExtensionLoader { if (!extension) { throw new Error(`Extension with name ${name} does not exist.`); } - const scopePath = - scope === SettingScope.Workspace ? 
this.workspaceDir : os.homedir(); - this.extensionEnablementManager.enable(name, true, scopePath); + + if (scope !== SettingScope.Session) { + const scopePath = + scope === SettingScope.Workspace ? this.workspaceDir : os.homedir(); + this.extensionEnablementManager.enable(name, true, scopePath); + } logExtensionEnable( this.telemetryConfig, new ExtensionEnableEvent(hashValue(name), extension.id, scope), ); - extension.isActive = true; + if (!this.config || this.config.getEnableExtensionReloading()) { + // Only toggle the isActive state if we are actually going to enable it + // in the current session, or we haven't been initialized yet. + extension.isActive = true; + } await this.maybeStartExtension(extension); } } diff --git a/packages/cli/src/config/keyBindings.ts b/packages/cli/src/config/keyBindings.ts index 62d672a5207..d8aaafbd1b7 100644 --- a/packages/cli/src/config/keyBindings.ts +++ b/packages/cli/src/config/keyBindings.ts @@ -63,6 +63,7 @@ export enum Command { SUBMIT_REVERSE_SEARCH = 'submitReverseSearch', ACCEPT_SUGGESTION_REVERSE_SEARCH = 'acceptSuggestionReverseSearch', TOGGLE_SHELL_INPUT_FOCUS = 'toggleShellInputFocus', + EXECUTE_PROMPT_COMMAND = 'executePromptCommand', // Suggestion expansion EXPAND_SUGGESTION = 'expandSuggestion', @@ -164,8 +165,6 @@ export const defaultKeyBindings: KeyBindingConfig = { // Split into multiple data-driven bindings // Now also includes shift+enter for multi-line input [Command.NEWLINE]: [ - { key: 'return', ctrl: true }, - { key: 'return', command: true }, { key: 'return', paste: true }, { key: 'return', shift: true }, { key: 'j', ctrl: true }, @@ -194,6 +193,10 @@ export const defaultKeyBindings: KeyBindingConfig = { [Command.SUBMIT_REVERSE_SEARCH]: [{ key: 'return', ctrl: false }], [Command.ACCEPT_SUGGESTION_REVERSE_SEARCH]: [{ key: 'tab' }], [Command.TOGGLE_SHELL_INPUT_FOCUS]: [{ key: 'f', ctrl: true }], + [Command.EXECUTE_PROMPT_COMMAND]: [ + { key: 'return', ctrl: true }, + { key: 'return', command: true }, +
], // Suggestion expansion [Command.EXPAND_SUGGESTION]: [{ key: 'right' }], diff --git a/packages/cli/src/config/policies/read-only.toml b/packages/cli/src/config/policies/read-only.toml deleted file mode 100644 index 0c36faf003f..00000000000 --- a/packages/cli/src/config/policies/read-only.toml +++ /dev/null @@ -1,56 +0,0 @@ -# Priority system for policy rules: -# - Higher priority numbers win over lower priority numbers -# - When multiple rules match, the highest priority rule is applied -# - Rules are evaluated in order of priority (highest first) -# -# Priority bands (tiers): -# - Default policies (TOML): 1 + priority/1000 (e.g., priority 100 → 1.100) -# - User policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100) -# - Admin policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100) -# -# This ensures Admin > User > Default hierarchy is always preserved, -# while allowing user-specified priorities to work within each tier. -# -# Settings-based and dynamic rules (all in user tier 2.x): -# 2.95: Tools that the user has selected as "Always Allow" in the interactive UI -# 2.9: MCP servers excluded list (security: persistent server blocks) -# 2.4: Command line flag --exclude-tools (explicit temporary blocks) -# 2.3: Command line flag --allowed-tools (explicit temporary allows) -# 2.2: MCP servers with trust=true (persistent trusted servers) -# 2.1: MCP servers allowed list (persistent general server allows) -# -# TOML policy priorities (before transformation): -# 10: Write tools default to ASK_USER (becomes 1.010 in default tier) -# 15: Auto-edit tool override (becomes 1.015 in default tier) -# 50: Read-only tools (becomes 1.050 in default tier) -# 999: YOLO mode allow-all (becomes 1.999 in default tier) - -[[rule]] -toolName = "glob" -decision = "allow" -priority = 50 - -[[rule]] -toolName = "search_file_content" -decision = "allow" -priority = 50 - -[[rule]] -toolName = "list_directory" -decision = "allow" -priority = 50 - -[[rule]] -toolName = 
"read_file" -decision = "allow" -priority = 50 - -[[rule]] -toolName = "read_many_files" -decision = "allow" -priority = 50 - -[[rule]] -toolName = "google_web_search" -decision = "allow" -priority = 50 diff --git a/packages/cli/src/config/policies/write.toml b/packages/cli/src/config/policies/write.toml deleted file mode 100644 index 8e4c1ae70eb..00000000000 --- a/packages/cli/src/config/policies/write.toml +++ /dev/null @@ -1,63 +0,0 @@ -# Priority system for policy rules: -# - Higher priority numbers win over lower priority numbers -# - When multiple rules match, the highest priority rule is applied -# - Rules are evaluated in order of priority (highest first) -# -# Priority bands (tiers): -# - Default policies (TOML): 1 + priority/1000 (e.g., priority 100 → 1.100) -# - User policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100) -# - Admin policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100) -# -# This ensures Admin > User > Default hierarchy is always preserved, -# while allowing user-specified priorities to work within each tier. 
-# -# Settings-based and dynamic rules (all in user tier 2.x): -# 2.95: Tools that the user has selected as "Always Allow" in the interactive UI -# 2.9: MCP servers excluded list (security: persistent server blocks) -# 2.4: Command line flag --exclude-tools (explicit temporary blocks) -# 2.3: Command line flag --allowed-tools (explicit temporary allows) -# 2.2: MCP servers with trust=true (persistent trusted servers) -# 2.1: MCP servers allowed list (persistent general server allows) -# -# TOML policy priorities (before transformation): -# 10: Write tools default to ASK_USER (becomes 1.010 in default tier) -# 15: Auto-edit tool override (becomes 1.015 in default tier) -# 50: Read-only tools (becomes 1.050 in default tier) -# 999: YOLO mode allow-all (becomes 1.999 in default tier) - -[[rule]] -toolName = "replace" -decision = "ask_user" -priority = 10 - -[[rule]] -toolName = "replace" -decision = "allow" -priority = 15 -modes = ["autoEdit"] - -[[rule]] -toolName = "save_memory" -decision = "ask_user" -priority = 10 - -[[rule]] -toolName = "run_shell_command" -decision = "ask_user" -priority = 10 - -[[rule]] -toolName = "write_file" -decision = "ask_user" -priority = 10 - -[[rule]] -toolName = "write_file" -decision = "allow" -priority = 15 -modes = ["autoEdit"] - -[[rule]] -toolName = "web_fetch" -decision = "ask_user" -priority = 10 diff --git a/packages/cli/src/config/policies/yolo.toml b/packages/cli/src/config/policies/yolo.toml deleted file mode 100644 index 0c5f9e9221b..00000000000 --- a/packages/cli/src/config/policies/yolo.toml +++ /dev/null @@ -1,31 +0,0 @@ -# Priority system for policy rules: -# - Higher priority numbers win over lower priority numbers -# - When multiple rules match, the highest priority rule is applied -# - Rules are evaluated in order of priority (highest first) -# -# Priority bands (tiers): -# - Default policies (TOML): 1 + priority/1000 (e.g., priority 100 → 1.100) -# - User policies (TOML): 2 + priority/1000 (e.g., priority 100 → 
2.100) -# - Admin policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100) -# -# This ensures Admin > User > Default hierarchy is always preserved, -# while allowing user-specified priorities to work within each tier. -# -# Settings-based and dynamic rules (all in user tier 2.x): -# 2.95: Tools that the user has selected as "Always Allow" in the interactive UI -# 2.9: MCP servers excluded list (security: persistent server blocks) -# 2.4: Command line flag --exclude-tools (explicit temporary blocks) -# 2.3: Command line flag --allowed-tools (explicit temporary allows) -# 2.2: MCP servers with trust=true (persistent trusted servers) -# 2.1: MCP servers allowed list (persistent general server allows) -# -# TOML policy priorities (before transformation): -# 10: Write tools default to ASK_USER (becomes 1.010 in default tier) -# 15: Auto-edit tool override (becomes 1.015 in default tier) -# 50: Read-only tools (becomes 1.050 in default tier) -# 999: YOLO mode allow-all (becomes 1.999 in default tier) - -[[rule]] -decision = "allow" -priority = 999 -modes = ["yolo"] diff --git a/packages/cli/src/config/policy-engine.integration.test.ts b/packages/cli/src/config/policy-engine.integration.test.ts index 9b8457bc332..0c22cfeba98 100644 --- a/packages/cli/src/config/policy-engine.integration.test.ts +++ b/packages/cli/src/config/policy-engine.integration.test.ts @@ -30,18 +30,22 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Allowed tool should be allowed - expect(engine.check({ name: 'run_shell_command' })).toBe( + expect(engine.check({ name: 'run_shell_command' }, undefined)).toBe( PolicyDecision.ALLOW, ); // Excluded tool should be denied - expect(engine.check({ name: 'write_file' })).toBe(PolicyDecision.DENY); + expect(engine.check({ name: 'write_file' }, undefined)).toBe( + PolicyDecision.DENY, + ); // Other write tools should ask user - expect(engine.check({ name: 'replace' })).toBe(PolicyDecision.ASK_USER); + 
expect(engine.check({ name: 'replace' }, undefined)).toBe( + PolicyDecision.ASK_USER, + ); // Unknown tools should use default - expect(engine.check({ name: 'unknown_tool' })).toBe( + expect(engine.check({ name: 'unknown_tool' }, undefined)).toBe( PolicyDecision.ASK_USER, ); }); @@ -68,31 +72,31 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Tools from allowed server should be allowed - expect(engine.check({ name: 'allowed-server__tool1' })).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'allowed-server__another_tool' })).toBe( + expect(engine.check({ name: 'allowed-server__tool1' }, undefined)).toBe( PolicyDecision.ALLOW, ); + expect( + engine.check({ name: 'allowed-server__another_tool' }, undefined), + ).toBe(PolicyDecision.ALLOW); // Tools from trusted server should be allowed - expect(engine.check({ name: 'trusted-server__tool1' })).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'trusted-server__special_tool' })).toBe( + expect(engine.check({ name: 'trusted-server__tool1' }, undefined)).toBe( PolicyDecision.ALLOW, ); + expect( + engine.check({ name: 'trusted-server__special_tool' }, undefined), + ).toBe(PolicyDecision.ALLOW); // Tools from blocked server should be denied - expect(engine.check({ name: 'blocked-server__tool1' })).toBe( - PolicyDecision.DENY, - ); - expect(engine.check({ name: 'blocked-server__any_tool' })).toBe( + expect(engine.check({ name: 'blocked-server__tool1' }, undefined)).toBe( PolicyDecision.DENY, ); + expect( + engine.check({ name: 'blocked-server__any_tool' }, undefined), + ).toBe(PolicyDecision.DENY); // Tools from unknown servers should use default - expect(engine.check({ name: 'unknown-server__tool' })).toBe( + expect(engine.check({ name: 'unknown-server__tool' }, undefined)).toBe( PolicyDecision.ASK_USER, ); }); @@ -114,13 +118,13 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // MCP server 
allowed (priority 2.1) provides general allow for server - expect(engine.check({ name: 'my-server__safe-tool' })).toBe( + expect(engine.check({ name: 'my-server__safe-tool' }, undefined)).toBe( PolicyDecision.ALLOW, ); // But specific tool exclude (priority 2.4) wins over server allow - expect(engine.check({ name: 'my-server__dangerous-tool' })).toBe( - PolicyDecision.DENY, - ); + expect( + engine.check({ name: 'my-server__dangerous-tool' }, undefined), + ).toBe(PolicyDecision.DENY); }); it('should handle complex mixed configurations', async () => { @@ -150,36 +154,44 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Read-only tools should be allowed (autoAccept) - expect(engine.check({ name: 'read_file' })).toBe(PolicyDecision.ALLOW); - expect(engine.check({ name: 'list_directory' })).toBe( + expect(engine.check({ name: 'read_file' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + expect(engine.check({ name: 'list_directory' }, undefined)).toBe( PolicyDecision.ALLOW, ); // But glob is explicitly excluded, so it should be denied - expect(engine.check({ name: 'glob' })).toBe(PolicyDecision.DENY); + expect(engine.check({ name: 'glob' }, undefined)).toBe( + PolicyDecision.DENY, + ); // Replace should ask user (normal write tool behavior) - expect(engine.check({ name: 'replace' })).toBe(PolicyDecision.ASK_USER); + expect(engine.check({ name: 'replace' }, undefined)).toBe( + PolicyDecision.ASK_USER, + ); // Explicitly allowed tools - expect(engine.check({ name: 'custom-tool' })).toBe(PolicyDecision.ALLOW); - expect(engine.check({ name: 'my-server__special-tool' })).toBe( + expect(engine.check({ name: 'custom-tool' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + expect(engine.check({ name: 'my-server__special-tool' }, undefined)).toBe( PolicyDecision.ALLOW, ); // MCP server tools - expect(engine.check({ name: 'allowed-server__tool' })).toBe( + expect(engine.check({ name: 'allowed-server__tool' }, undefined)).toBe( 
PolicyDecision.ALLOW, ); - expect(engine.check({ name: 'trusted-server__tool' })).toBe( + expect(engine.check({ name: 'trusted-server__tool' }, undefined)).toBe( PolicyDecision.ALLOW, ); - expect(engine.check({ name: 'blocked-server__tool' })).toBe( + expect(engine.check({ name: 'blocked-server__tool' }, undefined)).toBe( PolicyDecision.DENY, ); // Write tools should ask by default - expect(engine.check({ name: 'write_file' })).toBe( + expect(engine.check({ name: 'write_file' }, undefined)).toBe( PolicyDecision.ASK_USER, ); }); @@ -198,14 +210,18 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Most tools should be allowed in YOLO mode - expect(engine.check({ name: 'run_shell_command' })).toBe( + expect(engine.check({ name: 'run_shell_command' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + expect(engine.check({ name: 'write_file' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + expect(engine.check({ name: 'unknown_tool' }, undefined)).toBe( PolicyDecision.ALLOW, ); - expect(engine.check({ name: 'write_file' })).toBe(PolicyDecision.ALLOW); - expect(engine.check({ name: 'unknown_tool' })).toBe(PolicyDecision.ALLOW); // But explicitly excluded tools should still be denied - expect(engine.check({ name: 'dangerous-tool' })).toBe( + expect(engine.check({ name: 'dangerous-tool' }, undefined)).toBe( PolicyDecision.DENY, ); }); @@ -220,11 +236,15 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Edit tools should be allowed in AUTO_EDIT mode - expect(engine.check({ name: 'replace' })).toBe(PolicyDecision.ALLOW); - expect(engine.check({ name: 'write_file' })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'replace' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + expect(engine.check({ name: 'write_file' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); // Other tools should follow normal rules - expect(engine.check({ name: 'run_shell_command' })).toBe( + 
expect(engine.check({ name: 'run_shell_command' }, undefined)).toBe( PolicyDecision.ASK_USER, ); }); @@ -285,20 +305,24 @@ describe('Policy Engine Integration Tests', () => { expect(readOnlyToolRule?.priority).toBeCloseTo(1.05, 5); // Verify the engine applies these priorities correctly - expect(engine.check({ name: 'blocked-tool' })).toBe(PolicyDecision.DENY); - expect(engine.check({ name: 'blocked-server__any' })).toBe( + expect(engine.check({ name: 'blocked-tool' }, undefined)).toBe( PolicyDecision.DENY, ); - expect(engine.check({ name: 'specific-tool' })).toBe( + expect(engine.check({ name: 'blocked-server__any' }, undefined)).toBe( + PolicyDecision.DENY, + ); + expect(engine.check({ name: 'specific-tool' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + expect(engine.check({ name: 'trusted-server__any' }, undefined)).toBe( PolicyDecision.ALLOW, ); - expect(engine.check({ name: 'trusted-server__any' })).toBe( + expect(engine.check({ name: 'mcp-server__any' }, undefined)).toBe( PolicyDecision.ALLOW, ); - expect(engine.check({ name: 'mcp-server__any' })).toBe( + expect(engine.check({ name: 'glob' }, undefined)).toBe( PolicyDecision.ALLOW, ); - expect(engine.check({ name: 'glob' })).toBe(PolicyDecision.ALLOW); }); it('should handle edge case: MCP server with both trust and exclusion', async () => { @@ -322,7 +346,7 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Exclusion (195) should win over trust (90) - expect(engine.check({ name: 'conflicted-server__tool' })).toBe( + expect(engine.check({ name: 'conflicted-server__tool' }, undefined)).toBe( PolicyDecision.DENY, ); }); @@ -345,10 +369,10 @@ describe('Policy Engine Integration Tests', () => { // Server exclusion (195) wins over specific tool allow (100) // This might be counterintuitive but follows the priority system - expect(engine.check({ name: 'my-server__special-tool' })).toBe( + expect(engine.check({ name: 'my-server__special-tool' }, undefined)).toBe( 
PolicyDecision.DENY, ); - expect(engine.check({ name: 'my-server__other-tool' })).toBe( + expect(engine.check({ name: 'my-server__other-tool' }, undefined)).toBe( PolicyDecision.DENY, ); }); @@ -365,8 +389,10 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(engineConfig); // ASK_USER should become DENY in non-interactive mode - expect(engine.check({ name: 'unknown_tool' })).toBe(PolicyDecision.DENY); - expect(engine.check({ name: 'run_shell_command' })).toBe( + expect(engine.check({ name: 'unknown_tool' }, undefined)).toBe( + PolicyDecision.DENY, + ); + expect(engine.check({ name: 'run_shell_command' }, undefined)).toBe( PolicyDecision.DENY, ); }); @@ -381,13 +407,17 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Should have default rules for write tools - expect(engine.check({ name: 'write_file' })).toBe( + expect(engine.check({ name: 'write_file' }, undefined)).toBe( + PolicyDecision.ASK_USER, + ); + expect(engine.check({ name: 'replace' }, undefined)).toBe( PolicyDecision.ASK_USER, ); - expect(engine.check({ name: 'replace' })).toBe(PolicyDecision.ASK_USER); // Unknown tools should use default - expect(engine.check({ name: 'unknown' })).toBe(PolicyDecision.ASK_USER); + expect(engine.check({ name: 'unknown' }, undefined)).toBe( + PolicyDecision.ASK_USER, + ); }); it('should verify rules are created with correct priorities', async () => { diff --git a/packages/cli/src/config/settings.ts b/packages/cli/src/config/settings.ts index cfdbf7e44e6..3a3295beccf 100644 --- a/packages/cli/src/config/settings.ts +++ b/packages/cli/src/config/settings.ts @@ -157,6 +157,38 @@ export enum SettingScope { Workspace = 'Workspace', System = 'System', SystemDefaults = 'SystemDefaults', + // Note that this scope is not supported in the settings dialog at this time, + // it is only supported for extensions. 
+ Session = 'Session', +} + +/** + * A type representing the settings scopes that are supported for LoadedSettings. + */ +export type LoadableSettingScope = + | SettingScope.User + | SettingScope.Workspace + | SettingScope.System + | SettingScope.SystemDefaults; + +/** + * The actual values of the loadable settings scopes. + */ +const _loadableSettingScopes = [ + SettingScope.User, + SettingScope.Workspace, + SettingScope.System, + SettingScope.SystemDefaults, +]; + +/** + * A type guard function that checks if `scope` is a loadable settings scope, + * and allows promotion to the `LoadableSettingsScope` type based on the result. + */ +export function isLoadableSettingScope( + scope: SettingScope, +): scope is LoadableSettingScope { + return _loadableSettingScopes.includes(scope); } export interface CheckpointingSettings { @@ -398,14 +430,14 @@ export class LoadedSettings { user: SettingsFile, workspace: SettingsFile, isTrusted: boolean, - migratedInMemorScopes: Set, + migratedInMemoryScopes: Set, ) { this.system = system; this.systemDefaults = systemDefaults; this.user = user; this.workspace = workspace; this.isTrusted = isTrusted; - this.migratedInMemorScopes = migratedInMemorScopes; + this.migratedInMemoryScopes = migratedInMemoryScopes; this._merged = this.computeMergedSettings(); } @@ -414,7 +446,7 @@ export class LoadedSettings { readonly user: SettingsFile; readonly workspace: SettingsFile; readonly isTrusted: boolean; - readonly migratedInMemorScopes: Set; + readonly migratedInMemoryScopes: Set; private _merged: Settings; @@ -432,7 +464,7 @@ export class LoadedSettings { ); } - forScope(scope: SettingScope): SettingsFile { + forScope(scope: LoadableSettingScope): SettingsFile { switch (scope) { case SettingScope.User: return this.user; @@ -447,7 +479,7 @@ export class LoadedSettings { } } - setValue(scope: SettingScope, key: string, value: unknown): void { + setValue(scope: LoadableSettingScope, key: string, value: unknown): void { const settingsFile = 
this.forScope(scope); setNestedProperty(settingsFile.settings, key, value); setNestedProperty(settingsFile.originalSettings, key, value); @@ -563,7 +595,7 @@ export function loadSettings( const settingsErrors: SettingsError[] = []; const systemSettingsPath = getSystemSettingsPath(); const systemDefaultsPath = getSystemDefaultsPath(); - const migratedInMemorScopes = new Set(); + const migratedInMemoryScopes = new Set(); // Resolve paths to their canonical representation to handle symlinks const resolvedWorkspaceDir = path.resolve(workspaceDir); @@ -625,7 +657,7 @@ export function loadSettings( ); } } else { - migratedInMemorScopes.add(scope); + migratedInMemoryScopes.add(scope); } settingsObject = migratedSettings; } @@ -703,7 +735,7 @@ export function loadSettings( isTrusted, ); - // loadEnviroment depends on settings so we have to create a temp version of + // loadEnvironment depends on settings so we have to create a temp version of // the settings to avoid a cycle loadEnvironment(tempMergedSettings); @@ -744,7 +776,7 @@ export function loadSettings( rawJson: workspaceResult.rawJson, }, isTrusted, - migratedInMemorScopes, + migratedInMemoryScopes, ); } @@ -752,7 +784,7 @@ export function migrateDeprecatedSettings( loadedSettings: LoadedSettings, extensionManager: ExtensionManager, ): void { - const processScope = (scope: SettingScope) => { + const processScope = (scope: LoadableSettingScope) => { const settings = loadedSettings.forScope(scope).settings; if (settings.extensions?.disabled) { debugLogger.log( diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index ea6ca29fa2a..8d4e459f58a 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -21,6 +21,7 @@ import { DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, DEFAULT_GEMINI_MODEL, + DEFAULT_MODEL_CONFIGS, } from '@google/gemini-cli-core'; import type { CustomTheme } from 
'../ui/themes/theme.js'; import type { SessionRetentionSettings } from './settings.js'; @@ -680,6 +681,38 @@ const SETTINGS_SCHEMA = { }, }, + modelConfigs: { + type: 'object', + label: 'Model Configs', + category: 'Model', + requiresRestart: false, + default: DEFAULT_MODEL_CONFIGS, + description: 'Model configurations.', + showInDialog: false, + properties: { + aliases: { + type: 'object', + label: 'Model Config Aliases', + category: 'Model', + requiresRestart: false, + default: DEFAULT_MODEL_CONFIGS.aliases, + description: + 'Named presets for model configs. Can be used in place of a model name and can inherit from other aliases using an `extends` property.', + showInDialog: false, + }, + overrides: { + type: 'array', + label: 'Model Config Overrides', + category: 'Model', + requiresRestart: false, + default: [], + description: + 'Apply specific configuration overrides based on matches, with a primary key of model (or alias). The most specific match will be used.', + showInDialog: false, + }, + }, + }, + context: { type: 'object', label: 'Context', diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index d740f08e56c..eea524256c2 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -184,7 +184,7 @@ describe('runNonInteractive', () => { }, }, isTrusted: true, - migratedInMemorScopes: new Set(), + migratedInMemoryScopes: new Set(), forScope: vi.fn(), computeMergedSettings: vi.fn(), } as unknown as LoadedSettings; diff --git a/packages/cli/src/services/BuiltinCommandLoader.test.ts b/packages/cli/src/services/BuiltinCommandLoader.test.ts index 3ae6c6639a6..49792cb0815 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.test.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.test.ts @@ -66,7 +66,7 @@ vi.mock('../ui/commands/corgiCommand.js', () => ({ corgiCommand: {} })); vi.mock('../ui/commands/docsCommand.js', () => ({ docsCommand: {} })); 
vi.mock('../ui/commands/editorCommand.js', () => ({ editorCommand: {} })); vi.mock('../ui/commands/extensionsCommand.js', () => ({ - extensionsCommand: {}, + extensionsCommand: () => ({}), })); vi.mock('../ui/commands/helpCommand.js', () => ({ helpCommand: {} })); vi.mock('../ui/commands/memoryCommand.js', () => ({ memoryCommand: {} })); @@ -97,6 +97,7 @@ describe('BuiltinCommandLoader', () => { getFolderTrust: vi.fn().mockReturnValue(true), getUseModelRouter: () => false, getEnableMessageBusIntegration: () => false, + getEnableExtensionReloading: () => false, } as unknown as Config; restoreCommandMock.mockReturnValue({ @@ -222,6 +223,7 @@ describe('BuiltinCommandLoader profile', () => { getUseModelRouter: () => false, getCheckpointingEnabled: () => false, getEnableMessageBusIntegration: () => false, + getEnableExtensionReloading: () => false, } as unknown as Config; }); diff --git a/packages/cli/src/services/BuiltinCommandLoader.ts b/packages/cli/src/services/BuiltinCommandLoader.ts index 010cf60aa53..6e6c0a407fb 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.ts @@ -67,7 +67,7 @@ export class BuiltinCommandLoader implements ICommandLoader { docsCommand, directoryCommand, editorCommand, - extensionsCommand, + extensionsCommand(this.config?.getEnableExtensionReloading()), helpCommand, await ideCommand(), initCommand, diff --git a/packages/cli/src/ui/AppContainer.test.tsx b/packages/cli/src/ui/AppContainer.test.tsx index 63298829bf2..d6f15cbb464 100644 --- a/packages/cli/src/ui/AppContainer.test.tsx +++ b/packages/cli/src/ui/AppContainer.test.tsx @@ -1709,4 +1709,41 @@ describe('AppContainer State Management', () => { unmount(); }); }); + + describe('Shell Interaction', () => { + it('should not crash if resizing the pty fails', async () => { + const resizePtySpy = vi + .spyOn(ShellExecutionService, 'resizePty') + .mockImplementation(() => { + throw new Error('Cannot resize a pty that has already 
exited'); + }); + + mockedUseGeminiStream.mockReturnValue({ + streamingState: 'idle', + submitQuery: vi.fn(), + initError: null, + pendingHistoryItems: [], + thought: null, + cancelOngoingRequest: vi.fn(), + activePtyId: 'some-pty-id', // Make sure activePtyId is set + }); + + // The main assertion is that the render does not throw. + const { unmount } = render( + , + ); + + await act(async () => { + await new Promise((resolve) => setTimeout(resolve, 0)); + }); + + expect(resizePtySpy).toHaveBeenCalled(); + unmount(); + }); + }); }); diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index a4de005e509..ad47bb07bac 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -76,7 +76,11 @@ import { useTextBuffer } from './components/shared/text-buffer.js'; import { useLogger } from './hooks/useLogger.js'; import { useGeminiStream } from './hooks/useGeminiStream.js'; import { useVim } from './hooks/vim.js'; -import { type LoadedSettings, SettingScope } from '../config/settings.js'; +import { + type LoadableSettingScope, + type LoadedSettings, + SettingScope, +} from '../config/settings.js'; import { type InitializationResult } from '../core/initializer.js'; import { useFocus } from './hooks/useFocus.js'; import { useBracketedPaste } from './hooks/useBracketedPaste.js'; @@ -100,6 +104,7 @@ import { useExtensionUpdates, } from './hooks/useExtensionUpdates.js'; import { ShellFocusContext } from './contexts/ShellFocusContext.js'; +import { useShellHistory } from './hooks/useShellHistory.js'; import { type ExtensionManager } from '../config/extension-manager.js'; import { requestConsentInteractive } from '../config/extensions/consent.js'; import { disableMouseEvents, enableMouseEvents } from './utils/mouse.js'; @@ -153,6 +158,7 @@ export const AppContainer = (props: AppContainerProps) => { ); const [isProcessing, setIsProcessing] = useState(false); const [embeddedShellFocused, 
setEmbeddedShellFocused] = useState(false); + const quickCommandAbortControllerRef = useRef(null); const [showDebugProfiler, setShowDebugProfiler] = useState(false); const [copyModeEnabled, setCopyModeEnabled] = useState(false); @@ -323,6 +329,8 @@ export const AppContainer = (props: AppContainerProps) => { shellModeActive, }); + const shellHistory = useShellHistory(config.getProjectRoot()); + useEffect(() => { const fetchUserMessages = async () => { const pastMessagesRaw = (await logger?.getPreviousUserMessages()) || []; @@ -396,7 +404,7 @@ export const AppContainer = (props: AppContainerProps) => { // Create handleAuthSelect wrapper for backward compatibility const handleAuthSelect = useCallback( - async (authType: AuthType | undefined, scope: SettingScope) => { + async (authType: AuthType | undefined, scope: LoadableSettingScope) => { if (authType) { await clearCachedCredentialFile(); settings.setValue(scope, 'security.auth.selectedType', authType); @@ -655,6 +663,7 @@ Logging in with Google... Please restart Gemini CLI to continue. handleApprovalModeChange, activePtyId, loopDetectionConfirmationRequest, + handleShellCommand, } = useGeminiStream( config.getGeminiClient(), historyManager.history, @@ -805,11 +814,27 @@ Logging in with Google... Please restart Gemini CLI to continue. useEffect(() => { if (activePtyId) { - ShellExecutionService.resizePty( - activePtyId, - Math.floor(terminalWidth * SHELL_WIDTH_FRACTION), - Math.max(Math.floor(availableTerminalHeight - SHELL_HEIGHT_PADDING), 1), - ); + try { + ShellExecutionService.resizePty( + activePtyId, + Math.floor(terminalWidth * SHELL_WIDTH_FRACTION), + Math.max( + Math.floor(availableTerminalHeight - SHELL_HEIGHT_PADDING), + 1, + ), + ); + } catch (e) { + // This can happen in a race condition where the pty exits + // right before we try to resize it. 
+ if ( + !( + e instanceof Error && + e.message.includes('Cannot resize a pty that has already exited') + ) + ) { + throw e; + } + } } }, [terminalWidth, availableTerminalHeight, activePtyId]); @@ -1035,6 +1060,9 @@ Logging in with Google... Please restart Gemini CLI to continue. // If the user presses Ctrl+C, we want to cancel any ongoing requests. // This should happen regardless of the count. cancelOngoingRequest?.(); + // Also cancel quick command if one is running + quickCommandAbortControllerRef.current?.abort(); + quickCommandAbortControllerRef.current = null; setCtrlCPressCount((prev) => prev + 1); return; @@ -1074,6 +1102,31 @@ Logging in with Google... Please restart Gemini CLI to continue. !enteringConstrainHeightMode ) { setConstrainHeight(false); + } else if (keyMatchers[Command.EXECUTE_PROMPT_COMMAND](key)) { + const commandToExecute = buffer.text.trim(); + if (commandToExecute) { + buffer.setText(''); + // Add command to shell history + shellHistory.addCommandToHistory(commandToExecute); + const abortController = new AbortController(); + // Store the controller so Ctrl+C can cancel it + quickCommandAbortControllerRef.current = abortController; + + // Clear the ref when the command completes or is aborted + const cleanup = () => { + if (quickCommandAbortControllerRef.current === abortController) { + quickCommandAbortControllerRef.current = null; + } + }; + abortController.signal.addEventListener('abort', cleanup, { + once: true, + }); + + // Using the same shell command processor as Shell Mode + handleShellCommand(commandToExecute, abortController.signal); + } + // Consume the key event even if buffer is empty to prevent newline + return; } else if (keyMatchers[Command.TOGGLE_SHELL_INPUT_FOCUS](key)) { if (activePtyId || embeddedShellFocused) { setEmbeddedShellFocused((prev) => !prev); @@ -1087,13 +1140,15 @@ Logging in with Google... Please restart Gemini CLI to continue. 
config, ideContextState, setCtrlCPressCount, - buffer.text.length, + buffer, setCtrlDPressCount, handleSlashCommand, cancelOngoingRequest, activePtyId, embeddedShellFocused, settings.merged.general?.debugKeystrokeLogging, + handleShellCommand, + shellHistory, refreshStatic, setCopyModeEnabled, copyModeEnabled, diff --git a/packages/cli/src/ui/auth/AuthDialog.tsx b/packages/cli/src/ui/auth/AuthDialog.tsx index c024dd255e9..61ea01764d7 100644 --- a/packages/cli/src/ui/auth/AuthDialog.tsx +++ b/packages/cli/src/ui/auth/AuthDialog.tsx @@ -9,7 +9,10 @@ import { useCallback } from 'react'; import { Box, Text } from 'ink'; import { theme } from '../semantic-colors.js'; import { RadioButtonSelect } from '../components/shared/RadioButtonSelect.js'; -import type { LoadedSettings } from '../../config/settings.js'; +import type { + LoadableSettingScope, + LoadedSettings, +} from '../../config/settings.js'; import { SettingScope } from '../../config/settings.js'; import { AuthType, @@ -99,7 +102,7 @@ export function AuthDialog({ } const onSelect = useCallback( - async (authType: AuthType | undefined, scope: SettingScope) => { + async (authType: AuthType | undefined, scope: LoadableSettingScope) => { if (authType) { await clearCachedCredentialFile(); diff --git a/packages/cli/src/ui/commands/extensionsCommand.test.ts b/packages/cli/src/ui/commands/extensionsCommand.test.ts index c10b3896d1e..db947b39685 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.test.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.test.ts @@ -7,10 +7,16 @@ import type { GeminiCLIExtension } from '@google/gemini-cli-core'; import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; import { MessageType } from '../types.js'; -import { extensionsCommand } from './extensionsCommand.js'; -import { type CommandContext } from './types.js'; +import { + completeExtensions, + completeExtensionsAndScopes, + extensionsCommand, +} from './extensionsCommand.js'; +import { type 
CommandContext, type SlashCommand } from './types.js'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { type ExtensionUpdateAction } from '../state/extensions.js'; +import { ExtensionManager } from '../../config/extension-manager.js'; +import { SettingScope } from '../../config/settings.js'; import open from 'open'; vi.mock('open', () => ({ @@ -22,20 +28,72 @@ vi.mock('../../config/extensions/update.js', () => ({ checkForAllExtensionUpdates: vi.fn(), })); +const mockDisableExtension = vi.fn(); +const mockEnableExtension = vi.fn(); const mockGetExtensions = vi.fn(); +const inactiveExt: GeminiCLIExtension = { + name: 'ext-one', + id: 'ext-one-id', + version: '1.0.0', + isActive: false, // should suggest disabled extensions + path: '/test/dir/ext-one', + contextFiles: [], + installMetadata: { + type: 'git', + autoUpdate: false, + source: 'https://github.com/some/extension.git', + }, +}; +const activeExt: GeminiCLIExtension = { + name: 'ext-two', + id: 'ext-two-id', + version: '1.0.0', + isActive: true, // should not suggest enabled extensions + path: '/test/dir/ext-two', + contextFiles: [], + installMetadata: { + type: 'git', + autoUpdate: false, + source: 'https://github.com/some/extension.git', + }, +}; +const allExt: GeminiCLIExtension = { + name: 'all-ext', + id: 'all-ext-id', + version: '1.0.0', + isActive: true, + path: '/test/dir/all-ext', + contextFiles: [], + installMetadata: { + type: 'git', + autoUpdate: false, + source: 'https://github.com/some/extension.git', + }, +}; + describe('extensionsCommand', () => { let mockContext: CommandContext; const mockDispatchExtensionState = vi.fn(); beforeEach(() => { vi.resetAllMocks(); - mockGetExtensions.mockReturnValue([]); + + mockGetExtensions.mockReturnValue([inactiveExt, activeExt, allExt]); vi.mocked(open).mockClear(); mockContext = createMockCommandContext({ services: { config: { getExtensions: mockGetExtensions, + getExtensionLoader: vi.fn().mockImplementation(() => { + const 
actual = Object.create(ExtensionManager.prototype); + Object.assign(actual, { + enableExtension: mockEnableExtension, + disableExtension: mockDisableExtension, + getExtensions: mockGetExtensions, + }); + return actual; + }), getWorkingDir: () => '/test/dir', }, }, @@ -52,8 +110,9 @@ describe('extensionsCommand', () => { describe('list', () => { it('should add an EXTENSIONS_LIST item to the UI', async () => { - if (!extensionsCommand.action) throw new Error('Action not defined'); - await extensionsCommand.action(mockContext, ''); + const command = extensionsCommand(); + if (!command.action) throw new Error('Action not defined'); + await command.action(mockContext, ''); expect(mockContext.ui.addItem).toHaveBeenCalledWith( { @@ -65,8 +124,68 @@ describe('extensionsCommand', () => { }); }); + describe('completeExtensions', () => { + it.each([ + { + description: 'should return matching extension names', + partialArg: 'ext', + expected: ['ext-one', 'ext-two'], + }, + { + description: 'should return --all when partialArg matches', + partialArg: '--al', + expected: ['--all'], + }, + { + description: + 'should return both extension names and --all when both match', + partialArg: 'all', + expected: ['--all', 'all-ext'], + }, + { + description: 'should return an empty array if no matches', + partialArg: 'nomatch', + expected: [], + }, + { + description: + 'should suggest only disabled extension names for the enable command', + partialArg: 'ext', + expected: ['ext-one'], + command: 'enable', + }, + { + description: + 'should suggest only enabled extension names for the disable command', + partialArg: 'ext', + expected: ['ext-two'], + command: 'disable', + }, + ])('$description', async ({ partialArg, expected, command }) => { + if (command) { + mockContext.invocation!.name = command; + } + const suggestions = completeExtensions(mockContext, partialArg); + expect(suggestions).toEqual(expected); + }); + }); + + describe('completeExtensionsAndScopes', () => { + it('expands the 
list of suggestions with --scope args', () => { + const suggestions = completeExtensionsAndScopes(mockContext, 'ext'); + expect(suggestions).toEqual([ + 'ext-one --scope user', + 'ext-one --scope workspace', + 'ext-one --scope session', + 'ext-two --scope user', + 'ext-two --scope workspace', + 'ext-two --scope session', + ]); + }); + }); + describe('update', () => { - const updateAction = extensionsCommand.subCommands?.find( + const updateAction = extensionsCommand().subCommands?.find( (cmd) => cmd.name === 'update', )?.action; @@ -230,92 +349,10 @@ describe('extensionsCommand', () => { expect.any(Number), ); }); - - describe('completion', () => { - const updateCompletion = extensionsCommand.subCommands?.find( - (cmd) => cmd.name === 'update', - )?.completion; - - if (!updateCompletion) { - throw new Error('Update completion not found'); - } - - const extensionOne: GeminiCLIExtension = { - name: 'ext-one', - id: 'ext-one-id', - version: '1.0.0', - isActive: true, - path: '/test/dir/ext-one', - contextFiles: [], - installMetadata: { - type: 'git', - autoUpdate: false, - source: 'https://github.com/some/extension.git', - }, - }; - const extensionTwo: GeminiCLIExtension = { - name: 'another-ext', - id: 'another-ext-id', - version: '1.0.0', - isActive: true, - path: '/test/dir/another-ext', - contextFiles: [], - installMetadata: { - type: 'git', - autoUpdate: false, - source: 'https://github.com/some/extension.git', - }, - }; - const allExt: GeminiCLIExtension = { - name: 'all-ext', - id: 'all-ext-id', - version: '1.0.0', - isActive: true, - path: '/test/dir/all-ext', - contextFiles: [], - installMetadata: { - type: 'git', - autoUpdate: false, - source: 'https://github.com/some/extension.git', - }, - }; - - it.each([ - { - description: 'should return matching extension names', - extensions: [extensionOne, extensionTwo], - partialArg: 'ext', - expected: ['ext-one'], - }, - { - description: 'should return --all when partialArg matches', - extensions: [], - partialArg: 
'--al', - expected: ['--all'], - }, - { - description: - 'should return both extension names and --all when both match', - extensions: [allExt], - partialArg: 'all', - expected: ['--all', 'all-ext'], - }, - { - description: 'should return an empty array if no matches', - extensions: [extensionOne], - partialArg: 'nomatch', - expected: [], - }, - ])('$description', async ({ extensions, partialArg, expected }) => { - mockGetExtensions.mockReturnValue(extensions); - const suggestions = await updateCompletion(mockContext, partialArg); - expect(suggestions).toEqual(expected); - }); - }); }); describe('explore', () => { - const exploreAction = extensionsCommand.subCommands?.find( + const exploreAction = extensionsCommand().subCommands?.find( (cmd) => cmd.name === 'explore', )?.action; @@ -398,4 +435,141 @@ describe('extensionsCommand', () => { ); }); }); + + describe('when enableExtensionReloading is true', () => { + it('should include enable and disable subcommands', () => { + const command = extensionsCommand(true); + const subCommandNames = command.subCommands?.map((cmd) => cmd.name); + expect(subCommandNames).toContain('enable'); + expect(subCommandNames).toContain('disable'); + }); + }); + + describe('when enableExtensionReloading is false', () => { + it('should not include enable and disable subcommands', () => { + const command = extensionsCommand(false); + const subCommandNames = command.subCommands?.map((cmd) => cmd.name); + expect(subCommandNames).not.toContain('enable'); + expect(subCommandNames).not.toContain('disable'); + }); + }); + + describe('when enableExtensionReloading is not provided', () => { + it('should not include enable and disable subcommands by default', () => { + const command = extensionsCommand(); + const subCommandNames = command.subCommands?.map((cmd) => cmd.name); + expect(subCommandNames).not.toContain('enable'); + expect(subCommandNames).not.toContain('disable'); + }); + }); + + describe('enable', () => { + let enableAction: 
SlashCommand['action']; + + beforeEach(() => { + enableAction = extensionsCommand(true).subCommands?.find( + (cmd) => cmd.name === 'enable', + )?.action; + + expect(enableAction).not.toBeNull(); + + mockContext.invocation!.name = 'enable'; + }); + + it('should show usage if no extension name is provided', async () => { + await enableAction!(mockContext, ''); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.ERROR, + text: 'Usage: /extensions enable [--scope=]', + }, + expect.any(Number), + ); + }); + + it('should call enableExtension with the provided scope', async () => { + await enableAction!(mockContext, `${inactiveExt.name} --scope=user`); + expect(mockEnableExtension).toHaveBeenCalledWith( + inactiveExt.name, + SettingScope.User, + ); + + await enableAction!(mockContext, `${inactiveExt.name} --scope workspace`); + expect(mockEnableExtension).toHaveBeenCalledWith( + inactiveExt.name, + SettingScope.Workspace, + ); + }); + + it('should support --all', async () => { + mockGetExtensions.mockReturnValue([ + inactiveExt, + { ...inactiveExt, name: 'another-inactive-ext' }, + ]); + await enableAction!(mockContext, '--all --scope session'); + expect(mockEnableExtension).toHaveBeenCalledWith( + inactiveExt.name, + SettingScope.Session, + ); + expect(mockEnableExtension).toHaveBeenCalledWith( + 'another-inactive-ext', + SettingScope.Session, + ); + }); + }); + + describe('disable', () => { + let disableAction: SlashCommand['action']; + + beforeEach(() => { + disableAction = extensionsCommand(true).subCommands?.find( + (cmd) => cmd.name === 'disable', + )?.action; + + expect(disableAction).not.toBeNull(); + + mockContext.invocation!.name = 'disable'; + }); + + it('should show usage if no extension name is provided', async () => { + await disableAction!(mockContext, ''); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.ERROR, + text: 'Usage: /extensions disable [--scope=]', + }, + expect.any(Number), + ); + }); + + 
it('should call disableExtension with the provided scope', async () => { + await disableAction!(mockContext, `${activeExt.name} --scope=user`); + expect(mockDisableExtension).toHaveBeenCalledWith( + activeExt.name, + SettingScope.User, + ); + + await disableAction!(mockContext, `${activeExt.name} --scope workspace`); + expect(mockDisableExtension).toHaveBeenCalledWith( + activeExt.name, + SettingScope.Workspace, + ); + }); + + it('should support --all', async () => { + mockGetExtensions.mockReturnValue([ + activeExt, + { ...activeExt, name: 'another-active-ext' }, + ]); + await disableAction!(mockContext, '--all --scope session'); + expect(mockDisableExtension).toHaveBeenCalledWith( + activeExt.name, + SettingScope.Session, + ); + expect(mockDisableExtension).toHaveBeenCalledWith( + 'another-active-ext', + SettingScope.Session, + ); + }); + }); }); diff --git a/packages/cli/src/ui/commands/extensionsCommand.ts b/packages/cli/src/ui/commands/extensionsCommand.ts index 45ea3e47b67..2cb823543a4 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { listExtensions } from '@google/gemini-cli-core'; +import { debugLogger, listExtensions } from '@google/gemini-cli-core'; import type { ExtensionUpdateInfo } from '../../config/extension.js'; import { getErrorMessage } from '../../utils/errors.js'; import { MessageType, type HistoryItemExtensionsList } from '../types.js'; @@ -15,6 +15,8 @@ import { } from './types.js'; import open from 'open'; import process from 'node:process'; +import { ExtensionManager } from '../../config/extension-manager.js'; +import { SettingScope } from '../../config/settings.js'; async function listAction(context: CommandContext) { const historyItem: HistoryItemExtensionsList = { @@ -159,6 +161,158 @@ async function exploreAction(context: CommandContext) { } } +function getEnableDisableContext( + context: CommandContext, 
+ argumentsString: string, +): { + extensionManager: ExtensionManager; + names: string[]; + scope: SettingScope; +} | null { + const extensionLoader = context.services.config?.getExtensionLoader(); + if (!(extensionLoader instanceof ExtensionManager)) { + debugLogger.error( + `Cannot ${context.invocation?.name} extensions in this environment`, + ); + return null; + } + const parts = argumentsString.split(' '); + const name = parts[0]; + if ( + name === '' || + !( + (parts.length === 2 && parts[1].startsWith('--scope=')) || // --scope= + (parts.length === 3 && parts[1] === '--scope') // --scope + ) + ) { + context.ui.addItem( + { + type: MessageType.ERROR, + text: `Usage: /extensions ${context.invocation?.name} [--scope=]`, + }, + Date.now(), + ); + return null; + } + let scope: SettingScope; + // Transform `--scope=` to `--scope `. + if (parts.length === 2) { + parts.push(...parts[1].split('=')); + parts.splice(1, 1); + } + switch (parts[2].toLowerCase()) { + case 'workspace': + scope = SettingScope.Workspace; + break; + case 'user': + scope = SettingScope.User; + break; + case 'session': + scope = SettingScope.Session; + break; + default: + context.ui.addItem( + { + type: MessageType.ERROR, + text: `Unsupported scope ${parts[2]}, should be one of "user", "workspace", or "session"`, + }, + Date.now(), + ); + debugLogger.error(); + return null; + } + let names: string[] = []; + if (name === '--all') { + let extensions = extensionLoader.getExtensions(); + if (context.invocation?.name === 'enable') { + extensions = extensions.filter((ext) => !ext.isActive); + } + if (context.invocation?.name === 'disable') { + extensions = extensions.filter((ext) => ext.isActive); + } + names = extensions.map((ext) => ext.name); + } else { + names = [name]; + } + + return { + extensionManager: extensionLoader, + names, + scope, + }; +} + +async function disableAction(context: CommandContext, args: string) { + const enableContext = getEnableDisableContext(context, args); + if 
(!enableContext) return; + + const { names, scope, extensionManager } = enableContext; + for (const name of names) { + await extensionManager.disableExtension(name, scope); + context.ui.addItem( + { + type: MessageType.INFO, + text: `Extension "${name}" disabled for the scope "${scope}"`, + }, + Date.now(), + ); + } +} + +async function enableAction(context: CommandContext, args: string) { + const enableContext = getEnableDisableContext(context, args); + if (!enableContext) return; + + const { names, scope, extensionManager } = enableContext; + for (const name of names) { + await extensionManager.enableExtension(name, scope); + context.ui.addItem( + { + type: MessageType.INFO, + text: `Extension "${name}" enabled for the scope "${scope}"`, + }, + Date.now(), + ); + } +} + +/** + * Exported for testing. + */ +export function completeExtensions( + context: CommandContext, + partialArg: string, +) { + let extensions = context.services.config?.getExtensions() ?? []; + if (context.invocation?.name === 'enable') { + extensions = extensions.filter((ext) => !ext.isActive); + } + if (context.invocation?.name === 'disable') { + extensions = extensions.filter((ext) => ext.isActive); + } + const extensionNames = extensions.map((ext) => ext.name); + const suggestions = extensionNames.filter((name) => + name.startsWith(partialArg), + ); + + if ('--all'.startsWith(partialArg) || 'all'.startsWith(partialArg)) { + suggestions.unshift('--all'); + } + + return suggestions; +} + +export function completeExtensionsAndScopes( + context: CommandContext, + partialArg: string, +) { + return completeExtensions(context, partialArg).flatMap((s) => [ + `${s} --scope user`, + `${s} --scope workspace`, + `${s} --scope session`, + ]); +} + const listExtensionsCommand: SlashCommand = { name: 'list', description: 'List active extensions', @@ -171,21 +325,23 @@ const updateExtensionsCommand: SlashCommand = { description: 'Update extensions. 
Usage: update |--all', kind: CommandKind.BUILT_IN, action: updateAction, - completion: async (context, partialArg) => { - const extensions = context.services.config - ? listExtensions(context.services.config) - : []; - const extensionNames = extensions.map((ext) => ext.name); - const suggestions = extensionNames.filter((name) => - name.startsWith(partialArg), - ); + completion: completeExtensions, +}; - if ('--all'.startsWith(partialArg) || 'all'.startsWith(partialArg)) { - suggestions.unshift('--all'); - } +const disableCommand: SlashCommand = { + name: 'disable', + description: 'Disable an extension', + kind: CommandKind.BUILT_IN, + action: disableAction, + completion: completeExtensionsAndScopes, +}; - return suggestions; - }, +const enableCommand: SlashCommand = { + name: 'enable', + description: 'Enable an extension', + kind: CommandKind.BUILT_IN, + action: enableAction, + completion: completeExtensionsAndScopes, }; const exploreExtensionsCommand: SlashCommand = { @@ -195,16 +351,24 @@ const exploreExtensionsCommand: SlashCommand = { action: exploreAction, }; -export const extensionsCommand: SlashCommand = { - name: 'extensions', - description: 'Manage extensions', - kind: CommandKind.BUILT_IN, - subCommands: [ - listExtensionsCommand, - updateExtensionsCommand, - exploreExtensionsCommand, - ], - action: (context, args) => - // Default to list if no subcommand is provided - listExtensionsCommand.action!(context, args), -}; +export function extensionsCommand( + enableExtensionReloading?: boolean, +): SlashCommand { + const conditionalCommands = enableExtensionReloading + ? 
[disableCommand, enableCommand] + : []; + return { + name: 'extensions', + description: 'Manage extensions', + kind: CommandKind.BUILT_IN, + subCommands: [ + listExtensionsCommand, + updateExtensionsCommand, + exploreExtensionsCommand, + ...conditionalCommands, + ], + action: (context, args) => + // Default to list if no subcommand is provided + listExtensionsCommand.action!(context, args), + }; +} diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index 44080cbf616..99e514fbba4 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -211,7 +211,7 @@ export interface SlashCommand { completion?: ( context: CommandContext, partialArg: string, - ) => Promise; + ) => Promise | string[]; subCommands?: SlashCommand[]; } diff --git a/packages/cli/src/ui/components/CliSpinner.test.tsx b/packages/cli/src/ui/components/CliSpinner.test.tsx new file mode 100644 index 00000000000..bbea23ab5d5 --- /dev/null +++ b/packages/cli/src/ui/components/CliSpinner.test.tsx @@ -0,0 +1,24 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { render } from '../../test-utils/render.js'; +import { CliSpinner } from './CliSpinner.js'; +import { debugState } from '../debug.js'; +import { describe, it, expect, beforeEach } from 'vitest'; + +describe('', () => { + beforeEach(() => { + debugState.debugNumAnimatedComponents = 0; + }); + + it('should increment debugNumAnimatedComponents on mount and decrement on unmount', () => { + expect(debugState.debugNumAnimatedComponents).toBe(0); + const { unmount } = render(); + expect(debugState.debugNumAnimatedComponents).toBe(1); + unmount(); + expect(debugState.debugNumAnimatedComponents).toBe(0); + }); +}); diff --git a/packages/cli/src/ui/components/CliSpinner.tsx b/packages/cli/src/ui/components/CliSpinner.tsx index b194a6c4680..6795bf26706 100644 --- a/packages/cli/src/ui/components/CliSpinner.tsx +++ 
b/packages/cli/src/ui/components/CliSpinner.tsx @@ -6,17 +6,15 @@ import Spinner from 'ink-spinner'; import { type ComponentProps, useEffect } from 'react'; - -// A top-level field to track the total number of active spinners. -export let debugNumSpinners = 0; +import { debugState } from '../debug.js'; export type SpinnerProps = ComponentProps; export const CliSpinner = (props: SpinnerProps) => { useEffect(() => { - debugNumSpinners++; + debugState.debugNumAnimatedComponents++; return () => { - debugNumSpinners--; + debugState.debugNumAnimatedComponents--; }; }, []); diff --git a/packages/cli/src/ui/components/DebugProfiler.test.tsx b/packages/cli/src/ui/components/DebugProfiler.test.tsx index c7e63f0b262..604e54c5fd0 100644 --- a/packages/cli/src/ui/components/DebugProfiler.test.tsx +++ b/packages/cli/src/ui/components/DebugProfiler.test.tsx @@ -12,6 +12,7 @@ import { FRAME_TIMESTAMP_CAPACITY, } from './DebugProfiler.js'; import { FixedDeque } from 'mnemonist'; +import { debugState } from '../debug.js'; describe('DebugProfiler', () => { beforeEach(() => { @@ -29,12 +30,14 @@ describe('DebugProfiler', () => { Array, ACTION_TIMESTAMP_CAPACITY, ); + debugState.debugNumAnimatedComponents = 0; }); afterEach(() => { vi.restoreAllMocks(); profiler.actionTimestamps.clear(); profiler.possiblyIdleFrameTimestamps.clear(); + debugState.debugNumAnimatedComponents = 0; }); it('should not exceed action timestamp capacity', () => { @@ -193,4 +196,20 @@ describe('DebugProfiler', () => { expect(profiler.totalIdleFrames).toBe(0); }); + + it('should not report frames as idle if debugNumAnimatedComponents > 0', async () => { + const startTime = Date.now(); + vi.setSystemTime(startTime); + debugState.debugNumAnimatedComponents = 1; + + for (let i = 0; i < 5; i++) { + profiler.reportFrameRendered(); + vi.advanceTimersByTime(20); + } + + vi.advanceTimersByTime(1000); + profiler.checkForIdleFrames(); + + expect(profiler.totalIdleFrames).toBe(0); + }); }); diff --git 
a/packages/cli/src/ui/components/DebugProfiler.tsx b/packages/cli/src/ui/components/DebugProfiler.tsx index 5b46e332510..cbcdbe5f247 100644 --- a/packages/cli/src/ui/components/DebugProfiler.tsx +++ b/packages/cli/src/ui/components/DebugProfiler.tsx @@ -9,7 +9,7 @@ import { useEffect, useState } from 'react'; import { FixedDeque } from 'mnemonist'; import { theme } from '../semantic-colors.js'; import { useUIState } from '../contexts/UIStateContext.js'; -import { debugNumSpinners } from './CliSpinner.js'; +import { debugState } from '../debug.js'; import { appEvents, AppEvent } from '../../utils/events.js'; // Frames that render at least this far before or after an action are considered @@ -52,7 +52,7 @@ export const profiler = { if (now - this.lastFrameStartTime > 16) { this.lastFrameStartTime = now; this.numFrames++; - if (debugNumSpinners === 0) { + if (debugState.debugNumAnimatedComponents === 0) { if (this.possiblyIdleFrameTimestamps.size >= FRAME_TIMESTAMP_CAPACITY) { this.possiblyIdleFrameTimestamps.shift(); } diff --git a/packages/cli/src/ui/components/EditorSettingsDialog.tsx b/packages/cli/src/ui/components/EditorSettingsDialog.tsx index 3e70207bcb1..55434fdf9d2 100644 --- a/packages/cli/src/ui/components/EditorSettingsDialog.tsx +++ b/packages/cli/src/ui/components/EditorSettingsDialog.tsx @@ -14,14 +14,20 @@ import { type EditorDisplay, } from '../editors/editorSettingsManager.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; -import type { LoadedSettings } from '../../config/settings.js'; +import type { + LoadableSettingScope, + LoadedSettings, +} from '../../config/settings.js'; import { SettingScope } from '../../config/settings.js'; import type { EditorType } from '@google/gemini-cli-core'; import { isEditorAvailable } from '@google/gemini-cli-core'; import { useKeypress } from '../hooks/useKeypress.js'; interface EditorDialogProps { - onSelect: (editorType: EditorType | undefined, scope: SettingScope) => void; + onSelect: ( + 
editorType: EditorType | undefined, + scope: LoadableSettingScope, + ) => void; settings: LoadedSettings; onExit: () => void; } @@ -31,7 +37,7 @@ export function EditorSettingsDialog({ settings, onExit, }: EditorDialogProps): React.JSX.Element { - const [selectedScope, setSelectedScope] = useState( + const [selectedScope, setSelectedScope] = useState( SettingScope.User, ); const [focusedSection, setFocusedSection] = useState<'editor' | 'scope'>( @@ -64,7 +70,11 @@ export function EditorSettingsDialog({ editorIndex = 0; } - const scopeItems = [ + const scopeItems: Array<{ + label: string; + value: LoadableSettingScope; + key: string; + }> = [ { label: 'User Settings', value: SettingScope.User, @@ -85,7 +95,7 @@ export function EditorSettingsDialog({ onSelect(editorType, selectedScope); }; - const handleScopeSelect = (scope: SettingScope) => { + const handleScopeSelect = (scope: LoadableSettingScope) => { setSelectedScope(scope); setFocusedSection('editor'); }; diff --git a/packages/cli/src/ui/components/Help.tsx b/packages/cli/src/ui/components/Help.tsx index a3e124d759d..e57e34b5e1d 100644 --- a/packages/cli/src/ui/components/Help.tsx +++ b/packages/cli/src/ui/components/Help.tsx @@ -55,7 +55,13 @@ export const Help: React.FC = ({ commands }) => ( start server - ). + ). Or use{' '} + + {process.platform === 'darwin' + ? 'Meta+Enter / Ctrl+Enter' + : 'Ctrl+Enter'} + {' '} + to run commands directly without entering shell mode. @@ -138,7 +144,15 @@ export const Help: React.FC = ({ commands }) => ( - {process.platform === 'darwin' ? 'Ctrl+X / Meta+Enter' : 'Ctrl+X'} + {process.platform === 'darwin' + ? 
'Meta+Enter / Ctrl+Enter' + : 'Ctrl+Enter'} + {' '} + - Execute shell command directly without switching modes + + + + Ctrl+X {' '} - Open input in external editor diff --git a/packages/cli/src/ui/components/HistoryItemDisplay.test.tsx b/packages/cli/src/ui/components/HistoryItemDisplay.test.tsx index d0c28fa11f9..f603e9616a2 100644 --- a/packages/cli/src/ui/components/HistoryItemDisplay.test.tsx +++ b/packages/cli/src/ui/components/HistoryItemDisplay.test.tsx @@ -55,6 +55,18 @@ describe('', () => { expect(lastFrame()).toContain('/theme'); }); + it('renders InfoMessage for "info" type with multi-line text', () => { + const item: HistoryItem = { + ...baseItem, + type: MessageType.INFO, + text: '⚡ Line 1\n⚡ Line 2\n⚡ Line 3', + }; + const { lastFrame } = renderWithProviders( + , + ); + expect(lastFrame()).toMatchSnapshot(); + }); + it('renders StatsDisplay for "stats" type', () => { const item: HistoryItem = { ...baseItem, diff --git a/packages/cli/src/ui/components/SettingsDialog.tsx b/packages/cli/src/ui/components/SettingsDialog.tsx index b0f107c7a58..4f74afd12eb 100644 --- a/packages/cli/src/ui/components/SettingsDialog.tsx +++ b/packages/cli/src/ui/components/SettingsDialog.tsx @@ -7,7 +7,11 @@ import React, { useState, useEffect } from 'react'; import { Box, Text } from 'ink'; import { theme } from '../semantic-colors.js'; -import type { LoadedSettings, Settings } from '../../config/settings.js'; +import type { + LoadableSettingScope, + LoadedSettings, + Settings, +} from '../../config/settings.js'; import { SettingScope } from '../../config/settings.js'; import { getScopeItems, @@ -63,7 +67,7 @@ export function SettingsDialog({ 'settings', ); // Scope selector state (User by default) - const [selectedScope, setSelectedScope] = useState( + const [selectedScope, setSelectedScope] = useState( SettingScope.User, ); // Active indices @@ -358,11 +362,11 @@ export function SettingsDialog({ key: item.value, })); - const handleScopeHighlight = (scope: SettingScope) => 
{ + const handleScopeHighlight = (scope: LoadableSettingScope) => { setSelectedScope(scope); }; - const handleScopeSelect = (scope: SettingScope) => { + const handleScopeSelect = (scope: LoadableSettingScope) => { handleScopeHighlight(scope); setFocusSection('settings'); }; diff --git a/packages/cli/src/ui/components/ThemeDialog.tsx b/packages/cli/src/ui/components/ThemeDialog.tsx index f6f35ed8f51..611c9f9a715 100644 --- a/packages/cli/src/ui/components/ThemeDialog.tsx +++ b/packages/cli/src/ui/components/ThemeDialog.tsx @@ -12,7 +12,10 @@ import { themeManager, DEFAULT_THEME } from '../themes/theme-manager.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; import { DiffRenderer } from './messages/DiffRenderer.js'; import { colorizeCode } from '../utils/CodeColorizer.js'; -import type { LoadedSettings } from '../../config/settings.js'; +import type { + LoadableSettingScope, + LoadedSettings, +} from '../../config/settings.js'; import { SettingScope } from '../../config/settings.js'; import { getScopeMessageForSetting } from '../../utils/dialogScopeUtils.js'; import { useKeypress } from '../hooks/useKeypress.js'; @@ -20,7 +23,7 @@ import { ScopeSelector } from './shared/ScopeSelector.js'; interface ThemeDialogProps { /** Callback function when a theme is selected */ - onSelect: (themeName: string, scope: SettingScope) => void; + onSelect: (themeName: string, scope: LoadableSettingScope) => void; /** Callback function when the dialog is cancelled */ onCancel: () => void; @@ -41,7 +44,7 @@ export function ThemeDialog({ availableTerminalHeight, terminalWidth, }: ThemeDialogProps): React.JSX.Element { - const [selectedScope, setSelectedScope] = useState( + const [selectedScope, setSelectedScope] = useState( SettingScope.User, ); @@ -97,12 +100,12 @@ export function ThemeDialog({ onHighlight(themeName); }; - const handleScopeHighlight = useCallback((scope: SettingScope) => { + const handleScopeHighlight = useCallback((scope: LoadableSettingScope) => 
{ setSelectedScope(scope); }, []); const handleScopeSelect = useCallback( - (scope: SettingScope) => { + (scope: LoadableSettingScope) => { onSelect(highlightedThemeName, scope); }, [onSelect, highlightedThemeName], diff --git a/packages/cli/src/ui/components/__snapshots__/HistoryItemDisplay.test.tsx.snap b/packages/cli/src/ui/components/__snapshots__/HistoryItemDisplay.test.tsx.snap index de25e153e7e..b9c4b5e8bbd 100644 --- a/packages/cli/src/ui/components/__snapshots__/HistoryItemDisplay.test.tsx.snap +++ b/packages/cli/src/ui/components/__snapshots__/HistoryItemDisplay.test.tsx.snap @@ -1,5 +1,12 @@ // Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html +exports[` > renders InfoMessage for "info" type with multi-line text 1`] = ` +" +ℹ ⚡ Line 1 + ⚡ Line 2 + ⚡ Line 3" +`; + exports[` > should render a full gemini item when using availableTerminalHeightGemini 1`] = ` "✦ Example code block: 1 Line 1 diff --git a/packages/cli/src/ui/components/messages/InfoMessage.tsx b/packages/cli/src/ui/components/messages/InfoMessage.tsx index b8da1a4e20f..2d8fd2564b7 100644 --- a/packages/cli/src/ui/components/messages/InfoMessage.tsx +++ b/packages/cli/src/ui/components/messages/InfoMessage.tsx @@ -22,10 +22,12 @@ export const InfoMessage: React.FC = ({ text }) => { {prefix} - - - - + + {text.split('\n').map((line, index) => ( + + + + ))} ); diff --git a/packages/cli/src/ui/components/shared/ScopeSelector.tsx b/packages/cli/src/ui/components/shared/ScopeSelector.tsx index 30aa1e403f9..6ba19ddf6f9 100644 --- a/packages/cli/src/ui/components/shared/ScopeSelector.tsx +++ b/packages/cli/src/ui/components/shared/ScopeSelector.tsx @@ -6,19 +6,19 @@ import type React from 'react'; import { Box, Text } from 'ink'; -import type { SettingScope } from '../../../config/settings.js'; +import type { LoadableSettingScope } from '../../../config/settings.js'; import { getScopeItems } from '../../../utils/dialogScopeUtils.js'; import { RadioButtonSelect } from './RadioButtonSelect.js'; 
interface ScopeSelectorProps { /** Callback function when a scope is selected */ - onSelect: (scope: SettingScope) => void; + onSelect: (scope: LoadableSettingScope) => void; /** Callback function when a scope is highlighted */ - onHighlight: (scope: SettingScope) => void; + onHighlight: (scope: LoadableSettingScope) => void; /** Whether the component is focused */ isFocused: boolean; /** The initial scope to select */ - initialScope: SettingScope; + initialScope: LoadableSettingScope; } export function ScopeSelector({ diff --git a/packages/cli/src/ui/contexts/UIActionsContext.tsx b/packages/cli/src/ui/contexts/UIActionsContext.tsx index 31a0ec2a346..4e2cf4a5e63 100644 --- a/packages/cli/src/ui/contexts/UIActionsContext.tsx +++ b/packages/cli/src/ui/contexts/UIActionsContext.tsx @@ -9,22 +9,22 @@ import { type Key } from '../hooks/useKeypress.js'; import { type IdeIntegrationNudgeResult } from '../IdeIntegrationNudge.js'; import { type FolderTrustChoice } from '../components/FolderTrustDialog.js'; import { type AuthType, type EditorType } from '@google/gemini-cli-core'; -import { type SettingScope } from '../../config/settings.js'; +import { type LoadableSettingScope } from '../../config/settings.js'; import type { AuthState } from '../types.js'; export interface UIActions { - handleThemeSelect: (themeName: string, scope: SettingScope) => void; + handleThemeSelect: (themeName: string, scope: LoadableSettingScope) => void; closeThemeDialog: () => void; handleThemeHighlight: (themeName: string | undefined) => void; handleAuthSelect: ( authType: AuthType | undefined, - scope: SettingScope, + scope: LoadableSettingScope, ) => void; setAuthState: (state: AuthState) => void; onAuthError: (error: string | null) => void; handleEditorSelect: ( editorType: EditorType | undefined, - scope: SettingScope, + scope: LoadableSettingScope, ) => void; exitEditorDialog: () => void; exitPrivacyNotice: () => void; diff --git a/packages/cli/src/ui/debug.ts b/packages/cli/src/ui/debug.ts 
new file mode 100644 index 00000000000..833dcc8b810 --- /dev/null +++ b/packages/cli/src/ui/debug.ts @@ -0,0 +1,11 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +// A top-level field to track the total number of active animated components. +// This is used for testing to ensure we wait for animations to finish. +export const debugState = { + debugNumAnimatedComponents: 0, +}; diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.tsx b/packages/cli/src/ui/hooks/slashCommandProcessor.test.tsx index b8b0081872b..7fa0db1852b 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.tsx +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.tsx @@ -26,6 +26,7 @@ import { ToolConfirmationOutcome, makeFakeConfig, } from '@google/gemini-cli-core'; +import { appEvents } from '../../utils/events.js'; const { logSlashCommand } = vi.hoisted(() => ({ logSlashCommand: vi.fn(), @@ -1076,4 +1077,26 @@ describe('useSlashCommandProcessor', () => { expect(logSlashCommand).not.toHaveBeenCalled(); }); }); + + it('should reload commands on extension events', async () => { + const result = await setupProcessorHook(); + await waitFor(() => expect(result.current.slashCommands).toEqual([])); + + // Create a new command and make that the result of the fileLoadCommands + // (which is where extension commands come from) + const newCommand = createTestCommand({ + name: 'someNewCommand', + action: vi.fn(), + }); + mockFileLoadCommands.mockResolvedValue([newCommand]); + + // We should not see a change until we fire an event. 
+ await waitFor(() => expect(result.current.slashCommands).toEqual([])); + await act(() => { + appEvents.emit('extensionsStarting'); + }); + await waitFor(() => + expect(result.current.slashCommands).toEqual([newCommand]), + ); + }); }); diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index ec5ee1609a3..fe8be200128 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -8,7 +8,11 @@ import { useCallback, useMemo, useEffect, useState } from 'react'; import { type PartListUnion } from '@google/genai'; import process from 'node:process'; import type { UseHistoryManagerReturn } from './useHistoryManager.js'; -import type { Config } from '@google/gemini-cli-core'; +import type { + Config, + ExtensionsStartingEvent, + ExtensionsStoppingEvent, +} from '@google/gemini-cli-core'; import { GitService, Logger, @@ -39,6 +43,7 @@ import { type ExtensionUpdateAction, type ExtensionUpdateStatus, } from '../state/extensions.js'; +import { appEvents } from '../../utils/events.js'; interface SlashCommandProcessorActions { openAuthDialog: () => void; @@ -249,11 +254,27 @@ export const useSlashCommandProcessor = ( ideClient.addStatusChangeListener(listener); })(); + // TODO: Ideally this would happen more directly inside the ExtensionLoader, + // but the CommandService today is not conducive to that since it isn't a + // long lived service but instead gets fully re-created based on reload + // events within this hook. 
+ const extensionEventListener = ( + _event: ExtensionsStartingEvent | ExtensionsStoppingEvent, + ) => { + // We only care once at least one extension has completed + // starting/stopping + reloadCommands(); + }; + appEvents.on('extensionsStarting', extensionEventListener); + appEvents.on('extensionsStopping', extensionEventListener); + return () => { (async () => { const ideClient = await IdeClient.getInstance(); ideClient.removeStatusChangeListener(listener); })(); + appEvents.off('extensionsStarting', extensionEventListener); + appEvents.off('extensionsStopping', extensionEventListener); }; }, [config, reloadCommands]); diff --git a/packages/cli/src/ui/hooks/useAnimatedScrollbar.test.tsx b/packages/cli/src/ui/hooks/useAnimatedScrollbar.test.tsx new file mode 100644 index 00000000000..3fd84ad7a5a --- /dev/null +++ b/packages/cli/src/ui/hooks/useAnimatedScrollbar.test.tsx @@ -0,0 +1,73 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { act } from 'react'; +import { render } from '../../test-utils/render.js'; +import { useAnimatedScrollbar } from './useAnimatedScrollbar.js'; +import { debugState } from '../debug.js'; +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; + +const TestComponent = ({ isFocused = false }: { isFocused?: boolean }) => { + useAnimatedScrollbar(isFocused, () => {}); + return null; +}; + +describe('useAnimatedScrollbar', () => { + beforeEach(() => { + debugState.debugNumAnimatedComponents = 0; + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('should not increment debugNumAnimatedComponents when not focused', () => { + render(); + expect(debugState.debugNumAnimatedComponents).toBe(0); + }); + + it('should not increment debugNumAnimatedComponents on initial mount even if focused', () => { + render(); + expect(debugState.debugNumAnimatedComponents).toBe(0); + }); + + it('should increment debugNumAnimatedComponents when becoming 
focused', () => { + const { rerender } = render(); + expect(debugState.debugNumAnimatedComponents).toBe(0); + rerender(); + expect(debugState.debugNumAnimatedComponents).toBe(1); + }); + + it('should decrement debugNumAnimatedComponents when becoming unfocused', () => { + const { rerender } = render(); + rerender(); + expect(debugState.debugNumAnimatedComponents).toBe(1); + rerender(); + expect(debugState.debugNumAnimatedComponents).toBe(0); + }); + + it('should decrement debugNumAnimatedComponents on unmount', () => { + const { rerender, unmount } = render(); + rerender(); + expect(debugState.debugNumAnimatedComponents).toBe(1); + unmount(); + expect(debugState.debugNumAnimatedComponents).toBe(0); + }); + + it('should decrement debugNumAnimatedComponents after animation finishes', async () => { + const { rerender } = render(); + rerender(); + expect(debugState.debugNumAnimatedComponents).toBe(1); + + // Advance timers by enough time for animation to complete (200 + 1000 + 300 + buffer) + await act(async () => { + await vi.advanceTimersByTimeAsync(2000); + }); + + expect(debugState.debugNumAnimatedComponents).toBe(0); + }); +}); diff --git a/packages/cli/src/ui/hooks/useAnimatedScrollbar.ts b/packages/cli/src/ui/hooks/useAnimatedScrollbar.ts index fa290f5b540..aeb1d790410 100644 --- a/packages/cli/src/ui/hooks/useAnimatedScrollbar.ts +++ b/packages/cli/src/ui/hooks/useAnimatedScrollbar.ts @@ -7,6 +7,7 @@ import { useState, useEffect, useRef, useCallback } from 'react'; import { theme } from '../semantic-colors.js'; import { interpolateColor } from '../themes/color-utils.js'; +import { debugState } from '../debug.js'; export function useAnimatedScrollbar( isFocused: boolean, @@ -18,8 +19,13 @@ export function useAnimatedScrollbar( const animationFrame = useRef(null); const timeout = useRef(null); + const isAnimatingRef = useRef(false); const cleanup = useCallback(() => { + if (isAnimatingRef.current) { + debugState.debugNumAnimatedComponents--; + 
isAnimatingRef.current = false; + } if (animationFrame.current) { clearInterval(animationFrame.current); animationFrame.current = null; @@ -32,6 +38,8 @@ export function useAnimatedScrollbar( const flashScrollbar = useCallback(() => { cleanup(); + debugState.debugNumAnimatedComponents++; + isAnimatingRef.current = true; const fadeInDuration = 200; const visibleDuration = 1000; @@ -67,10 +75,7 @@ export function useAnimatedScrollbar( ); if (progress === 1) { - if (animationFrame.current) { - clearInterval(animationFrame.current); - animationFrame.current = null; - } + cleanup(); } }; diff --git a/packages/cli/src/ui/hooks/useEditorSettings.test.tsx b/packages/cli/src/ui/hooks/useEditorSettings.test.tsx index 3797198a8e3..db46856c7d8 100644 --- a/packages/cli/src/ui/hooks/useEditorSettings.test.tsx +++ b/packages/cli/src/ui/hooks/useEditorSettings.test.tsx @@ -16,7 +16,10 @@ import { import { act } from 'react'; import { render } from '../../test-utils/render.js'; import { useEditorSettings } from './useEditorSettings.js'; -import type { LoadedSettings } from '../../config/settings.js'; +import type { + LoadableSettingScope, + LoadedSettings, +} from '../../config/settings.js'; import { SettingScope } from '../../config/settings.js'; import { MessageType, type HistoryItem } from '../types.js'; import { @@ -186,7 +189,10 @@ describe('useEditorSettings', () => { render(); const editorType: EditorType = 'vscode'; - const scopes = [SettingScope.User, SettingScope.Workspace]; + const scopes: LoadableSettingScope[] = [ + SettingScope.User, + SettingScope.Workspace, + ]; scopes.forEach((scope) => { act(() => { diff --git a/packages/cli/src/ui/hooks/useEditorSettings.ts b/packages/cli/src/ui/hooks/useEditorSettings.ts index 7c0e35c2b54..075de1bc713 100644 --- a/packages/cli/src/ui/hooks/useEditorSettings.ts +++ b/packages/cli/src/ui/hooks/useEditorSettings.ts @@ -5,7 +5,10 @@ */ import { useState, useCallback } from 'react'; -import type { LoadedSettings, SettingScope } from 
'../../config/settings.js'; +import type { + LoadableSettingScope, + LoadedSettings, +} from '../../config/settings.js'; import { type HistoryItem, MessageType } from '../types.js'; import type { EditorType } from '@google/gemini-cli-core'; import { @@ -18,7 +21,7 @@ interface UseEditorSettingsReturn { openEditorDialog: () => void; handleEditorSelect: ( editorType: EditorType | undefined, - scope: SettingScope, + scope: LoadableSettingScope, ) => void; exitEditorDialog: () => void; } @@ -35,7 +38,7 @@ export const useEditorSettings = ( }, []); const handleEditorSelect = useCallback( - (editorType: EditorType | undefined, scope: SettingScope) => { + (editorType: EditorType | undefined, scope: LoadableSettingScope) => { if ( editorType && (!checkHasEditorType(editorType) || diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index fce41127359..0c1750a3972 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -662,7 +662,7 @@ export const useGeminiStream = ( { type: 'info', text: - `IMPORTANT: This conversation approached the input token limit for ${config.getModel()}. ` + + `IMPORTANT: This conversation exceeded the compress threshold. ` + `A compressed context will be sent for future messages (compressed from: ` + `${eventValue?.originalTokenCount ?? 'unknown'} to ` + `${eventValue?.newTokenCount ?? 
'unknown'} tokens).`, @@ -670,7 +670,7 @@ export const useGeminiStream = ( Date.now(), ); }, - [addItem, config, pendingHistoryItemRef, setPendingHistoryItem], + [addItem, pendingHistoryItemRef, setPendingHistoryItem], ); const handleMaxSessionTurnsEvent = useCallback( @@ -1286,5 +1286,6 @@ export const useGeminiStream = ( handleApprovalModeChange, activePtyId, loopDetectionConfirmationRequest, + handleShellCommand, }; }; diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts index 194f5f27fc9..3bdaff42955 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts @@ -66,42 +66,57 @@ export function useQuotaAndFallback({ if (error instanceof TerminalQuotaError) { // Pro Quota specific messages (Interactive) if (isPaidTier) { - message = `⚡ You have reached your daily ${failedModel} quota limit. -⚡ You can choose to authenticate with a paid API key or continue with the fallback model. -⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; + message = [ + `⚡ You have reached your daily ${failedModel} quota limit.`, + `⚡ You can choose to authenticate with a paid API key or continue with the fallback model.`, + `⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey}`, + ].join('\n'); } else { - message = `⚡ You have reached your daily ${failedModel} quota limit. -⚡ You can choose to authenticate with a paid API key or continue with the fallback model. -⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist -⚡ Or you can utilize a Gemini API Key. 
See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key -⚡ You can switch authentication methods by typing /auth`; + message = [ + `⚡ You have reached your daily ${failedModel} quota limit.`, + `⚡ You can choose to authenticate with a paid API key or continue with the fallback model.`, + `⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist`, + `⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key`, + `⚡ You can switch authentication methods by typing /auth`, + ].join('\n'); } } else if (error instanceof RetryableQuotaError) { // Short term quota retries exhausted (Automatic fallback) - const actionMessage = `⚡ Your requests are being throttled right now due to server being at capacity for ${failedModel}.\n⚡ Automatically switching from ${failedModel} to ${fallbackModel} for the remainder of this session.`; + const actionMessage = [ + `⚡ Your requests are being throttled right now due to server being at capacity for ${failedModel}.`, + `⚡ Automatically switching from ${failedModel} to ${fallbackModel} for the remainder of this session.`, + ].join('\n'); if (isPaidTier) { - message = `${actionMessage} -⚡ To continue accessing the ${failedModel} model, retry your request after some time or consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; + message = [ + actionMessage, + `⚡ To continue accessing the ${failedModel} model, retry your request after some time or consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey}`, + ].join('\n'); } else { - message = `${actionMessage} -⚡ Retry your requests after some time. 
Otherwise consider upgrading to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist -⚡ You can switch authentication methods by typing /auth`; + message = [ + actionMessage, + `⚡ Retry your requests after some time. Otherwise consider upgrading to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist`, + `⚡ You can switch authentication methods by typing /auth`, + ].join('\n'); } } else { // Other errors (Automatic fallback) const actionMessage = `⚡ Automatically switching from ${failedModel} to ${fallbackModel} for faster responses for the remainder of this session.`; if (isPaidTier) { - message = `${actionMessage} -⚡ Your requests are being throttled temporarily due to server being at capacity for ${failedModel} or there is a service outage. -⚡ To continue accessing the ${failedModel} model, you can retry your request after some time or consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; + message = [ + actionMessage, + `⚡ Your requests are being throttled temporarily due to server being at capacity for ${failedModel} or there is a service outage.`, + `⚡ To continue accessing the ${failedModel} model, you can retry your request after some time or consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`, + ].join('\n'); } else { - message = `${actionMessage} -⚡ Your requests are being throttled temporarily due to server being at capacity for ${failedModel} or there is a service outage. -⚡ To avoid being throttled, you can retry your request after some time or upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist -⚡ Or you can utilize a Gemini API Key.
See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key -⚡ You can switch authentication methods by typing /auth`; + message = [ + actionMessage, + `⚡ Your requests are being throttled temporarily due to server being at capacity for ${failedModel} or there is a service outage.`, + `⚡ To avoid being throttled, you can retry your request after some time or upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist`, + `⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key`, + `⚡ You can switch authentication methods by typing /auth`, + ].join('\n'); } } diff --git a/packages/cli/src/ui/hooks/useThemeCommand.ts b/packages/cli/src/ui/hooks/useThemeCommand.ts index 46cf0e5d851..72133e9b11b 100644 --- a/packages/cli/src/ui/hooks/useThemeCommand.ts +++ b/packages/cli/src/ui/hooks/useThemeCommand.ts @@ -6,7 +6,10 @@ import { useState, useCallback } from 'react'; import { themeManager } from '../themes/theme-manager.js'; -import type { LoadedSettings, SettingScope } from '../../config/settings.js'; // Import LoadedSettings, AppSettings, MergedSetting +import type { + LoadableSettingScope, + LoadedSettings, +} from '../../config/settings.js'; // Import LoadedSettings, AppSettings, MergedSetting import { type HistoryItem, MessageType } from '../types.js'; import process from 'node:process'; @@ -14,7 +17,7 @@ interface UseThemeCommandReturn { isThemeDialogOpen: boolean; openThemeDialog: () => void; closeThemeDialog: () => void; - handleThemeSelect: (themeName: string, scope: SettingScope) => void; + handleThemeSelect: (themeName: string, scope: LoadableSettingScope) => void; handleThemeHighlight: (themeName: string | undefined) => void; } @@ -68,7 +71,7 @@ export const useThemeCommand = ( }, [applyTheme, loadedSettings]); const handleThemeSelect = useCallback( - (themeName: string, scope: SettingScope) => { + (themeName: string, scope: LoadableSettingScope) => { try { // 
Merge user and workspace custom themes (workspace takes precedence) const mergedCustomThemes = { diff --git a/packages/cli/src/ui/keyMatchers.test.ts b/packages/cli/src/ui/keyMatchers.test.ts index baaa88dbff6..483f0329d13 100644 --- a/packages/cli/src/ui/keyMatchers.test.ts +++ b/packages/cli/src/ui/keyMatchers.test.ts @@ -50,7 +50,8 @@ describe('keyMatchers', () => { [Command.SUBMIT]: (key: Key) => key.name === 'return' && !key.ctrl && !key.meta && !key.paste, [Command.NEWLINE]: (key: Key) => - key.name === 'return' && (key.ctrl || key.meta || key.paste), + (key.name === 'return' && (key.paste || key.shift)) || + (key.name === 'j' && key.ctrl), [Command.OPEN_EXTERNAL_EDITOR]: (key: Key) => key.ctrl && (key.name === 'x' || key.sequence === '\x18'), [Command.PASTE_CLIPBOARD_IMAGE]: (key: Key) => key.ctrl && key.name === 'v', @@ -72,6 +73,8 @@ describe('keyMatchers', () => { key.ctrl && key.name === 'f', [Command.EXPAND_SUGGESTION]: (key: Key) => key.name === 'right', [Command.COLLAPSE_SUGGESTION]: (key: Key) => key.name === 'left', + [Command.EXECUTE_PROMPT_COMMAND]: (key: Key) => + key.name === 'return' && (key.ctrl || key.meta), }; // Test data for each command with positive and negative test cases @@ -212,11 +215,16 @@ describe('keyMatchers', () => { { command: Command.NEWLINE, positive: [ + createKey('return', { paste: true }), + createKey('return', { shift: true }), + createKey('j', { ctrl: true }), + ], + negative: [ + createKey('return'), + createKey('n'), createKey('return', { ctrl: true }), createKey('return', { meta: true }), - createKey('return', { paste: true }), ], - negative: [createKey('return'), createKey('n')], }, // External tools @@ -297,6 +305,14 @@ describe('keyMatchers', () => { positive: [createKey('f', { ctrl: true })], negative: [createKey('f')], }, + { + command: Command.EXECUTE_PROMPT_COMMAND, + positive: [ + createKey('return', { ctrl: true }), + createKey('return', { meta: true }), + ], + negative: [createKey('return'), 
createKey('return', { shift: true })], + }, ]; describe('Data-driven key binding matches original logic', () => { diff --git a/packages/cli/src/utils/dialogScopeUtils.ts b/packages/cli/src/utils/dialogScopeUtils.ts index fd4cbbd4fcb..ccf93b6a68b 100644 --- a/packages/cli/src/utils/dialogScopeUtils.ts +++ b/packages/cli/src/utils/dialogScopeUtils.ts @@ -4,8 +4,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { LoadedSettings } from '../config/settings.js'; -import { SettingScope } from '../config/settings.js'; +import type { + LoadableSettingScope, + LoadedSettings, +} from '../config/settings.js'; +import { isLoadableSettingScope, SettingScope } from '../config/settings.js'; import { settingExistsInScope } from './settingsUtils.js'; /** @@ -20,7 +23,10 @@ export const SCOPE_LABELS = { /** * Helper function to get scope items for radio button selects */ -export function getScopeItems() { +export function getScopeItems(): Array<{ + label: string; + value: LoadableSettingScope; +}> { return [ { label: SCOPE_LABELS[SettingScope.User], value: SettingScope.User }, { @@ -36,12 +42,12 @@ export function getScopeItems() { */ export function getScopeMessageForSetting( settingKey: string, - selectedScope: SettingScope, + selectedScope: LoadableSettingScope, settings: LoadedSettings, ): string { - const otherScopes = Object.values(SettingScope).filter( - (scope) => scope !== selectedScope, - ); + const otherScopes = Object.values(SettingScope) + .filter(isLoadableSettingScope) + .filter((scope) => scope !== selectedScope); const modifiedInOtherScopes = otherScopes.filter((scope) => { const scopeSettings = settings.forScope(scope).settings; diff --git a/packages/cli/src/utils/settingsUtils.ts b/packages/cli/src/utils/settingsUtils.ts index a9a429370ac..7ec5fd5885d 100644 --- a/packages/cli/src/utils/settingsUtils.ts +++ b/packages/cli/src/utils/settingsUtils.ts @@ -6,8 +6,8 @@ import type { Settings, - SettingScope, LoadedSettings, + LoadableSettingScope, } from 
'../config/settings.js'; import type { SettingDefinition, @@ -391,7 +391,7 @@ export function saveModifiedSettings( modifiedSettings: Set, pendingSettings: Settings, loadedSettings: LoadedSettings, - scope: SettingScope, + scope: LoadableSettingScope, ): void { modifiedSettings.forEach((settingKey) => { const path = settingKey.split('.'); diff --git a/packages/cli/src/validateNonInterActiveAuth.test.ts b/packages/cli/src/validateNonInterActiveAuth.test.ts index 475e079bdf7..e9f8c7c8ae5 100644 --- a/packages/cli/src/validateNonInterActiveAuth.test.ts +++ b/packages/cli/src/validateNonInterActiveAuth.test.ts @@ -69,7 +69,7 @@ describe('validateNonInterActiveAuth', () => { }, }, isTrusted: true, - migratedInMemorScopes: new Set(), + migratedInMemoryScopes: new Set(), forScope: vi.fn(), computeMergedSettings: vi.fn(), } as unknown as LoadedSettings; diff --git a/packages/core/package.json b/packages/core/package.json index 99bdba4a48a..13c57d7dea0 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -1,6 +1,6 @@ { "name": "@google/gemini-cli-core", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "description": "Gemini CLI Core", "repository": { "type": "git", diff --git a/packages/core/src/agents/executor.test.ts b/packages/core/src/agents/executor.test.ts index 13e56c6a876..3d58df3704f 100644 --- a/packages/core/src/agents/executor.test.ts +++ b/packages/core/src/agents/executor.test.ts @@ -4,7 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; import { AgentExecutor, type ActivityCallback } from './executor.js'; import { makeFakeConfig } from '../test-utils/config.js'; import { ToolRegistry } from '../tools/tool-registry.js'; @@ -20,6 +28,7 @@ import { type Part, type GenerateContentResponse, type GenerateContentConfig, + 
type Content, } from '@google/genai'; import type { Config } from '../config/config.js'; import { MockTool } from '../test-utils/mock-tool.js'; @@ -44,10 +53,26 @@ import type { } from './types.js'; import { AgentTerminateMode } from './types.js'; import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js'; +import { CompressionStatus } from '../core/turn.js'; +import { ChatCompressionService } from '../services/chatCompressionService.js'; + +const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted( + () => ({ + mockSendMessageStream: vi.fn(), + mockExecuteToolCall: vi.fn(), + mockCompress: vi.fn(), + }), +); + +let mockChatHistory: Content[] = []; +const mockSetHistory = vi.fn((newHistory: Content[]) => { + mockChatHistory = newHistory; +}); -const { mockSendMessageStream, mockExecuteToolCall } = vi.hoisted(() => ({ - mockSendMessageStream: vi.fn(), - mockExecuteToolCall: vi.fn(), +vi.mock('../services/chatCompressionService.js', () => ({ + ChatCompressionService: vi.fn().mockImplementation(() => ({ + compress: mockCompress, + })), })); vi.mock('../core/geminiChat.js', async (importOriginal) => { @@ -56,6 +81,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => { ...actual, GeminiChat: vi.fn().mockImplementation(() => ({ sendMessageStream: mockSendMessageStream, + getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]), + setHistory: mockSetHistory, })), }; }); @@ -193,6 +220,8 @@ describe('AgentExecutor', () => { beforeEach(async () => { vi.resetAllMocks(); + mockCompress.mockClear(); + mockSetHistory.mockClear(); mockSendMessageStream.mockReset(); mockExecuteToolCall.mockReset(); mockedLogAgentStart.mockReset(); @@ -200,10 +229,21 @@ describe('AgentExecutor', () => { mockedPromptIdContext.getStore.mockReset(); mockedPromptIdContext.run.mockImplementation((_id, fn) => fn()); + (ChatCompressionService as Mock).mockImplementation(() => ({ + compress: mockCompress, + })); + mockCompress.mockResolvedValue({ 
+ newHistory: null, + info: { compressionStatus: CompressionStatus.NOOP }, + }); + MockedGeminiChat.mockImplementation( () => ({ sendMessageStream: mockSendMessageStream, + getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]), + getLastPromptTokenCount: vi.fn(() => 100), + setHistory: mockSetHistory, }) as unknown as GeminiChat, ); @@ -1440,4 +1480,205 @@ describe('AgentExecutor', () => { expect(recoveryEvent.reason).toBe(AgentTerminateMode.MAX_TURNS); }); }); + describe('Chat Compression', () => { + const mockWorkResponse = (id: string) => { + mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); + mockExecuteToolCall.mockResolvedValueOnce({ + status: 'success', + request: { + callId: id, + name: LS_TOOL_NAME, + args: { path: '.' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: id, + resultDisplay: 'ok', + responseParts: [ + { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, + }); + }; + + it('should attempt to compress chat history on each turn', async () => { + const definition = createTestDefinition(); + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + + // Mock compression to do nothing + mockCompress.mockResolvedValue({ + newHistory: null, + info: { compressionStatus: CompressionStatus.NOOP }, + }); + + // Turn 1 + mockWorkResponse('t1'); + + // Turn 2: Complete + mockModelResponse( + [ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'Done' }, + id: 'call2', + }, + ], + 'T2', + ); + + await executor.run({ goal: 'Compress test' }, signal); + + expect(mockCompress).toHaveBeenCalledTimes(2); + }); + + it('should update chat history when compression is successful', async () => { + const definition = createTestDefinition(); + const executor = await AgentExecutor.create( + 
definition, + mockConfig, + onActivity, + ); + const compressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'compressed' }] }, + ]; + + mockCompress.mockResolvedValue({ + newHistory: compressedHistory, + info: { compressionStatus: CompressionStatus.COMPRESSED }, + }); + + // Turn 1: Complete + mockModelResponse( + [ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'Done' }, + id: 'call1', + }, + ], + 'T1', + ); + + await executor.run({ goal: 'Compress success' }, signal); + + expect(mockCompress).toHaveBeenCalledTimes(1); + expect(mockSetHistory).toHaveBeenCalledTimes(1); + expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory); + }); + + it('should pass hasFailedCompressionAttempt=true to compression after a failure', async () => { + const definition = createTestDefinition(); + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + + // First call fails + mockCompress.mockResolvedValueOnce({ + newHistory: null, + info: { + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }, + }); + // Second call is neutral + mockCompress.mockResolvedValueOnce({ + newHistory: null, + info: { compressionStatus: CompressionStatus.NOOP }, + }); + + // Turn 1 + mockWorkResponse('t1'); + // Turn 2: Complete + mockModelResponse( + [ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'Done' }, + id: 't2', + }, + ], + 'T2', + ); + + await executor.run({ goal: 'Compress fail' }, signal); + + expect(mockCompress).toHaveBeenCalledTimes(2); + // First call, hasFailedCompressionAttempt is false + expect(mockCompress.mock.calls[0][5]).toBe(false); + // Second call, hasFailedCompressionAttempt is true + expect(mockCompress.mock.calls[1][5]).toBe(true); + }); + + it('should reset hasFailedCompressionAttempt flag after a successful compression', async () => { + const definition = createTestDefinition(); + const executor = await AgentExecutor.create( + definition, + mockConfig, + 
onActivity, + ); + const compressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'compressed' }] }, + ]; + + // Turn 1: Fails + mockCompress.mockResolvedValueOnce({ + newHistory: null, + info: { + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }, + }); + // Turn 2: Succeeds + mockCompress.mockResolvedValueOnce({ + newHistory: compressedHistory, + info: { compressionStatus: CompressionStatus.COMPRESSED }, + }); + // Turn 3: Neutral + mockCompress.mockResolvedValueOnce({ + newHistory: null, + info: { compressionStatus: CompressionStatus.NOOP }, + }); + + // Turn 1 + mockWorkResponse('t1'); + // Turn 2 + mockWorkResponse('t2'); + // Turn 3: Complete + mockModelResponse( + [ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'Done' }, + id: 't3', + }, + ], + 'T3', + ); + + await executor.run({ goal: 'Compress reset' }, signal); + + expect(mockCompress).toHaveBeenCalledTimes(3); + // Call 1: hasFailed... is false + expect(mockCompress.mock.calls[0][5]).toBe(false); + // Call 2: hasFailed... is true + expect(mockCompress.mock.calls[1][5]).toBe(true); + // Call 3: hasFailed... 
is false again + expect(mockCompress.mock.calls[2][5]).toBe(false); + + expect(mockSetHistory).toHaveBeenCalledTimes(1); + expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory); + }); + }); }); diff --git a/packages/core/src/agents/executor.ts b/packages/core/src/agents/executor.ts index 8928a75e694..59828176608 100644 --- a/packages/core/src/agents/executor.ts +++ b/packages/core/src/agents/executor.ts @@ -18,7 +18,8 @@ import type { } from '@google/genai'; import { executeToolCall } from '../core/nonInteractiveToolExecutor.js'; import { ToolRegistry } from '../tools/tool-registry.js'; -import type { ToolCallRequestInfo } from '../core/turn.js'; +import { type ToolCallRequestInfo, CompressionStatus } from '../core/turn.js'; +import { ChatCompressionService } from '../services/chatCompressionService.js'; import { getDirectoryContextString } from '../utils/environmentContext.js'; import { GLOB_TOOL_NAME, @@ -84,6 +85,8 @@ export class AgentExecutor { private readonly toolRegistry: ToolRegistry; private readonly runtimeContext: Config; private readonly onActivity?: ActivityCallback; + private readonly compressionService: ChatCompressionService; + private hasFailedCompressionAttempt = false; /** * Creates and validates a new `AgentExecutor` instance. @@ -125,6 +128,7 @@ export class AgentExecutor { // registered; their schemas are passed directly to the model later. } + agentToolRegistry.sortTools(); // Validate that all registered tools are safe for non-interactive // execution. 
await AgentExecutor.validateTools(agentToolRegistry, definition.name); @@ -159,6 +163,7 @@ export class AgentExecutor { this.runtimeContext = runtimeContext; this.toolRegistry = toolRegistry; this.onActivity = onActivity; + this.compressionService = new ChatCompressionService(); const randomIdPart = Math.random().toString(36).slice(2, 8); // parentPromptId will be undefined if this agent is invoked directly @@ -184,6 +189,8 @@ export class AgentExecutor { ): Promise { const promptId = `${this.agentId}#${turnCounter}`; + await this.tryCompressChat(chat, promptId); + const { functionCalls } = await promptIdContext.run(promptId, async () => this.callModel(chat, currentMessage, tools, combinedSignal, promptId), ); @@ -548,6 +555,34 @@ export class AgentExecutor { } } + private async tryCompressChat( + chat: GeminiChat, + prompt_id: string, + ): Promise { + const model = this.definition.modelConfig.model; + + const { newHistory, info } = await this.compressionService.compress( + chat, + prompt_id, + false, + model, + this.runtimeContext, + this.hasFailedCompressionAttempt, + ); + + if ( + info.compressionStatus === + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT + ) { + this.hasFailedCompressionAttempt = true; + } else if (info.compressionStatus === CompressionStatus.COMPRESSED) { + if (newHistory) { + chat.setHistory(newHistory); + this.hasFailedCompressionAttempt = false; + } + } + } + /** * Calls the generative model with the current context and tools. 
* diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 6b4e65dc794..9f4650be072 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -31,6 +31,7 @@ import { RipGrepTool, canUseRipgrep } from '../tools/ripGrep.js'; import { logRipgrepFallback } from '../telemetry/loggers.js'; import { RipgrepFallbackEvent } from '../telemetry/types.js'; import { ToolRegistry } from '../tools/tool-registry.js'; +import { DEFAULT_MODEL_CONFIGS } from './defaultModelConfigs.js'; vi.mock('fs', async (importOriginal) => { const actual = await importOriginal(); @@ -49,6 +50,7 @@ vi.mock('../tools/tool-registry', () => { const ToolRegistryMock = vi.fn(); ToolRegistryMock.prototype.registerTool = vi.fn(); ToolRegistryMock.prototype.discoverAllTools = vi.fn(); + ToolRegistryMock.prototype.sortTools = vi.fn(); ToolRegistryMock.prototype.getAllTools = vi.fn(() => []); // Mock methods if needed ToolRegistryMock.prototype.getTool = vi.fn(); ToolRegistryMock.prototype.getFunctionDeclarations = vi.fn(() => []); @@ -1248,6 +1250,92 @@ describe('BaseLlmClient Lifecycle', () => { }); }); +describe('Generation Config Merging (HACK)', () => { + const MODEL = 'gemini-pro'; + const SANDBOX: SandboxConfig = { + command: 'docker', + image: 'gemini-cli-sandbox', + }; + const TARGET_DIR = '/path/to/target'; + const DEBUG_MODE = false; + const QUESTION = 'test question'; + const USER_MEMORY = 'Test User Memory'; + const TELEMETRY_SETTINGS = { enabled: false }; + const EMBEDDING_MODEL = 'gemini-embedding'; + const SESSION_ID = 'test-session-id'; + const baseParams: ConfigParameters = { + cwd: '/tmp', + embeddingModel: EMBEDDING_MODEL, + sandbox: SANDBOX, + targetDir: TARGET_DIR, + debugMode: DEBUG_MODE, + question: QUESTION, + userMemory: USER_MEMORY, + telemetry: TELEMETRY_SETTINGS, + sessionId: SESSION_ID, + model: MODEL, + usageStatisticsEnabled: false, + }; + + it('should merge default aliases when user 
provides only overrides', () => { + const userOverrides = [ + { + match: { model: 'test-model' }, + modelConfig: { generateContentConfig: { temperature: 0.1 } }, + }, + ]; + + const params: ConfigParameters = { + ...baseParams, + modelConfigServiceConfig: { + overrides: userOverrides, + }, + }; + + const config = new Config(params); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const serviceConfig = (config.modelConfigService as any).config; + + // Assert that the default aliases are present + expect(serviceConfig.aliases).toEqual(DEFAULT_MODEL_CONFIGS.aliases); + // Assert that the user's overrides are present + expect(serviceConfig.overrides).toEqual(userOverrides); + }); + + it('should use user-provided aliases if they exist', () => { + const userAliases = { + 'my-alias': { + modelConfig: { model: 'my-model' }, + }, + }; + + const params: ConfigParameters = { + ...baseParams, + modelConfigServiceConfig: { + aliases: userAliases, + }, + }; + + const config = new Config(params); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const serviceConfig = (config.modelConfigService as any).config; + + // Assert that the user's aliases are used, not the defaults + expect(serviceConfig.aliases).toEqual(userAliases); + }); + + it('should use default generation config if none is provided', () => { + const params: ConfigParameters = { ...baseParams }; + + const config = new Config(params); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const serviceConfig = (config.modelConfigService as any).config; + + // Assert that the full default config is used + expect(serviceConfig).toEqual(DEFAULT_MODEL_CONFIGS); + }); +}); + describe('Config getHooks', () => { const baseParams: ConfigParameters = { cwd: '/tmp', @@ -1354,6 +1442,21 @@ describe('Config getHooks', () => { expect(retrievedHooks).toEqual(allEventHooks); expect(Object.keys(retrievedHooks!)).toHaveLength(11); // All hook event types }); + + describe('setModel', () 
=> { + it('should allow setting a pro (any) model and disable fallback mode', () => { + const config = new Config(baseParams); + config.setFallbackMode(true); + expect(config.isInFallbackMode()).toBe(true); + + const proModel = 'gemini-2.5-pro'; + config.setModel(proModel); + + expect(config.getModel()).toBe(proModel); + expect(config.isInFallbackMode()).toBe(false); + expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith(proModel); + }); + }); }); describe('Config getExperiments', () => { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 20fe3b578d9..5dcaf503e0a 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -62,6 +62,9 @@ import { RipgrepFallbackEvent } from '../telemetry/types.js'; import type { FallbackModelHandler } from '../fallback/types.js'; import { ModelRouterService } from '../routing/modelRouterService.js'; import { OutputFormat } from '../output/types.js'; +import type { ModelConfigServiceConfig } from '../services/modelConfigService.js'; +import { ModelConfigService } from '../services/modelConfigService.js'; +import { DEFAULT_MODEL_CONFIGS } from './defaultModelConfigs.js'; // Re-export OAuth config type export type { MCPOAuthConfig, AnyToolInvocation }; @@ -291,6 +294,7 @@ export interface ConfigParameters { recordResponses?: string; ptyInfo?: string; disableYoloMode?: boolean; + modelConfigServiceConfig?: ModelConfigServiceConfig; enableHooks?: boolean; experiments?: Experiments; hooks?: { @@ -309,6 +313,7 @@ export class Config { private fileSystemService: FileSystemService; private contentGeneratorConfig!: ContentGeneratorConfig; private contentGenerator!: ContentGenerator; + readonly modelConfigService: ModelConfigService; private readonly embeddingModel: string; private readonly sandbox: SandboxConfig | undefined; private readonly targetDir: string; @@ -560,6 +565,25 @@ export class Config { } this.geminiClient = new GeminiClient(this); 
this.modelRouterService = new ModelRouterService(this); + + // HACK: The settings loading logic doesn't currently merge the default + // generation config with the user's settings. This means if a user provides + // any `generation` settings (e.g., just `overrides`), the default `aliases` + // are lost. This hack manually merges the default aliases back in if they + // are missing from the user's config. + // TODO(12593): Fix the settings loading logic to properly merge defaults and + // remove this hack. + let modelConfigServiceConfig = params.modelConfigServiceConfig; + if (modelConfigServiceConfig && !modelConfigServiceConfig.aliases) { + modelConfigServiceConfig = { + ...modelConfigServiceConfig, + aliases: DEFAULT_MODEL_CONFIGS.aliases, + }; + } + + this.modelConfigService = new ModelConfigService( + modelConfigServiceConfig ?? DEFAULT_MODEL_CONFIGS, + ); } /** @@ -692,10 +716,7 @@ export class Config { } setModel(newModel: string): void { - // Do not allow Pro usage if the user is in fallback mode. - if (newModel.includes('pro') && this.isInFallbackMode()) { - return; - } + this.setFallbackMode(false); if (this.model !== newModel) { this.model = newModel; @@ -1349,6 +1370,7 @@ export class Config { } await registry.discoverAllTools(); + registry.sortTools(); return registry; } diff --git a/packages/core/src/config/defaultModelConfigs.ts b/packages/core/src/config/defaultModelConfigs.ts new file mode 100644 index 00000000000..3ee1730defb --- /dev/null +++ b/packages/core/src/config/defaultModelConfigs.ts @@ -0,0 +1,129 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { ModelConfigServiceConfig } from '../services/modelConfigService.js'; + +// The default model configs. We use `base` as the parent for all of our model +// configs, while `chat-base`, a child of `base`, is the parent of the models +// we use in the "chat" experience. 
+export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = { + aliases: { + base: { + modelConfig: { + generateContentConfig: { + temperature: 0, + topP: 1, + }, + }, + }, + 'chat-base': { + extends: 'base', + modelConfig: { + generateContentConfig: { + thinkingConfig: { + includeThoughts: true, + thinkingBudget: -1, + }, + }, + }, + }, + // Because `gemini-2.5-pro` and related model configs are "user-facing" + // today, i.e. they could be passed via `--model`, we have to be careful to + // ensure these model configs can be used interactively. + // TODO(joshualitt): Introduce internal base configs for the various models, + // note: we will have to think carefully about names. + 'gemini-2.5-pro': { + extends: 'chat-base', + modelConfig: { + model: 'gemini-2.5-pro', + }, + }, + 'gemini-2.5-flash': { + extends: 'chat-base', + modelConfig: { + model: 'gemini-2.5-flash', + }, + }, + 'gemini-2.5-flash-lite': { + extends: 'chat-base', + modelConfig: { + model: 'gemini-2.5-flash-lite', + }, + }, + classifier: { + extends: 'base', + modelConfig: { + model: 'gemini-2.5-flash-lite', + generateContentConfig: { + maxOutputTokens: 1024, + thinkingConfig: { + thinkingBudget: 512, + }, + }, + }, + }, + 'prompt-completion': { + extends: 'base', + modelConfig: { + model: 'gemini-2.5-flash-lite', + generateContentConfig: { + temperature: 0.3, + maxOutputTokens: 16000, + thinkingConfig: { + thinkingBudget: 0, + }, + }, + }, + }, + 'edit-corrector': { + extends: 'base', + modelConfig: { + model: 'gemini-2.5-flash-lite', + generateContentConfig: { + thinkingConfig: { + thinkingBudget: 0, + }, + }, + }, + }, + 'summarizer-default': { + extends: 'base', + modelConfig: { + model: 'gemini-2.5-flash-lite', + generateContentConfig: { + maxOutputTokens: 2000, + }, + }, + }, + 'summarizer-shell': { + extends: 'base', + modelConfig: { + model: 'gemini-2.5-flash-lite', + generateContentConfig: { + maxOutputTokens: 2000, + }, + }, + }, + 'web-search-tool': { + extends: 'base', + modelConfig: 
{ + model: 'gemini-2.5-flash', + generateContentConfig: { + tools: [{ googleSearch: {} }], + }, + }, + }, + 'web-fetch-tool': { + extends: 'base', + modelConfig: { + model: 'gemini-2.5-flash', + generateContentConfig: { + tools: [{ urlContext: {} }], + }, + }, + }, + }, +}; diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts index b48129b4120..cd293a72a63 100644 --- a/packages/core/src/confirmation-bus/message-bus.ts +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -50,7 +50,10 @@ export class MessageBus extends EventEmitter { } if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { - const decision = this.policyEngine.check(message.toolCall); + const decision = this.policyEngine.check( + message.toolCall, + message.serverName, + ); switch (decision) { case PolicyDecision.ALLOW: diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index 2b4bcf56854..52d7bd2e9fe 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -19,6 +19,7 @@ export interface ToolConfirmationRequest { type: MessageBusType.TOOL_CONFIRMATION_REQUEST; toolCall: FunctionCall; correlationId: string; + serverName?: string; } export interface ToolConfirmationResponse { diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 7dbf8021b84..9b7aefa8bd1 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -1553,7 +1553,7 @@ describe('CoreToolScheduler request queueing', () => { expect(statusUpdates).toContain('awaiting_approval'); expect(executeFn).not.toHaveBeenCalled(); expect(onAllToolCallsComplete).not.toHaveBeenCalled(); - }); + }, 20000); it('should handle two synchronous calls to schedule', async () => { const executeFn = vi.fn().mockResolvedValue({ diff --git 
a/packages/core/src/index.ts b/packages/core/src/index.ts index 513fae847dc..a867354c647 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -6,6 +6,7 @@ // Export config export * from './config/config.js'; +export * from './config/defaultModelConfigs.js'; export * from './output/types.js'; export * from './output/json-formatter.js'; export * from './output/stream-json-formatter.js'; diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts index 03caf2524e3..fd3a4b62b2e 100644 --- a/packages/core/src/policy/policy-engine.test.ts +++ b/packages/core/src/policy/policy-engine.test.ts @@ -22,13 +22,13 @@ describe('PolicyEngine', () => { describe('constructor', () => { it('should use default config when none provided', () => { - const decision = engine.check({ name: 'test' }); + const decision = engine.check({ name: 'test' }, undefined); expect(decision).toBe(PolicyDecision.ASK_USER); }); it('should respect custom default decision', () => { engine = new PolicyEngine({ defaultDecision: PolicyDecision.DENY }); - const decision = engine.check({ name: 'test' }); + const decision = engine.check({ name: 'test' }, undefined); expect(decision).toBe(PolicyDecision.DENY); }); @@ -57,9 +57,15 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); - expect(engine.check({ name: 'shell' })).toBe(PolicyDecision.ALLOW); - expect(engine.check({ name: 'edit' })).toBe(PolicyDecision.DENY); - expect(engine.check({ name: 'other' })).toBe(PolicyDecision.ASK_USER); + expect(engine.check({ name: 'shell' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + expect(engine.check({ name: 'edit' }, undefined)).toBe( + PolicyDecision.DENY, + ); + expect(engine.check({ name: 'other' }, undefined)).toBe( + PolicyDecision.ASK_USER, + ); }); it('should match by args pattern', () => { @@ -87,8 +93,8 @@ describe('PolicyEngine', () => { args: { command: 'ls -la' }, }; - 
expect(engine.check(dangerousCall)).toBe(PolicyDecision.DENY); - expect(engine.check(safeCall)).toBe(PolicyDecision.ALLOW); + expect(engine.check(dangerousCall, undefined)).toBe(PolicyDecision.DENY); + expect(engine.check(safeCall, undefined)).toBe(PolicyDecision.ALLOW); }); it('should apply rules by priority', () => { @@ -100,7 +106,9 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Higher priority rule (ALLOW) should win - expect(engine.check({ name: 'shell' })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'shell' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); }); it('should apply wildcard rules (no toolName)', () => { @@ -111,8 +119,10 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); - expect(engine.check({ name: 'safe-tool' })).toBe(PolicyDecision.ALLOW); - expect(engine.check({ name: 'any-other-tool' })).toBe( + expect(engine.check({ name: 'safe-tool' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + expect(engine.check({ name: 'any-other-tool' }, undefined)).toBe( PolicyDecision.DENY, ); }); @@ -129,13 +139,17 @@ describe('PolicyEngine', () => { engine = new PolicyEngine(config); // ASK_USER should become DENY in non-interactive mode - expect(engine.check({ name: 'interactive-tool' })).toBe( + expect(engine.check({ name: 'interactive-tool' }, undefined)).toBe( PolicyDecision.DENY, ); // ALLOW should remain ALLOW - expect(engine.check({ name: 'allowed-tool' })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'allowed-tool' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); // Default ASK_USER should also become DENY - expect(engine.check({ name: 'unknown-tool' })).toBe(PolicyDecision.DENY); + expect(engine.check({ name: 'unknown-tool' }, undefined)).toBe( + PolicyDecision.DENY, + ); }); }); @@ -165,11 +179,15 @@ describe('PolicyEngine', () => { }); it('should apply newly added rules', () => { - expect(engine.check({ name: 'new-tool' })).toBe(PolicyDecision.ASK_USER); + 
expect(engine.check({ name: 'new-tool' }, undefined)).toBe( + PolicyDecision.ASK_USER, + ); engine.addRule({ toolName: 'new-tool', decision: PolicyDecision.ALLOW }); - expect(engine.check({ name: 'new-tool' })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'new-tool' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); }); }); @@ -235,29 +253,31 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Should match my-server tools - expect(engine.check({ name: 'my-server__tool1' })).toBe( + expect(engine.check({ name: 'my-server__tool1' }, undefined)).toBe( PolicyDecision.ALLOW, ); - expect(engine.check({ name: 'my-server__another_tool' })).toBe( + expect(engine.check({ name: 'my-server__another_tool' }, undefined)).toBe( PolicyDecision.ALLOW, ); // Should match blocked-server tools - expect(engine.check({ name: 'blocked-server__tool1' })).toBe( - PolicyDecision.DENY, - ); - expect(engine.check({ name: 'blocked-server__dangerous' })).toBe( + expect(engine.check({ name: 'blocked-server__tool1' }, undefined)).toBe( PolicyDecision.DENY, ); + expect( + engine.check({ name: 'blocked-server__dangerous' }, undefined), + ).toBe(PolicyDecision.DENY); // Should not match other patterns - expect(engine.check({ name: 'other-server__tool' })).toBe( + expect(engine.check({ name: 'other-server__tool' }, undefined)).toBe( PolicyDecision.ASK_USER, ); - expect(engine.check({ name: 'my-server-tool' })).toBe( + expect(engine.check({ name: 'my-server-tool' }, undefined)).toBe( PolicyDecision.ASK_USER, ); // No __ separator - expect(engine.check({ name: 'my-server' })).toBe(PolicyDecision.ASK_USER); // No tool name + expect(engine.check({ name: 'my-server' }, undefined)).toBe( + PolicyDecision.ASK_USER, + ); // No tool name }); it('should prioritize specific tool rules over server wildcards', () => { @@ -277,10 +297,62 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Specific tool deny should override server allow - 
expect(engine.check({ name: 'my-server__dangerous-tool' })).toBe( - PolicyDecision.DENY, + expect( + engine.check({ name: 'my-server__dangerous-tool' }, undefined), + ).toBe(PolicyDecision.DENY); + expect(engine.check({ name: 'my-server__safe-tool' }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + }); + + it('should NOT match spoofed server names when using wildcards', () => { + // Vulnerability: A rule for 'prefix__*' matches 'prefix__suffix__tool' + // effectively allowing a server named 'prefix__suffix' to spoof 'prefix'. + const rules: PolicyRule[] = [ + { + toolName: 'safe_server__*', + decision: PolicyDecision.ALLOW, + }, + ]; + engine = new PolicyEngine({ rules }); + + // A tool from a different server 'safe_server__malicious' + const spoofedToolCall = { name: 'safe_server__malicious__tool' }; + + // CURRENT BEHAVIOR (FIXED): Matches because it starts with 'safe_server__' BUT serverName doesn't match 'safe_server' + // We expect this to FAIL matching the ALLOW rule, thus falling back to default (ASK_USER) + expect(engine.check(spoofedToolCall, 'safe_server__malicious')).toBe( + PolicyDecision.ASK_USER, + ); + }); + + it('should verify tool name prefix even if serverName matches', () => { + const rules: PolicyRule[] = [ + { + toolName: 'safe_server__*', + decision: PolicyDecision.ALLOW, + }, + ]; + engine = new PolicyEngine({ rules }); + + // serverName matches, but tool name does not start with prefix + const invalidToolCall = { name: 'other_server__tool' }; + expect(engine.check(invalidToolCall, 'safe_server')).toBe( + PolicyDecision.ASK_USER, ); - expect(engine.check({ name: 'my-server__safe-tool' })).toBe( + }); + + it('should allow when both serverName and tool name prefix match', () => { + const rules: PolicyRule[] = [ + { + toolName: 'safe_server__*', + decision: PolicyDecision.ALLOW, + }, + ]; + engine = new PolicyEngine({ rules }); + + const validToolCall = { name: 'safe_server__tool' }; + expect(engine.check(validToolCall, 'safe_server')).toBe( 
PolicyDecision.ALLOW, ); }); @@ -302,17 +374,19 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Matches highest priority rule (ls command) - expect(engine.check({ name: 'shell', args: { command: 'ls -la' } })).toBe( - PolicyDecision.ALLOW, - ); + expect( + engine.check({ name: 'shell', args: { command: 'ls -la' } }, undefined), + ).toBe(PolicyDecision.ALLOW); // Matches middle priority rule (shell without ls) - expect(engine.check({ name: 'shell', args: { command: 'pwd' } })).toBe( - PolicyDecision.ASK_USER, - ); + expect( + engine.check({ name: 'shell', args: { command: 'pwd' } }, undefined), + ).toBe(PolicyDecision.ASK_USER); // Matches lowest priority rule (not shell) - expect(engine.check({ name: 'edit' })).toBe(PolicyDecision.DENY); + expect(engine.check({ name: 'edit' }, undefined)).toBe( + PolicyDecision.DENY, + ); }); it('should handle tools with no args', () => { @@ -327,17 +401,19 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Tool call without args should not match pattern - expect(engine.check({ name: 'read' })).toBe(PolicyDecision.ASK_USER); - - // Tool call with args not matching pattern - expect(engine.check({ name: 'read', args: { file: 'public.txt' } })).toBe( + expect(engine.check({ name: 'read' }, undefined)).toBe( PolicyDecision.ASK_USER, ); + // Tool call with args not matching pattern + expect( + engine.check({ name: 'read', args: { file: 'public.txt' } }, undefined), + ).toBe(PolicyDecision.ASK_USER); + // Tool call with args matching pattern - expect(engine.check({ name: 'read', args: { file: 'secret.txt' } })).toBe( - PolicyDecision.DENY, - ); + expect( + engine.check({ name: 'read', args: { file: 'secret.txt' } }, undefined), + ).toBe(PolicyDecision.DENY); }); it('should match args pattern regardless of property order', () => { @@ -356,16 +432,16 @@ describe('PolicyEngine', () => { const args1 = { command: 'rm -rf /', path: '/home' }; const args2 = { path: '/home', command: 'rm -rf 
/' }; - expect(engine.check({ name: 'shell', args: args1 })).toBe( + expect(engine.check({ name: 'shell', args: args1 }, undefined)).toBe( PolicyDecision.DENY, ); - expect(engine.check({ name: 'shell', args: args2 })).toBe( + expect(engine.check({ name: 'shell', args: args2 }, undefined)).toBe( PolicyDecision.DENY, ); // Verify safe command doesn't match const safeArgs = { command: 'ls -la', path: '/home' }; - expect(engine.check({ name: 'shell', args: safeArgs })).toBe( + expect(engine.check({ name: 'shell', args: safeArgs }, undefined)).toBe( PolicyDecision.ASK_USER, ); }); @@ -391,10 +467,10 @@ describe('PolicyEngine', () => { data: { value: 'secret', sensitive: true }, }; - expect(engine.check({ name: 'api', args: args1 })).toBe( + expect(engine.check({ name: 'api', args: args1 }, undefined)).toBe( PolicyDecision.DENY, ); - expect(engine.check({ name: 'api', args: args2 })).toBe( + expect(engine.check({ name: 'api', args: args2 }, undefined)).toBe( PolicyDecision.DENY, ); }); @@ -424,17 +500,17 @@ describe('PolicyEngine', () => { // Should not throw stack overflow error expect(() => - engine.check({ name: 'test', args: circularArgs }), + engine.check({ name: 'test', args: circularArgs }, undefined), ).not.toThrow(); // Should detect the circular reference pattern - expect(engine.check({ name: 'test', args: circularArgs })).toBe( - PolicyDecision.DENY, - ); + expect( + engine.check({ name: 'test', args: circularArgs }, undefined), + ).toBe(PolicyDecision.DENY); // Non-circular object should not match const normalArgs = { name: 'test', data: { value: 'normal' } }; - expect(engine.check({ name: 'test', args: normalArgs })).toBe( + expect(engine.check({ name: 'test', args: normalArgs }, undefined)).toBe( PolicyDecision.ASK_USER, ); }); @@ -471,13 +547,13 @@ describe('PolicyEngine', () => { // Should handle without stack overflow expect(() => - engine.check({ name: 'deep', args: deepCircular }), + engine.check({ name: 'deep', args: deepCircular }, undefined), 
).not.toThrow(); // Should detect the circular reference - expect(engine.check({ name: 'deep', args: deepCircular })).toBe( - PolicyDecision.DENY, - ); + expect( + engine.check({ name: 'deep', args: deepCircular }, undefined), + ).toBe(PolicyDecision.DENY); }); it('should handle repeated non-circular objects correctly', () => { @@ -506,7 +582,9 @@ describe('PolicyEngine', () => { }; // Should NOT mark repeated objects as circular, and should match the shared value pattern - expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'test', args }, undefined)).toBe( + PolicyDecision.ALLOW, + ); }); it('should omit undefined and function values from objects', () => { @@ -528,7 +606,9 @@ describe('PolicyEngine', () => { }; // Should match pattern with defined value, undefined and functions omitted - expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'test', args }, undefined)).toBe( + PolicyDecision.ALLOW, + ); // Check that the pattern would NOT match if undefined was included const rulesWithUndefined: PolicyRule[] = [ @@ -539,7 +619,7 @@ describe('PolicyEngine', () => { }, ]; engine = new PolicyEngine({ rules: rulesWithUndefined }); - expect(engine.check({ name: 'test', args })).toBe( + expect(engine.check({ name: 'test', args }, undefined)).toBe( PolicyDecision.ASK_USER, ); @@ -552,7 +632,7 @@ describe('PolicyEngine', () => { }, ]; engine = new PolicyEngine({ rules: rulesWithFunction }); - expect(engine.check({ name: 'test', args })).toBe( + expect(engine.check({ name: 'test', args }, undefined)).toBe( PolicyDecision.ASK_USER, ); }); @@ -573,7 +653,9 @@ describe('PolicyEngine', () => { }; // Should match pattern with undefined and functions converted to null - expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'test', args }, undefined)).toBe( + PolicyDecision.ALLOW, + ); }); it('should produce valid JSON for all 
inputs', () => { @@ -607,10 +689,12 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Should not throw when checking (which internally uses stableStringify) - expect(() => engine.check({ name: 'test', args: input })).not.toThrow(); + expect(() => + engine.check({ name: 'test', args: input }, undefined), + ).not.toThrow(); // The check should succeed - expect(engine.check({ name: 'test', args: input })).toBe( + expect(engine.check({ name: 'test', args: input }, undefined)).toBe( PolicyDecision.ALLOW, ); } @@ -641,7 +725,9 @@ describe('PolicyEngine', () => { }; // Should match the sanitized pattern, not the dangerous one - expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'test', args }, undefined)).toBe( + PolicyDecision.ALLOW, + ); }); it('should handle toJSON that returns primitives', () => { @@ -663,7 +749,9 @@ describe('PolicyEngine', () => { }; // toJSON returns a string, which should be properly stringified - expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'test', args }, undefined)).toBe( + PolicyDecision.ALLOW, + ); }); it('should handle toJSON that throws an error', () => { @@ -687,7 +775,23 @@ describe('PolicyEngine', () => { }; // Should fall back to regular object serialization when toJSON throws - expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'test', args }, undefined)).toBe( + PolicyDecision.ALLOW, + ); + }); + }); + + describe('serverName requirement', () => { + it('should require serverName for checks', () => { + // @ts-expect-error - intentionally testing missing serverName + expect(engine.check({ name: 'test' })).toBe(PolicyDecision.ASK_USER); + // When serverName is provided (even undefined), it should work + expect(engine.check({ name: 'test' }, undefined)).toBe( + PolicyDecision.ASK_USER, + ); + expect(engine.check({ name: 'test' }, 'some-server')).toBe( + 
PolicyDecision.ASK_USER, + ); }); }); }); diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts index cc98cc59660..034cce2c8b4 100644 --- a/packages/core/src/policy/policy-engine.ts +++ b/packages/core/src/policy/policy-engine.ts @@ -17,12 +17,21 @@ function ruleMatches( rule: PolicyRule, toolCall: FunctionCall, stringifiedArgs: string | undefined, + serverName: string | undefined, ): boolean { // Check tool name if specified if (rule.toolName) { // Support wildcard patterns: "serverName__*" matches "serverName__anyTool" if (rule.toolName.endsWith('__*')) { const prefix = rule.toolName.slice(0, -3); // Remove "__*" + if (serverName !== undefined) { + // Robust check: if serverName is provided, it MUST match the prefix exactly. + // This prevents "malicious-server" from spoofing "trusted-server" by naming itself "trusted-server__malicious". + if (serverName !== prefix) { + return false; + } + } + // Always verify the prefix, even if serverName matched if (!toolCall.name || !toolCall.name.startsWith(prefix + '__')) { return false; } @@ -65,7 +74,10 @@ export class PolicyEngine { /** * Check if a tool call is allowed based on the configured policies. 
*/ - check(toolCall: FunctionCall): PolicyDecision { + check( + toolCall: FunctionCall, + serverName: string | undefined, + ): PolicyDecision { let stringifiedArgs: string | undefined; // Compute stringified args once before the loop if (toolCall.args && this.rules.some((rule) => rule.argsPattern)) { @@ -78,7 +90,7 @@ export class PolicyEngine { // Find the first matching rule (already sorted by priority) for (const rule of this.rules) { - if (ruleMatches(rule, toolCall, stringifiedArgs)) { + if (ruleMatches(rule, toolCall, stringifiedArgs, serverName)) { debugLogger.debug( `[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`, ); diff --git a/packages/core/src/policy/toml-loader.test.ts b/packages/core/src/policy/toml-loader.test.ts index 1785faba719..0161bf6af16 100644 --- a/packages/core/src/policy/toml-loader.test.ts +++ b/packages/core/src/policy/toml-loader.test.ts @@ -8,6 +8,57 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { ApprovalMode, PolicyDecision } from './types.js'; import type { Dirent } from 'node:fs'; import nodePath from 'node:path'; +import type { PolicyLoadResult } from './toml-loader.js'; + +async function runLoadPoliciesFromToml( + tomlContent: string, + fileName = 'test.toml', +): Promise { + const actualFs = + await vi.importActual( + 'node:fs/promises', + ); + + const mockReaddir = vi.fn( + async ( + path: string, + _options?: { withFileTypes: boolean }, + ): Promise => { + if (nodePath.normalize(path) === nodePath.normalize('/policies')) { + return [ + { + name: fileName, + isFile: () => true, + isDirectory: () => false, + } as Dirent, + ]; + } + return []; + }, + ); + + const mockReadFile = vi.fn(async (path: string): Promise => { + if ( + nodePath.normalize(path) === + nodePath.normalize(nodePath.join('/policies', fileName)) + ) { + return tomlContent; + } + throw new Error('File not 
found'); + }); + + vi.doMock('node:fs/promises', () => ({ + ...actualFs, + default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, + readFile: mockReadFile, + readdir: mockReaddir, + })); + + const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); + + const getPolicyTier = (_dir: string) => 1; + return load(ApprovalMode.DEFAULT, ['/policies'], getPolicyTier); +} describe('policy-toml-loader', () => { beforeEach(() => { @@ -21,59 +72,12 @@ describe('policy-toml-loader', () => { describe('loadPoliciesFromToml', () => { it('should load and parse a simple policy file', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'test.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'test.toml')) - ) { - return ` + const result = await runLoadPoliciesFromToml(` [[rule]] toolName = "glob" decision = "allow" priority = 100 -`; - } - throw new Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - - const getPolicyTier = (_dir: string) => 1; - const result = await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); +`); expect(result.rules).toHaveLength(1); expect(result.rules[0]).toEqual({ @@ -85,60 +89,13 @@ priority = 100 }); it('should expand commandPrefix array to multiple rules', async () => { - const actualFs = - await vi.importActual( - 
'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'shell.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'shell.toml')) - ) { - return ` + const result = await runLoadPoliciesFromToml(` [[rule]] toolName = "run_shell_command" commandPrefix = ["git status", "git log"] decision = "allow" priority = 100 -`; - } - throw new Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - - const getPolicyTier = (_dir: string) => 2; - const result = await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); +`); expect(result.rules).toHaveLength(2); expect(result.rules[0].toolName).toBe('run_shell_command'); @@ -153,60 +110,13 @@ priority = 100 }); it('should transform commandRegex to argsPattern', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'shell.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'shell.toml')) - ) { - return ` + const result = await runLoadPoliciesFromToml(` [[rule]] 
toolName = "run_shell_command" commandRegex = "git (status|log).*" decision = "allow" priority = 100 -`; - } - throw new Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - - const getPolicyTier = (_dir: string) => 2; - const result = await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); +`); expect(result.rules).toHaveLength(1); expect( @@ -222,59 +132,12 @@ priority = 100 }); it('should expand toolName array', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'tools.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'tools.toml')) - ) { - return ` + const result = await runLoadPoliciesFromToml(` [[rule]] toolName = ["glob", "grep", "read"] decision = "allow" priority = 100 -`; - } - throw new Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - - const getPolicyTier = (_dir: string) => 1; - const result = await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); +`); expect(result.rules).toHaveLength(3); expect(result.rules.map((r) => r.toolName)).toEqual([ @@ -286,60 +149,13 @@ priority = 100 }); 
it('should transform mcpName to composite toolName', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'mcp.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'mcp.toml')) - ) { - return ` + const result = await runLoadPoliciesFromToml(` [[rule]] mcpName = "google-workspace" toolName = ["calendar.list", "calendar.get"] decision = "allow" priority = 100 -`; - } - throw new Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - - const getPolicyTier = (_dir: string) => 2; - const result = await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); +`); expect(result.rules).toHaveLength(2); expect(result.rules[0].toolName).toBe('google-workspace__calendar.list'); @@ -348,35 +164,7 @@ priority = 100 }); it('should filter rules by mode', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'modes.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - 
nodePath.normalize(nodePath.join('/policies', 'modes.toml')) - ) { - return ` + const result = await runLoadPoliciesFromToml(` [[rule]] toolName = "glob" decision = "allow" @@ -388,26 +176,7 @@ toolName = "grep" decision = "allow" priority = 100 modes = ["yolo"] -`; - } - throw new Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - - const getPolicyTier = (_dir: string) => 1; - const result = await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); +`); // Only the first rule should be included (modes includes "default") expect(result.rules).toHaveLength(1); @@ -416,119 +185,25 @@ modes = ["yolo"] }); it('should handle TOML parse errors', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'invalid.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'invalid.toml')) - ) { - return ` + const result = await runLoadPoliciesFromToml(` [[rule] toolName = "glob" decision = "allow" priority = 100 -`; - } - throw new Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - - const getPolicyTier = (_dir: string) => 1; - const result = 
await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); +`); expect(result.rules).toHaveLength(0); expect(result.errors).toHaveLength(1); expect(result.errors[0].errorType).toBe('toml_parse'); - expect(result.errors[0].fileName).toBe('invalid.toml'); - }); - - it('should handle schema validation errors', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'invalid.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'invalid.toml')) - ) { - return ` -[[rule]] -toolName = "glob" -priority = 100 -`; - } - throw new Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); + expect(result.errors[0].fileName).toBe('test.toml'); + }); - const getPolicyTier = (_dir: string) => 1; - const result = await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); + it('should handle schema validation errors', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "glob" +priority = 100 +`); expect(result.rules).toHaveLength(0); expect(result.errors).toHaveLength(1); @@ -537,60 +212,13 @@ priority = 100 }); it('should reject commandPrefix without run_shell_command', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, 
- ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'invalid.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'invalid.toml')) - ) { - return ` + const result = await runLoadPoliciesFromToml(` [[rule]] toolName = "glob" commandPrefix = "git status" decision = "allow" priority = 100 -`; - } - throw new Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - - const getPolicyTier = (_dir: string) => 1; - const result = await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); +`); expect(result.errors).toHaveLength(1); expect(result.errors[0].errorType).toBe('rule_validation'); @@ -598,61 +226,14 @@ priority = 100 }); it('should reject commandPrefix + argsPattern combination', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); - - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'invalid.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); - - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'invalid.toml')) - ) { - return ` + const result = await runLoadPoliciesFromToml(` [[rule]] toolName = "run_shell_command" commandPrefix = "git status" argsPattern = "test" decision = "allow" priority = 100 -`; - } - throw new 
Error('File not found'); - }); - - vi.doMock('node:fs/promises', () => ({ - ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, - readdir: mockReaddir, - })); - - const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - - const getPolicyTier = (_dir: string) => 1; - const result = await load( - ApprovalMode.DEFAULT, - ['/policies'], - getPolicyTier, - ); +`); expect(result.errors).toHaveLength(1); expect(result.errors[0].errorType).toBe('rule_validation'); @@ -660,6 +241,41 @@ priority = 100 }); it('should handle invalid regex patterns', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "run_shell_command" +commandRegex = "git (status|branch" +decision = "allow" +priority = 100 +`); + + expect(result.rules).toHaveLength(0); + expect(result.errors).toHaveLength(1); + expect(result.errors[0].errorType).toBe('regex_compilation'); + expect(result.errors[0].details).toContain('git (status|branch'); + }); + + it('should escape regex special characters in commandPrefix', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "run_shell_command" +commandPrefix = "git log *.txt" +decision = "allow" +priority = 100 +`); + + expect(result.rules).toHaveLength(1); + // The regex should have escaped the * and . 
+ expect( + result.rules[0].argsPattern?.test('{"command":"git log file.txt"}'), + ).toBe(false); + expect( + result.rules[0].argsPattern?.test('{"command":"git log *.txt"}'), + ).toBe(true); + expect(result.errors).toHaveLength(0); + }); + + it('should handle a mix of valid and invalid policy files', async () => { const actualFs = await vi.importActual( 'node:fs/promises', @@ -672,6 +288,11 @@ priority = 100 ): Promise => { if (nodePath.normalize(path) === nodePath.normalize('/policies')) { return [ + { + name: 'valid.toml', + isFile: () => true, + isDirectory: () => false, + } as Dirent, { name: 'invalid.toml', isFile: () => true, @@ -686,14 +307,24 @@ priority = 100 const mockReadFile = vi.fn(async (path: string): Promise => { if ( nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'invalid.toml')) + nodePath.normalize(nodePath.join('/policies', 'valid.toml')) ) { return ` [[rule]] -toolName = "run_shell_command" -commandRegex = "git (status|branch" +toolName = "glob" decision = "allow" priority = 100 +`; + } + if ( + nodePath.normalize(path) === + nodePath.normalize(nodePath.join('/policies', 'invalid.toml')) + ) { + return ` +[[rule]] +toolName = "grep" +decision = "allow" +priority = -1 `; } throw new Error('File not found'); @@ -715,61 +346,154 @@ priority = 100 getPolicyTier, ); - expect(result.rules).toHaveLength(0); + expect(result.rules).toHaveLength(1); + expect(result.rules[0].toolName).toBe('glob'); expect(result.errors).toHaveLength(1); - expect(result.errors[0].errorType).toBe('regex_compilation'); - expect(result.errors[0].details).toContain('git (status|branch'); + expect(result.errors[0].fileName).toBe('invalid.toml'); + expect(result.errors[0].errorType).toBe('schema_validation'); + }); + }); + describe('Negative Tests', () => { + it('should return a schema_validation error if priority is missing', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "test" +decision = "allow" +`); + 
expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + expect(error.errorType).toBe('schema_validation'); + expect(error.details).toContain('priority'); }); - it('should escape regex special characters in commandPrefix', async () => { - const actualFs = - await vi.importActual( - 'node:fs/promises', - ); + it('should return a schema_validation error if priority is a float', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "test" +decision = "allow" +priority = 1.5 +`); + expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + expect(error.errorType).toBe('schema_validation'); + expect(error.details).toContain('priority'); + expect(error.details).toContain('integer'); + }); - const mockReaddir = vi.fn( - async ( - path: string, - _options?: { withFileTypes: boolean }, - ): Promise => { - if (nodePath.normalize(path) === nodePath.normalize('/policies')) { - return [ - { - name: 'shell.toml', - isFile: () => true, - isDirectory: () => false, - } as Dirent, - ]; - } - return []; - }, - ); + it('should return a schema_validation error if priority is negative', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "test" +decision = "allow" +priority = -1 +`); + expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + expect(error.errorType).toBe('schema_validation'); + expect(error.details).toContain('priority'); + expect(error.details).toContain('>= 0'); + }); - const mockReadFile = vi.fn(async (path: string): Promise => { - if ( - nodePath.normalize(path) === - nodePath.normalize(nodePath.join('/policies', 'shell.toml')) - ) { - return ` + it('should return a schema_validation error if priority is >= 1000', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "test" +decision = "allow" +priority = 1000 +`); + expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + 
expect(error.errorType).toBe('schema_validation'); + expect(error.details).toContain('priority'); + expect(error.details).toContain('<= 999'); + }); + + it('should return a schema_validation error if decision is invalid', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "test" +decision = "maybe" +priority = 100 +`); + expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + expect(error.errorType).toBe('schema_validation'); + expect(error.details).toContain('decision'); + }); + + it('should return a schema_validation error if toolName is not a string or array', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = 123 +decision = "allow" +priority = 100 +`); + expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + expect(error.errorType).toBe('schema_validation'); + expect(error.details).toContain('toolName'); + }); + + it('should return a rule_validation error if commandRegex is used with wrong toolName', async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "not_shell" +commandRegex = ".*" +decision = "allow" +priority = 100 +`); + expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + expect(error.errorType).toBe('rule_validation'); + expect(error.details).toContain('run_shell_command'); + }); + + it('should return a rule_validation error if commandPrefix and commandRegex are combined', async () => { + const result = await runLoadPoliciesFromToml(` [[rule]] toolName = "run_shell_command" -commandPrefix = "git log *.txt" +commandPrefix = "git" +commandRegex = ".*" decision = "allow" priority = 100 -`; - } - throw new Error('File not found'); +`); + expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + expect(error.errorType).toBe('rule_validation'); + expect(error.details).toContain('mutually exclusive'); + }); + + it('should return a regex_compilation error for invalid argsPattern', 
async () => { + const result = await runLoadPoliciesFromToml(` +[[rule]] +toolName = "test" +argsPattern = "([a-z)" +decision = "allow" +priority = 100 +`); + expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + expect(error.errorType).toBe('regex_compilation'); + expect(error.message).toBe('Invalid regex pattern'); + }); + + it('should return a file_read error if readdir fails', async () => { + const actualFs = + await vi.importActual( + 'node:fs/promises', + ); + + const mockReaddir = vi.fn(async () => { + throw new Error('Permission denied'); }); vi.doMock('node:fs/promises', () => ({ ...actualFs, - default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, - readFile: mockReadFile, + default: { ...actualFs, readdir: mockReaddir }, readdir: mockReaddir, })); const { loadPoliciesFromToml: load } = await import('./toml-loader.js'); - const getPolicyTier = (_dir: string) => 1; const result = await load( ApprovalMode.DEFAULT, @@ -777,15 +501,10 @@ priority = 100 getPolicyTier, ); - expect(result.rules).toHaveLength(1); - // The regex should have escaped the * and . 
- expect( - result.rules[0].argsPattern?.test('{"command":"git log file.txt"}'), - ).toBe(false); - expect( - result.rules[0].argsPattern?.test('{"command":"git log *.txt"}'), - ).toBe(true); - expect(result.errors).toHaveLength(0); + expect(result.errors).toHaveLength(1); + const error = result.errors[0]; + expect(error.errorType).toBe('file_read'); + expect(error.message).toContain('Failed to read policy directory'); }); }); }); diff --git a/packages/core/src/services/modelConfig.golden.test.ts b/packages/core/src/services/modelConfig.golden.test.ts new file mode 100644 index 00000000000..c11f763306e --- /dev/null +++ b/packages/core/src/services/modelConfig.golden.test.ts @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import * as fs from 'node:fs/promises'; +import * as path from 'node:path'; +import { ModelConfigService } from './modelConfigService.js'; +import { DEFAULT_MODEL_CONFIGS } from '../config/defaultModelConfigs.js'; + +const GOLDEN_FILE_PATH = path.resolve( + process.cwd(), + 'src', + 'services', + 'test-data', + 'resolved-aliases.golden.json', +); + +describe('ModelConfigService Golden Test', () => { + it('should match the golden file for resolved default aliases', async () => { + const service = new ModelConfigService(DEFAULT_MODEL_CONFIGS); + const aliases = Object.keys(DEFAULT_MODEL_CONFIGS.aliases ?? {}); + + const resolvedAliases: Record = {}; + for (const alias of aliases) { + resolvedAliases[alias] = + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (service as any).internalGetResolvedConfig({ model: alias }); + } + + if (process.env['UPDATE_GOLDENS']) { + await fs.mkdir(path.dirname(GOLDEN_FILE_PATH), { recursive: true }); + await fs.writeFile( + GOLDEN_FILE_PATH, + JSON.stringify(resolvedAliases, null, 2), + 'utf-8', + ); + // In update mode, we pass the test after writing the file. 
+ return; + } + + let goldenContent: string; + try { + goldenContent = await fs.readFile(GOLDEN_FILE_PATH, 'utf-8'); + } catch (e) { + if ((e as NodeJS.ErrnoException).code === 'ENOENT') { + throw new Error( + 'Golden file not found. Run with `UPDATE_GOLDENS=true` to create it.', + ); + } + throw e; + } + + const goldenData = JSON.parse(goldenContent); + + expect( + resolvedAliases, + 'Golden file mismatch. If the new resolved aliases are correct, run the test with `UPDATE_GOLDENS=true` to regenerate the golden file.', + ).toEqual(goldenData); + }); +}); diff --git a/packages/core/src/services/modelConfig.integration.test.ts b/packages/core/src/services/modelConfig.integration.test.ts new file mode 100644 index 00000000000..fd478557660 --- /dev/null +++ b/packages/core/src/services/modelConfig.integration.test.ts @@ -0,0 +1,234 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { ModelConfigService } from './modelConfigService.js'; +import type { ModelConfigServiceConfig } from './modelConfigService.js'; + +// This test suite is designed to validate the end-to-end logic of the +// ModelConfigService with a complex, realistic configuration. +// It tests the interplay of global settings, alias inheritance, and overrides +// of varying specificities. 
+describe('ModelConfigService Integration', () => { + const complexConfig: ModelConfigServiceConfig = { + aliases: { + // Abstract base with no model + base: { + modelConfig: { + generateContentConfig: { + topP: 0.95, + topK: 64, + }, + }, + }, + 'default-text-model': { + extends: 'base', + modelConfig: { + model: 'gemini-1.5-pro-latest', + generateContentConfig: { + topK: 40, // Override base + }, + }, + }, + 'creative-writer': { + extends: 'default-text-model', + modelConfig: { + generateContentConfig: { + temperature: 0.9, // Override global + topK: 50, // Override parent + }, + }, + }, + 'fast-classifier': { + extends: 'base', + modelConfig: { + model: 'gemini-1.5-flash-latest', + generateContentConfig: { + temperature: 0.1, + candidateCount: 4, + }, + }, + }, + }, + overrides: [ + // Broad override for all flash models + { + match: { model: 'gemini-1.5-flash-latest' }, + modelConfig: { + generateContentConfig: { + maxOutputTokens: 2048, + }, + }, + }, + // Specific override for the 'core' agent + { + match: { overrideScope: 'core' }, + modelConfig: { + generateContentConfig: { + temperature: 0.5, + stopSequences: ['AGENT_STOP'], + }, + }, + }, + // Highly specific override for the 'fast-classifier' when used by the 'core' agent + { + match: { model: 'fast-classifier', overrideScope: 'core' }, + modelConfig: { + generateContentConfig: { + temperature: 0.0, + maxOutputTokens: 4096, + }, + }, + }, + // Override to provide a model for the abstract alias + { + match: { model: 'base', overrideScope: 'core' }, + modelConfig: { + model: 'gemini-1.5-pro-latest', + }, + }, + ], + }; + + const service = new ModelConfigService(complexConfig); + + it('should resolve a simple model, applying core agent defaults', () => { + const resolved = service.getResolvedConfig({ + model: 'gemini-test-model', + }); + + expect(resolved.model).toBe('gemini-test-model'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.5, // from agent override + stopSequences: 
['AGENT_STOP'], // from agent override + }); + }); + + it('should correctly apply a simple inherited alias and merge with global defaults', () => { + const resolved = service.getResolvedConfig({ + model: 'default-text-model', + }); + + expect(resolved.model).toBe('gemini-1.5-pro-latest'); // from alias + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.5, // from agent override + topP: 0.95, // from base + topK: 40, // from alias + stopSequences: ['AGENT_STOP'], // from agent override + }); + }); + + it('should resolve a multi-level inherited alias', () => { + const resolved = service.getResolvedConfig({ + model: 'creative-writer', + }); + + expect(resolved.model).toBe('gemini-1.5-pro-latest'); // from default-text-model + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.5, // from agent override + topP: 0.95, // from base + topK: 50, // from alias + stopSequences: ['AGENT_STOP'], // from agent override + }); + }); + + it('should apply an inherited alias and a broad model-based override', () => { + const resolved = service.getResolvedConfig({ + model: 'fast-classifier', + // No agent specified, so it should match core agent-specific rules + }); + + expect(resolved.model).toBe('gemini-1.5-flash-latest'); // from alias + expect(resolved.generateContentConfig).toEqual({ + topP: 0.95, // from base + topK: 64, // from base + candidateCount: 4, // from alias + stopSequences: ['AGENT_STOP'], // from agent override + maxOutputTokens: 4096, // from most specific override + temperature: 0.0, // from most specific override + }); + }); + + it('should apply settings for an unknown model but a known agent', () => { + const resolved = service.getResolvedConfig({ + model: 'gemini-test-model', + overrideScope: 'core', + }); + + expect(resolved.model).toBe('gemini-test-model'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.5, // from agent override + stopSequences: ['AGENT_STOP'], // from agent override + }); + }); + + 
it('should apply the most specific override for a known inherited alias and agent', () => { + const resolved = service.getResolvedConfig({ + model: 'fast-classifier', + overrideScope: 'core', + }); + + expect(resolved.model).toBe('gemini-1.5-flash-latest'); + expect(resolved.generateContentConfig).toEqual({ + // Inherited from 'base' + topP: 0.95, + topK: 64, + // From 'fast-classifier' alias + candidateCount: 4, + // From 'core' agent override + stopSequences: ['AGENT_STOP'], + // From most specific override (model+agent) + temperature: 0.0, + maxOutputTokens: 4096, + }); + }); + + it('should correctly apply agent override on top of a multi-level inherited alias', () => { + const resolved = service.getResolvedConfig({ + model: 'creative-writer', + overrideScope: 'core', + }); + + expect(resolved.model).toBe('gemini-1.5-pro-latest'); // from default-text-model + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.5, // from agent override (wins over alias) + topP: 0.95, // from base + topK: 50, // from creative-writer alias + stopSequences: ['AGENT_STOP'], // from agent override + }); + }); + + it('should resolve an abstract alias if a specific override provides the model', () => { + const resolved = service.getResolvedConfig({ + model: 'base', + overrideScope: 'core', + }); + + expect(resolved.model).toBe('gemini-1.5-pro-latest'); // from override + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.5, // from agent override + topP: 0.95, // from base alias + topK: 64, // from base alias + stopSequences: ['AGENT_STOP'], // from agent override + }); + }); + + it('should not apply core agent overrides when a different agent is specified', () => { + const resolved = service.getResolvedConfig({ + model: 'fast-classifier', + overrideScope: 'non-core-agent', + }); + + expect(resolved.model).toBe('gemini-1.5-flash-latest'); + expect(resolved.generateContentConfig).toEqual({ + candidateCount: 4, // from alias + maxOutputTokens: 2048, // from 
override of model + temperature: 0.1, // from alias + topK: 64, // from base + topP: 0.95, // from base + }); + }); +}); diff --git a/packages/core/src/services/modelConfigService.test.ts b/packages/core/src/services/modelConfigService.test.ts new file mode 100644 index 00000000000..998abe75b13 --- /dev/null +++ b/packages/core/src/services/modelConfigService.test.ts @@ -0,0 +1,553 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import type { ModelConfigServiceConfig } from './modelConfigService.js'; +import { ModelConfigService } from './modelConfigService.js'; + +describe('ModelConfigService', () => { + it('should resolve a basic alias to its model and settings', () => { + const config: ModelConfigServiceConfig = { + aliases: { + classifier: { + modelConfig: { + model: 'gemini-1.5-flash-latest', + generateContentConfig: { + temperature: 0, + topP: 0.9, + }, + }, + }, + }, + overrides: [], + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'classifier' }); + + expect(resolved.model).toBe('gemini-1.5-flash-latest'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0, + topP: 0.9, + }); + }); + + it('should apply a simple override on top of an alias', () => { + const config: ModelConfigServiceConfig = { + aliases: { + classifier: { + modelConfig: { + model: 'gemini-1.5-flash-latest', + generateContentConfig: { + temperature: 0, + topP: 0.9, + }, + }, + }, + }, + overrides: [ + { + match: { model: 'classifier' }, + modelConfig: { + generateContentConfig: { + temperature: 0.5, + maxOutputTokens: 1000, + }, + }, + }, + ], + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'classifier' }); + + expect(resolved.model).toBe('gemini-1.5-flash-latest'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.5, + topP: 0.9, + 
maxOutputTokens: 1000, + }); + }); + + it('should apply the most specific override rule', () => { + const config: ModelConfigServiceConfig = { + aliases: {}, + overrides: [ + { + match: { model: 'gemini-pro' }, + modelConfig: { generateContentConfig: { temperature: 0.5 } }, + }, + { + match: { model: 'gemini-pro', overrideScope: 'my-agent' }, + modelConfig: { generateContentConfig: { temperature: 0.1 } }, + }, + ], + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ + model: 'gemini-pro', + overrideScope: 'my-agent', + }); + + expect(resolved.model).toBe('gemini-pro'); + expect(resolved.generateContentConfig).toEqual({ temperature: 0.1 }); + }); + + it('should use the last override in case of a tie in specificity', () => { + const config: ModelConfigServiceConfig = { + aliases: {}, + overrides: [ + { + match: { model: 'gemini-pro' }, + modelConfig: { + generateContentConfig: { temperature: 0.5, topP: 0.8 }, + }, + }, + { + match: { model: 'gemini-pro' }, + modelConfig: { generateContentConfig: { temperature: 0.1 } }, + }, + ], + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'gemini-pro' }); + + expect(resolved.model).toBe('gemini-pro'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.1, + topP: 0.8, + }); + }); + + it('should correctly pass through generation config from an alias', () => { + const config: ModelConfigServiceConfig = { + aliases: { + 'thinking-alias': { + modelConfig: { + model: 'gemini-pro', + generateContentConfig: { + candidateCount: 500, + }, + }, + }, + }, + overrides: [], + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'thinking-alias' }); + + expect(resolved.generateContentConfig).toEqual({ candidateCount: 500 }); + }); + + it('should let an override generation config win over an alias config', () => { + const config: ModelConfigServiceConfig = { + 
aliases: { + 'thinking-alias': { + modelConfig: { + model: 'gemini-pro', + generateContentConfig: { + candidateCount: 500, + }, + }, + }, + }, + overrides: [ + { + match: { model: 'thinking-alias' }, + modelConfig: { + generateContentConfig: { + candidateCount: 1000, + }, + }, + }, + ], + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'thinking-alias' }); + + expect(resolved.generateContentConfig).toEqual({ + candidateCount: 1000, + }); + }); + + it('should merge settings from global, alias, and multiple matching overrides', () => { + const config: ModelConfigServiceConfig = { + aliases: { + 'test-alias': { + modelConfig: { + model: 'gemini-test-model', + generateContentConfig: { + topP: 0.9, + topK: 50, + }, + }, + }, + }, + overrides: [ + { + match: { model: 'gemini-test-model' }, + modelConfig: { + generateContentConfig: { + topK: 40, + maxOutputTokens: 2048, + }, + }, + }, + { + match: { overrideScope: 'test-agent' }, + modelConfig: { + generateContentConfig: { + maxOutputTokens: 4096, + }, + }, + }, + { + match: { model: 'gemini-test-model', overrideScope: 'test-agent' }, + modelConfig: { + generateContentConfig: { + temperature: 0.2, + }, + }, + }, + ], + }; + + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ + model: 'test-alias', + overrideScope: 'test-agent', + }); + + expect(resolved.model).toBe('gemini-test-model'); + expect(resolved.generateContentConfig).toEqual({ + // From global, overridden by most specific override + temperature: 0.2, + // From alias, not overridden + topP: 0.9, + // From alias, overridden by less specific override + topK: 40, + // From first matching override, overridden by second matching override + maxOutputTokens: 4096, + }); + }); + + it('should match an agent:core override when agent is undefined', () => { + const config: ModelConfigServiceConfig = { + aliases: {}, + overrides: [ + { + match: { overrideScope: 'core' }, + 
modelConfig: { + generateContentConfig: { + temperature: 0.1, + }, + }, + }, + ], + }; + + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ + model: 'gemini-pro', + overrideScope: undefined, // Explicitly undefined + }); + + expect(resolved.model).toBe('gemini-pro'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.1, + }); + }); + + describe('alias inheritance', () => { + it('should resolve a simple "extends" chain', () => { + const config: ModelConfigServiceConfig = { + aliases: { + base: { + modelConfig: { + model: 'gemini-1.5-pro-latest', + generateContentConfig: { + temperature: 0.7, + topP: 0.9, + }, + }, + }, + 'flash-variant': { + extends: 'base', + modelConfig: { + model: 'gemini-1.5-flash-latest', + }, + }, + }, + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'flash-variant' }); + + expect(resolved.model).toBe('gemini-1.5-flash-latest'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.7, + topP: 0.9, + }); + }); + + it('should override parent properties from child alias', () => { + const config: ModelConfigServiceConfig = { + aliases: { + base: { + modelConfig: { + model: 'gemini-1.5-pro-latest', + generateContentConfig: { + temperature: 0.7, + topP: 0.9, + }, + }, + }, + 'flash-variant': { + extends: 'base', + modelConfig: { + model: 'gemini-1.5-flash-latest', + generateContentConfig: { + temperature: 0.2, + }, + }, + }, + }, + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'flash-variant' }); + + expect(resolved.model).toBe('gemini-1.5-flash-latest'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.2, + topP: 0.9, + }); + }); + + it('should resolve a multi-level "extends" chain', () => { + const config: ModelConfigServiceConfig = { + aliases: { + base: { + modelConfig: { + model: 'gemini-1.5-pro-latest', + generateContentConfig: { 
+ temperature: 0.7, + topP: 0.9, + }, + }, + }, + 'base-flash': { + extends: 'base', + modelConfig: { + model: 'gemini-1.5-flash-latest', + }, + }, + 'classifier-flash': { + extends: 'base-flash', + modelConfig: { + generateContentConfig: { + temperature: 0, + }, + }, + }, + }, + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ + model: 'classifier-flash', + }); + + expect(resolved.model).toBe('gemini-1.5-flash-latest'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0, + topP: 0.9, + }); + }); + + it('should throw an error for circular dependencies', () => { + const config: ModelConfigServiceConfig = { + aliases: { + a: { extends: 'b', modelConfig: {} }, + b: { extends: 'a', modelConfig: {} }, + }, + }; + const service = new ModelConfigService(config); + expect(() => service.getResolvedConfig({ model: 'a' })).toThrow( + 'Circular alias dependency: a -> b -> a', + ); + }); + + describe('abstract aliases', () => { + it('should allow an alias to extend an abstract alias without a model', () => { + const config: ModelConfigServiceConfig = { + aliases: { + 'abstract-base': { + modelConfig: { + generateContentConfig: { + temperature: 0.1, + }, + }, + }, + 'concrete-child': { + extends: 'abstract-base', + modelConfig: { + model: 'gemini-1.5-pro-latest', + generateContentConfig: { + topP: 0.9, + }, + }, + }, + }, + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'concrete-child' }); + + expect(resolved.model).toBe('gemini-1.5-pro-latest'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.1, + topP: 0.9, + }); + }); + + it('should throw an error if a resolved alias chain has no model', () => { + const config: ModelConfigServiceConfig = { + aliases: { + 'abstract-base': { + modelConfig: { + generateContentConfig: { temperature: 0.7 }, + }, + }, + }, + }; + const service = new ModelConfigService(config); + expect(() => + 
service.getResolvedConfig({ model: 'abstract-base' }), + ).toThrow( + 'Could not resolve a model name for alias "abstract-base". Please ensure the alias chain or a matching override specifies a model.', + ); + }); + + it('should resolve an abstract alias if an override provides the model', () => { + const config: ModelConfigServiceConfig = { + aliases: { + 'abstract-base': { + modelConfig: { + generateContentConfig: { + temperature: 0.1, + }, + }, + }, + }, + overrides: [ + { + match: { model: 'abstract-base' }, + modelConfig: { + model: 'gemini-1.5-flash-latest', + }, + }, + ], + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'abstract-base' }); + + expect(resolved.model).toBe('gemini-1.5-flash-latest'); + expect(resolved.generateContentConfig).toEqual({ + temperature: 0.1, + }); + }); + }); + + it('should throw an error if an extended alias does not exist', () => { + const config: ModelConfigServiceConfig = { + aliases: { + 'bad-alias': { + extends: 'non-existent', + modelConfig: {}, + }, + }, + }; + const service = new ModelConfigService(config); + expect(() => service.getResolvedConfig({ model: 'bad-alias' })).toThrow( + 'Alias "non-existent" not found.', + ); + }); + }); + + describe('deep merging', () => { + it('should deep merge nested config objects from aliases and overrides', () => { + const config: ModelConfigServiceConfig = { + aliases: { + 'base-safe': { + modelConfig: { + model: 'gemini-pro', + generateContentConfig: { + safetySettings: { + HARM_CATEGORY_HARASSMENT: 'BLOCK_ONLY_HIGH', + HARM_CATEGORY_HATE_SPEECH: 'BLOCK_ONLY_HIGH', + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any, + }, + }, + }, + }, + overrides: [ + { + match: { model: 'base-safe' }, + modelConfig: { + generateContentConfig: { + safetySettings: { + HARM_CATEGORY_HATE_SPEECH: 'BLOCK_NONE', + HARM_CATEGORY_SEXUALLY_EXPLICIT: 'BLOCK_MEDIUM_AND_ABOVE', + // eslint-disable-next-line 
@typescript-eslint/no-explicit-any + } as any, + }, + }, + }, + ], + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'base-safe' }); + + expect(resolved.model).toBe('gemini-pro'); + expect(resolved.generateContentConfig.safetySettings).toEqual({ + // From alias + HARM_CATEGORY_HARASSMENT: 'BLOCK_ONLY_HIGH', + // From alias, overridden by override + HARM_CATEGORY_HATE_SPEECH: 'BLOCK_NONE', + // From override + HARM_CATEGORY_SEXUALLY_EXPLICIT: 'BLOCK_MEDIUM_AND_ABOVE', + }); + }); + + it('should not deeply merge merge arrays from aliases and overrides', () => { + const config: ModelConfigServiceConfig = { + aliases: { + base: { + modelConfig: { + model: 'gemini-pro', + generateContentConfig: { + stopSequences: ['foo'], + }, + }, + }, + }, + overrides: [ + { + match: { model: 'base' }, + modelConfig: { + generateContentConfig: { + stopSequences: ['overrideFoo'], + }, + }, + }, + ], + }; + const service = new ModelConfigService(config); + const resolved = service.getResolvedConfig({ model: 'base' }); + + expect(resolved.model).toBe('gemini-pro'); + expect(resolved.generateContentConfig.stopSequences).toEqual([ + 'overrideFoo', + ]); + }); + }); +}); diff --git a/packages/core/src/services/modelConfigService.ts b/packages/core/src/services/modelConfigService.ts new file mode 100644 index 00000000000..14b5e5bddb2 --- /dev/null +++ b/packages/core/src/services/modelConfigService.ts @@ -0,0 +1,248 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { GenerateContentConfig } from '@google/genai'; + +// The primary key for the ModelConfig is the model string. However, we also +// support a secondary key to limit the override scope, typically an agent name. +export interface ModelConfigKey { + model: string; + + // In many cases the model (or model config alias) is sufficient to fully + // scope an override. 
However, in some cases, we want additional scoping of + // an override. Consider the case of developing a new subagent, perhaps we + // want to override the temperature for all model calls made by this subagent. + // However, we most certainly do not want to change the temperature for other + // subagents, nor do we want to introduce a whole new set of aliases just for + // the new subagent. Using the `overrideScope` we can limit our overrides to + // model calls made by this specific subagent, and no others, while still + // ensuring model configs are fully orthogonal to the agents who use them. + overrideScope?: string; +} + +export interface ModelConfig { + model?: string; + generateContentConfig?: GenerateContentConfig; +} + +export interface ModelConfigOverride { + match: { + model?: string; // Can be a model name or an alias + overrideScope?: string; + }; + modelConfig: ModelConfig; +} + +export interface ModelConfigAlias { + extends?: string; + modelConfig: ModelConfig; +} + +export interface ModelConfigServiceConfig { + aliases?: Record; + overrides?: ModelConfigOverride[]; +} + +export type ResolvedModelConfig = _ResolvedModelConfig & { + readonly _brand: unique symbol; +}; + +export interface _ResolvedModelConfig { + model: string; // The actual, resolved model name + generateContentConfig: GenerateContentConfig; +} + +export class ModelConfigService { + // TODO(12597): Process config to build a typed alias hierarchy. 
+ constructor(private readonly config: ModelConfigServiceConfig) {} + + private resolveAlias( + aliasName: string, + aliases: Record, + visited = new Set(), + ): ModelConfigAlias { + if (visited.has(aliasName)) { + throw new Error( + `Circular alias dependency: ${[...visited, aliasName].join(' -> ')}`, + ); + } + visited.add(aliasName); + + const alias = aliases[aliasName]; + if (!alias) { + throw new Error(`Alias "${aliasName}" not found.`); + } + + if (!alias.extends) { + return alias; + } + + const baseAlias = this.resolveAlias(alias.extends, aliases, visited); + + return { + modelConfig: { + model: alias.modelConfig.model ?? baseAlias.modelConfig.model, + generateContentConfig: this.deepMerge( + baseAlias.modelConfig.generateContentConfig, + alias.modelConfig.generateContentConfig, + ), + }, + }; + } + + private internalGetResolvedConfig(context: ModelConfigKey): { + model: string | undefined; + generateContentConfig: GenerateContentConfig; + } { + const config = this.config || {}; + const { aliases = {}, overrides = [] } = config; + let baseModel: string | undefined = context.model; + let resolvedConfig: GenerateContentConfig = {}; + + // Step 1: Alias Resolution + if (aliases[context.model]) { + const resolvedAlias = this.resolveAlias(context.model, aliases); + baseModel = resolvedAlias.modelConfig.model; // This can now be undefined + resolvedConfig = this.deepMerge( + resolvedConfig, + resolvedAlias.modelConfig.generateContentConfig, + ); + } + + // If an alias was used but didn't resolve to a model, `baseModel` is undefined. + // We still need a model for matching overrides. We'll use the original alias name + // for matching if no model is resolved yet. + const modelForMatching = baseModel ?? 
context.model; + + const finalContext = { + ...context, + model: modelForMatching, + }; + + // Step 2: Override Application + const matches = overrides + .map((override, index) => { + const matchEntries = Object.entries(override.match); + if (matchEntries.length === 0) { + return null; + } + + const isMatch = matchEntries.every(([key, value]) => { + if (key === 'model') { + return value === context.model || value === finalContext.model; + } + if (key === 'overrideScope' && value === 'core') { + // The 'core' overrideScope is special. It should match if the + // overrideScope is explicitly 'core' or if the overrideScope + // is not specified. + return context.overrideScope === 'core' || !context.overrideScope; + } + return finalContext[key as keyof ModelConfigKey] === value; + }); + + if (isMatch) { + return { + specificity: matchEntries.length, + modelConfig: override.modelConfig, + index, + }; + } + return null; + }) + .filter((match): match is NonNullable => match !== null); + + // The override application logic is designed to be both simple and powerful. + // By first sorting all matching overrides by specificity (and then by their + // original order as a tie-breaker), we ensure that as we merge the `config` + // objects, the settings from the most specific rules are applied last, + // correctly overwriting any values from broader, less-specific rules. + // This achieves a per-property override effect without complex per-property logic. 
+ matches.sort((a, b) => { + if (a.specificity !== b.specificity) { + return a.specificity - b.specificity; + } + return a.index - b.index; + }); + + // Apply matching overrides + for (const match of matches) { + if (match.modelConfig.model) { + baseModel = match.modelConfig.model; + } + if (match.modelConfig.generateContentConfig) { + resolvedConfig = this.deepMerge( + resolvedConfig, + match.modelConfig.generateContentConfig, + ); + } + } + + return { + model: baseModel, + generateContentConfig: resolvedConfig, + }; + } + + getResolvedConfig(context: ModelConfigKey): ResolvedModelConfig { + const resolved = this.internalGetResolvedConfig(context); + + if (!resolved.model) { + throw new Error( + `Could not resolve a model name for alias "${context.model}". Please ensure the alias chain or a matching override specifies a model.`, + ); + } + + return { + model: resolved.model, + generateContentConfig: resolved.generateContentConfig, + } as ResolvedModelConfig; + } + + private isObject(item: unknown): item is Record { + return !!item && typeof item === 'object' && !Array.isArray(item); + } + + private deepMerge( + config1: GenerateContentConfig | undefined, + config2: GenerateContentConfig | undefined, + ): Record { + return this.genericDeepMerge( + config1 as Record | undefined, + config2 as Record | undefined, + ); + } + + private genericDeepMerge( + ...objects: Array | undefined> + ): Record { + return objects.reduce((acc: Record, obj) => { + if (!obj) { + return acc; + } + + Object.keys(obj).forEach((key) => { + const accValue = acc[key]; + const objValue = obj[key]; + + // For now, we only deep merge objects, and not arrays. This is because + // If we deep merge arrays, there is no way for the user to completely + // override the base array. + // TODO(joshualitt): Consider knobs here, i.e. opt-in to deep merging + // arrays on a case-by-case basis. 
+ if (this.isObject(accValue) && this.isObject(objValue)) { + acc[key] = this.deepMerge( + accValue as Record, + objValue as Record, + ); + } else { + acc[key] = objValue; + } + }); + + return acc; + }, {}); + } +} diff --git a/packages/core/src/services/shellExecutionService.test.ts b/packages/core/src/services/shellExecutionService.test.ts index 3e2fdc889e6..1532e863253 100644 --- a/packages/core/src/services/shellExecutionService.test.ts +++ b/packages/core/src/services/shellExecutionService.test.ts @@ -351,6 +351,23 @@ describe('ShellExecutionService', () => { expect(mockHeadlessTerminal.scrollLines).toHaveBeenCalledWith(10); }); + + it('should not throw when resizing a pty that has already exited (Windows)', () => { + const resizeError = new Error( + 'Cannot resize a pty that has already exited', + ); + mockPtyProcess.resize.mockImplementation(() => { + throw resizeError; + }); + + // This should catch the specific error and not re-throw it. + expect(() => { + ShellExecutionService.resizePty(mockPtyProcess.pid, 100, 40); + }).not.toThrow(); + + expect(mockPtyProcess.resize).toHaveBeenCalledWith(100, 40); + expect(mockHeadlessTerminal.resize).not.toHaveBeenCalled(); + }); }); describe('Failed Execution', () => { @@ -753,7 +770,7 @@ describe('ShellExecutionService child_process fallback', () => { expect(onOutputEventMock).not.toHaveBeenCalled(); }); - it('should truncate stdout using a sliding window and show a warning', async () => { + it.skip('should truncate stdout using a sliding window and show a warning', async () => { const MAX_SIZE = 16 * 1024 * 1024; const chunk1 = 'a'.repeat(MAX_SIZE / 2 - 5); const chunk2 = 'b'.repeat(MAX_SIZE / 2 - 5); @@ -781,7 +798,7 @@ describe('ShellExecutionService child_process fallback', () => { outputWithoutMessage.startsWith(expectedStart.substring(0, 10)), ).toBe(true); expect(outputWithoutMessage.endsWith('c'.repeat(20))).toBe(true); - }, 20000); + }, 120000); }); describe('Failed Execution', () => { diff --git 
a/packages/core/src/services/shellExecutionService.ts b/packages/core/src/services/shellExecutionService.ts index 66952afc036..d797dd83b07 100644 --- a/packages/core/src/services/shellExecutionService.ts +++ b/packages/core/src/services/shellExecutionService.ts @@ -771,9 +771,11 @@ export class ShellExecutionService { if ( e instanceof Error && (('code' in e && e.code === 'ESRCH') || - e.message === 'Cannot resize a pty that has already exited') + e.message.includes('Cannot resize a pty that has already exited')) ) { - // ignore + // On Unix, we get an ESRCH error. + // On Windows, we get a message-based error. + // In both cases, it's safe to ignore. } else { throw e; } diff --git a/packages/core/src/services/test-data/resolved-aliases.golden.json b/packages/core/src/services/test-data/resolved-aliases.golden.json new file mode 100644 index 00000000000..199c36ce3aa --- /dev/null +++ b/packages/core/src/services/test-data/resolved-aliases.golden.json @@ -0,0 +1,123 @@ +{ + "base": { + "generateContentConfig": { + "temperature": 0, + "topP": 1 + } + }, + "chat-base": { + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "thinkingConfig": { + "includeThoughts": true, + "thinkingBudget": -1 + } + } + }, + "gemini-2.5-pro": { + "model": "gemini-2.5-pro", + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "thinkingConfig": { + "includeThoughts": true, + "thinkingBudget": -1 + } + } + }, + "gemini-2.5-flash": { + "model": "gemini-2.5-flash", + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "thinkingConfig": { + "includeThoughts": true, + "thinkingBudget": -1 + } + } + }, + "gemini-2.5-flash-lite": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "thinkingConfig": { + "includeThoughts": true, + "thinkingBudget": -1 + } + } + }, + "classifier": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "maxOutputTokens": 1024, + 
"thinkingConfig": { + "thinkingBudget": 512 + } + } + }, + "prompt-completion": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "temperature": 0.3, + "topP": 1, + "maxOutputTokens": 16000, + "thinkingConfig": { + "thinkingBudget": 0 + } + } + }, + "edit-corrector": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "thinkingConfig": { + "thinkingBudget": 0 + } + } + }, + "summarizer-default": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "maxOutputTokens": 2000 + } + }, + "summarizer-shell": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "maxOutputTokens": 2000 + } + }, + "web-search-tool": { + "model": "gemini-2.5-flash", + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "tools": [ + { + "googleSearch": {} + } + ] + } + }, + "web-fetch-tool": { + "model": "gemini-2.5-flash", + "generateContentConfig": { + "temperature": 0, + "topP": 1, + "tools": [ + { + "urlContext": {} + } + ] + } + } +} diff --git a/packages/core/src/tools/base-tool-invocation.test.ts b/packages/core/src/tools/base-tool-invocation.test.ts new file mode 100644 index 00000000000..38d651f0760 --- /dev/null +++ b/packages/core/src/tools/base-tool-invocation.test.ts @@ -0,0 +1,131 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { BaseToolInvocation, type ToolResult } from './tools.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { + type Message, + MessageBusType, + type ToolConfirmationRequest, + type ToolConfirmationResponse, +} from '../confirmation-bus/types.js'; + +class TestBaseToolInvocation extends BaseToolInvocation { + getDescription(): string { + return 'test description'; + } + async execute(): Promise { + return { llmContent: [], 
returnDisplay: '' }; + } +} + +describe('BaseToolInvocation', () => { + let messageBus: MessageBus; + let abortController: AbortController; + + beforeEach(() => { + messageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as MessageBus; + abortController = new AbortController(); + }); + + it('should propagate serverName to ToolConfirmationRequest', async () => { + const serverName = 'test-server'; + const tool = new TestBaseToolInvocation( + {}, + messageBus, + 'test-tool', + 'Test Tool', + serverName, + ); + + let capturedRequest: ToolConfirmationRequest | undefined; + vi.mocked(messageBus.publish).mockImplementation((request: Message) => { + if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { + capturedRequest = request; + } + }); + + let responseHandler: + | ((response: ToolConfirmationResponse) => void) + | undefined; + vi.mocked(messageBus.subscribe).mockImplementation( + (type: MessageBusType, handler: (message: Message) => void) => { + if (type === MessageBusType.TOOL_CONFIRMATION_RESPONSE) { + responseHandler = handler as ( + response: ToolConfirmationResponse, + ) => void; + } + }, + ); + + const confirmationPromise = tool.shouldConfirmExecute( + abortController.signal, + ); + + // Wait for microtasks to ensure publish is called + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(messageBus.publish).toHaveBeenCalledTimes(1); + expect(capturedRequest).toBeDefined(); + expect(capturedRequest?.type).toBe( + MessageBusType.TOOL_CONFIRMATION_REQUEST, + ); + expect(capturedRequest?.serverName).toBe(serverName); + + // Simulate response to finish the promise cleanly + if (responseHandler && capturedRequest) { + responseHandler({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: capturedRequest.correlationId, + confirmed: true, + }); + } + + await confirmationPromise; + }); + + it('should NOT propagate serverName if not provided', async () => { + const tool = new 
TestBaseToolInvocation( + {}, + messageBus, + 'test-tool', + 'Test Tool', + // no serverName + ); + + let capturedRequest: ToolConfirmationRequest | undefined; + vi.mocked(messageBus.publish).mockImplementation((request: Message) => { + if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { + capturedRequest = request; + } + }); + + // We need to mock subscribe to avoid hanging if we want to await the promise, + // but for this test we just need to check publish. + // We'll abort to clean up. + const confirmationPromise = tool.shouldConfirmExecute( + abortController.signal, + ); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(messageBus.publish).toHaveBeenCalledTimes(1); + expect(capturedRequest).toBeDefined(); + expect(capturedRequest?.serverName).toBeUndefined(); + + abortController.abort(); + try { + await confirmationPromise; + } catch { + // ignore abort error + } + }); +}); diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 7811888eecc..14c134be0a2 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -89,6 +89,7 @@ describe('mcp-client', () => { } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), + sortTools: vi.fn(), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const client = new McpClient( @@ -153,6 +154,7 @@ describe('mcp-client', () => { } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), + sortTools: vi.fn(), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const client = new McpClient( @@ -237,6 +239,7 @@ describe('mcp-client', () => { const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool); const mockedToolRegistry = { registerTool: vi.fn(), + sortTools: vi.fn(), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const client = new 
McpClient( @@ -286,6 +289,7 @@ describe('mcp-client', () => { } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), + sortTools: vi.fn(), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const client = new McpClient( @@ -340,6 +344,7 @@ describe('mcp-client', () => { unregisterTool: vi.fn(), getMessageBus: vi.fn().mockReturnValue(undefined), removeMcpToolsByServer: vi.fn(), + sortTools: vi.fn(), } as unknown as ToolRegistry; const mockedPromptRegistry = { registerPrompt: vi.fn(), diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index c5b1dc6caa1..45e481390eb 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -151,6 +151,7 @@ export class McpClient { for (const tool of tools) { this.toolRegistry.registerTool(tool); } + this.toolRegistry.sortTools(); } /** @@ -568,6 +569,7 @@ export async function connectAndDiscover( for (const tool of tools) { toolRegistry.registerTool(tool); } + toolRegistry.sortTools(); } catch (error) { if (mcpClient) { mcpClient.close(); diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 5949bdf5b20..f5c9205e9d5 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -77,7 +77,14 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< // Use composite format for policy checks: serverName__toolName // This enables server wildcards (e.g., "google-workspace__*") // while still allowing specific tool rules - super(params, messageBus, `${serverName}__${serverToolName}`, displayName); + + super( + params, + messageBus, + `${serverName}__${serverToolName}`, + displayName, + serverName, + ); } protected override async getConfirmationDetails( diff --git a/packages/core/src/tools/shell.test.ts b/packages/core/src/tools/shell.test.ts index 4ba6dd83535..7143521eaeb 100644 --- 
a/packages/core/src/tools/shell.test.ts +++ b/packages/core/src/tools/shell.test.ts @@ -232,30 +232,34 @@ describe('ShellTool', () => { ); }); - itWindowsOnly('should not wrap command on windows', async () => { - vi.mocked(os.platform).mockReturnValue('win32'); - const invocation = shellTool.build({ command: 'dir' }); - const promise = invocation.execute(mockAbortSignal); - resolveShellExecution({ - rawOutput: Buffer.from(''), - output: '', - exitCode: 0, - signal: null, - error: null, - aborted: false, - pid: 12345, - executionMethod: 'child_process', - }); - await promise; - expect(mockShellExecutionService).toHaveBeenCalledWith( - 'dir', - '/test/dir', - expect.any(Function), - mockAbortSignal, - false, - {}, - ); - }); + itWindowsOnly( + 'should not wrap command on windows', + async () => { + vi.mocked(os.platform).mockReturnValue('win32'); + const invocation = shellTool.build({ command: 'dir' }); + const promise = invocation.execute(mockAbortSignal); + resolveShellExecution({ + rawOutput: Buffer.from(''), + output: '', + exitCode: 0, + signal: null, + error: null, + aborted: false, + pid: 12345, + executionMethod: 'child_process', + }); + await promise; + expect(mockShellExecutionService).toHaveBeenCalledWith( + 'dir', + '/test/dir', + expect.any(Function), + mockAbortSignal, + false, + {}, + ); + }, + 20000, + ); it('should format error messages correctly', async () => { const error = new Error('wrapped command failed'); diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index f002250910f..80e9390cce5 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -250,6 +250,51 @@ describe('ToolRegistry', () => { }); }); + describe('sortTools', () => { + it('should sort tools by priority: built-in, discovered, then MCP (by server name)', () => { + const builtIn1 = new MockTool({ name: 'builtin-1' }); + const builtIn2 = new MockTool({ name: 'builtin-2' 
}); + const discovered1 = new DiscoveredTool( + config, + 'discovered-1', + 'desc', + {}, + ); + const mockCallable = {} as CallableTool; + const mcpZebra = new DiscoveredMCPTool( + mockCallable, + 'zebra-server', + 'mcp-zebra', + 'desc', + {}, + ); + const mcpApple = new DiscoveredMCPTool( + mockCallable, + 'apple-server', + 'mcp-apple', + 'desc', + {}, + ); + + // Register in mixed order + toolRegistry.registerTool(mcpZebra); + toolRegistry.registerTool(discovered1); + toolRegistry.registerTool(builtIn1); + toolRegistry.registerTool(mcpApple); + toolRegistry.registerTool(builtIn2); + + toolRegistry.sortTools(); + + expect(toolRegistry.getAllToolNames()).toEqual([ + 'builtin-1', + 'builtin-2', + 'discovered-1', + 'mcp-apple', + 'mcp-zebra', + ]); + }); + }); + describe('discoverTools', () => { it('should will preserve tool parametersJsonSchema during discovery from command', async () => { const discoveryCommand = 'my-discovery-command'; diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index abb03d53295..c350abfbd27 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -205,6 +205,43 @@ export class ToolRegistry { this.tools.set(tool.name, tool); } + /** + * Sorts tools as: + * 1. Built in tools. + * 2. Discovered tools. + * 3. MCP tools ordered by server name. + * + * This is a stable sort in that ties preseve existing order. 
+ */ + sortTools(): void { + const getPriority = (tool: AnyDeclarativeTool): number => { + if (tool instanceof DiscoveredMCPTool) return 2; + if (tool instanceof DiscoveredTool) return 1; + return 0; // Built-in + }; + + this.tools = new Map( + Array.from(this.tools.entries()).sort((a, b) => { + const toolA = a[1]; + const toolB = b[1]; + const priorityA = getPriority(toolA); + const priorityB = getPriority(toolB); + + if (priorityA !== priorityB) { + return priorityA - priorityB; + } + + if (priorityA === 2) { + const serverA = (toolA as DiscoveredMCPTool).serverName; + const serverB = (toolB as DiscoveredMCPTool).serverName; + return serverA.localeCompare(serverB); + } + + return 0; + }), + ); + } + private removeDiscoveredTools(): void { for (const tool of this.tools.values()) { if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) { diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 98605aea9b1..59d1ef7bafb 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -78,6 +78,7 @@ export abstract class BaseToolInvocation< protected readonly messageBus?: MessageBus, readonly _toolName?: string, readonly _toolDisplayName?: string, + readonly _serverName?: string, ) {} abstract getDescription(): string; @@ -215,6 +216,7 @@ export abstract class BaseToolInvocation< type: MessageBusType.TOOL_CONFIRMATION_REQUEST, toolCall, correlationId, + serverName: this._serverName, }; try { diff --git a/packages/core/src/utils/extensionLoader.ts b/packages/core/src/utils/extensionLoader.ts index b65f2271434..f47ff950150 100644 --- a/packages/core/src/utils/extensionLoader.ts +++ b/packages/core/src/utils/extensionLoader.ts @@ -64,9 +64,12 @@ export abstract class ExtensionLoader { }); try { await this.config.getMcpClientManager()!.startExtension(extension); - // TODO: Move all extension features here, including at least: + // TODO: Update custom command updating away from the event based system + 
// and call directly into a custom command manager here. See the + // useSlashCommandProcessor hook which responds to events fired here today. + + // TODO: Move all enablement of extension features here, including at least: // - context file loading - // - custom command loading // - excluded tool configuration } finally { this.startCompletedCount++; @@ -116,9 +119,12 @@ export abstract class ExtensionLoader { try { await this.config.getMcpClientManager()!.stopExtension(extension); + // TODO: Update custom command updating away from the event based system + // and call directly into a custom command manager here. See the + // useSlashCommandProcessor hook which responds to events fired here today. + // TODO: Remove all extension features here, including at least: // - context files - // - custom commands // - excluded tools } finally { this.stopCompletedCount++; diff --git a/packages/core/src/utils/googleQuotaErrors.test.ts b/packages/core/src/utils/googleQuotaErrors.test.ts index cc5e5de43a0..34836119a8a 100644 --- a/packages/core/src/utils/googleQuotaErrors.test.ts +++ b/packages/core/src/utils/googleQuotaErrors.test.ts @@ -25,6 +25,44 @@ describe('classifyGoogleError', () => { expect(result).toBe(regularError); }); + it('should return RetryableQuotaError when message contains "Please retry in Xs"', () => { + const complexError = { + error: { + message: + '{"error": {"code": 429, "status": 429, "message": "You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/usage?tab=rate-limit. \n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 2\nPlease retry in 44.097740004s.", "details": [{"detail": "??? to (unknown) : APP_ERROR(8) You exceeded your current quota, please check your plan and billing details. 
For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/usage?tab=rate-limit. \n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 2\nPlease retry in 44.097740004s."}]}}', + code: 429, + status: 'Too Many Requests', + }, + }; + const rawError = new Error(JSON.stringify(complexError)); + vi.spyOn(errorParser, 'parseGoogleApiError').mockReturnValue(null); + + const result = classifyGoogleError(rawError); + + expect(result).toBeInstanceOf(RetryableQuotaError); + expect((result as RetryableQuotaError).retryDelayMs).toBe(44097.740004); + expect((result as RetryableQuotaError).message).toBe(rawError.message); + }); + + it('should return RetryableQuotaError when error is a string and message contains "Please retry in Xms"', () => { + const complexErrorString = JSON.stringify({ + error: { + message: + '{"error": {"code": 429, "status": 429, "message": "You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/usage?tab=rate-limit. \n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 2\nPlease retry in 900.2ms.", "details": [{"detail": "??? to (unknown) : APP_ERROR(8) You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/usage?tab=rate-limit. 
\n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 2\nPlease retry in 900.2ms."}]}}', + code: 429, + status: 'Too Many Requests', + }, + }); + const rawError = new Error(complexErrorString); + vi.spyOn(errorParser, 'parseGoogleApiError').mockReturnValue(null); + + const result = classifyGoogleError(rawError); + + expect(result).toBeInstanceOf(RetryableQuotaError); + expect((result as RetryableQuotaError).retryDelayMs).toBeCloseTo(900.2); + expect((result as RetryableQuotaError).message).toBe(rawError.message); + }); + it('should return original error if code is not 429', () => { const apiError: GoogleApiError = { code: 500, diff --git a/packages/core/src/utils/googleQuotaErrors.ts b/packages/core/src/utils/googleQuotaErrors.ts index 55c5d29a8a2..f09d8b24744 100644 --- a/packages/core/src/utils/googleQuotaErrors.ts +++ b/packages/core/src/utils/googleQuotaErrors.ts @@ -43,16 +43,20 @@ export class RetryableQuotaError extends Error { } /** - * Parses a duration string (e.g., "34.074824224s", "60s") and returns the time in seconds. + * Parses a duration string (e.g., "34.074824224s", "60s", "900ms") and returns the time in seconds. * @param duration The duration string to parse. * @returns The duration in seconds, or null if parsing fails. */ function parseDurationInSeconds(duration: string): number | null { - if (!duration.endsWith('s')) { - return null; + if (duration.endsWith('ms')) { + const milliseconds = parseFloat(duration.slice(0, -2)); + return isNaN(milliseconds) ? null : milliseconds / 1000; } - const seconds = parseFloat(duration.slice(0, -1)); - return isNaN(seconds) ? null : seconds; + if (duration.endsWith('s')) { + const seconds = parseFloat(duration.slice(0, -1)); + return isNaN(seconds) ? 
null : seconds; + } + return null; } /** @@ -64,6 +68,7 @@ function parseDurationInSeconds(duration: string): number | null { * - If the error suggests a retry delay of more than 2 minutes, it's a `TerminalQuotaError`. * - If the error suggests a retry delay of 2 minutes or less, it's a `RetryableQuotaError`. * - If the error indicates a per-minute limit, it's a `RetryableQuotaError`. + * - If the error message contains the phrase "Please retry in X[s|ms]", it's a `RetryableQuotaError`. * * @param error The error to classify. * @returns A `TerminalQuotaError`, `RetryableQuotaError`, or the original `unknown` error. @@ -72,6 +77,24 @@ export function classifyGoogleError(error: unknown): unknown { const googleApiError = parseGoogleApiError(error); if (!googleApiError || googleApiError.code !== 429) { + // Fallback: try to parse the error message for a retry delay + const errorMessage = error instanceof Error ? error.message : String(error); + const match = errorMessage.match(/Please retry in ([0-9.]+(?:ms|s))/); + if (match?.[1]) { + const retryDelaySeconds = parseDurationInSeconds(match[1]); + if (retryDelaySeconds !== null) { + return new RetryableQuotaError( + errorMessage, + googleApiError ?? { + code: 429, + message: errorMessage, + details: [], + }, + retryDelaySeconds, + ); + } + } + return error; // Not a 429 error we can handle. 
} diff --git a/packages/core/src/utils/pathReader.test.ts b/packages/core/src/utils/pathReader.test.ts index 45229a678b2..0aa4e308e14 100644 --- a/packages/core/src/utils/pathReader.test.ts +++ b/packages/core/src/utils/pathReader.test.ts @@ -20,7 +20,12 @@ const createMockConfig = ( cwd: string, otherDirs: string[] = [], mockFileService?: FileDiscoveryService, + fileFiltering: { + respectGitIgnore?: boolean; + respectGeminiIgnore?: boolean; + } = {}, ): Config => { + const { respectGitIgnore = true, respectGeminiIgnore = true } = fileFiltering; const workspace = new WorkspaceContext(cwd, otherDirs); const fileSystemService = new StandardFileSystemService(); return { @@ -29,6 +34,8 @@ const createMockConfig = ( getTargetDir: () => cwd, getFileSystemService: () => fileSystemService, getFileService: () => mockFileService, + getFileFilteringRespectGitIgnore: () => respectGitIgnore, + getFileFilteringRespectGeminiIgnore: () => respectGeminiIgnore, } as unknown as Config; }; @@ -333,6 +340,51 @@ describe('readPathFromWorkspace', () => { expect(resultText).not.toContain('invisible'); expect(mockFileService.filterFiles).toHaveBeenCalled(); }); + + it('should pass correct ignore flags to file service for a single file', async () => { + mock({ + [CWD]: { + 'file.txt': 'content', + }, + }); + const mockFileService = { + filterFiles: vi.fn(() => []), + } as unknown as FileDiscoveryService; + const config = createMockConfig(CWD, [], mockFileService, { + respectGitIgnore: false, + respectGeminiIgnore: true, + }); + await readPathFromWorkspace('file.txt', config); + expect(mockFileService.filterFiles).toHaveBeenCalledWith(['file.txt'], { + respectGitIgnore: false, + respectGeminiIgnore: true, + }); + }); + + it('should pass correct ignore flags to file service for a directory', async () => { + mock({ + [CWD]: { + 'my-dir': { + 'file.txt': 'content', + }, + }, + }); + const mockFileService = { + filterFiles: vi.fn((files) => files), + } as unknown as FileDiscoveryService; + const 
config = createMockConfig(CWD, [], mockFileService, { + respectGitIgnore: true, + respectGeminiIgnore: false, + }); + await readPathFromWorkspace('my-dir', config); + expect(mockFileService.filterFiles).toHaveBeenCalledWith( + [path.join('my-dir', 'file.txt')], + { + respectGitIgnore: true, + respectGeminiIgnore: false, + }, + ); + }); }); it('should throw an error for an absolute path outside the workspace', async () => { diff --git a/packages/core/src/utils/pathReader.ts b/packages/core/src/utils/pathReader.ts index 1b177848dbf..486ce2a8217 100644 --- a/packages/core/src/utils/pathReader.ts +++ b/packages/core/src/utils/pathReader.ts @@ -73,8 +73,8 @@ export async function readPathFromWorkspace( path.relative(config.getTargetDir(), p), ); const filteredFiles = fileService.filterFiles(relativeFiles, { - respectGitIgnore: true, - respectGeminiIgnore: true, + respectGitIgnore: config.getFileFilteringRespectGitIgnore(), + respectGeminiIgnore: config.getFileFilteringRespectGeminiIgnore(), }); const finalFiles = filteredFiles.map((p) => path.resolve(config.getTargetDir(), p), @@ -98,8 +98,8 @@ export async function readPathFromWorkspace( // It's a single file, check if it's ignored. 
const relativePath = path.relative(config.getTargetDir(), absolutePath); const filtered = fileService.filterFiles([relativePath], { - respectGitIgnore: true, - respectGeminiIgnore: true, + respectGitIgnore: config.getFileFilteringRespectGitIgnore(), + respectGeminiIgnore: config.getFileFilteringRespectGeminiIgnore(), }); if (filtered.length === 0) { diff --git a/packages/core/src/utils/workspaceContext.test.ts b/packages/core/src/utils/workspaceContext.test.ts index c93dffe47f2..01fd6da4982 100644 --- a/packages/core/src/utils/workspaceContext.test.ts +++ b/packages/core/src/utils/workspaceContext.test.ts @@ -83,18 +83,21 @@ describe('WorkspaceContext with real filesystem', () => { expect(directories).toHaveLength(2); }); - it('should handle symbolic links correctly', () => { - const realDir = path.join(tempDir, 'real'); - fs.mkdirSync(realDir, { recursive: true }); - const symlinkDir = path.join(tempDir, 'symlink-to-real'); - fs.symlinkSync(realDir, symlinkDir, 'dir'); - const workspaceContext = new WorkspaceContext(cwd); - workspaceContext.addDirectory(symlinkDir); + it.skipIf(os.platform() === 'win32')( + 'should handle symbolic links correctly', + () => { + const realDir = path.join(tempDir, 'real'); + fs.mkdirSync(realDir, { recursive: true }); + const symlinkDir = path.join(tempDir, 'symlink-to-real'); + fs.symlinkSync(realDir, symlinkDir, 'dir'); + const workspaceContext = new WorkspaceContext(cwd); + workspaceContext.addDirectory(symlinkDir); - const directories = workspaceContext.getDirectories(); + const directories = workspaceContext.getDirectories(); - expect(directories).toEqual([cwd, realDir]); - }); + expect(directories).toEqual([cwd, realDir]); + }, + ); }); describe('path validation', () => { @@ -158,7 +161,7 @@ describe('WorkspaceContext with real filesystem', () => { ); }); - describe('with symbolic link', () => { + describe.skipIf(os.platform() === 'win32')('with symbolic link', () => { describe('in the workspace', () => { let realDir: string; 
let symlinkDir: string; diff --git a/packages/core/vitest.config.ts b/packages/core/vitest.config.ts index b983891257f..b8027f65126 100644 --- a/packages/core/vitest.config.ts +++ b/packages/core/vitest.config.ts @@ -9,6 +9,7 @@ import { defineConfig } from 'vitest/config'; export default defineConfig({ test: { reporters: ['default', 'junit'], + testTimeout: 30000, // Per-test timeout in ms; Vitest's option is `testTimeout`, not `timeout`. silent: true, setupFiles: ['./test-setup.ts'], outputFile: { diff --git a/packages/test-utils/package.json b/packages/test-utils/package.json index f60db67da71..667c12f6df7 100644 --- a/packages/test-utils/package.json +++ b/packages/test-utils/package.json @@ -1,6 +1,6 @@ { "name": "@google/gemini-cli-test-utils", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "private": true, "main": "src/index.ts", "license": "Apache-2.0", diff --git a/packages/vscode-ide-companion/package.json b/packages/vscode-ide-companion/package.json index 47942f7c47e..bec835965a9 100644 --- a/packages/vscode-ide-companion/package.json +++ b/packages/vscode-ide-companion/package.json @@ -2,7 +2,7 @@ "name": "gemini-cli-vscode-ide-companion", "displayName": "Gemini CLI Companion", "description": "Enable Gemini CLI with direct access to your IDE workspace.", - "version": "0.13.0-nightly.20251031.c89bc30d", + "version": "0.14.0-nightly.20251104.da3da198", "publisher": "google", "icon": "assets/icon.png", "repository": { diff --git a/schemas/settings.schema.json b/schemas/settings.schema.json index 59992ce53f4..055164cd4e1 100644 --- a/schemas/settings.schema.json +++ b/schemas/settings.schema.json @@ -412,6 +412,270 @@ }, "additionalProperties": false }, + "modelConfigs": { + "title": "Model Configs", + "description": "Model configurations.", + "markdownDescription": "Model configurations.\n\n- Category: `Model`\n- Requires restart: `no`\n- Default: 
`{\"aliases\":{\"base\":{\"modelConfig\":{\"generateContentConfig\":{\"temperature\":0,\"topP\":1}}},\"chat-base\":{\"extends\":\"base\",\"modelConfig\":{\"generateContentConfig\":{\"thinkingConfig\":{\"includeThoughts\":true,\"thinkingBudget\":-1}}}},\"gemini-2.5-pro\":{\"extends\":\"chat-base\",\"modelConfig\":{\"model\":\"gemini-2.5-pro\"}},\"gemini-2.5-flash\":{\"extends\":\"chat-base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash\"}},\"gemini-2.5-flash-lite\":{\"extends\":\"chat-base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\"}},\"classifier\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"maxOutputTokens\":1024,\"thinkingConfig\":{\"thinkingBudget\":512}}}},\"prompt-completion\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"temperature\":0.3,\"maxOutputTokens\":16000,\"thinkingConfig\":{\"thinkingBudget\":0}}}},\"edit-corrector\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"thinkingConfig\":{\"thinkingBudget\":0}}}},\"summarizer-default\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"maxOutputTokens\":2000}}},\"summarizer-shell\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"maxOutputTokens\":2000}}},\"web-search-tool\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash\",\"generateContentConfig\":{\"tools\":[{\"googleSearch\":{}}]}}},\"web-fetch-tool\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash\",\"generateContentConfig\":{\"tools\":[{\"urlContext\":{}}]}}}}}`", + "default": { + "aliases": { + "base": { + "modelConfig": { + "generateContentConfig": { + "temperature": 0, + "topP": 1 + } + } + }, + "chat-base": { + "extends": "base", + "modelConfig": { + "generateContentConfig": { + "thinkingConfig": { + "includeThoughts": 
true, + "thinkingBudget": -1 + } + } + } + }, + "gemini-2.5-pro": { + "extends": "chat-base", + "modelConfig": { + "model": "gemini-2.5-pro" + } + }, + "gemini-2.5-flash": { + "extends": "chat-base", + "modelConfig": { + "model": "gemini-2.5-flash" + } + }, + "gemini-2.5-flash-lite": { + "extends": "chat-base", + "modelConfig": { + "model": "gemini-2.5-flash-lite" + } + }, + "classifier": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "maxOutputTokens": 1024, + "thinkingConfig": { + "thinkingBudget": 512 + } + } + } + }, + "prompt-completion": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "temperature": 0.3, + "maxOutputTokens": 16000, + "thinkingConfig": { + "thinkingBudget": 0 + } + } + } + }, + "edit-corrector": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "thinkingConfig": { + "thinkingBudget": 0 + } + } + } + }, + "summarizer-default": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "maxOutputTokens": 2000 + } + } + }, + "summarizer-shell": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "maxOutputTokens": 2000 + } + } + }, + "web-search-tool": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash", + "generateContentConfig": { + "tools": [ + { + "googleSearch": {} + } + ] + } + } + }, + "web-fetch-tool": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash", + "generateContentConfig": { + "tools": [ + { + "urlContext": {} + } + ] + } + } + } + } + }, + "type": "object", + "properties": { + "aliases": { + "title": "Model Config Aliases", + "description": "Named presets for model configs. 
Can be used in place of a model name and can inherit from other aliases using an `extends` property.", + "markdownDescription": "Named presets for model configs. Can be used in place of a model name and can inherit from other aliases using an `extends` property.\n\n- Category: `Model`\n- Requires restart: `no`\n- Default: `{\"base\":{\"modelConfig\":{\"generateContentConfig\":{\"temperature\":0,\"topP\":1}}},\"chat-base\":{\"extends\":\"base\",\"modelConfig\":{\"generateContentConfig\":{\"thinkingConfig\":{\"includeThoughts\":true,\"thinkingBudget\":-1}}}},\"gemini-2.5-pro\":{\"extends\":\"chat-base\",\"modelConfig\":{\"model\":\"gemini-2.5-pro\"}},\"gemini-2.5-flash\":{\"extends\":\"chat-base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash\"}},\"gemini-2.5-flash-lite\":{\"extends\":\"chat-base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\"}},\"classifier\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"maxOutputTokens\":1024,\"thinkingConfig\":{\"thinkingBudget\":512}}}},\"prompt-completion\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"temperature\":0.3,\"maxOutputTokens\":16000,\"thinkingConfig\":{\"thinkingBudget\":0}}}},\"edit-corrector\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"thinkingConfig\":{\"thinkingBudget\":0}}}},\"summarizer-default\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"maxOutputTokens\":2000}}},\"summarizer-shell\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash-lite\",\"generateContentConfig\":{\"maxOutputTokens\":2000}}},\"web-search-tool\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash\",\"generateContentConfig\":{\"tools\":[{\"googleSearch\":{}}]}}},\"web-fetch-tool\":{\"extends\":\"base\",\"modelConfig\":{\"model\":\"gemini-2.5-flash\",\"generateContentCon
fig\":{\"tools\":[{\"urlContext\":{}}]}}}}`", + "default": { + "base": { + "modelConfig": { + "generateContentConfig": { + "temperature": 0, + "topP": 1 + } + } + }, + "chat-base": { + "extends": "base", + "modelConfig": { + "generateContentConfig": { + "thinkingConfig": { + "includeThoughts": true, + "thinkingBudget": -1 + } + } + } + }, + "gemini-2.5-pro": { + "extends": "chat-base", + "modelConfig": { + "model": "gemini-2.5-pro" + } + }, + "gemini-2.5-flash": { + "extends": "chat-base", + "modelConfig": { + "model": "gemini-2.5-flash" + } + }, + "gemini-2.5-flash-lite": { + "extends": "chat-base", + "modelConfig": { + "model": "gemini-2.5-flash-lite" + } + }, + "classifier": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "maxOutputTokens": 1024, + "thinkingConfig": { + "thinkingBudget": 512 + } + } + } + }, + "prompt-completion": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "temperature": 0.3, + "maxOutputTokens": 16000, + "thinkingConfig": { + "thinkingBudget": 0 + } + } + } + }, + "edit-corrector": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "thinkingConfig": { + "thinkingBudget": 0 + } + } + } + }, + "summarizer-default": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "maxOutputTokens": 2000 + } + } + }, + "summarizer-shell": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash-lite", + "generateContentConfig": { + "maxOutputTokens": 2000 + } + } + }, + "web-search-tool": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash", + "generateContentConfig": { + "tools": [ + { + "googleSearch": {} + } + ] + } + } + }, + "web-fetch-tool": { + "extends": "base", + "modelConfig": { + "model": "gemini-2.5-flash", + "generateContentConfig": { + "tools": [ + { + "urlContext": {} + } + ] 
+ } + } + } + }, + "type": "object", + "additionalProperties": true + }, + "overrides": { + "title": "Model Config Overrides", + "description": "Apply specific configuration overrides based on matches, with a primary key of model (or alias). The most specific match will be used.", + "markdownDescription": "Apply specific configuration overrides based on matches, with a primary key of model (or alias). The most specific match will be used.\n\n- Category: `Model`\n- Requires restart: `no`\n- Default: `[]`", + "default": [], + "type": "array", + "items": {} + } + }, + "additionalProperties": false + }, "context": { "title": "Context", "description": "Settings for managing context provided to the model.",