Add a file-backed VectorStore and make sendMessageWithPdf retrieve the most
relevant PDF chunks instead of sending the whole PDF text to the model.

NOTE(review): this patch was re-expanded from a copy whose line breaks and
angle-bracket contents had been lost. Generic type arguments and the
whitespace of blank/context lines are reconstructed; regenerate the patch with
`git diff` (which also refreshes the `index` hashes) before applying it.

diff --git a/lib/providers/chat_providers.dart b/lib/providers/chat_providers.dart
index 6e024b5..a9bdaed 100644
--- a/lib/providers/chat_providers.dart
+++ b/lib/providers/chat_providers.dart
@@ -3,6 +3,7 @@ import '../models/chat_session.dart';
 import '../models/message.dart';
 import '../services/llm_service.dart';
 import '../services/llm_service_factory.dart';
+import '../services/vector_store.dart';
 
 // Provider for the LLM service factory
 final llmServiceFactoryProvider = Provider((ref) {
@@ -22,14 +23,16 @@ final selectedLlmServiceProvider = Provider((ref) {
 });
 
 // Provider for the list of chat sessions
-final chatSessionsProvider = StateNotifierProvider<ChatSessionsNotifier, List<ChatSession>>((ref) {
-  return ChatSessionsNotifier();
-});
+final chatSessionsProvider =
+    StateNotifierProvider<ChatSessionsNotifier, List<ChatSession>>((ref) {
+      return ChatSessionsNotifier();
+    });
 
 // Provider for the currently active chat session
-final activeChatSessionProvider = StateNotifierProvider<ActiveChatSessionNotifier, ChatSession?>((ref) {
-  return ActiveChatSessionNotifier(ref);
-});
+final activeChatSessionProvider =
+    StateNotifierProvider<ActiveChatSessionNotifier, ChatSession?>((ref) {
+      return ActiveChatSessionNotifier(ref);
+    });
 
 // Provider for the loading state
 final isLoadingProvider = StateProvider((ref) => false);
@@ -37,18 +40,18 @@ final isLoadingProvider = StateProvider((ref) => false);
 // Notifier for managing the list of chat sessions
 class ChatSessionsNotifier extends StateNotifier<List<ChatSession>> {
   ChatSessionsNotifier() : super([]);
-  
+
   void addSession(ChatSession session) {
     state = [...state, session];
   }
-  
+
   void updateSession(ChatSession updatedSession) {
     state = [
       for (final session in state)
         if (session.id == updatedSession.id) updatedSession else session,
     ];
   }
-  
+
   void deleteSession(String sessionId) {
     state = state.where((session) => session.id != sessionId).toList();
   }
@@ -62,7 +65,7 @@ class ChatSessionsNotifier extends StateNotifier<List<ChatSession>> {
           session,
     ];
   }
-  
+
   ChatSession? getSessionById(String id) {
     try {
       return state.firstWhere((session) => session.id == id);
@@ -75,50 +78,46 @@ class ChatSessionsNotifier extends StateNotifier<List<ChatSession>> {
 // Notifier for managing the active chat session
 class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
   final Ref _ref;
-  
+
   ActiveChatSessionNotifier(this._ref) : super(null);
-  
+
   void setActiveSession(ChatSession session) {
     state = session;
   }
-  
+
   void createNewSession(LlmServiceType serviceType) {
-    final newSession = ChatSession(
-      serviceType: serviceType,
-      title: 'New Chat',
-    );
-    
+    final newSession = ChatSession(serviceType: serviceType, title: 'New Chat');
+
     // Add to the list of sessions
     _ref.read(chatSessionsProvider.notifier).addSession(newSession);
-    
+
     // Set as active session
     state = newSession;
   }
-  
+
   Future<void> sendMessage(String content) async {
     if (state == null) return;
-    
+
     // Create user message
-    final userMessage = Message(
-      role: MessageRole.user,
-      content: content,
-    );
-    
+    final userMessage = Message(role: MessageRole.user, content: content);
+
     // Add user message to the session
     final updatedSession = state!.addMessage(userMessage);
     state = updatedSession;
     _ref.read(chatSessionsProvider.notifier).updateSession(updatedSession);
-    
+
     // Set loading state
     _ref.read(isLoadingProvider.notifier).state = true;
-    
+
     try {
       // Get the selected LLM service
       final llmService = _ref.read(selectedLlmServiceProvider);
-      
+
       // Send the message to the LLM service
-      final assistantMessage = await llmService.sendMessage(updatedSession.messages);
-      
+      final assistantMessage = await llmService.sendMessage(
+        updatedSession.messages,
+      );
+
       // Add the assistant's response to the session
       final finalSession = updatedSession.addMessage(assistantMessage);
       state = finalSession;
@@ -129,7 +128,7 @@ class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
         role: MessageRole.assistant,
         content: 'Error: ${e.toString()}',
       );
-      
+
       final finalSession = updatedSession.addMessage(errorMessage);
       state = finalSession;
       _ref.read(chatSessionsProvider.notifier).updateSession(finalSession);
@@ -139,10 +138,7 @@ class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
     }
   }
 
-  Future<void> sendMessageWithImage(
-    String content,
-    String base64Image,
-  ) async {
+  Future<void> sendMessageWithImage(String content, String base64Image) async {
     if (state == null) return;
 
     final userMessage = Message(
@@ -182,12 +178,8 @@ class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
     }
   }
 
-  Future<void> sendMessageWithPdf(
-    String content,
-    String pdfText,
-  ) async {
+  Future<void> sendMessageWithPdf(String content, String pdfText) async {
     if (state == null) return;
-
     ChatSession? updatedSession;
 
     _ref.read(isLoadingProvider.notifier).state = true;
@@ -195,29 +187,45 @@ class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
     try {
       final llmService = _ref.read(selectedLlmServiceProvider);
 
-      final embedding = await llmService.embedText(pdfText);
-
-      final userMessage = Message(
-        role: MessageRole.user,
-        content: content,
-        embedding: embedding,
-      );
-
-      updatedSession = state!.addMessage(userMessage);
-      state = updatedSession;
-      _ref.read(chatSessionsProvider.notifier).updateSession(updatedSession!);
-
-      final messagesWithPdf = [
-        ...updatedSession.messages,
-        Message(role: MessageRole.user, content: pdfText),
-      ];
-
-      final assistantMessage =
-          await llmService.sendMessage(messagesWithPdf);
-
-      final finalSession = updatedSession.addMessage(assistantMessage);
-      state = finalSession;
-      _ref.read(chatSessionsProvider.notifier).updateSession(finalSession);
+      // Snapshot the session now: `state` may change while the awaits below
+      // are pending, and the reply belongs to the session that asked.
+      final session = state!;
+      final userMessage = Message(role: MessageRole.user, content: content);
+
+      // Record the question before any network call so it is visible while
+      // loading and is still part of the session if a later step throws.
+      final withQuestion = session.addMessage(userMessage);
+      updatedSession = withQuestion;
+      state = withQuestion;
+      _ref.read(chatSessionsProvider.notifier).updateSession(withQuestion);
+
+      // Index the PDF text into the shared vector store
+      await VectorStore.instance.upsertDocument(pdfText, llmService);
+
+      // Retrieve relevant chunks based on the user's question
+      final relevantChunks = await VectorStore.instance.search(
+        content,
+        llmService,
+      );
+
+      // Build messages with the retrieved context before the question
+      final messagesForLlm = [
+        ...session.messages,
+        if (relevantChunks.isNotEmpty)
+          Message(
+            role: MessageRole.user,
+            content: relevantChunks.join('\n\n'),
+          ),
+        userMessage,
+      ];
+
+      final assistantMessage = await llmService.sendMessage(messagesForLlm);
+
+      // Append the assistant's response to the session
+      final withAnswer = withQuestion.addMessage(assistantMessage);
+      updatedSession = withAnswer;
+      state = withAnswer;
+      _ref.read(chatSessionsProvider.notifier).updateSession(withAnswer);
     } catch (e) {
       final errorMessage = Message(
         role: MessageRole.assistant,
@@ -232,4 +240,4 @@ class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
       _ref.read(isLoadingProvider.notifier).state = false;
     }
   }
-}
\ No newline at end of file
+}
diff --git a/lib/services/vector_store.dart b/lib/services/vector_store.dart
new file mode 100644
index 0000000..4d4ac8a
--- /dev/null
+++ b/lib/services/vector_store.dart
@@ -0,0 +1,199 @@
+import 'dart:convert';
+import 'dart:io';
+import 'dart:math';
+
+import 'package:crypto/crypto.dart';
+import 'package:path_provider/path_provider.dart';
+
+import 'llm_service.dart';
+
+/// A piece of an indexed document together with its embedding vector.
+class DocumentChunk {
+  /// Key of this chunk inside [VectorStore].
+  final String id;
+
+  /// The chunk text, returned verbatim by [VectorStore.search].
+  final String text;
+
+  /// Embedding of [text] as returned by [LlmService.embedText].
+  final List<double> embedding;
+
+  DocumentChunk({
+    required this.id,
+    required this.text,
+    required this.embedding,
+  });
+
+  /// Converts this chunk to the map read back by [DocumentChunk.fromJson].
+  Map<String, dynamic> toJson() => {
+    'id': id,
+    'text': text,
+    'embedding': embedding,
+  };
+
+  /// Rebuilds a chunk from a map produced by [toJson].
+  ///
+  /// Throws a [TypeError] when a field is missing or has an unexpected type.
+  factory DocumentChunk.fromJson(Map<String, dynamic> json) {
+    return DocumentChunk(
+      id: json['id'] as String,
+      text: json['text'] as String,
+      embedding: (json['embedding'] as List)
+          .map((e) => (e as num).toDouble())
+          .toList(),
+    );
+  }
+}
+
+/// Singleton, file-backed index of embedded document chunks.
+///
+/// Chunks are keyed by the SHA-256 digest of their text, so text that is
+/// already indexed is not embedded again. The index is stored as
+/// `doc_index.json` in the application support directory.
+class VectorStore {
+  VectorStore._internal();
+
+  /// The single shared store.
+  static final VectorStore instance = VectorStore._internal();
+
+  final Map<String, DocumentChunk> _chunks = {};
+  bool _loaded = false;
+
+  /// Resolves the file that backs the index.
+  Future<File> _indexFile() async {
+    final dir = await getApplicationSupportDirectory();
+    return File('${dir.path}/doc_index.json');
+  }
+
+  /// Reads the persisted index into memory on first use.
+  Future<void> _load() async {
+    if (_loaded) return;
+    final file = await _indexFile();
+    if (await file.exists()) {
+      try {
+        final data = jsonDecode(await file.readAsString());
+        if (data is List) {
+          for (final item in data) {
+            if (item is! Map<String, dynamic>) continue;
+            final chunk = DocumentChunk.fromJson(item);
+            _chunks[chunk.id] = chunk;
+          }
+        }
+      } on FormatException {
+        // The file only caches embeddings. Invalid JSON (for example a write
+        // that was cut short) must not break every later call, so start from
+        // an empty index and let the next save replace the file.
+      }
+    }
+    _loaded = true;
+  }
+
+  /// Writes the whole in-memory index back to disk.
+  Future<void> _save() async {
+    final file = await _indexFile();
+    await file.writeAsString(
+      jsonEncode([for (final chunk in _chunks.values) chunk.toJson()]),
+    );
+  }
+
+  /// Splits [text] into chunks of at most [chunkSize] UTF-16 code units and
+  /// embeds every chunk that is not in the index yet.
+  ///
+  /// Throws an [ArgumentError] if [chunkSize] is not positive. Errors from
+  /// [LlmService.embedText] propagate; chunks embedded before the failure are
+  /// still saved.
+  Future<void> upsertDocument(
+    String text,
+    LlmService llmService, {
+    int chunkSize = 1000,
+  }) async {
+    if (chunkSize <= 0) {
+      throw ArgumentError.value(chunkSize, 'chunkSize', 'must be positive');
+    }
+    await _load();
+    var added = false;
+    try {
+      for (final chunk in _splitIntoChunks(text, chunkSize)) {
+        final id = _fingerprint(chunk);
+        if (_chunks.containsKey(id)) continue;
+        final embedding = await llmService.embedText(chunk);
+        _chunks[id] = DocumentChunk(id: id, text: chunk, embedding: embedding);
+        added = true;
+      }
+    } finally {
+      // Keep what was embedded even if a later chunk failed, and skip the
+      // rewrite when the document added nothing new.
+      if (added) await _save();
+    }
+  }
+
+  /// Returns the text of the [topK] indexed chunks most similar to [query],
+  /// best match first.
+  ///
+  /// Returns an empty list when nothing has been indexed.
+  Future<List<String>> search(
+    String query,
+    LlmService llmService, {
+    int topK = 3,
+  }) async {
+    await _load();
+    if (_chunks.isEmpty) return [];
+    final queryEmbedding = await llmService.embedText(query);
+    final ranked = [
+      for (final chunk in _chunks.values)
+        // A vector of another length is not comparable with the query (it
+        // was presumably produced by a different embedding model), so it can
+        // never be a match.
+        if (chunk.embedding.length == queryEmbedding.length)
+          MapEntry(
+            chunk.text,
+            _cosineSimilarity(queryEmbedding, chunk.embedding),
+          ),
+    ]..sort((a, b) => b.value.compareTo(a.value));
+    return [for (final entry in ranked.take(topK)) entry.key];
+  }
+
+  /// Cuts [text] into consecutive pieces of at most [size] UTF-16 code units.
+  ///
+  /// A cut that would fall inside a surrogate pair is moved one unit earlier
+  /// so that no chunk holds half a character.
+  List<String> _splitIntoChunks(String text, int size) {
+    final chunks = <String>[];
+    var start = 0;
+    while (start < text.length) {
+      var end = min(start + size, text.length);
+      final splitsPair =
+          end < text.length &&
+          end - start > 1 &&
+          (text.codeUnitAt(end - 1) & 0xFC00) == 0xD800 &&
+          (text.codeUnitAt(end) & 0xFC00) == 0xDC00;
+      if (splitsPair) end--;
+      chunks.add(text.substring(start, end));
+      start = end;
+    }
+    return chunks;
+  }
+
+  /// Returns the SHA-256 digest of [text] as a hex string.
+  String _fingerprint(String text) {
+    return sha256.convert(utf8.encode(text)).toString();
+  }
+
+  /// Returns the cosine similarity of [a] and [b].
+  ///
+  /// Returns 0 when the vectors differ in length or either one is all zeros,
+  /// instead of throwing a [RangeError] or producing NaN.
+  double _cosineSimilarity(List<double> a, List<double> b) {
+    if (a.length != b.length) return 0.0;
+    var dot = 0.0;
+    var normA = 0.0;
+    var normB = 0.0;
+    for (var i = 0; i < a.length; i++) {
+      dot += a[i] * b[i];
+      normA += a[i] * a[i];
+      normB += b[i] * b[i];
+    }
+    if (normA == 0 || normB == 0) return 0.0;
+    return dot / (sqrt(normA) * sqrt(normB));
+  }
+}
diff --git a/pubspec.yaml b/pubspec.yaml
index 55535e9..e75106e 100644
--- a/pubspec.yaml
+++ b/pubspec.yaml
@@ -54,6 +54,7 @@ dependencies:
   image_picker: ^1.0.4
   file_picker: ^10.2.0
   doc_text_extractor: ^1.0.0
+  crypto: ^3.0.3
 
 dev_dependencies:
   flutter_test: