From a8a6f45e8caf18ee2479ce628e1fc288949f6ed2 Mon Sep 17 00:00:00 2001 From: liangtao <547670718@qq.com> Date: Thu, 19 Sep 2024 10:30:29 +0800 Subject: [PATCH] bug: StreamOption.includeUsage = true causes OpenAiService.mapStreamToAccumulator to throw java.lang.IndexOutOfBoundsException #60 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service/ChatMessageAccumulator.java | 14 ++++++++++++- .../openai/service/OpenAiService.java | 20 +++++++++++-------- .../openai/service/ChatCompletionTest.java | 8 +++++--- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/service/src/main/java/com/theokanning/openai/service/ChatMessageAccumulator.java b/service/src/main/java/com/theokanning/openai/service/ChatMessageAccumulator.java index cdb465e..30210a6 100644 --- a/service/src/main/java/com/theokanning/openai/service/ChatMessageAccumulator.java +++ b/service/src/main/java/com/theokanning/openai/service/ChatMessageAccumulator.java @@ -1,5 +1,6 @@ package com.theokanning.openai.service; +import com.theokanning.openai.Usage; import com.theokanning.openai.completion.chat.AssistantMessage; import com.theokanning.openai.completion.chat.ChatFunctionCall; @@ -15,15 +16,18 @@ public class ChatMessageAccumulator { private final AssistantMessage messageChunk; private final AssistantMessage accumulatedMessage; + private final Usage usage; + /** * Constructor that initializes the message chunk and accumulated message. * * @param messageChunk The message chunk. * @param accumulatedMessage The accumulated message. 
*/ - public ChatMessageAccumulator(AssistantMessage messageChunk, AssistantMessage accumulatedMessage) { + public ChatMessageAccumulator(AssistantMessage messageChunk, AssistantMessage accumulatedMessage,Usage usage) { this.messageChunk = messageChunk; this.accumulatedMessage = accumulatedMessage; + this.usage=usage; } /** @@ -64,6 +68,14 @@ public AssistantMessage getAccumulatedMessage() { return accumulatedMessage; } + + /** + * Returns the usage reported by this chunk; it is non-null only when {@link com.theokanning.openai.completion.chat.StreamOption#INCLUDE} is used. + */ + public Usage getUsage() { + return usage; + } + /** * Retrieves the function call from the message chunk. * This is equivalent to getMessageChunk().getFunctionCall(). diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index 8cd693a..5a3856f 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -742,15 +742,19 @@ public Flowable mapStreamToAccumulator(Flowable { - ChatCompletionChoice firstChoice = chunk.getChoices().get(0); - AssistantMessage messageChunk = firstChoice.getMessage(); - appendContent(messageChunk, accumulatedMessage); - processFunctionCall(messageChunk, functionCall, accumulatedMessage); - processToolCalls(messageChunk, accumulatedMessage); - if (firstChoice.getFinishReason() != null) { - handleFinishReason(firstChoice.getFinishReason(), functionCall, accumulatedMessage); + List choices = chunk.getChoices(); + AssistantMessage messageChunk=new AssistantMessage(); + if (choices!=null && !choices.isEmpty()){ + ChatCompletionChoice firstChoice = choices.get(0); + messageChunk = firstChoice.getMessage(); + appendContent(messageChunk, accumulatedMessage); + processFunctionCall(messageChunk, functionCall, accumulatedMessage); + processToolCalls(messageChunk, accumulatedMessage); + if (firstChoice.getFinishReason() != null) 
{ + handleFinishReason(firstChoice.getFinishReason(), functionCall, accumulatedMessage); + } } - return new ChatMessageAccumulator(messageChunk, accumulatedMessage); + return new ChatMessageAccumulator(messageChunk, accumulatedMessage,chunk.getUsage()); }); } diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java index d50efdb..ecd8707 100644 --- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java @@ -316,12 +316,15 @@ void zeroArgStreamToolTest() { .n(1) .maxTokens(100) .logitBias(new HashMap<>()) + .streamOptions(StreamOption.INCLUDE) .build(); - AssistantMessage accumulatedMessage = service.mapStreamToAccumulator(service.streamChatCompletion(chatCompletionRequest)) - .blockingLast().getAccumulatedMessage(); + ChatMessageAccumulator chatMessageAccumulator = service.mapStreamToAccumulator(service.streamChatCompletion(chatCompletionRequest)) + .blockingLast(); + AssistantMessage accumulatedMessage = chatMessageAccumulator.getAccumulatedMessage(); List toolCalls = accumulatedMessage.getToolCalls(); assertNotNull(toolCalls); assertEquals(1, toolCalls.size()); + assertNotNull(chatMessageAccumulator.getUsage()); ChatToolCall chatToolCall = toolCalls.get(0); ChatFunctionCall functionCall = chatToolCall.getFunction(); assertEquals("get_today", functionCall.getName()); @@ -954,5 +957,4 @@ void toolCallingStrictTest(){ assertEquals("asc",arguments.get("order_by").asText()); } - }