diff --git a/api/src/main/java/com/theokanning/openai/function/FunctionDefinition.java b/api/src/main/java/com/theokanning/openai/function/FunctionDefinition.java index 150594b..84aec5b 100644 --- a/api/src/main/java/com/theokanning/openai/function/FunctionDefinition.java +++ b/api/src/main/java/com/theokanning/openai/function/FunctionDefinition.java @@ -113,6 +113,7 @@ public FunctionDefinition build() { } FunctionDefinition functionDefinition = new FunctionDefinition(); functionDefinition.name = name; + functionDefinition.strict = strict; functionDefinition.description = description; functionDefinition.parametersDefinitionClass = parametersDefinitionClass; functionDefinition.parametersDefinition = parametersDefinition; diff --git a/api/src/main/java/com/theokanning/openai/function/FunctionParametersSerializer.java b/api/src/main/java/com/theokanning/openai/function/FunctionParametersSerializer.java index 87e9776..9924be3 100644 --- a/api/src/main/java/com/theokanning/openai/function/FunctionParametersSerializer.java +++ b/api/src/main/java/com/theokanning/openai/function/FunctionParametersSerializer.java @@ -29,6 +29,9 @@ public void serialize(FunctionDefinition value, JsonGenerator gen, SerializerPro parameterSchema.remove("$schema"); parameterSchema.remove("title"); parameterSchema.remove("additionalProperties"); + if (Boolean.TRUE.equals(value.getStrict())) { // equals() is null-safe and, unlike ==, does not depend on Boolean.TRUE identity + parameterSchema.put("additionalProperties", Boolean.FALSE); + } gen.writeRawValue(JsonUtil.writeValueAsString(parameterSchema)); } else { gen.writeFieldName("parameters"); diff --git a/api/src/test/java/com/theokanning/openai/FunctionDefinitionTest.java b/api/src/test/java/com/theokanning/openai/FunctionDefinitionTest.java new file mode 100644 index 0000000..2cd328f --- /dev/null +++ b/api/src/test/java/com/theokanning/openai/FunctionDefinitionTest.java @@ -0,0 +1,135 @@ +package com.theokanning.openai; + + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; 
+import com.theokanning.openai.function.FunctionDefinition; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class FunctionDefinitionTest { + + private final ObjectMapper objectMapper = new ObjectMapper(); + + static class TestParameters { // static: no hidden reference to the enclosing test instance + public String name; + } + + @Test + void shouldSerializeWithStrictTrue() throws Exception { + // given + FunctionDefinition function = + FunctionDefinition.builder() + .name("test_function") + .description("Test function") + .strict(true) + .build(); + + // when + String json = objectMapper.writeValueAsString(function); + JsonNode jsonNode = objectMapper.readTree(json); + + // then + assertEquals("test_function", jsonNode.get("name").asText()); + assertEquals("Test function", jsonNode.get("description").asText()); + assertTrue(jsonNode.get("strict").booleanValue()); + } + + @Test + void shouldSerializeWithStrictFalse() throws Exception { + // given + FunctionDefinition function = + FunctionDefinition.builder() + .name("test_function") + .description("Test function") + .strict(false) + .build(); + + // when + String json = objectMapper.writeValueAsString(function); + JsonNode jsonNode = objectMapper.readTree(json); + + // then + assertEquals("test_function", jsonNode.get("name").asText()); + assertEquals("Test function", jsonNode.get("description").asText()); + assertEquals(objectMapper.getNodeFactory().booleanNode(false), jsonNode.get("strict")); + } + + @Test + void shouldSerializeWithoutStrict() throws Exception { + // given + FunctionDefinition function = + FunctionDefinition.builder().name("test_function").description("Test function").build(); + + // when + String json = objectMapper.writeValueAsString(function); + JsonNode jsonNode = objectMapper.readTree(json); + + // then + assertEquals("test_function", jsonNode.get("name").asText()); + assertEquals("Test function", 
jsonNode.get("description").asText()); + assertFalse(jsonNode.has("strict")); + } + + @Test + void shouldSetAdditionalPropertiesFalseWhenStrictIsTrue() throws Exception { + // given + FunctionDefinition function = + FunctionDefinition.builder() + .name("test_function") + .description("Test function") + .strict(true) + .parametersDefinitionByClass(TestParameters.class) + .build(); + + // when + String json = objectMapper.writeValueAsString(function); + JsonNode jsonNode = objectMapper.readTree(json); + JsonNode parametersNode = jsonNode.get("parameters"); + + // then + assertEquals(objectMapper.getNodeFactory().booleanNode(false), parametersNode.get("additionalProperties")); // must be present and the JSON boolean false + } + + @Test + void shouldNotSetAdditionalPropertiesWhenStrictIsFalse() throws Exception { + // given + FunctionDefinition function = + FunctionDefinition.builder() + .name("test_function") + .description("Test function") + .strict(false) + .parametersDefinitionByClass(TestParameters.class) + .build(); + + // when + String json = objectMapper.writeValueAsString(function); + JsonNode jsonNode = objectMapper.readTree(json); + JsonNode parametersNode = jsonNode.get("parameters"); + + // then + assertFalse(parametersNode.has("additionalProperties")); + } + + @Test + void shouldNotSetAdditionalPropertiesWhenStrictIsNull() throws Exception { + // given + FunctionDefinition function = + FunctionDefinition.builder() + .name("test_function") + .description("Test function") + .parametersDefinitionByClass(TestParameters.class) + .build(); + + // when + String json = objectMapper.writeValueAsString(function); + JsonNode jsonNode = objectMapper.readTree(json); + JsonNode parametersNode = jsonNode.get("parameters"); + + // then + assertFalse(parametersNode.has("additionalProperties")); + } +}