+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions apps/query-eval/queryeval/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
QueryEvalResult,
QueryEvalResultsFile,
)
from sycamore.data import Document
from sycamore.docset import DocSet
from sycamore.query.client import SycamoreQueryClient, configure_logging

Expand Down Expand Up @@ -88,18 +89,24 @@ def __init__(

# Configure logging.
if self.config.config.log_file:
os.makedirs(os.path.dirname(self.config.config.log_file), exist_ok=True)
os.makedirs(os.path.dirname(os.path.abspath(self.config.config.log_file)), exist_ok=True)
configure_logging(logfile=self.config.config.log_file, log_level=logging.INFO)

if not self.config.config.index:
raise ValueError("Index must be specified")
if not self.config.config.results_file:
raise ValueError("Results file must be specified")

console.print(f"Writing results to: {self.config.config.results_file}")
os.makedirs(os.path.dirname(self.config.config.results_file), exist_ok=True)
if self.config.config.results_file:
console.print(f"Writing results to: {self.config.config.results_file}")
os.makedirs(os.path.dirname(os.path.abspath(self.config.config.results_file)), exist_ok=True)

# Read results file if it exists.
if not self.config.config.overwrite and os.path.exists(self.config.config.results_file):
if (
not self.config.config.overwrite
and self.config.config.results_file
and os.path.exists(self.config.config.results_file)
):
results = self.read_results_file(self.config.config.results_file)
console.print(
f":white_check_mark: Read {len(results.results or [])} "
Expand Down Expand Up @@ -170,11 +177,15 @@ def write_results_file(self):
results_file.write(to_yaml_str(results_file_obj))
console.print(f":white_check_mark: Wrote {len(self.results_map)} results to {self.config.config.results_file}")

def format_docset(self, docset: DocSet) -> List[Dict[str, Any]]:
"""Convert a DocSet query result to a list of dicts."""
def format_doclist(self, doclist: List[Document]) -> List[Dict[str, Any]]:
"""Convert a document list query result to a list of dicts."""
results = []
for doc in docset.take_all():
results.append(doc.data)
for doc in doclist:
if hasattr(doc, "data"):
if hasattr(doc.data, "model_dump"):
results.append(doc.data.model_dump())
else:
results.append(doc.data)
return results

def get_result(self, query: QueryEvalQuery) -> Optional[QueryEvalResult]:
Expand Down Expand Up @@ -267,12 +278,16 @@ def do_query(self, query: QueryEvalQuery, result: QueryEvalResult) -> QueryEvalR
else:
query_result.result = query_result.result.take_all()
t2 = time.time()
result.result = self.format_docset(query_result.result)
result.result = self.format_doclist(query_result.result)
else:
result.result = str(query_result.result)
t2 = time.time()
assert result.metrics
result.metrics.query_time = t2 - t1
try:
result.retrieved_docs = query_result.retrieved_docs()
except Exception:
result.retrieved_docs = None

console.print(f"[green]:clock9: Executed query in {result.metrics.query_time:.2f} seconds")
console.print(f":white_check_mark: Result: {result.result}")
Expand Down
6 changes: 6 additions & 0 deletions apps/query-eval/queryeval/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# data/ntsb-queries.yaml \
# run

import tempfile
from typing import Optional, Tuple

import click
Expand Down Expand Up @@ -56,6 +57,11 @@ def cli(
raw_output: bool,
):
ctx.ensure_object(dict)

if not query_cache_path:
query_cache_path = tempfile.mkdtemp()
console.print(f"[yellow]Using query cache path: {query_cache_path}")

driver = QueryEvalDriver(
input_file_path=config_file,
index=index,
Expand Down
89 changes: 89 additions & 0 deletions apps/query-eval/queryeval/test_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
from unittest.mock import MagicMock, patch

from queryeval.driver import QueryEvalDriver
from queryeval.types import (
QueryEvalQuery,
QueryEvalResult,
QueryEvalMetrics,
)
from sycamore.query.result import SycamoreQueryResult
from sycamore.query.logical_plan import LogicalPlan, Node


@pytest.fixture
def mock_client():
    """Provide a MagicMock query client whose schema lookup returns a fixed two-field schema."""
    canned_schema = {"field1": "text", "field2": "keyword"}
    fake_client = MagicMock()
    fake_client.get_opensearch_schema.return_value = canned_schema
    return fake_client


@pytest.fixture
def mock_plan():
    """Provide a minimal LogicalPlan made of a single node that is also the result node."""
    only_node = Node(node_id=0, node_type="test_op")
    return LogicalPlan(query="test query", nodes={0: only_node}, result_node=0)


@pytest.fixture
def test_input_file(tmp_path):
    """Write a minimal query-eval input file under ``tmp_path`` and return its path.

    The file declares a ``config`` section (index and results file) and two
    queries, the first of which carries the ``test`` tag.

    The results file is placed inside ``tmp_path`` as well, so that each test
    gets its own location: a relative name would resolve against the current
    working directory, letting a stale results file from an earlier run be
    picked up and leaving artifacts behind.

    Returns:
        The input file path as a ``str``.
    """
    results_path = tmp_path / "test-results.yaml"
    input_file = tmp_path / "test_input.yaml"
    # Nesting matters here: index/results_file belong to `config`, and `tags`
    # belongs to the first entry of `queries`.
    input_file.write_text(
        "config:\n"
        "  index: test-index\n"
        f"  results_file: {results_path}\n"
        "queries:\n"
        '  - query: "test query 1"\n'
        '    tags: ["test"]\n'
        '  - query: "test query 2"\n'
    )
    return str(input_file)


def test_driver_init(test_input_file, mock_client):
    """Check that constructor arguments and the input file are reflected in the driver config."""
    client_patch = patch("queryeval.driver.SycamoreQueryClient", return_value=mock_client)
    with client_patch:
        driver = QueryEvalDriver(
            input_file_path=test_input_file,
            index="test-index",
            doc_limit=10,
            tags=["test"],
        )

        settings = driver.config.config
        assert settings.index == "test-index"
        assert settings.doc_limit == 10
        assert settings.tags == ["test"]
        # Both queries from the input file should have been loaded.
        assert len(driver.config.queries) == 2


def test_driver_do_plan(test_input_file, mock_client, mock_plan):
    """Check that do_plan yields a result with a plan and calls generate_plan exactly once."""
    with patch("queryeval.driver.SycamoreQueryClient", return_value=mock_client):
        driver = QueryEvalDriver(input_file_path=test_input_file)

        eval_query = QueryEvalQuery(query="test query")
        pending = QueryEvalResult(query=eval_query, metrics=QueryEvalMetrics())

        # The mocked client hands back the canned plan when asked to plan.
        mock_client.generate_plan.return_value = mock_plan

        planned = driver.do_plan(eval_query, pending)
        assert planned.plan is not None
        mock_client.generate_plan.assert_called_once()


def test_driver_do_query(test_input_file, mock_client, mock_plan):
    """Check that do_query copies a plain (non-DocSet) query result onto the eval result."""
    with patch("queryeval.driver.SycamoreQueryClient", return_value=mock_client):
        driver = QueryEvalDriver(input_file_path=test_input_file)

        eval_query = QueryEvalQuery(query="test query")
        pending = QueryEvalResult(query=eval_query, plan=mock_plan, metrics=QueryEvalMetrics())

        # run_plan on the mocked client returns a result carrying a plain string.
        canned_result = SycamoreQueryResult(query_id="test", plan=pending.plan, result="test result")
        mock_client.run_plan.return_value = canned_result

        executed = driver.do_query(eval_query, pending)
        assert executed.result == "test result"


def test_driver_do_eval(test_input_file, mock_client, mock_plan):
    """Check that a plan identical to the expected plan gets a plan similarity of 1.0."""
    with patch("queryeval.driver.SycamoreQueryClient", return_value=mock_client):
        driver = QueryEvalDriver(input_file_path=test_input_file)

        # Expected plan and actual plan are the same object.
        eval_query = QueryEvalQuery(query="test query", expected_plan=mock_plan)
        pending = QueryEvalResult(query=eval_query, plan=mock_plan)

        evaluated = driver.do_eval(eval_query, pending)
        assert evaluated.metrics.plan_similarity == 1.0
3 changes: 2 additions & 1 deletion apps/query-eval/queryeval/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This module defines types used for the config, input, and output
# files for the Sycamore Query evaluator.

from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Set

from pydantic import BaseModel

Expand Down Expand Up @@ -68,6 +68,7 @@ class QueryEvalResult(BaseModel):
error: Optional[str] = None
metrics: Optional[QueryEvalMetrics] = None
notes: Optional[str] = None
retrieved_docs: Optional[Set[str]] = None


class QueryEvalResultsFile(BaseModel):
Expand Down
20 changes: 12 additions & 8 deletions apps/query-server/queryserver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import os
import tempfile
from typing import Annotated, List, Optional
from typing import Annotated, Any, List, Optional, Set

from fastapi import FastAPI, Path
from pydantic import BaseModel
Expand Down Expand Up @@ -41,9 +41,12 @@ class Query(BaseModel):
index: str


def get_index_schema(index: str) -> OpenSearchSchema:
"""Get the schema for the given index."""
return sqclient.get_opensearch_schema(index)
class QueryResult(BaseModel):
    """Result of a query."""

    # The logical plan associated with this result.
    plan: LogicalPlan
    # The query output; typed as Any, so no particular shape is enforced here.
    result: Any
    # Set of string identifiers for the retrieved documents.
    # NOTE(review): presumably document paths -- confirm against the producer.
    retrieved_docs: Set[str]


@app.get("/v1/indices")
Expand All @@ -53,7 +56,7 @@ async def list_indices() -> List[Index]:
retval = []
indices = util.get_opensearch_indices()
for index in indices:
index_schema = sqclient.get_opensearch_schema(index)
index_schema = util.get_schema(sqclient, index)
retval.append(Index(index=index, index_schema=index_schema))
return retval

Expand All @@ -64,7 +67,7 @@ async def get_index(
) -> Index:
"""Return details on the given index."""

schema = get_index_schema(index)
schema = util.get_schema(sqclient, index)
return Index(index=index, index_schema=schema)


Expand All @@ -84,8 +87,9 @@ async def run_plan(plan: LogicalPlan) -> SycamoreQueryResult:


@app.post("/v1/query")
async def run_query(query: Query) -> SycamoreQueryResult:
async def run_query(query: Query) -> QueryResult:
"""Generate a plan for the given query, run it, and return the result."""

plan = sqclient.generate_plan(query.query, query.index, util.get_schema(sqclient, query.index))
return sqclient.run_plan(plan)
sqresult = sqclient.run_plan(plan)
return QueryResult(plan=sqresult.plan, result=sqresult.result, retrieved_docs=sqresult.retrieved_docs())
2 changes: 1 addition & 1 deletion lib/sycamore/sycamore/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def execute_iter(self, plan: Node, **kwargs) -> Iterable[Document]:
for d in self.recursive_execute(plan):
yield d
else:
assert False
raise ValueError(f"Unknown execution mode {self._exec_mode}")

plan.traverse(visit=lambda n: n.finalize())

Expand Down
4 changes: 2 additions & 2 deletions lib/sycamore/sycamore/query/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import logging
import os
import uuid
from typing import Any, List, Optional, Union
from typing import List, Optional, Union

import structlog
import yaml
Expand Down Expand Up @@ -225,7 +225,7 @@ def query(
index: str,
dry_run: bool = False,
codegen_mode: bool = False,
) -> Any:
) -> SycamoreQueryResult:
"""Run a query against the given index."""
schema = self.get_opensearch_schema(index)
plan = self.generate_plan(query, index, schema)
Expand Down
40 changes: 39 additions & 1 deletion lib/sycamore/sycamore/query/result.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import io
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Set

from pydantic import BaseModel

import sycamore
from sycamore.query.logical_plan import LogicalPlan
from sycamore import DocSet

Expand Down Expand Up @@ -52,3 +53,40 @@ def to_str(self, limit: int = 100) -> str:
return out.getvalue()
else:
return str(self.result)

def retrieved_docs(self) -> Set[str]:
    """Return a set of Document paths for the documents retrieved by the query.

    Starting at the plan's result node, each node's trace directory (when it
    has one) is read back as a materialized DocSet. The first node on a branch
    whose trace contains documents with a "path" property supplies the paths
    for that branch; otherwise the search continues through the node's inputs
    and the results of all inputs are unioned.

    Returns:
        The set of "path" property values found; empty if no node yields any.

    Raises:
        ValueError: If no execution data is available.
    """
    # Validate before initializing a context, so the error path does no work.
    if self.execution is None:
        raise ValueError("No execution data available.")

    context = sycamore.init()

    # We want to return the set of documents from the deepest node in the query plan
    # that yields "true" documents from the data source. To do this, we recurse up the query
    # plan tree, collecting the set of documents from each node that has a trace directory
    # and which contain documents with "path" properties.

    def get_source_docs(context: sycamore.Context, node_id: int) -> Set[str]:
        """Helper function to recursively retrieve the source document paths for a given node."""
        if self.execution is not None and node_id in self.execution:
            node_trace_dir = self.execution[node_id].trace_dir
            if node_trace_dir:
                try:
                    mds = context.read.materialize(node_trace_dir)
                    keep = mds.filter(lambda doc: doc.properties.get("path") is not None)
                    # Read the trace once and test the collected set, rather than
                    # calling count() and then take_all() on the same docset.
                    paths = {doc.properties.get("path") for doc in keep.take_all()}
                    if paths:
                        return paths
                except ValueError:
                    # This can happen if the materialize directory is empty.
                    # Ignore and move onto the next node.
                    pass

        # Walk up the tree.
        node = self.plan.nodes[node_id]
        retval: Set[str] = set()
        for input_node_id in node.inputs:
            retval = retval.union(get_source_docs(context, input_node_id))
        return retval

    return get_source_docs(context, self.plan.result_node)
Loading
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载