diff --git a/api/src/main/java/com/theokanning/openai/assistants/run/CreateThreadAndRunRequest.java b/api/src/main/java/com/theokanning/openai/assistants/run/CreateThreadAndRunRequest.java
index 883f111..1d8b267 100644
--- a/api/src/main/java/com/theokanning/openai/assistants/run/CreateThreadAndRunRequest.java
+++ b/api/src/main/java/com/theokanning/openai/assistants/run/CreateThreadAndRunRequest.java
@@ -95,6 +95,15 @@ public class CreateThreadAndRunRequest {
     @JsonProperty("tool_choice")
     ToolChoice toolChoice;
 
+    /**
+     * Whether to enable parallel function calling during tool use.
+     * When {@code false}, the model calls at most one tool per response.
+     *
+     * @see <a href="https://platform.openai.com/docs/guides/function-calling/parallel-function-calling">Parallel function calling</a>
+     */
+    @JsonProperty("parallel_tool_calls")
+    Boolean parallelToolCalls;
+
     /**
      * Specifies the format that the model must output. Compatible with GPT-4 Turbo and all GPT-3.5 Turbo models since gpt-3.5-turbo-1106.
      * Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java
index 81318eb..31e002d 100644
--- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java
+++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java
@@ -157,5 +157,13 @@ public class ChatCompletionRequest {
     @JsonProperty("tool_choice")
     ToolChoice toolChoice;
 
+    /**
+     * Whether to enable parallel function calling during tool use.
+     * When {@code false}, the model calls at most one tool per response.
+     *
+     * @see <a href="https://platform.openai.com/docs/guides/function-calling/parallel-function-calling">Parallel function calling</a>
+     */
+    @JsonProperty("parallel_tool_calls")
+    Boolean parallelToolCalls;
 
 }
diff --git a/api/src/test/resources/assistants/CreateThreadAndRunRequest.json b/api/src/test/resources/assistants/CreateThreadAndRunRequest.json
index 5b3abcb..10d40ce 100644
--- a/api/src/test/resources/assistants/CreateThreadAndRunRequest.json
+++ b/api/src/test/resources/assistants/CreateThreadAndRunRequest.json
@@ -36,5 +36,6 @@
       }
     }
   ],
-  "stream": true
+  "stream": true,
+  "parallel_tool_calls": true
 }
diff --git a/api/src/test/resources/fixtures/ChatCompletionRequest.json b/api/src/test/resources/fixtures/ChatCompletionRequest.json
index a8786e7..bfb2c84 100644
--- a/api/src/test/resources/fixtures/ChatCompletionRequest.json
+++ b/api/src/test/resources/fixtures/ChatCompletionRequest.json
@@ -105,5 +105,6 @@
   ],
   "stream_options": {
     "include_usage": true
-  }
+  },
+  "parallel_tool_calls": true
 }
diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
index 63f8054..11aaaa3 100644
--- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
+++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
@@ -715,4 +715,47 @@ void streamChatMultipleToolCalls() {
 
     }
 
+    @Test
+    void parallelToolCallTest() {
+        final List<ChatMessage> messages = new ArrayList<>();
+        final ChatMessage systemMessage = new SystemMessage("You are a helpful assistant.");
+        final ChatMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
+        messages.add(systemMessage);
+        messages.add(userMessage);
+
+        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
+                .builder()
+                .model("gpt-3.5-turbo")
+                .messages(messages)
+                .tools(Arrays.asList(new ChatTool(ToolUtil.weatherFunction())))
+                .toolChoice(ToolChoice.AUTO)
+                .parallelToolCalls(false)
+                .n(1)
+                .maxTokens(200)
+                .build();
+
+        // With parallel tool calls disabled, the model may request at most one tool call per response.
+        AssistantMessage accumulatedMessage = service.mapStreamToAccumulator(service.streamChatCompletion(chatCompletionRequest))
+                .blockingLast()
+                .getAccumulatedMessage();
+
+        List<ChatToolCall> toolCalls = accumulatedMessage.getToolCalls();
+        assertEquals(1, toolCalls.size());
+        assertEquals("get_weather", toolCalls.get(0).getFunction().getName());
+        assertInstanceOf(ObjectNode.class, toolCalls.get(0).getFunction().getArguments());
+
+        // Re-send the same request with parallel tool calls enabled: the prompt names three cities,
+        // so one get_weather call per city is expected in a single response.
+        chatCompletionRequest.setParallelToolCalls(true);
+        AssistantMessage accumulatedMessage2 = service.mapStreamToAccumulator(service.streamChatCompletion(chatCompletionRequest))
+                .blockingLast()
+                .getAccumulatedMessage();
+
+        List<ChatToolCall> toolCalls2 = accumulatedMessage2.getToolCalls();
+        assertEquals(3, toolCalls2.size());
+        for (ChatToolCall toolCall : toolCalls2) {
+            assertEquals("get_weather", toolCall.getFunction().getName());
+        }
+    }
+
 }