Skip to content

Commit

Permalink
Add support for structured outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Aug 21, 2024
1 parent e7f7bd5 commit ed6c392
Show file tree
Hide file tree
Showing 15 changed files with 266 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ public sealed interface AssistantsResponseFormat
*/
record StringResponseFormat(String format) implements AssistantsResponseFormat {}

static AssistantsResponseFormat none() {
return new StringResponseFormat("none");
}

static AssistantsResponseFormat auto() {
return new StringResponseFormat("auto");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import com.fasterxml.jackson.databind.exc.InvalidFormatException;
import java.io.IOException;
import java.util.Optional;

class AssistantsResponseFormatDeserializer extends StdDeserializer<AssistantsResponseFormat> {

Expand All @@ -21,7 +22,12 @@ public AssistantsResponseFormat deserialize(JsonParser p, DeserializationContext
return new AssistantsResponseFormat.StringResponseFormat(node.asText());
} else if (node.isObject()) {
String type = node.get("type").asText();
return new ResponseFormat(type);
if (node.has("json_schema")) {
JsonSchema jsonSchema = p.getCodec().treeToValue(node.get("json_schema"), JsonSchema.class);
return new ResponseFormat(type, Optional.of(jsonSchema));
} else {
return new ResponseFormat(type, Optional.empty());
}
}
throw InvalidFormatException.from(
p, "Expected String or Object", node, AssistantsResponseFormat.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
import java.io.IOException;
import java.util.Optional;

class AssistantsResponseFormatSerializer extends StdSerializer<AssistantsResponseFormat> {

Expand All @@ -20,6 +21,10 @@ public void serialize(
} else if (value instanceof ResponseFormat responseFormat) {
gen.writeStartObject();
gen.writeStringField("type", responseFormat.type());
Optional<JsonSchema> jsonSchema = responseFormat.jsonSchema();
if (jsonSchema.isPresent()) {
gen.writeObjectField("json_schema", jsonSchema.get());
}
gen.writeEndObject();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ public record ChatCompletion(
public record Choice(int index, Message message, Logprobs logprobs, String finishReason) {

/** A chat completion message generated by the model. */
public record Message(String content, List<ToolCall> toolCalls, String role) {}
public record Message(String content, String refusal, List<ToolCall> toolCalls, String role) {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ public record ChatCompletionChunk(
public record Choice(Delta delta, int index, Logprobs logprobs, String finishReason) {

/** A chat completion delta generated by streamed model responses. */
public record Delta(String role, String content, List<ToolCall> toolCalls) {}
public record Delta(String role, String content, String refusal, List<ToolCall> toolCalls) {}
}
}
17 changes: 14 additions & 3 deletions src/main/java/io/github/stefanbratanov/jvm/openai/ChatMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ record UserMessageWithContentParts(List<ContentPart> content, Optional<String> n
implements UserMessage<List<ContentPart>> {}
}

record AssistantMessage(String content, Optional<String> name, Optional<List<ToolCall>> toolCalls)
record AssistantMessage(
String content,
Optional<String> refusal,
Optional<String> name,
Optional<List<ToolCall>> toolCalls)
implements ChatMessage {
@Override
public String role() {
Expand Down Expand Up @@ -67,11 +71,18 @@ static UserMessageWithContentParts userMessage(ContentPart... content) {
}

static AssistantMessage assistantMessage(String content) {
return new AssistantMessage(content, Optional.empty(), Optional.empty());
return new AssistantMessage(content, Optional.empty(), Optional.empty(), Optional.empty());
}

static AssistantMessage assistantMessage(String content, List<ToolCall> toolCalls) {
return new AssistantMessage(content, Optional.empty(), Optional.of(toolCalls));
return new AssistantMessage(
content, Optional.empty(), Optional.empty(), Optional.of(toolCalls));
}

static AssistantMessage assistantMessage(
String content, String refusal, List<ToolCall> toolCalls) {
return new AssistantMessage(
content, Optional.of(refusal), Optional.empty(), Optional.of(toolCalls));
}

static ToolMessage toolMessage(String content, String toolCallId) {
Expand Down
43 changes: 17 additions & 26 deletions src/main/java/io/github/stefanbratanov/jvm/openai/Function.java
Original file line number Diff line number Diff line change
@@ -1,37 +1,17 @@
package io.github.stefanbratanov.jvm.openai;

import com.fasterxml.jackson.databind.JsonNode;
import java.io.IOException;
import java.util.AbstractMap;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/** Function that the model may generate JSON inputs for. */
public record Function(
String name, Optional<String> description, Optional<Map<String, Object>> parameters) {
String name,
Optional<String> description,
Optional<Map<String, Object>> parameters,
Optional<Boolean> strict) {

public Function {
parameters = parameters.map(this::parametersWithoutJsonEscaping);
}

private Map<String, Object> parametersWithoutJsonEscaping(Map<String, Object> parameters) {
return parameters.entrySet().stream()
.map(
entry -> {
if (entry.getValue() instanceof String value) {
try {
JsonNode node = ObjectMapperSingleton.getInstance().readTree(value);
if (node != null && !node.isNull()) {
return new AbstractMap.SimpleEntry<>(entry.getKey(), node);
}
} catch (IOException ex) {
return entry;
}
}
return entry;
})
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
parameters = parameters.map(Utils::mapWithoutJsonEscaping);
}

public static Builder newBuilder() {
Expand All @@ -43,6 +23,7 @@ public static class Builder {
private String name;
private Optional<String> description = Optional.empty();
private Optional<Map<String, Object>> parameters = Optional.empty();
private Optional<Boolean> strict = Optional.empty();

/**
* @param name The name of the function to be called. Must be a-z, A-Z, 0-9, or contain
Expand Down Expand Up @@ -72,8 +53,18 @@ public Builder parameters(Map<String, Object> parameters) {
return this;
}

/**
* @param strict Whether to enable strict schema adherence when generating the function call. If
* set to true, the model will follow the exact schema defined in the parameters field. Only
* a subset of JSON Schema is supported when strict is true.
*/
public Builder strict(boolean strict) {
this.strict = Optional.of(strict);
return this;
}

public Function build() {
return new Function(name, description, parameters);
return new Function(name, description, parameters, strict);
}
}
}
68 changes: 68 additions & 0 deletions src/main/java/io/github/stefanbratanov/jvm/openai/JsonSchema.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package io.github.stefanbratanov.jvm.openai;

import java.util.Map;
import java.util.Optional;

public record JsonSchema(
String name,
Optional<String> description,
Optional<Map<String, Object>> schema,
Optional<Boolean> strict) {

public JsonSchema {
schema = schema.map(Utils::mapWithoutJsonEscaping);
}

public static Builder newBuilder() {
return new Builder();
}

public static class Builder {

private String name;
private Optional<String> description = Optional.empty();
private Optional<Map<String, Object>> schema = Optional.empty();
private Optional<Boolean> strict = Optional.empty();

/**
* @param name The name of the response format.
*/
public Builder name(String name) {
this.name = name;
return this;
}

/**
* @param description A description of what the response format is for, used by the model to
* determine how to respond in the format.
*/
public Builder description(String description) {
this.description = Optional.of(description);
return this;
}

/**
* @param schema The schema for the response format, described as a JSON Schema object. The JSON
* schema should be defined as {@link Map} where a value could be a raw escaped JSON {@link
* String} and it will be serialized without escaping.
*/
public Builder schema(Map<String, Object> schema) {
this.schema = Optional.of(schema);
return this;
}

/**
* @param strict Whether to enable strict schema adherence when generating the output. If set to
* true, the model will always follow the exact schema defined in the schema field. Only a
* subset of JSON Schema is supported when strict is true.
*/
public Builder strict(boolean strict) {
this.strict = Optional.of(strict);
return this;
}

public JsonSchema build() {
return new JsonSchema(name, description, schema, strict);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import java.util.List;

/** Log probability information */
public record Logprobs(List<Content> content) {
public record Logprobs(List<Content> content, List<Refusal> refusal) {

public record Content(
String token, double logprob, List<Byte> bytes, List<TopLogprob> topLogprobs) {}

public record Refusal(
String token, double logprob, List<Byte> bytes, List<TopLogprob> topLogprobs) {}

public record TopLogprob(String token, double logprob, List<Byte> bytes) {}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package io.github.stefanbratanov.jvm.openai;

import java.util.Optional;

/** An object specifying the format that the model must output. */
public record ResponseFormat(String type) implements AssistantsResponseFormat {
public record ResponseFormat(String type, Optional<JsonSchema> jsonSchema)
implements AssistantsResponseFormat {
public static ResponseFormat text() {
return new ResponseFormat("text");
return new ResponseFormat("text", Optional.empty());
}

public static ResponseFormat json() {
return new ResponseFormat("json_object");
return new ResponseFormat("json_object", Optional.empty());
}

public static ResponseFormat jsonSchema(JsonSchema jsonSchema) {
return new ResponseFormat("json_schema", Optional.of(jsonSchema));
}
}
31 changes: 31 additions & 0 deletions src/main/java/io/github/stefanbratanov/jvm/openai/Utils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package io.github.stefanbratanov.jvm.openai;

import com.fasterxml.jackson.databind.JsonNode;
import java.io.IOException;
import java.util.AbstractMap;
import java.util.Map;
import java.util.stream.Collectors;

class Utils {

private Utils() {}

static Map<String, Object> mapWithoutJsonEscaping(Map<String, Object> map) {
return map.entrySet().stream()
.map(
entry -> {
if (entry.getValue() instanceof String value) {
try {
JsonNode node = ObjectMapperSingleton.getInstance().readTree(value);
if (node != null && !node.isNull()) {
return new AbstractMap.SimpleEntry<>(entry.getKey(), node);
}
} catch (IOException ex) {
return entry;
}
}
return entry;
})
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ void deserializesAssistantsResponseFormat() throws JsonProcessingException {
objectMapper.readValue(
getStringResource("/assistants-response-formats.json"), new TypeReference<>() {});

assertThat(choices).hasSize(4);
assertThat(choices).hasSize(5);

assertThat(choices.get(0))
.isInstanceOfSatisfying(
Expand All @@ -150,6 +150,28 @@ void deserializesAssistantsResponseFormat() throws JsonProcessingException {
assertThat(choices.get(3))
.isInstanceOfSatisfying(
ResponseFormat.class, format -> assertThat(format.type()).isEqualTo("json_object"));
assertThat(choices.get(4))
.isInstanceOfSatisfying(
ResponseFormat.class,
format -> {
assertThat(format.type()).isEqualTo("json_schema");
assertThat(format.jsonSchema())
.hasValueSatisfying(
jsonSchema -> {
assertThat(jsonSchema.name()).isEqualTo("math_response");
assertThat(jsonSchema.strict()).hasValue(true);
assertThat(jsonSchema.schema())
.hasValueSatisfying(
schema ->
assertThat(schema)
.isNotEmpty()
.containsKeys(
"type",
"properties",
"required",
"additionalProperties"));
});
});
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ public class TestConstants {

private TestConstants() {}

// change after https://github.com/openai/openai-openapi/pull/313 and
// https://github.com/openai/openai-openapi/pull/314 are merged
public static final String OPEN_AI_SPECIFICATION_URL =
"https://github.com/openai/openai-openapi/raw/master/openapi.yaml";
"https://raw.githubusercontent.com/StefanBratanov/openai-openapi/temp_fixes/openapi.yaml";
}
Loading

0 comments on commit ed6c392

Please sign in to comment.