diff --git a/api/src/main/java/com/theokanning/openai/assistants/run/CreateThreadAndRunRequest.java b/api/src/main/java/com/theokanning/openai/assistants/run/CreateThreadAndRunRequest.java index 283fe38..883f111 100644 --- a/api/src/main/java/com/theokanning/openai/assistants/run/CreateThreadAndRunRequest.java +++ b/api/src/main/java/com/theokanning/openai/assistants/run/CreateThreadAndRunRequest.java @@ -93,8 +93,6 @@ public class CreateThreadAndRunRequest { * Specifying a particular tool like {"type": "file_search"} or {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool. */ @JsonProperty("tool_choice") - @JsonSerialize(using = ToolChoice.Serializer.class) - @JsonDeserialize(using = ToolChoice.Deserializer.class) ToolChoice toolChoice; /** diff --git a/api/src/main/java/com/theokanning/openai/assistants/run/Run.java b/api/src/main/java/com/theokanning/openai/assistants/run/Run.java index dbbe270..d730b0c 100644 --- a/api/src/main/java/com/theokanning/openai/assistants/run/Run.java +++ b/api/src/main/java/com/theokanning/openai/assistants/run/Run.java @@ -122,8 +122,6 @@ public class Run { * Specifying a particular tool like {"type": "file_search"} or {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool. */ @JsonProperty("tool_choice") - @JsonSerialize(using = ToolChoice.Serializer.class) - @JsonDeserialize(using = ToolChoice.Deserializer.class) ToolChoice toolChoice; /** diff --git a/api/src/main/java/com/theokanning/openai/assistants/run/RunCreateRequest.java b/api/src/main/java/com/theokanning/openai/assistants/run/RunCreateRequest.java index 7459d54..c6b27c9 100644 --- a/api/src/main/java/com/theokanning/openai/assistants/run/RunCreateRequest.java +++ b/api/src/main/java/com/theokanning/openai/assistants/run/RunCreateRequest.java @@ -101,8 +101,6 @@ public class RunCreateRequest { * Specifying a particular tool like {"type": "file_search"} or {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool. */ @JsonProperty("tool_choice") - @JsonSerialize(using = ToolChoice.Serializer.class) - @JsonDeserialize(using = ToolChoice.Deserializer.class) ToolChoice toolChoice; /** diff --git a/api/src/main/java/com/theokanning/openai/assistants/run/ToolChoice.java b/api/src/main/java/com/theokanning/openai/assistants/run/ToolChoice.java index f6b4196..625503d 100644 --- a/api/src/main/java/com/theokanning/openai/assistants/run/ToolChoice.java +++ b/api/src/main/java/com/theokanning/openai/assistants/run/ToolChoice.java @@ -7,6 +7,8 @@ import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import lombok.Data; import java.io.IOException; @@ -17,25 +19,26 @@ * @date 2024年04月18 17:18 **/ @Data +@JsonSerialize(using = ToolChoice.Serializer.class) +@JsonDeserialize(using = ToolChoice.Deserializer.class) public class ToolChoice { public static final ToolChoice REQUIRED = new ToolChoice("required"); + public static final ToolChoice NONE = new ToolChoice("none"); + + public static final ToolChoice AUTO = new ToolChoice("auto"); + /** * The name of the function to call. */ Function function; - public static final ToolChoice NONE = new ToolChoice("none"); - - public static final ToolChoice AUTO = new ToolChoice("auto"); /** * The type of the tool. If type is function, the function name must be set * enum: none/auto/function/required */ String type; - - private ToolChoice(String type) { this.type = type; } diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java index 743df14..81318eb 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java @@ -1,8 +1,6 @@ package com.theokanning.openai.completion.chat; import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.theokanning.openai.assistants.run.ToolChoice; import lombok.AllArgsConstructor; import lombok.Builder; @@ -19,7 +17,6 @@ public class ChatCompletionRequest { - /** * ID of the model to use. */ @@ -114,9 +111,9 @@ public class ChatCompletionRequest { String user; /** + * @since 0.20.5 {@link com.theokanning.openai.completion.chat.ChatFunction} {@link com.theokanning.openai.completion.chat.ChatFunctionDynamic}will be deprecated * @deprecated Replaced by {@link #tools} * recommend to use {@link com.theokanning.openai.function.FunctionDefinition} or custom class - * @since 0.20.5 {@link com.theokanning.openai.completion.chat.ChatFunction} {@link com.theokanning.openai.completion.chat.ChatFunctionDynamic}will be deprecated */ @Deprecated List functions; @@ -126,9 +123,7 @@ public class ChatCompletionRequest { */ @JsonProperty("function_call") @Deprecated - @JsonSerialize(using = ChatCompletionRequestFunctionCall.Serializer.class) - @JsonDeserialize(using = ChatCompletionRequestFunctionCall.Deserializer.class) - Object functionCall; + ChatCompletionRequestFunctionCall functionCall; /** @@ -148,7 +143,6 @@ public class ChatCompletionRequest { Integer topLogprobs; - /** * Function definition, only used if type is "function" * recommend to use {@link com.theokanning.openai.function.FunctionDefinition} or custom class @@ -161,39 +155,7 @@ public class ChatCompletionRequest { * Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. */ @JsonProperty("tool_choice") - @JsonSerialize(using = ToolChoice.Serializer.class) - @JsonDeserialize(using = ToolChoice.Deserializer.class) ToolChoice toolChoice; - - public static ChatCompletionRequestBuilder builder() { - return new InternalBuilder(); - } - private static class InternalBuilder extends ChatCompletionRequestBuilder { - public InternalBuilder() { - super(); - } - - @Override - public ChatCompletionRequest build() { - ChatCompletionRequest request = super.build(); - request.functionCallParamCheck(); - return request; - } - } - - private void functionCallParamCheck() { - if (functionCall==null){ - return; - } - if (!(functionCall instanceof ChatCompletionRequestFunctionCall || functionCall instanceof String)) { - throw new IllegalArgumentException("functionCall must be a ChatCompletionRequestFunctionCall or a String type"); - } - } - - public void setFunctionCall(Object functionCall) { - this.functionCall = functionCall; - functionCallParamCheck(); - } } diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequestFunctionCall.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequestFunctionCall.java index b7fe32a..10e3daf 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequestFunctionCall.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequestFunctionCall.java @@ -7,6 +7,8 @@ import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -20,38 +22,59 @@ @Data @AllArgsConstructor @NoArgsConstructor +@JsonSerialize(using = ChatCompletionRequestFunctionCall.Serializer.class) +@JsonDeserialize(using = ChatCompletionRequestFunctionCall.Deserializer.class) public class ChatCompletionRequestFunctionCall { + /** + * Controls which (if any) function is called by the model. + * none means the model will not call a function and instead generates a message. + * auto means the model can pick between generating a message or calling a function. + * Specifying a particular function via {"name": "my_function"} forces the model to call that function. + */ + + public static final ChatCompletionRequestFunctionCall NONE = new ChatCompletionRequestFunctionCall("none"); + + public static final ChatCompletionRequestFunctionCall AUTO = new ChatCompletionRequestFunctionCall("auto"); + String name; - public static class Serializer extends JsonSerializer { + public static ChatCompletionRequestFunctionCall of(String name) { + return new ChatCompletionRequestFunctionCall(name); + } + + + public static class Serializer extends JsonSerializer { @Override - public void serialize(Object value, JsonGenerator gen, SerializerProvider serializers) throws IOException { - if (value instanceof String) { - gen.writeString((String) value); + public void serialize(ChatCompletionRequestFunctionCall value, JsonGenerator gen, SerializerProvider serializers) throws IOException { + String name = value.getName(); + if ("none".equals(name) || "auto".equals(name)) { + gen.writeString(name); return; } - if (value instanceof ChatCompletionRequestFunctionCall) { - ChatCompletionRequestFunctionCall functionCall = (ChatCompletionRequestFunctionCall) value; - if (functionCall.getName() == null) { - gen.writeNull(); - } else { - gen.writeStartObject(); - gen.writeObjectField("name", functionCall.getName()); - gen.writeEndObject(); - } - return; + if (name == null) { + gen.writeNull(); + } else { + gen.writeStartObject(); + gen.writeObjectField("name", name); + gen.writeEndObject(); } - // This should never happen - throw new IllegalArgumentException("Unexpected value to function call: " + value); } } - public static class Deserializer extends JsonDeserializer { + public static class Deserializer extends JsonDeserializer { @Override - public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + public ChatCompletionRequestFunctionCall deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { //如果是字符串,则读取字符串 if (p.getCurrentToken() == JsonToken.VALUE_STRING) { - return p.getText(); + String text = p.getText(); + switch (text) { + case "none": + return ChatCompletionRequestFunctionCall.NONE; + case "auto": + return ChatCompletionRequestFunctionCall.AUTO; + default: + return ChatCompletionRequestFunctionCall.of(text); + } } //如果是对象,则读取对象 if (p.getCurrentToken() == JsonToken.START_OBJECT) { diff --git a/pom.xml b/pom.xml index 7e504e5..42daa3e 100644 --- a/pom.xml +++ b/pom.xml @@ -28,9 +28,9 @@ - https://github.com/Lambdua/openai-java.git - scm:git:ssh://git@ssh.github.com:443/Lambdua/openai-java.git - https://github.com/Lambdua/openai-java + scm:git:https://github.com/Lambdua/openai4j.git + scm:git:https://github.com/Lambdua/openai4j.git + https://github.com/Lambdua/openai4j @@ -213,12 +213,12 @@ org.apache.maven.plugins maven-surefire-plugin - 2.22.2 + 2.22.2 org.junit.jupiter junit-jupiter-engine - 5.10.1 + 5.10.1 @@ -228,7 +228,7 @@ org.apache.maven.plugins maven-clean-plugin - 3.1.0 + 3.1.0 diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java index ea07b3d..ada5b47 100644 --- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java @@ -145,15 +145,15 @@ void createChatCompletionWithFunctions() { assertEquals("get_weather", functionCall.getName()); assertInstanceOf(ObjectNode.class, functionCall.getArguments()); - ChatMessage callResponse = functionExecutor.executeAndConvertToChatMessage(functionCall.getName(),functionCall.getArguments()); + ChatMessage callResponse = functionExecutor.executeAndConvertToChatMessage(functionCall.getName(), functionCall.getArguments()); assertNotEquals("error", callResponse.getName()); // this performs an unchecked cast - ToolUtil.WeatherResponse functionExecutionResponse = functionExecutor.execute(functionCall.getName(),functionCall.getArguments()); + ToolUtil.WeatherResponse functionExecutionResponse = functionExecutor.execute(functionCall.getName(), functionCall.getArguments()); assertInstanceOf(ToolUtil.WeatherResponse.class, functionExecutionResponse); assertEquals(25, functionExecutionResponse.temperature); - JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(functionCall.getName(),functionCall.getArguments()); + JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(functionCall.getName(), functionCall.getArguments()); assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse); assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText()); @@ -250,15 +250,15 @@ void streamChatCompletionWithFunctions() { assertEquals("get_weather", functionCall.getName()); assertInstanceOf(ObjectNode.class, functionCall.getArguments()); - ChatMessage callResponse = functionExecutor.executeAndConvertToChatMessage(functionCall.getName(),functionCall.getArguments()); + ChatMessage callResponse = functionExecutor.executeAndConvertToChatMessage(functionCall.getName(), functionCall.getArguments()); assertNotEquals("error", callResponse.getName()); // this performs an unchecked cast - ToolUtil.WeatherResponse functionExecutionResponse = functionExecutor.execute(functionCall.getName(),functionCall.getArguments()); + ToolUtil.WeatherResponse functionExecutionResponse = functionExecutor.execute(functionCall.getName(), functionCall.getArguments()); assertInstanceOf(ToolUtil.WeatherResponse.class, functionExecutionResponse); assertEquals(25, functionExecutionResponse.temperature); - JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(functionCall.getName(),functionCall.getArguments()); + JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(functionCall.getName(), functionCall.getArguments()); assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse); assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText()); @@ -367,7 +367,7 @@ void createChatCompletionWithToolFunctions() { assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse); assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText()); - ToolMessage chatMessageTool = toolExecutor.executeAndConvertToChatMessage(function.getName(),function.getArguments(), toolCall.getId()); + ToolMessage chatMessageTool = toolExecutor.executeAndConvertToChatMessage(function.getName(), function.getArguments(), toolCall.getId()); //确保不是异常的返回 assertNotEquals("error", chatMessageTool.getName()); @@ -397,7 +397,7 @@ void createChatCompletionWithMultipleToolCalls() { .name("get_weather") .description("Get the current weather in a given location") .parametersDefinitionByClass(ToolUtil.Weather.class) - .executor( w -> { + .executor(w -> { switch (w.location) { case "tokyo": return new ToolUtil.WeatherResponse(w.location, w.unit, 10, "cloudy"); @@ -412,9 +412,9 @@ void createChatCompletionWithMultipleToolCalls() { FunctionDefinition.builder().name("getCities").description("Get a list of cities by time") .parametersDefinitionByClass(ToolUtil.City.class) .executor(v -> { - assertEquals("2022-12-01", v.time); - return Arrays.asList("tokyo", "paris"); - }).build() + assertEquals("2022-12-01", v.time); + return Arrays.asList("tokyo", "paris"); + }).build() ); final FunctionExecutorManager toolExecutor = new FunctionExecutorManager(functions); @@ -452,11 +452,11 @@ void createChatCompletionWithMultipleToolCalls() { assertInstanceOf(List.class, execute); - JsonNode jsonNode = toolExecutor.executeAndConvertToJson(function.getName(),function.getArguments()); + JsonNode jsonNode = toolExecutor.executeAndConvertToJson(function.getName(), function.getArguments()); assertInstanceOf(ArrayNode.class, jsonNode); - ToolMessage toolMessage = toolExecutor.executeAndConvertToChatMessage(function.getName(),function.getArguments(), toolCall.getId()); + ToolMessage toolMessage = toolExecutor.executeAndConvertToChatMessage(function.getName(), function.getArguments(), toolCall.getId()); assertNotEquals("error", toolMessage.getName()); messages.add(choice.getMessage()); @@ -488,7 +488,7 @@ void createChatCompletionWithMultipleToolCalls() { for (ChatToolCall weatherToolCall : choice2.getMessage().getToolCalls()) { Object itemResult = toolExecutor.execute(weatherToolCall.getFunction().getName(), weatherToolCall.getFunction().getArguments()); assertInstanceOf(ToolUtil.WeatherResponse.class, itemResult); - messages.add(toolExecutor.executeAndConvertToChatMessage(weatherToolCall.getFunction().getName(),weatherToolCall.getFunction().getArguments(), weatherToolCall.getId())); + messages.add(toolExecutor.executeAndConvertToChatMessage(weatherToolCall.getFunction().getName(), weatherToolCall.getFunction().getArguments(), weatherToolCall.getId())); } ChatCompletionRequest chatCompletionRequest3 = ChatCompletionRequest @@ -546,7 +546,7 @@ void streamChatMultipleToolCalls() { .name("get_weather") .description("Get the current weather in a given location") .parametersDefinitionByClass(ToolUtil.Weather.class) - .executor( w -> { + .executor(w -> { switch (w.location) { case "tokyo": return new ToolUtil.WeatherResponse(w.location, w.unit, 10, "cloudy"); @@ -603,7 +603,7 @@ void streamChatMultipleToolCalls() { assertInstanceOf(ArrayNode.class, jsonNode); - ToolMessage toolMessage = toolExecutor.executeAndConvertToChatMessage(function.getName(),function.getArguments(), toolCall.getId()); + ToolMessage toolMessage = toolExecutor.executeAndConvertToChatMessage(function.getName(), function.getArguments(), toolCall.getId()); assertNotEquals("error", toolMessage.getName()); messages.add(accumulatedMessage); @@ -637,7 +637,7 @@ void streamChatMultipleToolCalls() { ChatFunctionCall call2 = weatherToolCall.getFunction(); Object itemResult = toolExecutor.execute(call2.getName(), call2.getArguments()); assertInstanceOf(ToolUtil.WeatherResponse.class, itemResult); - messages.add(toolExecutor.executeAndConvertToChatMessage(call2.getName(),call2.getArguments(), weatherToolCall.getId())); + messages.add(toolExecutor.executeAndConvertToChatMessage(call2.getName(), call2.getArguments(), weatherToolCall.getId())); } ChatCompletionRequest chatCompletionRequest3 = ChatCompletionRequest