lib/sycamore/sycamore/query/client.py (34 changes: 20 additions & 14 deletions)
@@ -28,7 +28,7 @@
 from sycamore.llms.openai import OpenAI, OpenAIModels
 from sycamore.query.execution.sycamore_executor import SycamoreExecutor
 from sycamore.query.logical_plan import LogicalPlan
-from sycamore.query.planner import LlmPlanner, PlannerExample
+from sycamore.query.planner import LlmPlanner, PlannerExample, Planner
 from sycamore.query.result import SycamoreQueryResult
 from sycamore.query.schema import OpenSearchSchema, OpenSearchSchemaFetcher
 from sycamore.query.strategy import DefaultQueryPlanStrategy, QueryPlanStrategy
@@ -131,6 +131,7 @@ def __init__(
         sycamore_exec_mode: ExecMode = ExecMode.RAY,
         llm: Optional[Union[LLM, str]] = None,
         query_plan_strategy: Optional[QueryPlanStrategy] = None,
+        query_planner: Optional[Planner] = None,
     ):
         from sycamore.connectors.opensearch.utils import OpenSearchClientWithLogging

@@ -139,6 +140,7 @@ def __init__(
         self.cache_dir = cache_dir
         self.sycamore_exec_mode = sycamore_exec_mode
         self.query_plan_strategy = query_plan_strategy
+        self.query_planner = query_planner

         # TODO: remove these assertions and simplify the code to get all customization via the
         # context.
@@ -194,19 +196,23 @@ def generate_plan(
             natural_language_response: Whether to generate a natural language response. If False,
                 raw data will be returned.
         """
-        llm_client = self.context.params.get("default", {}).get("llm")
-        if not llm_client:
-            llm_client = OpenAI(OpenAIModels.GPT_4O.value, cache=cache_from_path(self.llm_cache_dir))
-        planner = LlmPlanner(
-            index,
-            data_schema=schema,
-            os_config=self.os_config,
-            os_client=self._os_client,
-            llm_client=llm_client,
-            strategy=self.query_plan_strategy or DefaultQueryPlanStrategy(),
-            examples=examples,
-            natural_language_response=natural_language_response,
-        )
+        planner: Planner
+        if self.query_planner is None:
+            llm_client = self.context.params.get("default", {}).get("llm")
+            if not llm_client:
+                llm_client = OpenAI(OpenAIModels.GPT_4O.value, cache=cache_from_path(self.llm_cache_dir))
+            planner = LlmPlanner(
+                index,
+                data_schema=schema,
+                os_config=self.os_config,
+                os_client=self._os_client,
+                llm_client=llm_client,
+                strategy=self.query_plan_strategy or DefaultQueryPlanStrategy(),
+                examples=examples,
+                natural_language_response=natural_language_response,
+            )
+        else:
+            planner = self.query_planner
         plan = planner.plan(query)
         return plan

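Taken together, these changes let callers bypass the built-in LLM planning path entirely. A minimal sketch of the intended usage, assuming the client class in client.py is SycamoreQueryClient and that its remaining constructor arguments can stay at their defaults; StubPlanner and prebuilt_plan are hypothetical stand-ins for illustration, not part of this PR:

from sycamore.query.client import SycamoreQueryClient
from sycamore.query.logical_plan import LogicalPlan
from sycamore.query.planner import Planner


class StubPlanner(Planner):
    """Hypothetical planner that returns a pre-built plan, never calling an LLM."""

    def __init__(self, plan: LogicalPlan):
        self._plan = plan

    def plan(self, question: str) -> LogicalPlan:
        return self._plan


# Assume prebuilt_plan is a LogicalPlan obtained elsewhere (e.g. deserialized
# from a cache); constructing one by hand is out of scope for this sketch.
client = SycamoreQueryClient(query_planner=StubPlanner(prebuilt_plan))
# generate_plan() now returns prebuilt_plan directly via the else branch above.

Because the OpenAI client is only constructed inside the `self.query_planner is None` branch of generate_plan, no OpenAI credentials are needed when a custom planner is supplied.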
lib/sycamore/sycamore/query/planner.py (9 changes: 8 additions & 1 deletion)
@@ -1,5 +1,6 @@
 import logging
 import typing
+from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Any, List, Optional, Tuple, Type, Union

@@ -368,7 +369,13 @@ def __init__(self, schema: Union[OpenSearchSchema, Schema], plan: LogicalPlan) -
 ]


-class LlmPlanner:
+class Planner:
+    @abstractmethod
+    def plan(self, question: str) -> LogicalPlan:
+        raise NotImplementedError
+
+
+class LlmPlanner(Planner):
     """The top-level query planner for SycamoreQuery. This class is responsible for generating
     a logical query plan from a user query using the OpenAI LLM.

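With Planner as an abstract base, the planning step becomes pluggable: anything implementing `plan(question) -> LogicalPlan` can be swapped in. A sketch of a deterministic subclass for tests; CannedPlanner and its dictionary-based lookup are illustrative assumptions, not part of this PR:

from sycamore.query.logical_plan import LogicalPlan
from sycamore.query.planner import Planner


class CannedPlanner(Planner):
    """Hypothetical planner serving pre-computed plans keyed by question text."""

    def __init__(self, plans: dict[str, LogicalPlan]):
        self._plans = plans

    def plan(self, question: str) -> LogicalPlan:
        # Deterministic lookup: no LLM call, so results are reproducible in tests.
        if question not in self._plans:
            raise KeyError(f"No canned plan for question: {question!r}")
        return self._plans[question]

Note that as added here, Planner uses @abstractmethod without inheriting from abc.ABC, so a subclass that forgets to override plan() fails only when plan() is called, not at instantiation time.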