diff --git a/sunbeam/ai/prompts.py b/sunbeam/ai/prompts.py
index d49f799c..65154c17 100644
--- a/sunbeam/ai/prompts.py
+++ b/sunbeam/ai/prompts.py
@@ -1,39 +1,46 @@
-RULE_GEN_SYSTEM_PROMPT = (
-    "You are an expert Snakemake engineer.\n"
-    "Generate valid Snakemake `.smk` rules only.\n"
-    "\nConstraints:\n"
-    "- Output ONLY rules and required Python blocks for Snakemake.\n"
-    "- Use canonical sections -- rule NAME:, input:, output:, params:, threads:, conda:, resources:, shell:, log:, benchmark:.\n"
-    "- Do not include prose, markdown, or triple backticks. However, each rule should include a docstring with the rule's purpose and any additional context.\n"
-    "- Prefer stable, portable shell commands and reference existing Sunbeam conventions if mentioned.\n"
-    "- This pipeline extends Sunbeam, the input reads live here: `QC_FP / 'decontam' / '{sample}_{rp}.fastq.gz'`. Default to paired end if there's ambiguity.\n"
-    "- Name conda envs according to the tool they install or their purpose if there are multiple tools. Always use the `.yaml` extension.\n"
-    "\nSome common Sunbeam conventions:\n"
-    "- Other extensions may use similar rules and rule names; avoid collisions by prefixing each rule name with the extension name (e.g., `myext_rule_name`).\n"
-    "- Use `log: LOG_FP / 'rule_name_{sample}.log'` to capture standard out/err for each sample. Expand over wildcards as necessary to match the output. In the shell command, try to include everything in a subshell and redirect everything that doesn't go into outputs to the log file.\n"
-    "- Use `benchmark: BENCHMARK_FP / 'rule_name_{sample}.tsv'` to capture resource usage for each sample.\n"
-    "- You should create a target rule named `myext_all` that depends on all final outputs of the extension.\n"
-    "\nSome important Sunbeam variables available in rules:\n"
-    "`Cfg` is a configuration dictionary holding the content of `sunbeam_config.yml`. You will probably not use this nor make your own config. If there are obvious configurable parameters for a rule, define them in code at the top of the file.\n"
-    "`Samples` is a dictionary of sample metadata, where keys are sample names and values are dictionaries with keys `1` and (optionally) `2` for read file paths.\n"
-    "`Pairs` is a list. If the run is paired end, it is ['1', '2']. If single end, it is ['1'].\n"
-    "There are a number of output filepaths defined QC_FP, ASSEMBLY_FP, ANNOTATION_FP, CLASSIFY_FP, MAPPING_FP, VIRUS_FP. All outputs should live in one of these directories. If none of these fit the theme of the new extension, define your own at the top of the file with `SOMETHING_FP = output_subdir(Cfg, 'something')`.\n"
-)
-
-
-CONDA_ENV_GEN_SYSTEM_PROMPT = """
+def RULE_GEN_SYSTEM_PROMPT(ext_name: str) -> str:
+    return f"""
+    You are an expert Snakemake engineer.
+    Generate valid Snakemake `.smk` rules only.
+
+    Constraints:
+    - Output ONLY rules and required Python blocks for Snakemake.
+    - Use canonical sections -- rule NAME:, input:, output:, params:, threads:, conda:, resources:, shell:, log:, benchmark:. Threads should be defined in the `threads:` field and not in the `params:` field.
+    - Only use `shell:` for shell commands. Do not use `script:`, `run:`, or `wrapper:` sections.
+    - Do not include prose, markdown, or triple backticks. However, each rule should include a docstring with the rule's purpose and any additional context.
+    - Prefer stable, portable shell commands and reference existing Sunbeam conventions if mentioned.
+    - This pipeline extends Sunbeam, the input reads live here: `QC_FP / 'decontam' / '{{sample}}_{{rp}}.fastq.gz'`. Default to paired end if there's ambiguity.
+    - Name conda envs according to the tool they install or their purpose if there are multiple tools. Always use the `.yaml` extension.
+
+    Some common Sunbeam conventions:
+    - Other extensions may use similar rules and rule names; avoid collisions by prefixing each rule name with the extension name (e.g., `{ext_name}_rule_name`).
+    - Use `log: LOG_FP / 'rule_name_{{sample}}.log'` to capture standard out/err for each sample. Expand over wildcards as necessary to match the output. In the shell command, try to include everything in a subshell and redirect everything that doesn't go into outputs to the log file.
+    - Use `benchmark: BENCHMARK_FP / 'rule_name_{{sample}}.tsv'` to capture resource usage for each sample.
+    - You should create a target rule named `{ext_name}_all` that depends on all final outputs of the extension.
+
+    Some important Sunbeam variables available in rules:
+    `Cfg` is a configuration dictionary holding the content of `sunbeam_config.yml`. You will probably not use this nor make your own config. If there are obvious configurable parameters for a rule, define them in code at the top of the file.
+    `Samples` is a dictionary of sample metadata, where keys are sample names and values are dictionaries with keys `1` and (optionally) `2` for read file paths.
+    `Pairs` is a list. If the run is paired end, it is ['1', '2']. If single end, it is ['1'].
+    At the top of the file, create a new output directory with `{ext_name.upper()}_FP = output_subdir(Cfg, '{ext_name}')` and then put all outputs into subdirectories of this.
+    """
+
+
+def CONDA_ENV_GEN_SYSTEM_PROMPT() -> str:
+    return """
 You are an expert Snakemake engineer and bioinformatician.
 Generate a valid conda environment YAML file to satisfy the dependencies of the following Snakemake rule.
 
 Constraints:
 - Output ONLY a valid conda environment YAML file.
+- Do not include prose, markdown, or triple backticks.
 - Include a name field matching the conda environment name used in the rule's `conda:` section.
 - Include the `defaults`, `conda-forge`, and `bioconda` channels.
 - Use bioconda packages where possible.
 
-Examples:
+Examples (backticks are for clarity; do not include them in your output):
 
-```yaml
+```
 name: blast
 channels:
   - defaults
@@ -43,7 +50,7 @@
   - blast
 ```
 
-```yaml
+```
 name: assembly
 channels:
   - defaults
diff --git a/sunbeam/ai/rule_creator.py b/sunbeam/ai/rule_creator.py
index 3a74d96f..5eedde3f 100644
--- a/sunbeam/ai/rule_creator.py
+++ b/sunbeam/ai/rule_creator.py
@@ -45,6 +45,7 @@ def _default_model_name() -> str:
 
 
 def create_rules_from_prompt(
+    ext_name: str,
     prompt: str,
     *,
     context_files: Optional[Sequence[Path]] = None,
@@ -58,6 +59,7 @@ def create_rules_from_prompt(
     lazily and raises a clear error if they are not installed.
 
     Args:
+        ext_name: Name of the extension (used for namespacing rules).
         prompt: User's high-level request describing desired workflow behavior.
         context_files: Optional paths whose contents provide additional context
             to the LLM (e.g., existing rules or configs). Read as text.
@@ -76,27 +78,28 @@ def create_rules_from_prompt(
     from langchain_openai import ChatOpenAI
     from langchain_core.messages import SystemMessage, HumanMessage
 
-    system = RULE_GEN_SYSTEM_PROMPT
+    system = RULE_GEN_SYSTEM_PROMPT(ext_name)
 
     # Assemble context payloads.
     context_blobs: list[str] = []
     if context_files:
         for p in context_files:
             try:
-                blob = Path(p).read_text()
+                blob = p.read_text()
             except Exception:
                 continue
-            context_blobs.append(f"CONTEXT FILE {Path(p).name}:\n{blob}")
+            context_blobs.append(f"CONTEXT FILE {p.name}:\n{blob}")
     context_text = ("\n\n".join(context_blobs)).strip()
+    context_text = f"Extension name: {ext_name}\n\n{context_text}"
     user_content = (
         prompt
         if not context_text
         else (f"Context:\n{context_text}\n\nTask:\n{prompt}")
     )
     model_name = model or _default_model_name()
-    llm = ChatOpenAI(model=model_name, api_key=api_key)
+    llm = ChatOpenAI(model=model_name, api_key=api_key, cache=False)
 
     # Run rule generation
     resp = llm.invoke(
         [
             SystemMessage(content=system),
@@ -116,7 +119,7 @@ def create_rules_from_prompt(
         context = f"Generate a conda environment YAML file for the following Snakemake rule:\n\n{rule}"
         resp = llm.invoke(
             [
-                SystemMessage(content=CONDA_ENV_GEN_SYSTEM_PROMPT),
+                SystemMessage(content=CONDA_ENV_GEN_SYSTEM_PROMPT()),
                 HumanMessage(content=context),
             ]
         )
@@ -170,13 +173,24 @@ def get_envs_from_rules(
     # Extract env names
     envs = {}
     for rule in rules:
-        for line in rule.splitlines():
+        lines = rule.splitlines()
+        for i, line in enumerate(lines):
             if line.strip().startswith("conda:"):
-                parts = line.split(":", 1)
-                if len(parts) == 2:
-                    env_name = parts[1].strip().strip('"').strip("'").strip(".yaml")
-                    envs[env_name] = rule
-                else:
-                    print("Warning: Malformed conda line:", line)
+                # The env path may be inline (`conda: "envs/x.yaml"`) or, as
+                # snakefmt-style rules have it, alone on the following line.
+                env_spec = line.split(":", 1)[1].strip()
+                if not env_spec and i + 1 < len(lines):
+                    env_spec = lines[i + 1].strip()
+                env_name = (
+                    env_spec.replace('"', "")
+                    .replace("'", "")
+                    .replace(".yaml", "")
+                    .replace(",", "")
+                    .split("/")[-1]
+                )
+                if env_name:
+                    envs[env_name] = rule
+                else:
+                    print("Warning: Malformed conda line:", line)
 
     return envs
diff --git a/sunbeam/scripts/generate.py b/sunbeam/scripts/generate.py
index 675e00e9..5ce16303 100644
--- a/sunbeam/scripts/generate.py
+++ b/sunbeam/scripts/generate.py
@@ -21,7 +21,17 @@ def main(argv=sys.argv):
         raise SystemExit(f"Extension directory {ext_dir} already exists")
 
     rules_path = ext_dir / f"sbx_{ruleset_name}.smk"
-    result = create_rules_from_prompt(prompt, write_to=rules_path)
+
+    # Default context files to include
+    context_files = [
+        Path(__file__).parent.parent / "workflow" / "Snakefile",
+        Path(__file__).parent.parent / "workflow" / "rules" / "qc.smk",
+        Path(__file__).parent.parent / "workflow" / "rules" / "decontam.smk",
+    ]
+
+    result = create_rules_from_prompt(
+        ruleset_name, prompt, context_files=context_files, write_to=rules_path
+    )
 
     logger.info(f"Created extension scaffold at {ext_dir}")
 