chore: Adding responseFormat parameter in OpenAI Chat Completion Request #2329

Merged · 5 commits · Jan 3, 2025
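
This change adds a `responseFormat` parameter to `OpenAIChatCompletion` (and, via the new `HasOpenAITextParamsExtended` trait, to `OpenAIPrompt`), so callers can ask the chat completions API for `text` or `json_object` output. Below is a minimal usage sketch, not part of the diff, based on the setters exercised in the test suite; the deployment name, service name, and key are placeholders, and a `SparkSession` named `spark` is assumed to be in scope.

```scala
// Sketch only: resource names and key are placeholders; a SparkSession `spark` is assumed.
import com.microsoft.azure.synapse.ml.services.openai.{OpenAIChatCompletion, OpenAIMessage, OpenAIResponseFormat}
import spark.implicits._

val chat = new OpenAIChatCompletion()
  .setDeploymentName("gpt-4o")                   // placeholder deployment
  .setCustomServiceName("my-openai-service")     // placeholder Azure OpenAI resource
  .setSubscriptionKey(sys.env("OPENAI_API_KEY"))
  .setApiVersion("2023-05-15")
  .setMaxTokens(500)
  .setMessagesCol("messages")
  .setOutputCol("out")
  .setResponseFormat("json_object")              // new parameter added by this PR
  // equivalent forms:
  // .setResponseFormat(OpenAIResponseFormat.JSON)
  // .setResponseFormat(Map("type" -> "json_object"))

val df = Seq(
  Seq(
    OpenAIMessage("system", "Answer as a JSON object"),
    OpenAIMessage("user", "What is your favorite color?")
  )
).toDF("messages")

val results = chat.transform(df)
```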
@@ -279,7 +279,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {

// list of shared text parameters. In method getOptionalParams, we will iterate over these parameters
// to compute the optional parameters. Since this list never changes, we can create it once and reuse it.
private val sharedTextParams = Seq(
private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq(
maxTokens,
temperature,
topP,
@@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
@@ -16,10 +17,91 @@ import spray.json._

import scala.language.existentials

object OpenAIResponseFormat extends Enumeration {
case class ResponseFormat(payloadName: String, prompt: String) extends super.Val(payloadName)

val TEXT: ResponseFormat = ResponseFormat("text", "Output must be in text format")
val JSON: ResponseFormat = ResponseFormat("json_object", "Output must be in JSON format")

def asStringSet: Set[String] =
OpenAIResponseFormat.values.map(_.asInstanceOf[OpenAIResponseFormat.ResponseFormat].payloadName)

def fromResponseFormatString(format: String): OpenAIResponseFormat.ResponseFormat = {
if (TEXT.payloadName == format) {
TEXT
} else if (JSON.payloadName == format) {
JSON
} else {
throw new IllegalArgumentException("Response format must be valid for OpenAI API. " +
"Currently supported formats are " +
asStringSet.mkString(", "))
}
}
}

trait HasOpenAITextParamsExtended extends HasOpenAITextParams {
val responseFormat: ServiceParam[Map[String, String]] = new ServiceParam[Map[String, String]](
this,
"responseFormat",
"Response format for the completion. Can be 'json_object' or 'text'.",
isRequired = false) {
override val payloadName: String = "response_format"
}

def getResponseFormat: Map[String, String] = getScalarParam(responseFormat)

def setResponseFormat(value: Map[String, String]): this.type = {
val allowedFormat = OpenAIResponseFormat.asStringSet

// This test is to validate that value is properly formatted Map('type' -> '<format>')
if (value == null || value.size != 1 || !value.contains("type") || value("type").isEmpty) {
throw new IllegalArgumentException("Response format map must be of the form Map('type' -> '<format>')"
+ " where <format> is one of " + allowedFormat.mkString(", "))
}

// This test is to validate that the format is one of the allowed formats
if (!allowedFormat.contains(value("type").toLowerCase)) {
throw new IllegalArgumentException("Response format must be valid for OpenAI API. " +
"Currently supported formats are " +
allowedFormat.mkString(", "))
}
setScalarParam(responseFormat, value)
}

def setResponseFormat(value: String): this.type = {
if (value == null || value.isEmpty) {
this
} else {
setResponseFormat(Map("type" -> value.toLowerCase))
}
}

def setResponseFormat(value: OpenAIResponseFormat.ResponseFormat): this.type = {
setScalarParam(responseFormat, Map("type" -> value.payloadName))
}

// override this field to include the new parameter
override private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq(
maxTokens,
temperature,
topP,
user,
n,
echo,
stop,
cacheLevel,
presencePenalty,
frequencyPenalty,
bestOf,
logProbs,
responseFormat
)
}

object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasMessagesInput with HasCognitiveServiceInput
with HasOpenAITextParamsExtended with HasMessagesInput with HasCognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

@@ -54,7 +136,7 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(

override def responseDataType: DataType = ChatCompletionResponse.schema

private[this] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = {
private[openai] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = {
val mappedMessages: Seq[Map[String, String]] = messages.map { m =>
Seq("role", "content", "name").map(n =>
n -> Option(m.getAs[String](n))
@@ -10,20 +10,20 @@ import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.{HasGlobalParams, StringStringMapParam}
import com.microsoft.azure.synapse.ml.services._
import org.apache.http.entity.AbstractHttpEntity
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}

import scala.collection.JavaConverters._

object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt]

class OpenAIPrompt(override val uid: String) extends Transformer
with HasOpenAITextParams with HasMessagesInput
with HasOpenAITextParamsExtended with HasMessagesInput
with HasErrorCol with HasOutputCol
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
@@ -131,9 +131,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
df.map({ row =>
val originalOutput = Option(row.getAs[Row](outputCol))
.map({ row => openAIResultFromRow(row).choices.head })
val isFiltered = originalOutput
.map(output => Option(output.message.content).isEmpty)
.getOrElse(false)
val isFiltered = originalOutput.exists(output => Option(output.message.content).isEmpty)

if (isFiltered) {
val updatedRowSeq = row.toSeq.updated(
@@ -152,24 +150,23 @@ class OpenAIPrompt(override val uid: String) extends Transformer
transferGlobalParamsToParamMap()
logTransform[DataFrame]({
val df = dataset.toDF

val completion = openAICompletion
val promptCol = Functions.template(getPromptTemplate)
val createMessagesUDF = udf((userMessage: String) => {
Seq(
OpenAIMessage("system", getSystemPrompt),
OpenAIMessage("user", userMessage)
)
})

completion match {
case chatCompletion: OpenAIChatCompletion =>
if (isSet(responseFormat)) {
chatCompletion.setResponseFormat(getResponseFormat)
}
val messageColName = getMessagesCol
val createMessagesUDF = udf((userMessage: String) => {
getPromptsForMessage(userMessage)
})
val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
val completionNamed = chatCompletion.setMessagesCol(messageColName)

val transformed = addRAIErrors(
completionNamed.transform(dfTemplated), chatCompletion.getErrorCol, chatCompletion.getOutputCol)

val results = transformed
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
@@ -183,10 +180,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

case completion: OpenAICompletion =>
if (isSet(responseFormat)) {
throw new IllegalArgumentException("responseFormat is not supported for completion models")
}
val promptColName = df.withDerivativeCol("prompt")
val dfTemplated = df.withColumn(promptColName, promptCol)
val completionNamed = completion.setPromptCol(promptColName)

// run completion
val results = completionNamed
.transform(dfTemplated)
Expand All @@ -204,6 +203,26 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}, dataset.columns.length)
}

// If the response format is set, add a system prompt to the messages. The OpenAI API requires this:
// if responseFormat is json_object and no message contains the string 'JSON', the request fails with a 400 error.
// For this reason we add a system prompt to the messages.
// This method is made private[openai] for testing purposes
private[openai] def getPromptsForMessage(userMessage: String) = {
val basePrompts = Seq(
OpenAIMessage("system", getSystemPrompt),
OpenAIMessage("user", userMessage)
)

if (isSet(responseFormat)) {
val responseFormatPrompt = OpenAIResponseFormat
.fromResponseFormatString(getResponseFormat("type"))
.prompt
basePrompts :+ OpenAIMessage("system", responseFormatPrompt)
} else {
basePrompts
}
}

private val legacyModels = Set("ada", "babbage", "curie", "davinci",
"text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
"code-cushman-001", "code-davinci-002")
@@ -240,8 +259,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer

getPostProcessing.toLowerCase match {
case "csv" => new DelimiterParser(opts.getOrElse("delimiter", ","))
case "json" => new JsonParser(opts.get("jsonSchema").get, Map.empty)
case "regex" => new RegexParser(opts.get("regex").get, opts.get("regexGroup").get.toInt)
case "json" => new JsonParser(opts("jsonSchema"), Map.empty)
case "regex" => new RegexParser(opts("regex"), opts("regexGroup").toInt)
case "" => new PassThroughParser()
case _ => throw new IllegalArgumentException(s"Unsupported postProcessing type: '$getPostProcessing'")
}
Expand Up @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.commons.io.IOUtils
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{DataFrame, Row}
import org.scalactic.Equality
@@ -151,6 +152,113 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
assert(Option(results.apply(2).getAs[Row]("out")).isEmpty)
}

test("getOptionalParam should include responseFormat"){
val completion = new OpenAIChatCompletion()
.setDeploymentName(deploymentNameGpt4)

def validateResponseFormat(params: Map[String, Any], responseFormat: String): Unit = {
val responseFormatPayloadName = this.completion.responseFormat.payloadName
assert(params.contains(responseFormatPayloadName))
val responseFormatMap = params(responseFormatPayloadName).asInstanceOf[Map[String, String]]
assert(responseFormatMap.contains("type"))
assert(responseFormatMap("type") == responseFormat)
}

val messages: Seq[Row] = Seq(
OpenAIMessage("user", "Whats your favorite color")
).toDF("role", "content", "name").collect()

val optionalParams: Map[String, Any] = completion.getOptionalParams(messages.head)
assert(!optionalParams.contains("response_format"))

completion.setResponseFormat("")
val optionalParams0: Map[String, Any] = completion.getOptionalParams(messages.head)
assert(!optionalParams0.contains("response_format"))

completion.setResponseFormat("json_object")
val optionalParams1: Map[String, Any] = completion.getOptionalParams(messages.head)
validateResponseFormat(optionalParams1, "json_object")

completion.setResponseFormat("text")
val optionalParams2: Map[String, Any] = completion.getOptionalParams(messages.head)
validateResponseFormat(optionalParams2, "text")

completion.setResponseFormat(Map("type" -> "json_object"))
val optionalParams3: Map[String, Any] = completion.getOptionalParams(messages.head)
validateResponseFormat(optionalParams3, "json_object")

completion.setResponseFormat(OpenAIResponseFormat.TEXT)
val optionalParams4: Map[String, Any] = completion.getOptionalParams(messages.head)
validateResponseFormat(optionalParams4, "text")
}

test("setResponseFormat should throw exception if invalid format"){
val completion = new OpenAIChatCompletion()
.setDeploymentName(deploymentNameGpt4)

assertThrows[IllegalArgumentException] {
completion.setResponseFormat("invalid_format")
}

assertThrows[IllegalArgumentException] {
completion.setResponseFormat(Map("type" -> "invalid_format"))
}

assertThrows[IllegalArgumentException] {
completion.setResponseFormat(Map("invalid_key" -> "json_object"))
}
}

test("validate that gpt4o accepts json_object response format") {
val goodDf: DataFrame = Seq(
Seq(
OpenAIMessage("system", "You are an AI chatbot with red as your favorite color"),
OpenAIMessage("system", OpenAIResponseFormat.JSON.prompt),
OpenAIMessage("user", "Whats your favorite color")
),
Seq(
OpenAIMessage("system", "You are very excited"),
OpenAIMessage("system", OpenAIResponseFormat.JSON.prompt),
OpenAIMessage("user", "How are you today")
),
Seq(
OpenAIMessage("system", OpenAIResponseFormat.JSON.prompt),
OpenAIMessage("system", "You are very excited"),
OpenAIMessage("user", "How are you today"),
OpenAIMessage("system", "Better than ever"),
OpenAIMessage("user", "Why?")
)
).toDF("messages")

val completion = new OpenAIChatCompletion()
.setDeploymentName(deploymentNameGpt4o)
.setCustomServiceName(openAIServiceName)
.setApiVersion("2023-05-15")
.setMaxTokens(500)
.setOutputCol("out")
.setMessagesCol("messages")
.setTemperature(0)
.setSubscriptionKey(openAIAPIKey)
.setResponseFormat("json_object")

testCompletion(completion, goodDf)
}

test("validate that gpt4 accepts text response format") {
val completion = new OpenAIChatCompletion()
.setDeploymentName(deploymentNameGpt4)
.setCustomServiceName(openAIServiceName)
.setApiVersion("2023-05-15")
.setMaxTokens(5000)
.setOutputCol("out")
.setMessagesCol("messages")
.setTemperature(0)
.setSubscriptionKey(openAIAPIKey)
.setResponseFormat("text")

testCompletion(completion, goodDf)
}

ignore("Custom EndPoint") {
lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "")
lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")
@@ -17,6 +17,7 @@ trait OpenAIAPIKey {
lazy val deploymentName: String = "gpt-35-turbo"
lazy val modelName: String = "gpt-35-turbo"
lazy val deploymentNameGpt4: String = "gpt-4"
lazy val deploymentNameGpt4o: String = "gpt-4o"
lazy val modelNameGpt4: String = "gpt-4"
}
