+
= {};
+const pending: Record<string, Promise<unknown> | undefined> = {};
+
+/**
+ * Memoizes the result of an async producer under `key`.
+ *
+ * - A successful result is stored in the module-level `cache` and returned
+ *   on every later call with the same key (including falsy results).
+ * - Concurrent calls for a key that is still in flight share one promise, so
+ *   `func` is not started again while a previous run is pending.
+ * - A failed run is not cached: the pending entry is cleared and the next
+ *   call for that key runs `func` again.
+ *
+ * @param func - Producer, invoked with no arguments on a cache miss.
+ * @param key - Cache key; callers must keep it unique per logical value.
+ * @returns The cached or freshly produced value.
+ */
+export const cached = async <T>(
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  func: (...args: any) => Promise<T>,
+  key: string,
+): Promise<T> => {
+  // Own-property check instead of truthiness: falsy results (0, "", false,
+  // undefined) are still cache hits, and inherited keys such as
+  // "constructor" are never mistaken for cached entries.
+  if (Object.prototype.hasOwnProperty.call(cache, key)) {
+    return cache[key] as T;
+  }
+
+  // Share an in-flight request for this key, if there is one.
+  const inFlight = pending[key];
+  if (inFlight) {
+    return inFlight as Promise<T>;
+  }
+
+  const run = async (): Promise<T> => {
+    const value = await func();
+    cache[key] = value;
+    return value;
+  };
+
+  // `.finally` callbacks always run asynchronously, so the cleanup below
+  // happens after `pending[key]` is assigned, even when `func` throws
+  // synchronously. Cleaning up inside the producer could run first and leave
+  // a rejected promise stuck in `pending` forever.
+  const promise = run().finally(() => {
+    delete pending[key];
+  });
+
+  pending[key] = promise;
+  return promise;
+};
diff --git a/clients/search-component/src/utils/estimation.ts b/clients/search-component/src/utils/estimation.ts
new file mode 100644
index 0000000000..d1c2b9874e
--- /dev/null
+++ b/clients/search-component/src/utils/estimation.ts
@@ -0,0 +1,115 @@
+import { ChunkWithHighlights } from "./types";
+
+/**
+ * Derives a display title and an HTML description for a search hit.
+ *
+ * The description is the hit's highlights joined with "..." when present,
+ * otherwise the chunk HTML. `<b>` tags are unwrapped (their text is kept).
+ * The first heading's HTML is removed from the description, and `&nbsp;`
+ * entities are turned into plain spaces.
+ *
+ * The title is the inner HTML of the first h1-h6 in the chunk HTML, falling
+ * back to `metadata.title`, `metadata.page_title`, `metadata.name`, and
+ * finally to "".
+ *
+ * Requires a DOM (`document`).
+ */
+export const guessTitleAndDesc = (
+  item: ChunkWithHighlights,
+): {
+  title: string;
+  descriptionHtml: string;
+} => {
+  const descriptionContainer = document.createElement("div");
+  descriptionContainer.innerHTML = item.highlights
+    ? item.highlights.join("...")
+    : item.chunk.chunk_html || "";
+  // Unwrap highlight <b> tags, keeping their text.
+  descriptionContainer.querySelectorAll("b").forEach((b) => {
+    b.replaceWith(b.textContent || "");
+  });
+  let descriptionHtml = descriptionContainer.innerHTML;
+
+  const headingContainer = document.createElement("div");
+  headingContainer.innerHTML = item.chunk.chunk_html || "";
+  // querySelector returns the first match in document order, i.e. the first
+  // heading of any level.
+  const firstHeadingHtml =
+    headingContainer.querySelector("h1, h2, h3, h4, h5, h6")?.innerHTML ?? "";
+
+  const metadata = item.chunk.metadata;
+  // The trailing "" keeps the template literal from rendering the string
+  // "undefined" when no title source is available.
+  const title = `${
+    firstHeadingHtml ||
+    metadata?.title ||
+    metadata?.page_title ||
+    metadata?.name ||
+    ""
+  }`;
+
+  // Remove the heading first, while both strings still use the same
+  // serialisation, so it is not repeated in the description.
+  if (firstHeadingHtml) {
+    descriptionHtml = descriptionHtml.replace(firstHeadingHtml, "");
+  }
+  // innerHTML serialises non-breaking spaces as "&nbsp;"; show them as plain
+  // spaces (every occurrence, not just the first).
+  descriptionHtml = descriptionHtml.replace(/&nbsp;/g, " ");
+
+  return {
+    title,
+    descriptionHtml,
+  };
+};
+
+/**
+ * Returns the case-insensitive common prefix of `names`, cleaned up for
+ * display, or null when there is none.
+ *
+ * The prefix keeps the casing of the first name. Cleanup, in order:
+ * 1. HTML tags are removed, including a tag cut off by the prefix ("<b").
+ * 2. Trailing characters that are not letters (any script) are dropped.
+ * 3. A trailing " /X" is removed, then step 2 is re-applied.
+ *
+ * @param names - Names to compare; an empty or missing list yields null.
+ */
+export const findCommonName = (names: string[]) => {
+  if (!names || names.length === 0) return null;
+
+  const firstString = names[0];
+
+  // Count leading characters shared by every name, compared case-insensitively.
+  let prefixLength = 0;
+  for (let i = 0; i < firstString.length; i++) {
+    const expected = firstString[i].toLowerCase();
+    if (!names.every((str) => str[i]?.toLowerCase() === expected)) break;
+    prefixLength = i + 1;
+  }
+
+  // Use the original casing from the first string.
+  let commonPrefix = firstString.slice(0, prefixLength);
+
+  // Strip HTML before trimming the tail. Otherwise trimming eats the ">" of a
+  // tag at the end of the prefix, and the tag can no longer be matched.
+  commonPrefix = commonPrefix.replace(/<[^>]*>/g, "").replace(/<[^>]*$/, "");
+
+  // \p{L} keeps letters from any script, so "Café" keeps its "é".
+  const trailingNonLetters = /[^\p{L}]+$/u;
+  commonPrefix = commonPrefix.replace(trailingNonLetters, "");
+
+  if (commonPrefix.endsWith(" /X")) {
+    commonPrefix = commonPrefix.slice(0, -3).replace(trailingNonLetters, "");
+  }
+
+  return commonPrefix.length > 0 ? commonPrefix : null;
+};
+
+interface HasTitle {
+  title: string;
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  [key: string]: any;
+}
+
+/**
+ * Removes the leading words shared by every item's title, so only the part
+ * that tells the variants apart remains.
+ *
+ * For example, "Red Shirt - S" / "Red Shirt - M" become "S" / "M". The shared
+ * prefix is cut back to a word boundary, so "Red Shirt" / "Red Shoes" become
+ * "Shirt" / "Shoes" rather than "irt" / "oes". An item whose title would
+ * become empty (e.g. the only item, or one equal to the shared prefix) keeps
+ * its trimmed original title.
+ *
+ * @param array - Items to relabel; they are not mutated.
+ * @returns New shallow copies with updated `title`; [] for empty input.
+ */
+export function uniquifyVariants<T extends HasTitle>(array: T[]): T[] {
+  // Longest shared prefix, shortened to end on a word boundary.
+  const findCommonPrefix = (strings: string[]): string => {
+    if (strings.length === 0) return "";
+    let prefix = strings[0];
+    for (const str of strings) {
+      while (!str.startsWith(prefix)) {
+        prefix = prefix.slice(0, -1);
+      }
+    }
+    const endsAtWordBoundary = strings.every(
+      (str) =>
+        str.length === prefix.length || /\s/.test(str.charAt(prefix.length)),
+    );
+    // Drop a trailing partial word; a prefix ending in whitespace is kept as is.
+    return endsAtWordBoundary ? prefix : prefix.replace(/\S+$/, "");
+  };
+
+  if (!array || array.length === 0) {
+    return [];
+  }
+
+  const commonPrefix = findCommonPrefix(array.map((item) => item.title));
+
+  return array.map((item) => {
+    const stripped = item.title.slice(commonPrefix.length).trim();
+    return {
+      ...item,
+      title: stripped || item.title.trim(),
+    };
+  });
+}
diff --git a/clients/search-component/src/utils/hooks/chat-context.tsx b/clients/search-component/src/utils/hooks/chat-context.tsx
index 8f520a7f54..56c325b543 100644
--- a/clients/search-component/src/utils/hooks/chat-context.tsx
+++ b/clients/search-component/src/utils/hooks/chat-context.tsx
@@ -3,6 +3,9 @@ import { useModalState } from "./modal-context";
import { Chunk } from "../types";
import { getFingerprint } from "@thumbmarkjs/thumbmarkjs";
import { useEffect } from "react";
+import { cached } from "../cache";
+import { getAllChunksForGroup } from "../trieve";
+import { ChatMessageProxy, ChunkGroup, RoleProxy } from "trieve-ts-sdk";
type Messages = {
queryId: string | null;
@@ -11,7 +14,24 @@ type Messages = {
additional: Chunk[] | null;
}[][];
-const ModalContext = createContext<{
+/** Converts one chat entry into the SDK message shape used for `prev_messages`. */
+const mapMessageType = ({ text, type }: Messages[0][0]): ChatMessageProxy =>
+  ({
+    content: text,
+    role: type as RoleProxy,
+  }) satisfies ChatMessageProxy;
+
+/**
+ * Removes bracketed spans such as "[1]" from streamed group-chat text. These
+ * are presumably reference markers from the RAG endpoint; confirm.
+ *
+ * A trailing "[" that is not closed yet (the stream is still arriving) is
+ * removed through to the end. The whitespace left behind is then tidied.
+ * Line breaks are preserved; only runs of spaces/tabs are collapsed.
+ */
+function removeBrackets(str: string) {
+  // Closed brackets on a single line: "[...]"
+  let result = str.replace(/\[.*?\]/g, "");
+
+  // Handle unclosed brackets: remove from [ to end (last line only)
+  result = result.replace(/\[.*$/, "");
+
+  return (
+    result
+      // Collapse runs of spaces/tabs, but keep newlines so paragraphs and
+      // lists in the answer survive.
+      .replace(/[ \t]+/g, " ")
+      // Drop a space left dangling before a line break.
+      .replace(/ +(?=\r?\n)/g, "")
+      // "claim [1]." -> "claim ." -> "claim."
+      .replace(/ +\./g, ".")
+      .trim()
+  );
+}
+
+const ChatContext = createContext<{
askQuestion: (question?: string) => Promise;
isLoading: boolean;
messages: Messages;
@@ -20,6 +40,8 @@ const ModalContext = createContext<{
stopGeneratingMessage: () => void;
clearConversation: () => void;
switchToChatAndAskQuestion: (query: string) => Promise;
+ cancelGroupChat: () => void;
+ chatWithGroup: (group: ChunkGroup, betterGroupName?: string) => void;
isDoneReading?: React.MutableRefObject;
rateChatCompletion: (isPositive: boolean, queryId: string | null) => void;
}>({
@@ -28,14 +50,17 @@ const ModalContext = createContext<{
isLoading: false,
messages: [],
setCurrentQuestion: () => {},
+ cancelGroupChat: () => {},
clearConversation: () => {},
+ chatWithGroup: () => {},
switchToChatAndAskQuestion: async () => {},
stopGeneratingMessage: () => {},
rateChatCompletion: () => {},
});
function ChatProvider({ children }: { children: React.ReactNode }) {
- const { query, trieveSDK, modalRef, setMode } = useModalState();
+ const { query, trieveSDK, modalRef, setMode, setCurrentGroup } =
+ useModalState();
const [currentQuestion, setCurrentQuestion] = useState(query);
const [currentTopic, setCurrentTopic] = useState("");
const called = useRef(false);
@@ -65,7 +90,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
setMessages([]);
};
- const { currentTag } = useModalState();
+ const { currentTag, currentGroup } = useModalState();
useEffect(() => {
if (currentTag) {
@@ -91,7 +116,25 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
const decoder = new TextDecoder();
const newText = decoder.decode(value);
textInStream += newText;
- const [jsonData, text] = textInStream.split("||");
+
+ let text: string = "";
+ let jsonData: string = "";
+
+ if (textInStream.includes("||")) {
+ // The RAG over chunks endpoint returns references last
+ if (currentGroup) {
+ [text, jsonData] = textInStream.split("||");
+ } else {
+ [jsonData, text] = textInStream.split("||");
+ }
+ } else {
+ text = textInStream;
+ }
+
+ if (currentGroup) {
+ text = removeBrackets(text);
+ }
+
let json;
try {
json = JSON.parse(jsonData);
@@ -110,6 +153,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
},
],
]);
+
setTimeout(() => {
modalRef.current?.scroll({
top: modalRef.current.scrollHeight + 200,
@@ -128,24 +172,65 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
question?: string;
}) => {
setIsLoading(true);
- const { reader, queryId } = await trieveSDK.createMessageReaderWithQueryId(
- {
- topic_id: id || currentTopic,
- new_message_content: question || currentQuestion,
- llm_options: {
- completion_first: false,
+
+ // Use group search
+ if (currentGroup) {
+ // Should already be preloaded when group selected to chat with
+ const groupChunks = await cached(() => {
+ return getAllChunksForGroup(currentGroup.id, trieveSDK);
+ }, `chunk-ids-${currentGroup.id}`);
+
+ const { reader, queryId } = await trieveSDK.ragOnChunkReaderWithQueryId(
+ {
+ chunk_ids: groupChunks.map((c) => c.id),
+ prev_messages: [
+ ...messages.slice(0, -1).map((m) => mapMessageType(m[0])),
+ {
+ content: question || currentQuestion,
+ role: "user",
+ },
+ ],
+ stream_response: true,
},
- page_size: 5,
- filters:
- currentTag !== "all"
- ? {
- must: [{ field: "tag_set", match_any: [currentTag] }], // Apply tag filter
- }
- : null,
- },
- chatMessageAbortController.current.signal,
- );
- handleReader(reader, queryId);
+ chatMessageAbortController.current.signal,
+ );
+ handleReader(reader, queryId);
+ } else {
+ const { reader, queryId } =
+ await trieveSDK.createMessageReaderWithQueryId(
+ {
+ topic_id: id || currentTopic,
+ new_message_content: question || currentQuestion,
+ llm_options: {
+ completion_first: false,
+ },
+ page_size: 5,
+ filters:
+ currentTag !== "all"
+ ? {
+ must: [{ field: "tag_set", match_any: [currentTag] }], // Apply tag filter
+ }
+ : null,
+ },
+ chatMessageAbortController.current.signal,
+ );
+ handleReader(reader, queryId);
+ }
+ };
+
+ const chatWithGroup = async (group: ChunkGroup, betterGroupName?: string) => {
+ if (betterGroupName) {
+ group.name = betterGroupName;
+ }
+ clearConversation();
+ setCurrentGroup(group);
+ setMode("chat");
+ // preload the chunk ids
+ cached(() => {
+ return getAllChunksForGroup(group.id, trieveSDK);
+ }, `chunk-ids-${group.id}`).catch((e) => {
+ console.error(e);
+ });
};
const stopGeneratingMessage = () => {
@@ -164,6 +249,11 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
}
};
+ const cancelGroupChat = () => {
+ setCurrentGroup(null);
+ clearConversation();
+ };
+
const askQuestion = async (question?: string) => {
isDoneReading.current = false;
setMessages((m) => [
@@ -213,12 +303,14 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
};
return (
-
{children}
-
+
);
}
function useChatState() {
- const context = useContext(ModalContext);
+ const context = useContext(ChatContext);
if (!context) {
throw new Error("useChatState must be used within a ChatProvider");
}
diff --git a/clients/search-component/src/utils/hooks/modal-context.tsx b/clients/search-component/src/utils/hooks/modal-context.tsx
index 8c96807e50..ded7651385 100644
--- a/clients/search-component/src/utils/hooks/modal-context.tsx
+++ b/clients/search-component/src/utils/hooks/modal-context.tsx
@@ -7,6 +7,7 @@ import React, {
} from "react";
import { Chunk, ChunkWithHighlights, GroupChunk } from "../types";
import {
+ ChunkGroup,
CountChunkQueryResponseBody,
SearchChunksReqPayload,
TrieveSDK,
@@ -31,6 +32,7 @@ export type currencyPosition = "before" | "after";
export type ModalTypes = "ecommerce" | "docs";
export type SearchModes = "chat" | "search";
export type searchOptions = simpleSearchReqPayload & customAutoCompleteAddOn;
+
export type ModalProps = {
datasetId: string;
apiKey: string;
@@ -118,6 +120,8 @@ const ModalContext = createContext<{
setContextProps: (props: ModalProps) => void;
currentTag: string;
setCurrentTag: React.Dispatch>;
+ currentGroup: ChunkGroup | null;
+ setCurrentGroup: React.Dispatch>;
tagCounts: CountChunkQueryResponseBody[];
}>({
props: defaultProps,
@@ -138,6 +142,8 @@ const ModalContext = createContext<{
setLoadingResults: () => {},
setCurrentTag: () => {},
currentTag: "all",
+ currentGroup: null,
+ setCurrentGroup: () => {},
tagCounts: [],
setContextProps: () => {},
});
@@ -165,9 +171,11 @@ const ModalProvider = ({
const modalRef = useRef(null);
const [tagCounts, setTagCounts] = useState([]);
const [currentTag, setCurrentTag] = useState(
- props.tags?.find((t) => t.selected)?.tag || "all"
+ props.tags?.find((t) => t.selected)?.tag || "all",
);
+ const [currentGroup, setCurrentGroup] = useState(null);
+
const trieve = new TrieveSDK({
baseUrl: props.baseUrl,
apiKey: props.apiKey,
@@ -242,8 +250,8 @@ const ModalProvider = ({
trieve: trieve,
abortController,
...(tag.tag !== "all" && { tag: tag.tag }),
- })
- )
+ }),
+ ),
);
setTagCounts(numberOfRecords);
} catch (e) {
@@ -340,6 +348,8 @@ const ModalProvider = ({
modalRef,
currentTag,
setCurrentTag,
+ currentGroup,
+ setCurrentGroup,
tagCounts,
}}
>
diff --git a/clients/search-component/src/utils/trieve.ts b/clients/search-component/src/utils/trieve.ts
index 901788656a..c6081b15a4 100644
--- a/clients/search-component/src/utils/trieve.ts
+++ b/clients/search-component/src/utils/trieve.ts
@@ -1,4 +1,9 @@
-import { SearchResponseBody, TrieveSDK } from "trieve-ts-sdk";
+import {
+ ChunkMetadata,
+ ChunkMetadataStringTagSet,
+ SearchResponseBody,
+ TrieveSDK,
+} from "trieve-ts-sdk";
import { Chunk, GroupSearchResults, Props, SearchResults } from "./types";
import { defaultHighlightOptions, highlightText } from "./highlight";
import { ModalTypes } from "./hooks/modal-context";
@@ -7,7 +12,7 @@ export const omit = (obj: object | null | undefined, keys: string[]) => {
if (!obj) return obj;
return Object.fromEntries(
- Object.entries(obj).filter(([key]) => !keys.includes(key))
+ Object.entries(obj).filter(([key]) => !keys.includes(key)),
);
};
@@ -56,7 +61,7 @@ export const searchWithTrieve = async ({
search_type: searchOptions.search_type ?? "fulltext",
...omit(searchOptions, ["use_autocomplete"]),
},
- abortController?.signal
+ abortController?.signal,
)) as SearchResponseBody;
} else {
results = (await trieve.search(
@@ -84,7 +89,7 @@ export const searchWithTrieve = async ({
search_type: searchOptions.search_type ?? "fulltext",
...omit(searchOptions, ["use_autocomplete"]),
},
- abortController?.signal
+ abortController?.signal,
)) as SearchResponseBody;
}
@@ -145,7 +150,7 @@ export const groupSearchWithTrieve = async ({
search_type: searchOptions.search_type ?? "fulltext",
...omit(searchOptions, ["use_autocomplete"]),
},
- abortController?.signal
+ abortController?.signal,
);
const resultsWithHighlight = results.results.map((group) => {
@@ -198,7 +203,7 @@ export const countChunks = async ({
search_type: "fulltext",
...omit(searchOptions, ["search_type"]),
},
- abortController?.signal
+ abortController?.signal,
);
return results;
};
@@ -240,7 +245,7 @@ export const getSuggestedQueries = async ({
search_type: "semantic",
context: "You are a user searching through a docs website",
},
- abortController?.signal
+ abortController?.signal,
);
};
@@ -257,10 +262,41 @@ export const getSuggestedQuestions = async ({
search_type: "semantic",
context: "You are a user searching through a docs website",
},
- abortController?.signal
+ abortController?.signal,
);
};
export const sendFeedback = async ({ trieve }: { trieve: TrieveSDK }) => {
return trieve;
};
+
+export type SimpleChunk = ChunkMetadata | ChunkMetadataStringTagSet;
+
+/**
+ * Fetches every chunk in a chunk group by walking the paginated
+ * `/api/chunk_group/{group_id}/{page}` endpoint from page 1 until a page
+ * comes back with no chunks.
+ *
+ * Pages are requested one after another, including one final empty page, so
+ * the number of requests grows with the group's size.
+ *
+ * @param groupId - Id of the chunk group to read.
+ * @param trieve - SDK instance; its `datasetId` scopes the request.
+ * @returns All chunks of the group, in page order.
+ */
+export const getAllChunksForGroup = async (
+  groupId: string,
+  trieve: TrieveSDK,
+): Promise<SimpleChunk[]> => {
+  const chunks: SimpleChunk[] = [];
+  for (let page = 1; ; page += 1) {
+    const results = await trieve.trieve.fetch(
+      "/api/chunk_group/{group_id}/{page}",
+      "get",
+      {
+        datasetId: trieve.datasetId,
+        groupId,
+        page,
+      },
+    );
+    // An empty page marks the end of the group.
+    if (results.chunks.length === 0) {
+      return chunks;
+    }
+    chunks.push(...results.chunks);
+  }
+};
diff --git a/clients/search-component/src/utils/types.ts b/clients/search-component/src/utils/types.ts
index 2bf325b769..0d61f02af7 100644
--- a/clients/search-component/src/utils/types.ts
+++ b/clients/search-component/src/utils/types.ts
@@ -31,13 +31,13 @@ export type GroupSearchResults = {
};
export function isChunksWithHighlights(
- result: ChunkWithHighlights | GroupChunk[]
+ result: ChunkWithHighlights | GroupChunk[],
): result is ChunkWithHighlights {
- return (result as ChunkWithHighlights).highlights !== undefined;
+ return !Array.isArray(result);
}
export function isGroupChunk(
- result: ChunkWithHighlights | GroupChunk
+ result: ChunkWithHighlights | GroupChunk,
): result is GroupChunk {
return (result as GroupChunk).group !== undefined;
}