这是indexloc提供的服务,不要输入任何密码
Skip to content

Function call 参数类型优化 #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ public class CreateThreadAndRunRequest {
* Specifying a particular tool like {"type": "file_search"} or {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool.
*/
@JsonProperty("tool_choice")
@JsonSerialize(using = ToolChoice.Serializer.class)
@JsonDeserialize(using = ToolChoice.Deserializer.class)
ToolChoice toolChoice;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ public class Run {
* Specifying a particular tool like {"type": "file_search"} or {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool.
*/
@JsonProperty("tool_choice")
@JsonSerialize(using = ToolChoice.Serializer.class)
@JsonDeserialize(using = ToolChoice.Deserializer.class)
ToolChoice toolChoice;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ public class RunCreateRequest {
* Specifying a particular tool like {"type": "file_search"} or {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool.
*/
@JsonProperty("tool_choice")
@JsonSerialize(using = ToolChoice.Serializer.class)
@JsonDeserialize(using = ToolChoice.Deserializer.class)
ToolChoice toolChoice;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import lombok.Data;

import java.io.IOException;
Expand All @@ -17,25 +19,26 @@
* @date 2024年04月18 17:18
**/
@Data
@JsonSerialize(using = ToolChoice.Serializer.class)
@JsonDeserialize(using = ToolChoice.Deserializer.class)
public class ToolChoice {
public static final ToolChoice REQUIRED = new ToolChoice("required");

public static final ToolChoice NONE = new ToolChoice("none");

public static final ToolChoice AUTO = new ToolChoice("auto");

/**
* The name of the function to call.
*/
Function function;

public static final ToolChoice NONE = new ToolChoice("none");

public static final ToolChoice AUTO = new ToolChoice("auto");
/**
* The type of the tool. If type is function, the function name must be set
* enum: none/auto/function/required
*/
String type;



/**
 * Creates a tool choice that carries only a {@code type}; this constructor does not touch {@code function}.
 * It is private and is what the {@code REQUIRED}, {@code NONE} and {@code AUTO} constants are built with.
 *
 * @param type the value stored in the {@code type} field
 */
private ToolChoice(String type) {
this.type = type;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package com.theokanning.openai.completion.chat;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.theokanning.openai.assistants.run.ToolChoice;
import lombok.AllArgsConstructor;
import lombok.Builder;
Expand All @@ -19,7 +17,6 @@
public class ChatCompletionRequest {



/**
* ID of the model to use.
*/
Expand Down Expand Up @@ -114,9 +111,9 @@ public class ChatCompletionRequest {
String user;

/**
* @since 0.20.5 {@link com.theokanning.openai.completion.chat.ChatFunction} and {@link com.theokanning.openai.completion.chat.ChatFunctionDynamic} will be deprecated
* @deprecated Replaced by {@link #tools}
* It is recommended to use {@link com.theokanning.openai.function.FunctionDefinition} or a custom class instead.
* @since 0.20.5 {@link com.theokanning.openai.completion.chat.ChatFunction} and {@link com.theokanning.openai.completion.chat.ChatFunctionDynamic} will be deprecated
*/
@Deprecated
List<?> functions;
Expand All @@ -126,9 +123,7 @@ public class ChatCompletionRequest {
*/
@JsonProperty("function_call")
@Deprecated
@JsonSerialize(using = ChatCompletionRequestFunctionCall.Serializer.class)
@JsonDeserialize(using = ChatCompletionRequestFunctionCall.Deserializer.class)
Object functionCall;
ChatCompletionRequestFunctionCall functionCall;


/**
Expand All @@ -148,7 +143,6 @@ public class ChatCompletionRequest {
Integer topLogprobs;



/**
* Function definition, only used if type is "function"
* It is recommended to use {@link com.theokanning.openai.function.FunctionDefinition} or a custom class.
Expand All @@ -161,39 +155,7 @@ public class ChatCompletionRequest {
* Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function.
*/
@JsonProperty("tool_choice")
@JsonSerialize(using = ToolChoice.Serializer.class)
@JsonDeserialize(using = ToolChoice.Deserializer.class)
ToolChoice toolChoice;



/**
 * Entry point for building a {@code ChatCompletionRequest}.
 *
 * @return a new {@code InternalBuilder}, exposed under the {@code ChatCompletionRequestBuilder} type
 */
public static ChatCompletionRequestBuilder builder() {
return new InternalBuilder();
}
/**
 * Builder subclass whose {@code build()} validates the {@code functionCall} value
 * of the request before handing it back.
 */
private static class InternalBuilder extends ChatCompletionRequestBuilder {

    public InternalBuilder() {
        super();
    }

    /**
     * Delegates construction to the parent builder, then runs
     * {@code functionCallParamCheck()} on the result.
     *
     * @return the request produced by the parent builder, after the check has passed
     */
    @Override
    public ChatCompletionRequest build() {
        final ChatCompletionRequest built = super.build();
        built.functionCallParamCheck();
        return built;
    }
}

/**
 * Validates the runtime type of {@code functionCall}.
 * A {@code null} value is accepted; any other value must be either a
 * {@code String} or a {@code ChatCompletionRequestFunctionCall}.
 *
 * @throws IllegalArgumentException if {@code functionCall} holds any other type
 */
private void functionCallParamCheck() {
    if (functionCall == null) {
        return;
    }
    final boolean supportedType = functionCall instanceof String
            || functionCall instanceof ChatCompletionRequestFunctionCall;
    if (supportedType) {
        return;
    }
    throw new IllegalArgumentException("functionCall must be a ChatCompletionRequestFunctionCall or a String type");
}

/**
 * Stores {@code functionCall} on this request and then validates it with
 * {@code functionCallParamCheck()}.
 *
 * @param functionCall the new value; {@code null}, a {@code String} or a
 *                     {@code ChatCompletionRequestFunctionCall} pass the check
 * @throws IllegalArgumentException if the value has any other type
 */
public void setFunctionCall(Object functionCall) {
this.functionCall = functionCall;
// The check runs after the assignment, so a rejected value is still left on the field.
functionCallParamCheck();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
Expand All @@ -20,38 +22,59 @@
@Data
@AllArgsConstructor
@NoArgsConstructor
@JsonSerialize(using = ChatCompletionRequestFunctionCall.Serializer.class)
@JsonDeserialize(using = ChatCompletionRequestFunctionCall.Deserializer.class)
public class ChatCompletionRequestFunctionCall {
/**
* Controls which (if any) function is called by the model.
* none means the model will not call a function and instead generates a message.
* auto means the model can pick between generating a message or calling a function.
* Specifying a particular function via {"name": "my_function"} forces the model to call that function.
*/

public static final ChatCompletionRequestFunctionCall NONE = new ChatCompletionRequestFunctionCall("none");

public static final ChatCompletionRequestFunctionCall AUTO = new ChatCompletionRequestFunctionCall("auto");

String name;

public static class Serializer extends JsonSerializer<Object> {
/**
 * Static factory for a function call with the given name.
 *
 * @param name the name passed to the one-argument constructor
 * @return a new {@code ChatCompletionRequestFunctionCall} created from {@code name}
 */
public static ChatCompletionRequestFunctionCall of(String name) {
return new ChatCompletionRequestFunctionCall(name);
}


public static class Serializer extends JsonSerializer<ChatCompletionRequestFunctionCall> {
@Override
public void serialize(Object value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
if (value instanceof String) {
gen.writeString((String) value);
public void serialize(ChatCompletionRequestFunctionCall value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
String name = value.getName();
if ("none".equals(name) || "auto".equals(name)) {
gen.writeString(name);
return;
}
if (value instanceof ChatCompletionRequestFunctionCall) {
ChatCompletionRequestFunctionCall functionCall = (ChatCompletionRequestFunctionCall) value;
if (functionCall.getName() == null) {
gen.writeNull();
} else {
gen.writeStartObject();
gen.writeObjectField("name", functionCall.getName());
gen.writeEndObject();
}
return;
if (name == null) {
gen.writeNull();
} else {
gen.writeStartObject();
gen.writeObjectField("name", name);
gen.writeEndObject();
}
// This should never happen
throw new IllegalArgumentException("Unexpected value to function call: " + value);
}
}

public static class Deserializer extends JsonDeserializer<Object> {
public static class Deserializer extends JsonDeserializer<ChatCompletionRequestFunctionCall> {
@Override
public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
public ChatCompletionRequestFunctionCall deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
//如果是字符串,则读取字符串
if (p.getCurrentToken() == JsonToken.VALUE_STRING) {
return p.getText();
String text = p.getText();
switch (text) {
case "none":
return ChatCompletionRequestFunctionCall.NONE;
case "auto":
return ChatCompletionRequestFunctionCall.AUTO;
default:
return ChatCompletionRequestFunctionCall.of(text);
}
}
//如果是对象,则读取对象
if (p.getCurrentToken() == JsonToken.START_OBJECT) {
Expand Down
12 changes: 6 additions & 6 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
</developer>
</developers>
<scm>
<connection>https://github.com/Lambdua/openai-java.git</connection>
<developerConnection>scm:git:ssh://git@ssh.github.com:443/Lambdua/openai-java.git </developerConnection>
<url>https://github.com/Lambdua/openai-java</url>
<connection>scm:git:https://github.com/Lambdua/openai4j.git</connection>
<developerConnection>scm:git:https://github.com/Lambdua/openai4j.git</developerConnection>
<url>https://github.com/Lambdua/openai4j</url>
</scm>

<properties>
Expand Down Expand Up @@ -213,12 +213,12 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.2</version> <!-- 确保使用支持JUnit 5的版本 -->
<version>2.22.2</version>
<dependencies>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>5.10.1</version> <!-- 使用与你的测试一致的JUnit Jupiter版本 -->
<version>5.10.1</version>
</dependency>
</dependencies>
<configuration>
Expand All @@ -228,7 +228,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-clean-plugin</artifactId>
<version>3.1.0</version> <!-- 使用最新版本 -->
<version>3.1.0</version>
<configuration>
<filesets>
<fileset>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ void createChatCompletionWithFunctions() {
assertEquals("get_weather", functionCall.getName());
assertInstanceOf(ObjectNode.class, functionCall.getArguments());

ChatMessage callResponse = functionExecutor.executeAndConvertToChatMessage(functionCall.getName(),functionCall.getArguments());
ChatMessage callResponse = functionExecutor.executeAndConvertToChatMessage(functionCall.getName(), functionCall.getArguments());
assertNotEquals("error", callResponse.getName());

// this performs an unchecked cast
ToolUtil.WeatherResponse functionExecutionResponse = functionExecutor.execute(functionCall.getName(),functionCall.getArguments());
ToolUtil.WeatherResponse functionExecutionResponse = functionExecutor.execute(functionCall.getName(), functionCall.getArguments());
assertInstanceOf(ToolUtil.WeatherResponse.class, functionExecutionResponse);
assertEquals(25, functionExecutionResponse.temperature);

JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(functionCall.getName(),functionCall.getArguments());
JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(functionCall.getName(), functionCall.getArguments());
assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse);
assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText());

Expand Down Expand Up @@ -250,15 +250,15 @@ void streamChatCompletionWithFunctions() {
assertEquals("get_weather", functionCall.getName());
assertInstanceOf(ObjectNode.class, functionCall.getArguments());

ChatMessage callResponse = functionExecutor.executeAndConvertToChatMessage(functionCall.getName(),functionCall.getArguments());
ChatMessage callResponse = functionExecutor.executeAndConvertToChatMessage(functionCall.getName(), functionCall.getArguments());
assertNotEquals("error", callResponse.getName());

// this performs an unchecked cast
ToolUtil.WeatherResponse functionExecutionResponse = functionExecutor.execute(functionCall.getName(),functionCall.getArguments());
ToolUtil.WeatherResponse functionExecutionResponse = functionExecutor.execute(functionCall.getName(), functionCall.getArguments());
assertInstanceOf(ToolUtil.WeatherResponse.class, functionExecutionResponse);
assertEquals(25, functionExecutionResponse.temperature);

JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(functionCall.getName(),functionCall.getArguments());
JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(functionCall.getName(), functionCall.getArguments());
assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse);
assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText());

Expand Down Expand Up @@ -367,7 +367,7 @@ void createChatCompletionWithToolFunctions() {
assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse);
assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText());

ToolMessage chatMessageTool = toolExecutor.executeAndConvertToChatMessage(function.getName(),function.getArguments(), toolCall.getId());
ToolMessage chatMessageTool = toolExecutor.executeAndConvertToChatMessage(function.getName(), function.getArguments(), toolCall.getId());
//确保不是异常的返回
assertNotEquals("error", chatMessageTool.getName());

Expand Down Expand Up @@ -397,7 +397,7 @@ void createChatCompletionWithMultipleToolCalls() {
.name("get_weather")
.description("Get the current weather in a given location")
.parametersDefinitionByClass(ToolUtil.Weather.class)
.executor( w -> {
.executor(w -> {
switch (w.location) {
case "tokyo":
return new ToolUtil.WeatherResponse(w.location, w.unit, 10, "cloudy");
Expand All @@ -412,9 +412,9 @@ void createChatCompletionWithMultipleToolCalls() {
FunctionDefinition.<ToolUtil.City>builder().name("getCities").description("Get a list of cities by time")
.parametersDefinitionByClass(ToolUtil.City.class)
.executor(v -> {
assertEquals("2022-12-01", v.time);
return Arrays.asList("tokyo", "paris");
}).build()
assertEquals("2022-12-01", v.time);
return Arrays.asList("tokyo", "paris");
}).build()
);
final FunctionExecutorManager toolExecutor = new FunctionExecutorManager(functions);

Expand Down Expand Up @@ -452,11 +452,11 @@ void createChatCompletionWithMultipleToolCalls() {
assertInstanceOf(List.class, execute);


JsonNode jsonNode = toolExecutor.executeAndConvertToJson(function.getName(),function.getArguments());
JsonNode jsonNode = toolExecutor.executeAndConvertToJson(function.getName(), function.getArguments());
assertInstanceOf(ArrayNode.class, jsonNode);


ToolMessage toolMessage = toolExecutor.executeAndConvertToChatMessage(function.getName(),function.getArguments(), toolCall.getId());
ToolMessage toolMessage = toolExecutor.executeAndConvertToChatMessage(function.getName(), function.getArguments(), toolCall.getId());
assertNotEquals("error", toolMessage.getName());

messages.add(choice.getMessage());
Expand Down Expand Up @@ -488,7 +488,7 @@ void createChatCompletionWithMultipleToolCalls() {
for (ChatToolCall weatherToolCall : choice2.getMessage().getToolCalls()) {
Object itemResult = toolExecutor.execute(weatherToolCall.getFunction().getName(), weatherToolCall.getFunction().getArguments());
assertInstanceOf(ToolUtil.WeatherResponse.class, itemResult);
messages.add(toolExecutor.executeAndConvertToChatMessage(weatherToolCall.getFunction().getName(),weatherToolCall.getFunction().getArguments(), weatherToolCall.getId()));
messages.add(toolExecutor.executeAndConvertToChatMessage(weatherToolCall.getFunction().getName(), weatherToolCall.getFunction().getArguments(), weatherToolCall.getId()));
}

ChatCompletionRequest chatCompletionRequest3 = ChatCompletionRequest
Expand Down Expand Up @@ -546,7 +546,7 @@ void streamChatMultipleToolCalls() {
.name("get_weather")
.description("Get the current weather in a given location")
.parametersDefinitionByClass(ToolUtil.Weather.class)
.executor( w -> {
.executor(w -> {
switch (w.location) {
case "tokyo":
return new ToolUtil.WeatherResponse(w.location, w.unit, 10, "cloudy");
Expand Down Expand Up @@ -603,7 +603,7 @@ void streamChatMultipleToolCalls() {
assertInstanceOf(ArrayNode.class, jsonNode);


ToolMessage toolMessage = toolExecutor.executeAndConvertToChatMessage(function.getName(),function.getArguments(), toolCall.getId());
ToolMessage toolMessage = toolExecutor.executeAndConvertToChatMessage(function.getName(), function.getArguments(), toolCall.getId());
assertNotEquals("error", toolMessage.getName());

messages.add(accumulatedMessage);
Expand Down Expand Up @@ -637,7 +637,7 @@ void streamChatMultipleToolCalls() {
ChatFunctionCall call2 = weatherToolCall.getFunction();
Object itemResult = toolExecutor.execute(call2.getName(), call2.getArguments());
assertInstanceOf(ToolUtil.WeatherResponse.class, itemResult);
messages.add(toolExecutor.executeAndConvertToChatMessage(call2.getName(),call2.getArguments(), weatherToolCall.getId()));
messages.add(toolExecutor.executeAndConvertToChatMessage(call2.getName(), call2.getArguments(), weatherToolCall.getId()));
}

ChatCompletionRequest chatCompletionRequest3 = ChatCompletionRequest
Expand Down