Skip to content

Commit

Permalink
Add enums for role, purpose and voice
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed May 22, 2024
1 parent b8aba2e commit 7854dc6
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ public sealed interface ChatMessage
record SystemMessage(String content, Optional<String> name) implements ChatMessage {
@Override
public String role() {
return Constants.SYSTEM_MESSAGE_ROLE;
return Role.SYSTEM.getId();
}
}

sealed interface UserMessage<T> extends ChatMessage
permits UserMessageWithTextContent, UserMessageWithContentParts {
@Override
default String role() {
return Constants.USER_MESSAGE_ROLE;
return Role.USER.getId();
}

T content();
Expand All @@ -43,14 +43,14 @@ record AssistantMessage(String content, Optional<String> name, Optional<List<Too
implements ChatMessage {
@Override
public String role() {
return Constants.ASSISTANT_MESSAGE_ROLE;
return Role.ASSISTANT.getId();
}
}

record ToolMessage(String content, String toolCallId) implements ChatMessage {
@Override
public String role() {
return Constants.TOOL_MESSAGE_ROLE;
return Role.TOOL.getId();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ private Constants() {}
static final String OPENAI_PROJECT_HEADER = "OpenAI-Project";
static final String OPENAI_BETA_HEADER = "OpenAI-Beta";

static final String SYSTEM_MESSAGE_ROLE = "system";
static final String USER_MESSAGE_ROLE = "user";
static final String ASSISTANT_MESSAGE_ROLE = "assistant";
static final String TOOL_MESSAGE_ROLE = "tool";

static final String IMAGE_FILE_MESSAGE_CONTENT_TYPE = "image_file";
static final String IMAGE_URL_MESSAGE_CONTENT_TYPE = "image_url";
static final String TEXT_MESSAGE_CONTENT_TYPE = "text";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,28 @@ public static Builder newBuilder() {

public static class Builder {

private static final String DEFAULT_ROLE = "user";

private String role = DEFAULT_ROLE;
private String role = Role.USER.getId();

private Object content;
private Optional<List<Attachment>> attachments = Optional.empty();
private Optional<Map<String, String>> metadata = Optional.empty();

/**
* @param role The role of the entity that is creating the message. Allowed values include:
* <ul>
* <li>`user`: Indicates the message is sent by an actual user and should be used in most
* cases to represent user-generated messages.
* <li>`assistant`: Indicates the message is generated by the assistant. Use this value to
* insert messages from the assistant into the conversation.
* </ul>
* @param role The role of the entity that is creating the message.
*/
public Builder role(String role) {
this.role = role;
return this;
}

/**
* @param role The role of the entity that is creating the message.
*/
public Builder role(Role role) {
this.role = role.getId();
return this;
}

/**
* @param content The content of the message.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,28 @@ public static Builder newBuilder() {

public static class Builder {

private String role = Constants.USER_MESSAGE_ROLE;
private String role = Role.USER.getId();

private String content;
private Optional<List<Attachment>> attachments = Optional.empty();
private Optional<Map<String, String>> metadata = Optional.empty();

/**
* @param role The role of the entity that is creating the message. Currently only user is
* supported.
* @param role The role of the entity that is creating the message.
*/
public Builder role(String role) {
this.role = role;
return this;
}

/**
* @param role The role of the entity that is creating the message.
*/
public Builder role(Role role) {
this.role = role.getId();
return this;
}

/**
* @param content The content of the message.
*/
Expand Down
19 changes: 19 additions & 0 deletions src/main/java/io/github/stefanbratanov/jvm/openai/Purpose.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.github.stefanbratanov.jvm.openai;

/** The intended purpose of a file */
public enum Purpose {
ASSISTANTS("assistants"),
BATCH("batch"),
FINE_TUNE("fine-tune"),
VISION("vision");

private final String id;

Purpose(String id) {
this.id = id;
}

public String getId() {
return id;
}
}
19 changes: 19 additions & 0 deletions src/main/java/io/github/stefanbratanov/jvm/openai/Role.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.github.stefanbratanov.jvm.openai;

public enum Role {
SYSTEM("system"),
USER("user"),
ASSISTANT("assistant"),
TOOL("tool"),
FUNCTION("function");

private final String id;

Role(String id) {
this.id = id;
}

public String getId() {
return id;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public static Builder newBuilder() {
public static class Builder {

private static final String DEFAULT_MODEL = OpenAIModel.TTS_1.getId();
private static final String DEFAULT_VOICE = "alloy";
private static final String DEFAULT_VOICE = Voice.ALLOY.getId();

private String model = DEFAULT_MODEL;
private String input;
Expand Down Expand Up @@ -61,6 +61,17 @@ public Builder voice(String voice) {
return this;
}

/**
* @param voice The voice to use when generating the audio. Previews of the voices are available
* in the <a
* href="https://platform.openai.com/docs/guides/text-to-speech/voice-options">Text to
* speech guide</a>.
*/
public Builder voice(Voice voice) {
this.voice = voice.getId();
return this;
}

/**
* @param responseFormat The format to audio in
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ public Builder purpose(String purpose) {
return this;
}

/**
* @param purpose The intended purpose of the uploaded file.
*/
public Builder purpose(Purpose purpose) {
this.purpose = purpose.getId();
return this;
}

public UploadFileRequest build() {
return new UploadFileRequest(file, purpose);
}
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/io/github/stefanbratanov/jvm/openai/Voice.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package io.github.stefanbratanov.jvm.openai;

/** The voice to use when generating an audio */
public enum Voice {
ALLOY("alloy"),
ECHO("echo"),
FABLE("fable"),
ONYX("onyx"),
NOVA("nova"),
SHIMMER("shimmer");

private final String id;

Voice(String id) {
this.id = id;
}

public String getId() {
return id;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void testMessagesClient() {
UploadFileRequest uploadFileRequest =
UploadFileRequest.newBuilder()
.file(getTestResource("/mydata.jsonl"))
.purpose("assistants")
.purpose(Purpose.ASSISTANTS)
.build();
File file = openAI.filesClient().uploadFile(uploadFileRequest);
// create thread
Expand Down Expand Up @@ -503,7 +503,7 @@ private File uploadRealEstateAgentAssistantFile() {
UploadFileRequest uploadFileRequest =
UploadFileRequest.newBuilder()
.file(getTestResource("/real-estate-agent-assistant.txt"))
.purpose("assistants")
.purpose(Purpose.ASSISTANTS)
.build();
return openAI.filesClient().uploadFile(uploadFileRequest);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ void testBatchClient() {
UploadFileRequest uploadInputFileRequest =
UploadFileRequest.newBuilder()
.file(getTestResource("/batch-input-file.jsonl"))
.purpose("batch")
.purpose(Purpose.BATCH)
.build();

File inputFile = filesClient.uploadFile(uploadInputFileRequest);
Expand Down Expand Up @@ -358,7 +358,7 @@ void testFilesClient() {
Path jsonlFile = getTestResource("/mydata.jsonl");

UploadFileRequest uploadFileRequest =
UploadFileRequest.newBuilder().file(jsonlFile).purpose("fine-tune").build();
UploadFileRequest.newBuilder().file(jsonlFile).purpose(Purpose.FINE_TUNE).build();

File uploadedFile = filesClient.uploadFile(uploadFileRequest);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public SpeechRequest randomSpeechRequest() {
return SpeechRequest.newBuilder()
.model(randomTtsModel())
.input(randomString(15))
.voice(oneOf("alloy", "echo", "fable", "onyx", "nova", "shimmer"))
.voice(oneOf(Voice.ALLOY, Voice.ECHO, Voice.FABLE, Voice.ONYX, Voice.NOVA, Voice.SHIMMER))
.responseFormat(oneOf("mp3", "opus", "aac", "flac"))
.speed(randomDouble(0.25, 4.0))
.build();
Expand Down Expand Up @@ -855,7 +855,7 @@ private ChatCompletion.Choice randomChatCompletionChoice() {
new ChatCompletion.Choice.Message(
randomString(10),
listOf(randomInt(0, 3), () -> randomFunctionToolCall(false)),
Constants.ASSISTANT_MESSAGE_ROLE),
Role.ASSISTANT.getId()),
randomLogprobs(),
randomFinishReason());
}
Expand Down

0 comments on commit 7854dc6

Please sign in to comment.