diff --git a/.gitignore b/.gitignore index 848c1dd..a261919 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,12 @@ hs_err_pid* # IntelliJ Files # .idea/ *.iml + +# Eclipse Files # +.project +.settings +.classpath + # Ignore Gradle project-specific cache directory .gradle diff --git a/api/.gitignore b/api/.gitignore new file mode 100644 index 0000000..b83d222 --- /dev/null +++ b/api/.gitignore @@ -0,0 +1 @@ +/target/ diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/AssistantMessage.java b/api/src/main/java/com/theokanning/openai/completion/chat/AssistantMessage.java index 1685e54..139e9eb 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/AssistantMessage.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/AssistantMessage.java @@ -1,13 +1,15 @@ package com.theokanning.openai.completion.chat; +import java.util.List; + import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; +import com.theokanning.openai.utils.JsonUtil; + import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; -import java.util.List; - /** * @author LiangTao * @date 2024年04月10 10:31 @@ -49,4 +51,18 @@ public AssistantMessage(String content, String name) { public String getTextContent() { return content; } + + /** + * Deserializes the message to an object of the specified target class. + * + * @param targetClass the type of the object + * @return the deserialized object + **/ + public T parsed(Class targetClass) { + try { + return JsonUtil.getInstance().readValue(getTextContent(), targetClass); + } catch (Exception e) { + throw new RuntimeException(e); + } + } } diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java index 31e002d..3284f7c 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java @@ -1,15 +1,18 @@ package com.theokanning.openai.completion.chat; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.theokanning.openai.assistants.run.ToolChoice; + import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; -import java.util.List; -import java.util.Map; - @Data @Builder @AllArgsConstructor @@ -29,10 +32,11 @@ public class ChatCompletionRequest { List messages; /** - * Must be either 'text' or 'json_object'.
+ * Must be either 'text', 'json_object' or 'json_schema'.
* When specifying 'json_object' as the request format it's still necessary to instruct the model to return JSON. */ @JsonProperty("response_format") + @JsonSerialize(using = ChatResponseFormat.ChatResponseFormatSerializer.class) ChatResponseFormat responseFormat; /** diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatResponseFormat.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatResponseFormat.java index 8922802..b7d2f48 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatResponseFormat.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatResponseFormat.java @@ -1,29 +1,48 @@ package com.theokanning.openai.completion.chat; +import java.io.IOException; + import com.fasterxml.jackson.core.JacksonException; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.exc.InvalidFormatException; +import com.kjetland.jackson.jsonSchema.JsonSchemaConfig; +import com.kjetland.jackson.jsonSchema.JsonSchemaGenerator; +import com.theokanning.openai.utils.JsonUtil; + import lombok.Data; import lombok.NoArgsConstructor; -import java.io.IOException; - /** * see {@link ChatCompletionRequest} documentation. */ @Data @NoArgsConstructor public class ChatResponseFormat { + private static final ObjectMapper MAPPER = JsonUtil.getInstance(); + private static final JsonSchemaConfig CONFIG = JsonSchemaConfig.vanillaJsonSchemaDraft4(); + private static final JsonSchemaGenerator JSON_SCHEMA_GENERATOR = new JsonSchemaGenerator(MAPPER, CONFIG); + /** * auto/text/json_object */ private String type; + + /** + * This is used together with type field set to "json_schema" + * to enable structured outputs. + * + * @see https://openai.com/index/introducing-structured-outputs-in-the-api/ + * + */ + private JsonNode json_schema; /** * 构造私有,只允许从静态变量获取 @@ -37,7 +56,13 @@ private ChatResponseFormat(String type) { public static final ChatResponseFormat TEXT = new ChatResponseFormat("text"); public static final ChatResponseFormat JSON_OBJECT = new ChatResponseFormat("json_object"); - + + public static ChatResponseFormat jsonSchema(Class rootClass) { + JsonNode jsonSchema = JSON_SCHEMA_GENERATOR.generateJsonSchema(rootClass); + ChatResponseFormat jsonSchemaFormat = new ChatResponseFormat("json_schema"); + jsonSchemaFormat.setJson_schema(jsonSchema); + return jsonSchemaFormat; + } @NoArgsConstructor public static class ChatResponseFormatSerializer extends JsonSerializer { @@ -48,6 +73,18 @@ public void serialize(ChatResponseFormat value, JsonGenerator gen, SerializerPro } else { gen.writeStartObject(); gen.writeObjectField("type", (value).getType()); + + if (value.getType().equals("json_schema")) { + JsonNode jsonSchema = value.getJson_schema(); + + gen.writeObjectFieldStart("json_schema"); + gen.writeStringField("name", "ChatResponseFormat"); + gen.writeBooleanField("strict", true); + gen.writeFieldName("schema"); + gen.writeTree(jsonSchema); + gen.writeEndObject(); + } + gen.writeEndObject(); } } diff --git a/client/.gitignore b/client/.gitignore new file mode 100644 index 0000000..b83d222 --- /dev/null +++ b/client/.gitignore @@ -0,0 +1 @@ +/target/ diff --git a/service/.gitignore b/service/.gitignore new file mode 100644 index 0000000..b83d222 --- /dev/null +++ b/service/.gitignore @@ -0,0 +1 @@ +/target/ diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java index 11aaaa3..62bf988 100644 --- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java @@ -1,22 +1,52 @@ package com.theokanning.openai.service; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Duration; +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +import javax.validation.constraints.NotNull; + +import org.junit.jupiter.api.Test; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.theokanning.openai.assistants.run.ToolChoice; -import com.theokanning.openai.completion.chat.*; +import com.theokanning.openai.completion.chat.AssistantMessage; +import com.theokanning.openai.completion.chat.ChatCompletionChoice; +import com.theokanning.openai.completion.chat.ChatCompletionChunk; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatFunctionCall; +import com.theokanning.openai.completion.chat.ChatFunctionDynamic; +import com.theokanning.openai.completion.chat.ChatFunctionProperty; +import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.completion.chat.ChatResponseFormat; +import com.theokanning.openai.completion.chat.ChatTool; +import com.theokanning.openai.completion.chat.ChatToolCall; +import com.theokanning.openai.completion.chat.StreamOption; +import com.theokanning.openai.completion.chat.SystemMessage; +import com.theokanning.openai.completion.chat.ToolMessage; +import com.theokanning.openai.completion.chat.UserMessage; import com.theokanning.openai.function.FunctionDefinition; import com.theokanning.openai.function.FunctionExecutorManager; import com.theokanning.openai.service.util.ToolUtil; -import org.junit.jupiter.api.Test; - -import java.time.Duration; -import java.time.LocalDate; -import java.util.*; -import static org.junit.jupiter.api.Assertions.*; +import lombok.Data; +import lombok.NoArgsConstructor; class ChatCompletionTest { @@ -106,7 +136,7 @@ void createChatCompletionWithJsonMode() { ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0); assertTrue(isValidJson(choice.getMessage().getContent()), "Response is not valid JSON"); } - + private boolean isValidJson(String jsonString) { ObjectMapper objectMapper = new ObjectMapper(); try { @@ -116,7 +146,48 @@ private boolean isValidJson(String jsonString) { return false; } } + + @Test + void createChatCompletionWithJsonSchema() throws JsonProcessingException { + final List messages = new ArrayList<>(); + final ChatMessage systemMessage = new SystemMessage("You are a helpful math tutor. Guide the user through the solution step by step."); + final ChatMessage userMessage = new UserMessage("how can I solve 8x + 7 = -23"); + messages.add(systemMessage); + messages.add(userMessage); + + Class rootClass = MathReasoning.class; + ChatResponseFormat responseFormat = ChatResponseFormat.jsonSchema(rootClass); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest + .builder() + .model("gpt-4o-2024-08-06") + .messages(messages) + .responseFormat(responseFormat) + .maxTokens(1000) + .build(); + ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0); + MathReasoning mathReasoning = choice.getMessage().parsed(rootClass); + + String finalAnswer = mathReasoning.getFinal_answer(); + assertTrue(finalAnswer.contains("x")); + assertTrue(finalAnswer.contains("=")); + } + + @Data + @NoArgsConstructor + private static class MathReasoning { + @NotNull private List steps; + @NotNull private String final_answer; + } + + @Data + @NoArgsConstructor + private static class Step { + @NotNull private String explanation; + @NotNull private String output; + } + @Test void createChatCompletionWithFunctions() { final List functions = Collections.singletonList(ToolUtil.weatherFunction()); @@ -131,7 +202,7 @@ void createChatCompletionWithFunctions() { ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest .builder() - .model("gpt-3.5-turbo-0613") + .model("gpt-4o-2024-08-06") .messages(messages) .functions(functions) .n(1) @@ -163,7 +234,7 @@ void createChatCompletionWithFunctions() { ChatCompletionRequest chatCompletionRequest2 = ChatCompletionRequest .builder() - .model("gpt-3.5-turbo-0613") + .model("gpt-4o-2024-08-06") .messages(messages) .functions(functions) .n(1) @@ -205,7 +276,7 @@ void createChatCompletionWithDynamicFunctions() { ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest .builder() - .model("gpt-3.5-turbo-0613") + .model("gpt-4o-2024-08-06") .messages(messages) .functions(Collections.singletonList(function)) .n(1) @@ -290,7 +361,7 @@ void streamChatCompletionWithFunctions() { ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest .builder() - .model("gpt-3.5-turbo-0613") + .model("gpt-4o-2024-08-06") .messages(messages) .functions(functions) .n(1) @@ -324,7 +395,7 @@ void streamChatCompletionWithFunctions() { ChatCompletionRequest chatCompletionRequest2 = ChatCompletionRequest .builder() - .model("gpt-3.5-turbo-0613") + .model("gpt-4o-2024-08-06") .messages(messages) .functions(functions) .n(1)