119 changes: 60 additions & 59 deletions lib/providers/chat_providers.dart
@@ -3,6 +3,7 @@ import '../models/chat_session.dart';
import '../models/message.dart';
import '../services/llm_service.dart';
import '../services/llm_service_factory.dart';
import '../services/vector_store.dart';

// Provider for the LLM service factory
final llmServiceFactoryProvider = Provider<LlmServiceFactory>((ref) {
@@ -22,33 +23,35 @@ final selectedLlmServiceProvider = Provider<LlmService>((ref) {
});

// Provider for the list of chat sessions
final chatSessionsProvider = StateNotifierProvider<ChatSessionsNotifier, List<ChatSession>>((ref) {
return ChatSessionsNotifier();
});
final chatSessionsProvider =
StateNotifierProvider<ChatSessionsNotifier, List<ChatSession>>((ref) {
return ChatSessionsNotifier();
});

// Provider for the currently active chat session
final activeChatSessionProvider = StateNotifierProvider<ActiveChatSessionNotifier, ChatSession?>((ref) {
return ActiveChatSessionNotifier(ref);
});
final activeChatSessionProvider =
StateNotifierProvider<ActiveChatSessionNotifier, ChatSession?>((ref) {
return ActiveChatSessionNotifier(ref);
});

// Provider for the loading state
final isLoadingProvider = StateProvider<bool>((ref) => false);

// Notifier for managing the list of chat sessions
class ChatSessionsNotifier extends StateNotifier<List<ChatSession>> {
ChatSessionsNotifier() : super([]);

void addSession(ChatSession session) {
state = [...state, session];
}

void updateSession(ChatSession updatedSession) {
state = [
for (final session in state)
if (session.id == updatedSession.id) updatedSession else session,
];
}

void deleteSession(String sessionId) {
state = state.where((session) => session.id != sessionId).toList();
}
@@ -62,7 +65,7 @@ class ChatSessionsNotifier extends StateNotifier<List<ChatSession>> {
session,
];
}

ChatSession? getSessionById(String id) {
try {
return state.firstWhere((session) => session.id == id);
@@ -75,50 +78,46 @@ class ChatSessionsNotifier extends StateNotifier<List<ChatSession>> {
// Notifier for managing the active chat session
class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
final Ref _ref;

ActiveChatSessionNotifier(this._ref) : super(null);

void setActiveSession(ChatSession session) {
state = session;
}

void createNewSession(LlmServiceType serviceType) {
final newSession = ChatSession(
serviceType: serviceType,
title: 'New Chat',
);

final newSession = ChatSession(serviceType: serviceType, title: 'New Chat');

// Add to the list of sessions
_ref.read(chatSessionsProvider.notifier).addSession(newSession);

// Set as active session
state = newSession;
}

Future<void> sendMessage(String content) async {
if (state == null) return;

// Create user message
final userMessage = Message(
role: MessageRole.user,
content: content,
);

final userMessage = Message(role: MessageRole.user, content: content);

// Add user message to the session
final updatedSession = state!.addMessage(userMessage);
state = updatedSession;
_ref.read(chatSessionsProvider.notifier).updateSession(updatedSession);

// Set loading state
_ref.read(isLoadingProvider.notifier).state = true;

try {
// Get the selected LLM service
final llmService = _ref.read(selectedLlmServiceProvider);

// Send the message to the LLM service
final assistantMessage = await llmService.sendMessage(updatedSession.messages);

final assistantMessage = await llmService.sendMessage(
updatedSession.messages,
);

// Add the assistant's response to the session
final finalSession = updatedSession.addMessage(assistantMessage);
state = finalSession;
@@ -129,7 +128,7 @@ class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
role: MessageRole.assistant,
content: 'Error: ${e.toString()}',
);

final finalSession = updatedSession.addMessage(errorMessage);
state = finalSession;
_ref.read(chatSessionsProvider.notifier).updateSession(finalSession);
Expand All @@ -139,10 +138,7 @@ class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
}
}

Future<void> sendMessageWithImage(
String content,
String base64Image,
) async {
Future<void> sendMessageWithImage(String content, String base64Image) async {
if (state == null) return;

final userMessage = Message(
@@ -182,42 +178,47 @@ class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
}
}

Future<void> sendMessageWithPdf(
String content,
String pdfText,
) async {
Future<void> sendMessageWithPdf(String content, String pdfText) async {
if (state == null) return;

ChatSession? updatedSession;

_ref.read(isLoadingProvider.notifier).state = true;

try {
final llmService = _ref.read(selectedLlmServiceProvider);

final embedding = await llmService.embedText(pdfText);
// Index the PDF text into the shared vector store
await VectorStore.instance.upsertDocument(pdfText, llmService);

final userMessage = Message(
role: MessageRole.user,
content: content,
embedding: embedding,
// Retrieve relevant chunks based on the user's question
final relevantChunks = await VectorStore.instance.search(
content,
llmService,
);

updatedSession = state!.addMessage(userMessage);
state = updatedSession;
_ref.read(chatSessionsProvider.notifier).updateSession(updatedSession!);

final messagesWithPdf = [
...updatedSession.messages,
Message(role: MessageRole.user, content: pdfText),
final contextMessage = relevantChunks.isEmpty
? null
: Message(
role: MessageRole.user,
content: relevantChunks.join('\n\n'),
);

// Build messages with the retrieved context before the question
final messagesForLlm = [
...state!.messages,
if (contextMessage != null) contextMessage,
Message(role: MessageRole.user, content: content),
];

final assistantMessage =
await llmService.sendMessage(messagesWithPdf);
final assistantMessage = await llmService.sendMessage(messagesForLlm);

final finalSession = updatedSession.addMessage(assistantMessage);
state = finalSession;
_ref.read(chatSessionsProvider.notifier).updateSession(finalSession);
// Update session with user question and assistant response
updatedSession = state!
.addMessage(Message(role: MessageRole.user, content: content))
.addMessage(assistantMessage);

state = updatedSession;
_ref.read(chatSessionsProvider.notifier).updateSession(updatedSession);
} catch (e) {
final errorMessage = Message(
role: MessageRole.assistant,
@@ -232,4 +233,4 @@ class ActiveChatSessionNotifier extends StateNotifier<ChatSession?> {
_ref.read(isLoadingProvider.notifier).state = false;
}
}
}
}
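
For context, a minimal sketch of how the reworked PDF flow might be invoked from the UI layer; the onPdfPicked helper and the assumption that the PDF text has already been extracted (e.g. via doc_text_extractor) are illustrative and not part of this diff:

// Hypothetical caller sketch: the provider wiring mirrors how the other
// send* methods are already used, but this helper itself is assumed.
Future<void> onPdfPicked(WidgetRef ref, String question, String pdfText) async {
  final notifier = ref.read(activeChatSessionProvider.notifier);
  await notifier.sendMessageWithPdf(question, pdfText);
}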
123 changes: 123 additions & 0 deletions lib/services/vector_store.dart
@@ -0,0 +1,123 @@
import 'dart:convert';
import 'dart:math';
import 'dart:io';

import 'package:crypto/crypto.dart';
import 'package:path_provider/path_provider.dart';

import 'llm_service.dart';

class DocumentChunk {
final String id;
final String text;
final List<double> embedding;

DocumentChunk({
required this.id,
required this.text,
required this.embedding,
});

Map<String, dynamic> toJson() => {
'id': id,
'text': text,
'embedding': embedding,
};

factory DocumentChunk.fromJson(Map<String, dynamic> json) {
return DocumentChunk(
id: json['id'],
text: json['text'],
embedding: (json['embedding'] as List)
.map((e) => (e as num).toDouble())
.toList(),
);
}
}

class VectorStore {
VectorStore._internal();
static final VectorStore instance = VectorStore._internal();

final Map<String, DocumentChunk> _chunks = {};
bool _loaded = false;

Future<void> _load() async {
if (_loaded) return;
final dir = await getApplicationSupportDirectory();
final file = File('${dir.path}/doc_index.json');
if (await file.exists()) {
final data = jsonDecode(await file.readAsString());
for (final item in data) {
final chunk = DocumentChunk.fromJson(item);
_chunks[chunk.id] = chunk;
}
}
_loaded = true;
}

Future<void> _save() async {
final dir = await getApplicationSupportDirectory();
final file = File('${dir.path}/doc_index.json');
await file.writeAsString(
jsonEncode(_chunks.values.map((c) => c.toJson()).toList()),
);
}

Future<void> upsertDocument(
String text,
LlmService llmService, {
int chunkSize = 1000,
}) async {
await _load();
final chunks = _splitIntoChunks(text, chunkSize);
for (final chunk in chunks) {
final id = _fingerprint(chunk);
if (!_chunks.containsKey(id)) {
final embedding = await llmService.embedText(chunk);
_chunks[id] = DocumentChunk(id: id, text: chunk, embedding: embedding);
}
}
await _save();
}

Future<List<String>> search(
String query,
LlmService llmService, {
int topK = 3,
}) async {
await _load();
if (_chunks.isEmpty) return [];
final queryEmbedding = await llmService.embedText(query);
final scores = <DocumentChunk, double>{};
for (final chunk in _chunks.values) {
final score = _cosineSimilarity(queryEmbedding, chunk.embedding);
scores[chunk] = score;
}
final sorted = scores.entries.toList()
..sort((a, b) => b.value.compareTo(a.value));
return sorted.take(topK).map((e) => e.key.text).toList();
}

List<String> _splitIntoChunks(String text, int size) {
final regex = RegExp('.{1,$size}', dotAll: true);
return regex.allMatches(text).map((m) => m.group(0)!).toList();
}

String _fingerprint(String text) {
return sha256.convert(utf8.encode(text)).toString();
}

double _cosineSimilarity(List<double> a, List<double> b) {
double dot = 0;
double normA = 0;
double normB = 0;
for (var i = 0; i < a.length; i++) {
dot += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
if (normA == 0 || normB == 0) return 0.0;
return dot / (sqrt(normA) * sqrt(normB));
}
}
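
A minimal usage sketch of the new VectorStore API, assuming an LlmService obtained the same way sendMessageWithPdf obtains one; the query string and the surrounding Riverpod ref are illustrative:

// Hypothetical usage sketch: index a document once, then retrieve the
// top-K most similar chunks for a question before prompting the LLM.
final llmService = ref.read(selectedLlmServiceProvider);
await VectorStore.instance.upsertDocument(pdfText, llmService);
final relevant = await VectorStore.instance.search(
  'What does the document say about refunds?',
  llmService,
  topK: 3,
);
// relevant can then be joined with '\n\n' and prepended to the user's
// question, as the reworked sendMessageWithPdf does above.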
1 change: 1 addition & 0 deletions pubspec.yaml
@@ -54,6 +54,7 @@ dependencies:
image_picker: ^1.0.4
file_picker: ^10.2.0
doc_text_extractor: ^1.0.0
crypto: ^3.0.3

dev_dependencies:
flutter_test: