diff --git a/frontends/search/src/components/ResultsPage.tsx b/frontends/search/src/components/ResultsPage.tsx index ac5f6311a4..d055044e3c 100644 --- a/frontends/search/src/components/ResultsPage.tsx +++ b/frontends/search/src/components/ResultsPage.tsx @@ -251,6 +251,7 @@ const ResultsPage = (props: ResultsPageProps) => { if (!dataset) return; let sort_by; + let mmr; if (isSortBySearchType(props.search.debounced.sort_by)) { props.search.debounced.sort_by.rerank_type != "" @@ -262,6 +263,12 @@ const ResultsPage = (props: ResultsPageProps) => { : (sort_by = undefined); } + if (!props.search.debounced.mmr.use_mmr) { + mmr = undefined; + } else { + mmr = props.search.debounced.mmr; + } + const query = props.search.debounced.multiQueries.length > 0 ? props.search.debounced.multiQueries @@ -280,6 +287,7 @@ const ResultsPage = (props: ResultsPageProps) => { score_threshold: props.search.debounced.scoreThreshold, sort_options: { sort_by: sort_by, + mmr: mmr, }, slim_chunks: props.search.debounced.slimChunks ?? false, page_size: props.search.debounced.pageSize ?? 10, diff --git a/frontends/search/src/components/SearchForm.tsx b/frontends/search/src/components/SearchForm.tsx index b27c2c153e..f911348179 100644 --- a/frontends/search/src/components/SearchForm.tsx +++ b/frontends/search/src/components/SearchForm.tsx @@ -1193,6 +1193,46 @@ const SearchForm = (props: { }} /> +
+ + { + setTempSearchValues((prev) => { + return { + ...prev, + mmr: { + ...prev.mmr, + use_mmr: e.target.checked, + }, + }; + }); + }} + /> +
+
+ + { + setTempSearchValues((prev) => { + return { + ...prev, + mmr: { + ...prev.mmr, + mmr_lambda: parseFloat( + e.currentTarget.value, + ), + }, + }; + }); + }} + /> +
Search Refinement
diff --git a/frontends/search/src/hooks/useSearch.ts b/frontends/search/src/hooks/useSearch.ts index 00e37d8665..0bd12fab87 100644 --- a/frontends/search/src/hooks/useSearch.ts +++ b/frontends/search/src/hooks/useSearch.ts @@ -68,6 +68,10 @@ export interface SearchOptions { prioritize_domain_specifc_words: boolean | null; disableOnWords: string[]; sort_by: SortByField | SortBySearchType; + mmr: { + use_mmr: boolean; + mmr_lambda?: number; + }; pageSize: number; getTotalPages: boolean; highlightResults: boolean; @@ -98,6 +102,9 @@ const initalState: SearchOptions = { sort_by: { field: "", }, + mmr: { + use_mmr: false, + }, pageSize: 10, getTotalPages: true, correctTypos: false, @@ -146,6 +153,7 @@ const fromStateToParams = (state: SearchOptions): Params => { oneTypoWordRangeMax: state.oneTypoWordRangeMax?.toString() ?? "6", twoTypoWordRangeMin: state.twoTypoWordRangeMin.toString(), twoTypoWordRangeMax: state.twoTypoWordRangeMax?.toString() ?? "", + mmr: JSON.stringify(state.mmr), prioritize_domain_specifc_words: state.prioritize_domain_specifc_words?.toString() ?? "", disableOnWords: state.disableOnWords.join(","), @@ -189,6 +197,11 @@ const fromParamsToState = ( initalState.sort_by, pageSize: parseInt(params.pageSize ?? "10"), getTotalPages: (params.getTotalPages ?? "true") === "true", + mmr: + (JSON.parse(params.mmr ?? "{}") as { + use_mmr: boolean; + mmr_lambda?: number; + }) ?? initalState.mmr, correctTypos: (params.correctTypos ?? "false") === "true", oneTypoWordRangeMin: parseInt(params.oneTypoWordRangeMin ?? "4"), oneTypoWordRangeMax: parseIntOrNull(params.oneTypoWordRangeMax), diff --git a/server/src/data/models.rs b/server/src/data/models.rs index 79ffdb8a84..839b587c61 100644 --- a/server/src/data/models.rs +++ b/server/src/data/models.rs @@ -3517,6 +3517,7 @@ impl ApiKeyRequestParams { new_message_content: payload.new_message_content, topic_id: payload.topic_id, user_id: payload.user_id, + sort_options: payload.sort_options, highlight_options: self.highlight_options.or(payload.highlight_options), search_type: self.search_type.or(payload.search_type), use_group_search: payload.use_group_search, @@ -6667,6 +6668,17 @@ pub struct SortOptions { pub use_weights: Option, /// Tag weights is a JSON object which can be used to boost the ranking of chunks with certain tags. This is useful for when you want to be able to bias towards chunks with a certain tag on the fly. The keys are the tag names and the values are the weights. pub tag_weights: Option>, + /// Set use_mmr to true to use the Maximal Marginal Relevance algorithm to rerank the results. If not specified, this defaults to false. + pub mmr: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, ToSchema, Default)] +/// MMR Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks. +pub struct MmrOptions { + /// Set use_mmr to true to use the Maximal Marginal Relevance algorithm to rerank the results. + pub use_mmr: bool, + /// Set mmr_lambda to a value between 0.0 and 1.0 to control the tradeoff between relevance and diversity. Closer to 1.0 will give more diverse results, closer to 0.0 will give more relevant results. If not specified, this defaults to 0.5. + pub mmr_lambda: Option, } #[derive(Serialize, Deserialize, Debug, Clone, ToSchema, Default)] @@ -6787,6 +6799,9 @@ fn extract_sort_highlight_options( if let Some(value) = other.remove("tag_weights") { sort_options.tag_weights = serde_json::from_value(value).ok(); } + if let Some(value) = other.remove("mmr") { + sort_options.mmr = serde_json::from_value(value).ok(); + } // Extract highlight options if let Some(value) = other.remove("highlight_results") { @@ -6815,6 +6830,7 @@ fn extract_sort_highlight_options( && sort_options.location_bias.is_none() && sort_options.use_weights.is_none() && sort_options.tag_weights.is_none() + && sort_options.mmr.is_none() { None } else { @@ -7140,6 +7156,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload { pub search_type: Option, pub concat_user_messages_query: Option, pub search_query: Option, + pub sort_options: Option, pub page_size: Option, pub filters: Option, pub score_threshold: Option, @@ -7169,6 +7186,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload { new_message_content: helper.new_message_content, topic_id: helper.topic_id, highlight_options, + sort_options: helper.sort_options, search_type: helper.search_type, use_group_search: helper.use_group_search, concat_user_messages_query: helper.concat_user_messages_query, @@ -7195,6 +7213,8 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload { pub highlight_options: Option, pub search_type: Option, pub concat_user_messages_query: Option, + pub sort_options: Option, + pub search_query: Option, pub page_size: Option, pub filters: Option, @@ -7224,6 +7244,7 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload { Ok(RegenerateMessageReqPayload { topic_id: helper.topic_id, highlight_options, + sort_options: helper.sort_options, search_type: helper.search_type, concat_user_messages_query: helper.concat_user_messages_query, search_query: helper.search_query, @@ -7251,6 +7272,8 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload { pub new_message_content: String, pub highlight_options: Option, pub search_type: Option, + pub sort_options: Option, + pub use_group_search: Option, pub concat_user_messages_query: Option, pub search_query: Option, @@ -7281,6 +7304,7 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload { Ok(EditMessageReqPayload { topic_id: helper.topic_id, message_sort_order: helper.message_sort_order, + sort_options: helper.sort_options, new_message_content: helper.new_message_content, highlight_options, search_type: helper.search_type, diff --git a/server/src/handlers/message_handler.rs b/server/src/handlers/message_handler.rs index c7758eb1ee..24985e8706 100644 --- a/server/src/handlers/message_handler.rs +++ b/server/src/handlers/message_handler.rs @@ -6,7 +6,7 @@ use crate::{ data::models::{ self, ChunkMetadata, ChunkMetadataStringTagSet, ChunkMetadataTypes, ContextOptions, DatasetAndOrgWithSubAndPlan, DatasetConfiguration, HighlightOptions, LLMOptions, Pool, - QdrantChunkMetadata, RedisPool, SearchMethod, SuggestType, + QdrantChunkMetadata, RedisPool, SearchMethod, SortOptions, SuggestType, }, errors::ServiceError, get_env, @@ -98,6 +98,8 @@ pub struct CreateMessageReqPayload { pub search_query: Option, /// Page size is the number of chunks to fetch during RAG. If 0, then no search will be performed. If specified, this will override the N retrievals to include in the dataset configuration. Default is None. pub page_size: Option, + /// Sort Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks. + pub sort_options: Option, /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. pub filters: Option, /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. @@ -349,6 +351,8 @@ pub struct RegenerateMessageReqPayload { pub search_query: Option, /// Page size is the number of chunks to fetch during RAG. If 0, then no search will be performed. If specified, this will override the N retrievals to include in the dataset configuration. Default is None. pub page_size: Option, + /// Sort Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks. + pub sort_options: Option, /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. pub filters: Option, /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. @@ -381,6 +385,8 @@ pub struct EditMessageReqPayload { pub concat_user_messages_query: Option, /// Query is the search query. This can be any string. The search_query will be used to create a dense embedding vector and/or sparse vector which will be used to find the result set. If not specified, will default to the last user message or HyDE if HyDE is enabled in the dataset configuration. Default is None. pub search_query: Option, + /// Sort Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks. + pub sort_options: Option, /// Page size is the number of chunks to fetch during RAG. If 0, then no search will be performed. If specified, this will override the N retrievals to include in the dataset configuration. Default is None. pub page_size: Option, /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. @@ -404,6 +410,7 @@ impl From for CreateMessageReqPayload { topic_id: data.topic_id, highlight_options: data.highlight_options, search_type: data.search_type, + sort_options: data.sort_options, use_group_search: data.use_group_search, concat_user_messages_query: data.concat_user_messages_query, search_query: data.search_query, @@ -426,6 +433,7 @@ impl From for CreateMessageReqPayload { highlight_options: data.highlight_options, search_type: data.search_type, use_group_search: data.use_group_search, + sort_options: data.sort_options, concat_user_messages_query: data.concat_user_messages_query, search_query: data.search_query, page_size: data.page_size, diff --git a/server/src/lib.rs b/server/src/lib.rs index 847987adbb..b36d398198 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -494,6 +494,7 @@ impl Modify for SecurityAddon { data::models::OrganizationUsageCount, data::models::Dataset, data::models::DatasetAndUsage, + data::models::MmrOptions, data::models::DatasetUsageCount, data::models::DatasetDTO, data::models::DatasetUsageCount, diff --git a/server/src/operators/chunk_operator.rs b/server/src/operators/chunk_operator.rs index a6aff80f80..71997ec200 100644 --- a/server/src/operators/chunk_operator.rs +++ b/server/src/operators/chunk_operator.rs @@ -138,7 +138,7 @@ pub struct ChunkMetadataWithQdrantId { pub qdrant_id: uuid::Uuid, } -pub async fn get_chunk_metadatas_and_collided_chunks_from_point_ids_query( +pub async fn get_chunk_metadatas_from_point_ids_query( point_ids: Vec, pool: web::Data, ) -> Result, ServiceError> { diff --git a/server/src/operators/message_operator.rs b/server/src/operators/message_operator.rs index 162d0574d5..985d6d803e 100644 --- a/server/src/operators/message_operator.rs +++ b/server/src/operators/message_operator.rs @@ -348,6 +348,7 @@ pub async fn get_rag_chunks_query( .page_size .unwrap_or(n_retrievals_to_include.try_into().unwrap_or(8)), ), + sort_options: create_message_req_payload.sort_options, highlight_options: create_message_req_payload.highlight_options, filters: create_message_req_payload.filters, group_size: Some(1), @@ -453,6 +454,7 @@ pub async fn get_rag_chunks_query( search_type: search_type.clone(), query: QueryTypes::Single(query.clone()), score_threshold: create_message_req_payload.score_threshold, + sort_options: create_message_req_payload.sort_options, page_size: Some( create_message_req_payload .page_size diff --git a/server/src/operators/qdrant_operator.rs b/server/src/operators/qdrant_operator.rs index 15362a7952..281230c39a 100644 --- a/server/src/operators/qdrant_operator.rs +++ b/server/src/operators/qdrant_operator.rs @@ -1,6 +1,6 @@ use super::{ group_operator::get_groups_from_group_ids_query, - search_operator::{assemble_qdrant_filter, SearchResult}, + search_operator::{assemble_qdrant_filter, SearchResult, SearchResultTrait}, }; use crate::{ data::models::{ @@ -17,11 +17,11 @@ use itertools::Itertools; use qdrant_client::{ qdrant::{ group_id::Kind, point_id::PointIdOptions, quantization_config::Quantization, query, - BinaryQuantization, CreateCollectionBuilder, CreateFieldIndexCollectionBuilder, - DeleteFieldIndexCollectionBuilder, DeletePointsBuilder, Distance, FieldType, Filter, - GetPointsBuilder, HnswConfigDiff, OrderBy, PointId, PointStruct, PrefetchQuery, - QuantizationConfig, Query, QueryBatchPoints, QueryPointGroups, QueryPoints, - RecommendPointGroups, RecommendPoints, RecommendStrategy, RetrievedPoint, + vectors::VectorsOptions, BinaryQuantization, CreateCollectionBuilder, + CreateFieldIndexCollectionBuilder, DeleteFieldIndexCollectionBuilder, DeletePointsBuilder, + Distance, FieldType, Filter, GetPointsBuilder, HnswConfigDiff, OrderBy, PointId, + PointStruct, PrefetchQuery, QuantizationConfig, Query, QueryBatchPoints, QueryPointGroups, + QueryPoints, RecommendPointGroups, RecommendPoints, RecommendStrategy, RetrievedPoint, ScrollPointsBuilder, SearchBatchPoints, SearchParams, SearchPointGroups, SearchPoints, SetPayloadPointsBuilder, SparseIndexConfig, SparseVectorConfig, SparseVectorParams, TextIndexParamsBuilder, TokenizerType, UpsertPointsBuilder, UuidIndexParamsBuilder, Value, @@ -766,6 +766,34 @@ pub struct GroupSearchResults { pub hits: Vec, } +impl SearchResultTrait for GroupSearchResults { + fn score(&self) -> f32 { + self.hits.get(0).map_or(0.0, |hit| hit.score) + } + + fn point_id(&self) -> uuid::Uuid { + self.hits + .get(0) + .map_or(uuid::Uuid::default(), |hit| hit.point_id) + } + + fn payload(&self) -> HashMap { + self.hits + .get(0) + .map_or(HashMap::new(), |hit| hit.payload.clone()) + } + + fn set_score(&mut self, score: f32) { + if let Some(hit) = self.hits.get_mut(0) { + hit.score = score; + } + } + + fn embedding(&self) -> Option> { + self.hits.get(0).and_then(|hit| hit.embedding.clone()) + } +} + #[derive(Debug, Clone)] pub enum VectorType { SpladeSparse(Vec<(u32, f32)>), @@ -790,6 +818,7 @@ pub async fn search_over_groups_qdrant_query( queries: Vec, dataset_config: DatasetConfiguration, get_total_pages: bool, + use_mmr: bool, ) -> Result<(Vec, u64), ServiceError> { if queries.is_empty() || queries.iter().all(|query| query.limit == 0) { return Ok((vec![], 0)); @@ -801,6 +830,8 @@ pub async fn search_over_groups_qdrant_query( .max() .unwrap_or(3); + let get_payload = dataset_config.QDRANT_ONLY; + let limit = queries.iter().map(|query| query.limit).max().unwrap_or(10); let qdrant_collection = get_qdrant_collection_from_dataset_config(&dataset_config); @@ -847,8 +878,8 @@ pub async fn search_over_groups_qdrant_query( using: vector_name, query: Some(qdrant_query), score_threshold, - with_payload: Some(WithPayloadSelector::from(false)), - with_vectors: Some(WithVectorsSelector::from(false)), + with_payload: Some(WithPayloadSelector::from(get_payload)), + with_vectors: Some(WithVectorsSelector::from(use_mmr)), timeout: Some(60), filter: Some(query.filter.clone()), params: Some(SearchParams { @@ -901,6 +932,18 @@ pub async fn search_over_groups_qdrant_query( score: hit.score, point_id: uuid::Uuid::parse_str(&id).ok()?, payload: hit.payload.clone(), + embedding: hit.vectors.clone().map(|v| match v.vectors_options { + Some(VectorsOptions::Vectors(named_v)) => named_v + .vectors + .into_iter() + .filter(|v| v.1.indices.is_none()) + .map(|v| v.1.data) + .collect::>() + .get(0) + .unwrap_or(&vec![]) + .clone(), + _ => vec![], + }), }), PointIdOptions::Num(_) => None, }) @@ -1014,12 +1057,13 @@ pub async fn search_qdrant_query( queries: Vec, dataset_config: DatasetConfiguration, get_total_pages: bool, + use_mmr: bool, ) -> Result<(Vec, u64, Vec), ServiceError> { if queries.is_empty() || queries.iter().all(|query| query.limit == 0) { return Ok((vec![], 0, vec![])); } - let qdrant_only = dataset_config.QDRANT_ONLY; + let get_payload = dataset_config.QDRANT_ONLY; let qdrant_collection = get_qdrant_collection_from_dataset_config(&dataset_config); @@ -1055,46 +1099,24 @@ pub async fn search_qdrant_query( _ => query.score_threshold, }; - if qdrant_only { - QueryPoints { - collection_name: qdrant_collection.to_string(), - limit: Some(query.limit), - offset: Some(offset), - prefetch, - using: vector_name, - query: Some(qdrant_query), - score_threshold, - with_payload: Some(WithPayloadSelector::from(true)), - with_vectors: Some(WithVectorsSelector::from(false)), - timeout: Some(60), - filter: Some(query.filter.clone()), - params: Some(SearchParams { - exact: Some(false), - indexed_only: Some(dataset_config.INDEXED_ONLY), - ..Default::default() - }), - ..Default::default() - } - } else { - QueryPoints { - collection_name: qdrant_collection.to_string(), - limit: Some(query.limit), - offset: Some(offset), - prefetch, - using: vector_name, - query: Some(qdrant_query), - score_threshold, - with_payload: Some(WithPayloadSelector::from(false)), - with_vectors: Some(WithVectorsSelector::from(false)), - timeout: Some(60), - filter: Some(query.filter.clone()), - params: Some(SearchParams { - exact: Some(false), - indexed_only: Some(dataset_config.INDEXED_ONLY), - ..Default::default() - }), + QueryPoints { + collection_name: qdrant_collection.to_string(), + limit: Some(query.limit), + offset: Some(offset), + prefetch, + using: vector_name, + query: Some(qdrant_query), + score_threshold, + with_payload: Some(WithPayloadSelector::from(get_payload)), + with_vectors: Some(WithVectorsSelector::from(use_mmr)), + timeout: Some(60), + filter: Some(query.filter.clone()), + params: Some(SearchParams { + exact: Some(false), + indexed_only: Some(dataset_config.INDEXED_ONLY), ..Default::default() - } + }), + ..Default::default() } }) .collect::>(); @@ -1135,6 +1157,20 @@ pub async fn search_qdrant_query( score: scored_point.score, point_id: uuid::Uuid::parse_str(&id).ok()?, payload: scored_point.payload.clone(), + embedding: scored_point.vectors.clone().map(|v| { + match v.vectors_options { + Some(VectorsOptions::Vectors(named_v)) => named_v + .vectors + .into_iter() + .filter(|v| v.1.indices.is_none()) + .map(|v| v.1.data) + .collect::>() + .get(0) + .unwrap_or(&vec![]) + .clone(), + _ => vec![], + } + }), }), PointIdOptions::Num(_) => None, }, @@ -1165,7 +1201,7 @@ pub async fn recommend_qdrant_query( dataset_id: uuid::Uuid, dataset_config: DatasetConfiguration, pool: web::Data, -) -> Result, ServiceError> { +) -> Result, ServiceError> { let qdrant_collection = get_qdrant_collection_from_dataset_config(&dataset_config); let recommend_strategy = match strategy { @@ -1255,12 +1291,14 @@ pub async fn recommend_qdrant_query( } }; - Some(QdrantRecommendResult { + Some(SearchResult { point_id, score: point.score, + payload: point.payload.clone(), + embedding: None, }) }) - .collect::>(); + .collect::>(); Ok(recommended_point_ids) } @@ -1381,6 +1419,7 @@ pub async fn recommend_qdrant_groups_query( score: hit.score, point_id: uuid::Uuid::parse_str(&id).ok()?, payload: hit.payload.clone(), + embedding: None, }), PointIdOptions::Num(_) => None, }) @@ -1969,6 +2008,7 @@ pub async fn scroll_dataset_points( score: 0 as f32, point_id, payload, + embedding: None, }) }) .collect::>(); diff --git a/server/src/operators/search_operator.rs b/server/src/operators/search_operator.rs index 3a27a3dd40..c0cbb7e6f3 100644 --- a/server/src/operators/search_operator.rs +++ b/server/src/operators/search_operator.rs @@ -1,7 +1,7 @@ use super::chunk_operator::{ - get_chunk_metadatas_and_collided_chunks_from_point_ids_query, - get_content_chunk_from_point_ids_query, get_highlights, get_highlights_with_exact_match, - get_qdrant_ids_from_chunk_ids_query, get_slim_chunks_from_point_ids_query, HighlightStrategy, + get_chunk_metadatas_from_point_ids_query, get_content_chunk_from_point_ids_query, + get_highlights, get_highlights_with_exact_match, get_qdrant_ids_from_chunk_ids_query, + get_slim_chunks_from_point_ids_query, HighlightStrategy, }; use super::group_operator::{ get_group_ids_from_tracking_ids_query, get_groups_from_group_ids_query, @@ -17,9 +17,9 @@ use super::typo_operator::correct_query; use crate::data::models::{ convert_to_date_time, ChunkGroup, ChunkGroupAndFileId, ChunkMetadata, ChunkMetadataStringTagSet, ChunkMetadataTypes, ConditionType, ContentChunkMetadata, Dataset, - DatasetConfiguration, HasIDCondition, QdrantChunkMetadata, QdrantSortBy, QueryTypes, - ReRankOptions, RedisPool, ScoreChunk, ScoreChunkDTO, SearchMethod, SlimChunkMetadata, - SortByField, SortBySearchType, SortOptions, UnifiedId, + DatasetConfiguration, HasIDCondition, MmrOptions, QdrantChunkMetadata, QdrantSortBy, + QueryTypes, ReRankOptions, RedisPool, ScoreChunk, ScoreChunkDTO, SearchMethod, + SlimChunkMetadata, SortByField, SortBySearchType, SortOptions, UnifiedId, }; use crate::handlers::chunk_handler::{ AutocompleteReqPayload, ChunkFilter, CountChunkQueryResponseBody, CountChunksReqPayload, @@ -48,11 +48,42 @@ use simple_server_timing_header::Timer; use std::collections::{HashMap, HashSet}; use utoipa::ToSchema; +pub trait SearchResultTrait { + fn score(&self) -> f32; + fn set_score(&mut self, score: f32); + fn point_id(&self) -> uuid::Uuid; + fn payload(&self) -> HashMap; + fn embedding(&self) -> Option>; +} + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct SearchResult { pub score: f32, pub point_id: uuid::Uuid, pub payload: HashMap, + pub embedding: Option>, +} + +impl SearchResultTrait for SearchResult { + fn score(&self) -> f32 { + self.score + } + + fn set_score(&mut self, score: f32) { + self.score = score; + } + + fn point_id(&self) -> uuid::Uuid { + self.point_id + } + + fn payload(&self) -> HashMap { + self.payload.clone() + } + + fn embedding(&self) -> Option> { + self.embedding.clone() + } } #[derive(Serialize, Deserialize, Clone, Debug)] @@ -439,21 +470,96 @@ impl RetrievePointQuery { } } -#[allow(clippy::too_many_arguments)] +fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { + let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); + let norm_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); + let norm_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); + + if norm_a == 0.0 || norm_b == 0.0 { + 0.0 + } else { + dot_product / (norm_a * norm_b) + } +} + +pub fn apply_mmr( + mut docs: Vec, + lambda: f32, + max_results: usize, +) -> Vec { + if docs.is_empty() || docs.iter().any(|doc| doc.embedding().is_none()) { + return vec![]; + } + + let mut selected_indices = Vec::with_capacity(max_results); + let mut remaining_indices: Vec = (0..docs.len()).collect(); + + let (first_idx_pos, &first_idx) = remaining_indices + .iter() + .enumerate() + .max_by(|(_, &a), (_, &b)| docs[a].score().partial_cmp(&docs[b].score()).unwrap()) + .unwrap(); + + selected_indices.push(first_idx); + remaining_indices.remove(first_idx_pos); + + // Iteratively select documents + while selected_indices.len() < max_results && !remaining_indices.is_empty() { + let mut best_score = f32::NEG_INFINITY; + let mut best_idx_pos = 0; + + // Calculate MMR score for each remaining document + for (idx_pos, &idx) in remaining_indices.iter().enumerate() { + // Calculate similarity to already selected documents + let max_similarity = selected_indices + .iter() + .map(|&sel_idx| { + cosine_similarity( + docs[idx].embedding().as_ref().unwrap().as_slice(), + docs[sel_idx].embedding().as_ref().unwrap().as_slice(), + ) + }) + .fold(f32::NEG_INFINITY, |a, b| a.max(b)); + + // Calculate MMR score + let mmr_score = lambda * docs[idx].score() * (1.0 - (1.0 - lambda) * max_similarity); + + docs[idx].set_score(mmr_score); + + if mmr_score > best_score { + best_score = mmr_score; + best_idx_pos = idx_pos; + } + } + + selected_indices.push(remaining_indices[best_idx_pos]); + remaining_indices.remove(best_idx_pos); + } + log::info!("Selected indices: {:?}", selected_indices); + // Return document IDs in selection order + selected_indices + .iter() + .map(|&idx| docs[idx].clone()) + .collect() +} pub async fn retrieve_qdrant_points_query( qdrant_searches: Vec, page: u64, + mmr_options: Option, get_total_pages: bool, config: &DatasetConfiguration, ) -> Result { let page = if page == 0 { 1 } else { page }; + let use_mmr = mmr_options.is_some() && mmr_options.as_ref().unwrap().use_mmr; + let (point_ids, count, batch_lengths) = search_qdrant_query( page, qdrant_searches.clone(), config.clone(), get_total_pages, + use_mmr, ) .await?; @@ -991,15 +1097,18 @@ pub struct SearchOverGroupsQueryResult { pub async fn retrieve_group_qdrant_points_query( qdrant_searches: Vec, page: u64, + mmr_options: Option, get_total_pages: bool, config: &DatasetConfiguration, ) -> Result { let page = if page == 0 { 1 } else { page }; + let use_mmr = mmr_options.is_some() && mmr_options.as_ref().unwrap().use_mmr; let (point_ids, count) = search_over_groups_qdrant_query( page, qdrant_searches.clone(), config.clone(), get_total_pages, + use_mmr, ) .await?; @@ -1125,15 +1234,11 @@ pub async fn retrieve_chunks_for_groups( .flat_map(|hit| hit.hits.iter().map(|point| point.point_id).collect_vec()) .collect_vec(); - let metadata_chunks = match data.slim_chunks.unwrap_or(false) - && data.search_type != SearchMethod::Hybrid - { - true => get_slim_chunks_from_point_ids_query(point_ids, pool.clone()).await?, - _ => { - get_chunk_metadatas_and_collided_chunks_from_point_ids_query(point_ids, pool.clone()) - .await? - } - }; + let metadata_chunks = + match data.slim_chunks.unwrap_or(false) && data.search_type != SearchMethod::Hybrid { + true => get_slim_chunks_from_point_ids_query(point_ids, pool.clone()).await?, + _ => get_chunk_metadatas_from_point_ids_query(point_ids, pool.clone()).await?, + }; let groups = get_groups_from_group_ids_query( search_over_groups_query_result @@ -1280,10 +1385,7 @@ pub async fn get_metadata_from_groups( let chunk_metadatas = match slim_chunks { Some(true) => get_slim_chunks_from_point_ids_query(point_ids, pool.clone()).await?, - _ => { - get_chunk_metadatas_and_collided_chunks_from_point_ids_query(point_ids, pool.clone()) - .await? - } + _ => get_chunk_metadatas_from_point_ids_query(point_ids, pool.clone()).await?, }; let groups = get_groups_from_group_ids_query( @@ -1377,8 +1479,7 @@ pub async fn retrieve_chunks_from_point_ids( } else if data.content_only.unwrap_or(false) { get_content_chunk_from_point_ids_query(point_ids, pool.clone()).await? } else { - get_chunk_metadatas_and_collided_chunks_from_point_ids_query(point_ids, pool.clone()) - .await? + get_chunk_metadatas_from_point_ids_query(point_ids, pool.clone()).await? }; let timer = if let Some(timer) = timer { @@ -1494,6 +1595,7 @@ pub async fn retrieve_chunks_from_point_ids( pub fn rerank_chunks( chunks: Vec, + search_results: Vec, sort_options: Option, ) -> Vec { let mut reranked_chunks = Vec::new(); @@ -1619,6 +1721,31 @@ pub fn rerank_chunks( }) .collect::>(); } + + if sort_options.mmr.is_some() + && sort_options + .mmr + .as_ref() + .map(|m| m.use_mmr) + .unwrap_or(false) + { + let lambda = sort_options.mmr.unwrap().mmr_lambda.unwrap_or(0.3); + let max_result = search_results.len(); + let reranked_results = apply_mmr(search_results, lambda, max_result); + + reranked_chunks = reranked_chunks + .iter_mut() + .map(|chunk| { + let search_result = reranked_results + .iter() + .find(|result| result.point_id == chunk.metadata[0].qdrant_point_id()) + .unwrap(); + chunk.score = search_result.score.into(); + chunk.clone() + }) + .collect::>(); + } + reranked_chunks.sort_by(|a, b| { b.score .partial_cmp(&a.score) @@ -1630,6 +1757,7 @@ pub fn rerank_chunks( pub fn rerank_groups( groups: Vec, + search_results: Vec, sort_options: Option, ) -> Vec { let mut reranked_groups = Vec::new(); @@ -1758,6 +1886,33 @@ pub fn rerank_groups( }) .collect::>(); } + + if sort_options.mmr.is_some() + && sort_options + .mmr + .as_ref() + .map(|m| m.use_mmr) + .unwrap_or(false) + { + let lambda = sort_options.mmr.unwrap().mmr_lambda.unwrap_or(0.3); + let max_result = search_results.len(); + let reranked_results = apply_mmr(search_results, lambda, max_result); + + reranked_groups = reranked_groups + .iter_mut() + .map(|group| { + let search_result = reranked_results + .iter() + .find(|result| { + result.point_id() == group.metadata[0].metadata[0].qdrant_point_id() + }) + .unwrap(); + let first_chunk = group.metadata.get_mut(0).unwrap(); + first_chunk.score = search_result.score().into(); + group.clone() + }) + .collect::>(); + } reranked_groups.sort_by(|a, b| { let a_first_chunk = a.metadata.get(0).unwrap(); let b_first_chunk = b.metadata.get(0).unwrap(); @@ -1978,6 +2133,7 @@ pub async fn search_chunks_query( let search_chunk_query_results = retrieve_qdrant_points_query( vec![qdrant_query], data.page.unwrap_or(1), + data.sort_options.as_ref().and_then(|d| d.mmr.clone()), data.get_total_pages.unwrap_or(false), config, ) @@ -1986,7 +2142,7 @@ pub async fn search_chunks_query( timer.add("fetched from qdrant"); let mut result_chunks = retrieve_chunks_from_point_ids( - search_chunk_query_results, + search_chunk_query_results.clone(), Some(timer), &data, config.QDRANT_ONLY, @@ -2017,7 +2173,11 @@ pub async fn search_chunks_query( result_chunks.score_chunks }; - result_chunks.score_chunks = rerank_chunks(rerank_chunks_input, data.sort_options); + result_chunks.score_chunks = rerank_chunks( + rerank_chunks_input, + search_chunk_query_results.search_results, + data.sort_options, + ); timer.add("reranking"); @@ -2131,13 +2291,14 @@ pub async fn search_hybrid_chunks( let search_chunk_query_results = retrieve_qdrant_points_query( qdrant_queries, data.page.unwrap_or(1), + data.sort_options.as_ref().and_then(|d| d.mmr.clone()), data.get_total_pages.unwrap_or(false), config, ) .await?; let result_chunks = retrieve_chunks_from_point_ids( - search_chunk_query_results, + search_chunk_query_results.clone(), Some(timer), &data, config.QDRANT_ONLY, @@ -2161,7 +2322,11 @@ pub async fn search_hybrid_chunks( cross_encoder_results.retain(|chunk| chunk.score >= score_threshold.into()); } - rerank_chunks(cross_encoder_results, data.sort_options) + rerank_chunks( + cross_encoder_results, + search_chunk_query_results.search_results, + data.sort_options, + ) }; reranked_chunks.truncate(data.page_size.unwrap_or(10) as usize); @@ -2276,13 +2441,14 @@ pub async fn search_groups_query( let search_semantic_chunk_query_results = retrieve_qdrant_points_query( vec![qdrant_query], data.page.unwrap_or(1), + data.sort_options.as_ref().and_then(|d| d.mmr.clone()), data.get_total_pages.unwrap_or(false), config, ) .await?; let mut result_chunks = retrieve_chunks_from_point_ids( - search_semantic_chunk_query_results, + search_semantic_chunk_query_results.clone(), None, &web::Json(data.clone().into()), config.QDRANT_ONLY, @@ -2313,7 +2479,11 @@ pub async fn search_groups_query( result_chunks.score_chunks }; - result_chunks.score_chunks = rerank_chunks(rerank_chunks_input, data.sort_options); + result_chunks.score_chunks = rerank_chunks( + rerank_chunks_input, + search_semantic_chunk_query_results.search_results, + data.sort_options, + ); Ok(SearchWithinGroupResults { bookmarks: result_chunks.score_chunks, @@ -2415,6 +2585,7 @@ pub async fn search_hybrid_groups( let mut qdrant_results = retrieve_qdrant_points_query( qdrant_queries, data.page.unwrap_or(1), + data.sort_options.as_ref().and_then(|d| d.mmr.clone()), data.get_total_pages.unwrap_or(false), config, ) @@ -2428,7 +2599,7 @@ pub async fn search_hybrid_groups( .collect(); let result_chunks = retrieve_chunks_from_point_ids( - qdrant_results, + qdrant_results.clone(), None, &web::Json(data.clone().into()), config.QDRANT_ONLY, @@ -2454,7 +2625,11 @@ pub async fn search_hybrid_groups( config, ) .await?; - let score_chunks = rerank_chunks(cross_encoder_results, data.sort_options); + let score_chunks = rerank_chunks( + cross_encoder_results, + qdrant_results.search_results, + data.sort_options, + ); score_chunks .iter() @@ -2470,7 +2645,11 @@ pub async fn search_hybrid_groups( ) .await?; - rerank_chunks(cross_encoder_results, data.sort_options) + rerank_chunks( + cross_encoder_results, + qdrant_results.search_results, + data.sort_options, + ) }; if let Some(score_threshold) = data.score_threshold { @@ -2567,6 +2746,7 @@ pub async fn search_over_groups_query( let search_over_groups_qdrant_result = retrieve_group_qdrant_points_query( vec![qdrant_query], data.page.unwrap_or(1), + data.sort_options.as_ref().and_then(|d| d.mmr.clone()), data.get_total_pages.unwrap_or(false), config, ) @@ -2595,7 +2775,11 @@ pub async fn search_over_groups_query( timer.add("fetched from postgres"); - result_chunks.group_chunks = rerank_groups(result_chunks.group_chunks, data.sort_options); + result_chunks.group_chunks = rerank_groups( + result_chunks.group_chunks, + search_over_groups_qdrant_result.search_results, + data.sort_options, + ); result_chunks.corrected_query = corrected_query.map(|c| c.query); @@ -2745,6 +2929,7 @@ pub async fn hybrid_search_over_groups( let mut qdrant_results = retrieve_group_qdrant_points_query( qdrant_queries, data.page.unwrap_or(1), + data.sort_options.as_ref().and_then(|d| d.mmr.clone()), data.get_total_pages.unwrap_or(false), config, ) @@ -2807,7 +2992,11 @@ pub async fn hybrid_search_over_groups( }); } - reranked_chunks = rerank_groups(reranked_chunks, data.sort_options); + reranked_chunks = rerank_groups( + reranked_chunks, + qdrant_results.search_results, + data.sort_options, + ); let result_chunks = DeprecatedSearchOverGroupsResponseBody { group_chunks: reranked_chunks, @@ -2913,8 +3102,14 @@ pub async fn autocomplete_chunks_query( ); }; - let search_chunk_query_results = - retrieve_qdrant_points_query(qdrant_query, 1, false, config).await?; + let search_chunk_query_results = retrieve_qdrant_points_query( + qdrant_query, + 1, + data.sort_options.as_ref().and_then(|d| d.mmr.clone()), + false, + config, + ) + .await?; timer.add("fetching from qdrant"); @@ -2942,8 +3137,16 @@ pub async fn autocomplete_chunks_query( (result_chunks.score_chunks.as_slice(), empty_vec) }; - let mut reranked_chunks = rerank_chunks(before_increase.to_vec(), data.sort_options.clone()); - reranked_chunks.extend(rerank_chunks(after_increase.to_vec(), data.sort_options)); + let mut reranked_chunks = rerank_chunks( + before_increase.to_vec(), + search_chunk_query_results.search_results.clone(), + data.sort_options.clone(), + ); + reranked_chunks.extend(rerank_chunks( + after_increase.to_vec(), + search_chunk_query_results.search_results, + data.sort_options, + )); result_chunks.score_chunks = reranked_chunks;