diff --git a/plugins/file_search.py b/plugins/file_search.py
index 5d91a87..febe946 100644
--- a/plugins/file_search.py
+++ b/plugins/file_search.py
@@ -1,28 +1,18 @@
-import google.generativeai as genai
-import yaml
-import os
+from .gemini_utils import first_code_line, gemini_generate_content
 import re
 
-def get_gemini_model():
-    config_path = os.path.expanduser("config.yaml")
-    with open(config_path) as f:
-        config = yaml.safe_load(f)
-    genai.configure(api_key=config["gemini_api_key"])
-    return genai.GenerativeModel('models/gemini-1.5-flash')
-
-model = get_gemini_model()
 
 def handle_find(query: str) -> str:
     """
     Generate precise, executable find commands matching the exact request
     Returns sanitized, shell-ready commands
     """
-    clean_query = ' '.join(query.strip().split()).lower()
+    clean_query = " ".join(query.strip().split()).lower()
 
     simple_cases = {
-        'count all .txt files': 'find . -type f -name "*.txt" | wc -l',
-        'sh files': 'find . -type f -name "*.sh"',
-        'python or javascript files': r'find . -type f \( -name "*.py" -o -name "*.js" \)',
+        "count all .txt files": 'find . -type f -name "*.txt" | wc -l',
+        "sh files": 'find . -type f -name "*.sh"',
+        "python or javascript files": r'find . -type f \( -name "*.py" -o -name "*.js" \)',
     }
 
     if clean_query in simple_cases:
@@ -44,19 +34,24 @@ def handle_find(query: str) -> str:
     Command: find ."""
 
     try:
-        response = model.generate_content(prompt)
-        command = response.text.strip()
-
-        command = re.sub(r'^\s*find\s*\.', 'find .', command)
-        command = re.sub(r'\s+', ' ', command).split('\n')[0].split('#')[0].strip()
+        # Pick one line *before* collapsing whitespace, so a multi-line or
+        # fenced reply is not merged into a single command.
+        command = first_code_line(gemini_generate_content(prompt))
+        command = re.sub(r"^\s*find\s*\.", "find .", command)
+        command = re.sub(r"\s+", " ", command).split("#")[0].strip()
 
-        if not command.startswith('find .'):
+        if not command.startswith("find ."):
             command = f"find . {command}"
 
-        if any(char in command for char in [';', '&&', '||', '`', '$(']):
+        # Reject chaining, substitution and redirection. "|" stays allowed so
+        # counts such as "| wc -l" work; a pipe into xargs/sh is not caught.
+        if any(char in command for char in [";", "&&", "||", "`", "$(", ">", "<"]):
+            raise ValueError("Potential command injection")
+        # find's own actions can run commands, delete files or write files.
+        if re.search(r"\s-(exec|execdir|ok|okdir|delete|fls|fprint|fprint0|fprintf)\b", command):
             raise ValueError("Potential command injection")
 
-        if not command.startswith('find . '):
+        if not command.startswith("find . "):
             raise ValueError("Command must start with 'find .'")
 
         return command
@@ -64,6 +59,7 @@ def handle_find(query: str) -> str:
         print(f"Gemini error: {str(e)}")
         return _basic_find_handler(query)
 
+
 def _basic_find_handler(query: str) -> str:
     """Fallback handler with basic patterns"""
     patterns = {
@@ -75,26 +71,27 @@ def _basic_find_handler(query: str) -> str:
         "large": "-size +10M",
         "empty": "-empty",
         "dir": "-type d",
-        "file": "-type f"
+        "file": "-type f",
     }
 
     terms = []
-    for word in re.findall(r'[\w\.]+', query.lower()):
+    for word in re.findall(r"[\w\.]+", query.lower()):
         if word in patterns:
             terms.append(patterns[word])
-        elif re.match(r'^\..+$', word):  # Handle extensions like .sh
+        elif re.match(r"^\..+$", word):  # Handle extensions like .sh
             terms.append(f"-type f -name '*{word}'")
         elif word.isnumeric():
             terms.append(f"-size +{word}M")
 
     return f"find . {' '.join(terms)}" if terms else "find . -type f"
 
+
 if __name__ == "__main__":
     test_queries = [
         "!find .sh files",
         "find python files",
         "locate large javascript files",
-        "find empty directories"
+        "find empty directories",
     ]
 
     for query in test_queries:
diff --git a/plugins/gemini_utils.py b/plugins/gemini_utils.py
new file mode 100644
index 0000000..1f03de4
--- /dev/null
+++ b/plugins/gemini_utils.py
@@ -0,0 +1,75 @@
+import functools
+import os
+
+import google.generativeai as genai
+import yaml
+
+# Model names tried in order by gemini_generate_content(); later entries are
+# fallbacks that are only used when an earlier model raises.
+models = ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-1.5-flash"]
+
+
+@functools.lru_cache(maxsize=None)
+def _configure_gemini():
+    """Read the API key from config.yaml and configure the client once.
+
+    config.yaml is resolved against the current working directory
+    (expanduser() only changes paths that start with "~"). A failed
+    attempt raises and is not cached, so the next call retries.
+    """
+    config_path = os.path.expanduser("config.yaml")
+    with open(config_path) as f:
+        config = yaml.safe_load(f)
+    genai.configure(api_key=config["gemini_api_key"])
+
+
+def get_gemini_model(model_name=models[0]):
+    """Return a GenerativeModel for *model_name*, e.g. "gemini-2.5-flash"."""
+    _configure_gemini()
+    return genai.GenerativeModel(f"models/{model_name}")
+
+
+def gemini_generate_content(prompt, model_names=None):
+    """Return the stripped reply text for *prompt*.
+
+    Models are tried in order (default: ``models``); the first successful
+    response wins. Raises RuntimeError, chained to the last underlying
+    error, if configuration fails or every model fails.
+    """
+    if model_names is None:
+        model_names = models
+    # Configuration problems (missing config.yaml or key) are identical for
+    # every model, so report them once instead of retrying each model.
+    try:
+        _configure_gemini()
+    except Exception as e:
+        raise RuntimeError(f"Gemini configuration failed: {e}") from e
+
+    last_exception = None
+    for model_name in model_names:
+        try:
+            model = get_gemini_model(model_name)
+            response = model.generate_content(prompt)
+            return response.text.strip()
+        except Exception as e:
+            last_exception = e
+    raise RuntimeError(
+        f"All Gemini models failed. Last error: {last_exception}"
+    ) from last_exception
+
+
+def first_code_line(text):
+    """Return the first non-blank line of *text*, ignoring ``` fence lines.
+
+    Surrounding whitespace and backticks are stripped, so a reply such as
+    "```bash" / "find . -name '*.py'" / "```" yields "find . -name '*.py'".
+    Returns "" if no such line exists.
+    """
+    for line in text.splitlines():
+        stripped = line.strip()
+        if stripped.startswith("```"):
+            continue
+        stripped = stripped.strip("`").strip()
+        if stripped:
+            return stripped
+    return ""
diff --git a/plugins/git_helper.py b/plugins/git_helper.py
index fb5487f..20e84e0 100644
--- a/plugins/git_helper.py
+++ b/plugins/git_helper.py
@@ -1,16 +1,6 @@
-import google.generativeai as genai
 from typing import List
-import yaml
-import os
+from .gemini_utils import first_code_line, gemini_generate_content
 
-def load_gemini_config():
-    config_path = os.path.expanduser("config.yaml")
-    with open(config_path) as f:
-        config = yaml.safe_load(f)
-    genai.configure(api_key=config["gemini_api_key"])
-    return genai.GenerativeModel('models/gemini-1.5-flash')
-
-model = load_gemini_config()
 
 def handle_git(args: List[str]) -> str:
     """
@@ -25,7 +15,7 @@ def handle_git(args: List[str]) -> str:
         "log": "git log --oneline -n 10",
         "branch": "git branch -vv",
         "stash": "git stash list",
-        "diff": "git diff --cached"
+        "diff": "git diff --cached",
     }
 
     if args[0] in simple_commands:
@@ -40,8 +30,7 @@ def handle_git(args: List[str]) -> str:
     Command: git """
 
     try:
-        response = model.generate_content(prompt)
-        full_command = response.text.strip()
+        full_command = first_code_line(gemini_generate_content(prompt))
         if full_command.startswith("git "):
             return full_command
         return f"git {full_command}"