这是indexloc提供的服务,不要输入任何密码
Skip to content

Basic support for structured outputs #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ hs_err_pid*
# IntelliJ Files #
.idea/
*.iml

# Eclipse Files #
.project
.settings
.classpath

# Ignore Gradle project-specific cache directory
.gradle

Expand Down
1 change: 1 addition & 0 deletions api/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/target/
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package com.theokanning.openai.completion.chat;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.theokanning.openai.utils.JsonUtil;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;

/**
* @author LiangTao
* @date 2024年04月10 10:31
Expand Down Expand Up @@ -49,4 +51,18 @@ public AssistantMessage(String content, String name) {
public String getTextContent() {
    // An assistant message keeps its plain text in the content field, so it is returned unchanged.
    return this.content;
}

/**
 * Deserializes the text content of this message into an instance of {@code targetClass}.
 * <p>
 * Intended for replies requested with a JSON-producing response format (for example one built by
 * {@code ChatResponseFormat.jsonSchema(Class)}), whose text content is a JSON document.
 *
 * @param targetClass the type to deserialize into; must not be null
 * @param <T>         the type of the returned object
 * @return the deserialized object
 * @throws NullPointerException  if {@code targetClass} is null
 * @throws IllegalStateException if this message has no text content to parse
 * @throws RuntimeException      wrapping the underlying Jackson exception if the content cannot be
 *                               deserialized into {@code targetClass}
 */
public <T> T parsed(Class<T> targetClass) {
    if (targetClass == null) {
        throw new NullPointerException("targetClass must not be null");
    }
    String text = getTextContent();
    // An assistant message may carry no text (e.g. only tool calls); report that explicitly instead of
    // surfacing an opaque Jackson argument error.
    if (text == null) {
        throw new IllegalStateException(
                "Assistant message has no text content to parse as " + targetClass.getName());
    }
    try {
        return JsonUtil.getInstance().readValue(text, targetClass);
    } catch (java.io.IOException e) {
        // Jackson's parse/mapping failures are IOException subtypes; keep the cause and add context.
        throw new RuntimeException(
                "Failed to parse assistant message content as " + targetClass.getName(), e);
    }
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package com.theokanning.openai.completion.chat;

import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.theokanning.openai.assistants.run.ToolChoice;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;
import java.util.Map;

@Data
@Builder
@AllArgsConstructor
Expand All @@ -29,10 +32,11 @@ public class ChatCompletionRequest {
List<ChatMessage> messages;

/**
* Response format requested from the model. Must be either 'text', 'json_object' or 'json_schema'. <br>
* When specifying 'json_object' as the request format it's still necessary to instruct the model to return JSON. <br>
* For 'json_schema' (structured outputs), build the format with {@link ChatResponseFormat#jsonSchema(Class)},
* which derives the schema from a Java class.
*/
@JsonProperty("response_format")
@JsonSerialize(using = ChatResponseFormat.ChatResponseFormatSerializer.class)
ChatResponseFormat responseFormat;

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,48 @@
package com.theokanning.openai.completion.chat;

import java.io.IOException;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.exc.InvalidFormatException;
import com.kjetland.jackson.jsonSchema.JsonSchemaConfig;
import com.kjetland.jackson.jsonSchema.JsonSchemaGenerator;
import com.theokanning.openai.utils.JsonUtil;

import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.IOException;

/**
* see {@link ChatCompletionRequest} documentation.
*/
@Data
@NoArgsConstructor
public class ChatResponseFormat {
// Shared Jackson mapper obtained from JsonUtil; also backs the schema generator below.
private static final ObjectMapper MAPPER = JsonUtil.getInstance();
// Schema-generation preset (the library's "vanilla" JSON Schema draft 4 configuration).
private static final JsonSchemaConfig CONFIG = JsonSchemaConfig.vanillaJsonSchemaDraft4();
// Derives JSON schemas from Java classes; used by jsonSchema(Class).
private static final JsonSchemaGenerator JSON_SCHEMA_GENERATOR = new JsonSchemaGenerator(MAPPER, CONFIG);

/**
* Format type: auto/text/json_object, or json_schema for structured outputs.
*/
private String type;

/**
* JSON schema describing the expected response. It is only serialized when {@link #type} is
* "json_schema", which enables structured outputs; create such a format with {@link #jsonSchema(Class)}.
* <p>
* NOTE(review): the serializer nests this schema under "json_schema.schema" with a fixed
* name ("ChatResponseFormat") and strict=true.
*
* @see <a href="https://openai.com/index/introducing-structured-outputs-in-the-api/">Introducing Structured Outputs in the API</a>
*/
private JsonNode json_schema;

/**
* Private constructor: instances are obtained only through this class's static constants and factory methods.
Expand All @@ -37,7 +56,13 @@ private ChatResponseFormat(String type) {
/** Plain-text response format (type "text"). */
public static final ChatResponseFormat TEXT = new ChatResponseFormat("text");

/**
 * JSON mode (type "json_object"). The model must still be instructed to return JSON.
 */
public static final ChatResponseFormat JSON_OBJECT = new ChatResponseFormat("json_object");


/**
 * Creates a "json_schema" response format whose schema is generated from a Java class.
 *
 * @param rootClass the class describing the structure the model's reply must follow
 * @return a new format of type "json_schema" carrying the schema generated for {@code rootClass}
 */
public static ChatResponseFormat jsonSchema(Class<?> rootClass) {
    final JsonNode generatedSchema = JSON_SCHEMA_GENERATOR.generateJsonSchema(rootClass);
    final ChatResponseFormat format = new ChatResponseFormat("json_schema");
    format.setJson_schema(generatedSchema);
    return format;
}

@NoArgsConstructor
public static class ChatResponseFormatSerializer extends JsonSerializer<ChatResponseFormat> {
Expand All @@ -48,6 +73,18 @@ public void serialize(ChatResponseFormat value, JsonGenerator gen, SerializerPro
} else {
gen.writeStartObject();
gen.writeObjectField("type", (value).getType());

if (value.getType().equals("json_schema")) {
JsonNode jsonSchema = value.getJson_schema();

gen.writeObjectFieldStart("json_schema");
gen.writeStringField("name", "ChatResponseFormat");
gen.writeBooleanField("strict", true);
gen.writeFieldName("schema");
gen.writeTree(jsonSchema);
gen.writeEndObject();
}

gen.writeEndObject();
}
}
Expand Down
1 change: 1 addition & 0 deletions client/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/target/
1 change: 1 addition & 0 deletions service/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/target/
Original file line number Diff line number Diff line change
@@ -1,22 +1,52 @@
package com.theokanning.openai.service;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.time.Duration;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

import javax.validation.constraints.NotNull;

import org.junit.jupiter.api.Test;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.theokanning.openai.assistants.run.ToolChoice;
import com.theokanning.openai.completion.chat.*;
import com.theokanning.openai.completion.chat.AssistantMessage;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatFunctionCall;
import com.theokanning.openai.completion.chat.ChatFunctionDynamic;
import com.theokanning.openai.completion.chat.ChatFunctionProperty;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatResponseFormat;
import com.theokanning.openai.completion.chat.ChatTool;
import com.theokanning.openai.completion.chat.ChatToolCall;
import com.theokanning.openai.completion.chat.StreamOption;
import com.theokanning.openai.completion.chat.SystemMessage;
import com.theokanning.openai.completion.chat.ToolMessage;
import com.theokanning.openai.completion.chat.UserMessage;
import com.theokanning.openai.function.FunctionDefinition;
import com.theokanning.openai.function.FunctionExecutorManager;
import com.theokanning.openai.service.util.ToolUtil;
import org.junit.jupiter.api.Test;

import java.time.Duration;
import java.time.LocalDate;
import java.util.*;

import static org.junit.jupiter.api.Assertions.*;
import lombok.Data;
import lombok.NoArgsConstructor;

class ChatCompletionTest {

Expand Down Expand Up @@ -106,7 +136,7 @@ void createChatCompletionWithJsonMode() {
ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0);
assertTrue(isValidJson(choice.getMessage().getContent()), "Response is not valid JSON");
}

private boolean isValidJson(String jsonString) {
ObjectMapper objectMapper = new ObjectMapper();
try {
Expand All @@ -116,7 +146,48 @@ private boolean isValidJson(String jsonString) {
return false;
}
}

@Test
void createChatCompletionWithJsonSchema() {
    final List<ChatMessage> messages = new ArrayList<>();
    final ChatMessage systemMessage = new SystemMessage("You are a helpful math tutor. Guide the user through the solution step by step.");
    final ChatMessage userMessage = new UserMessage("how can I solve 8x + 7 = -23");
    messages.add(systemMessage);
    messages.add(userMessage);

    // The response format carries a schema generated from MathReasoning, so the reply can be parsed back into it.
    final Class<MathReasoning> rootClass = MathReasoning.class;
    final ChatResponseFormat responseFormat = ChatResponseFormat.jsonSchema(rootClass);

    final ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
            .builder()
            .model("gpt-4o-2024-08-06")
            .messages(messages)
            .responseFormat(responseFormat)
            .maxTokens(1000)
            .build();

    final ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0);
    final MathReasoning mathReasoning = choice.getMessage().parsed(rootClass);

    // Verify the parsed structure explicitly so a missing field fails with a clear message instead of an NPE.
    assertNotNull(mathReasoning, "Structured output was parsed as null");
    assertNotNull(mathReasoning.getSteps(), "Structured output has no 'steps'");
    assertTrue(!mathReasoning.getSteps().isEmpty(), "Structured output has an empty 'steps' list");
    for (Step step : mathReasoning.getSteps()) {
        assertNotNull(step, "Structured output contains a null step");
        assertNotNull(step.getExplanation(), "Step is missing 'explanation'");
        assertNotNull(step.getOutput(), "Step is missing 'output'");
    }

    final String finalAnswer = mathReasoning.getFinal_answer();
    assertNotNull(finalAnswer, "Structured output has no 'final_answer'");
    assertTrue(finalAnswer.contains("x"));
    assertTrue(finalAnswer.contains("="));
}

/**
 * Target type for the structured math-tutor reply; its generated JSON schema is sent as the
 * request's response format and the reply is parsed back into it.
 */
@NoArgsConstructor
@Data
private static class MathReasoning {
    // NOTE(review): @NotNull presumably makes the property required in the generated schema — verify.
    @NotNull
    private List<Step> steps;

    // Field name doubles as the JSON property name, hence the snake_case spelling.
    @NotNull
    private String final_answer;
}

/**
 * One step of the tutor's solution inside {@code MathReasoning}.
 */
@NoArgsConstructor
@Data
private static class Step {
    // NOTE(review): @NotNull presumably makes the property required in the generated schema — verify.
    @NotNull
    private String explanation;

    @NotNull
    private String output;
}

@Test
void createChatCompletionWithFunctions() {
final List<FunctionDefinition> functions = Collections.singletonList(ToolUtil.weatherFunction());
Expand All @@ -131,7 +202,7 @@ void createChatCompletionWithFunctions() {

ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.model("gpt-4o-2024-08-06")
.messages(messages)
.functions(functions)
.n(1)
Expand Down Expand Up @@ -163,7 +234,7 @@ void createChatCompletionWithFunctions() {

ChatCompletionRequest chatCompletionRequest2 = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.model("gpt-4o-2024-08-06")
.messages(messages)
.functions(functions)
.n(1)
Expand Down Expand Up @@ -205,7 +276,7 @@ void createChatCompletionWithDynamicFunctions() {

ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.model("gpt-4o-2024-08-06")
.messages(messages)
.functions(Collections.singletonList(function))
.n(1)
Expand Down Expand Up @@ -290,7 +361,7 @@ void streamChatCompletionWithFunctions() {

ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.model("gpt-4o-2024-08-06")
.messages(messages)
.functions(functions)
.n(1)
Expand Down Expand Up @@ -324,7 +395,7 @@ void streamChatCompletionWithFunctions() {

ChatCompletionRequest chatCompletionRequest2 = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.model("gpt-4o-2024-08-06")
.messages(messages)
.functions(functions)
.n(1)
Expand Down