Skip to content

Commit

Permalink
feat: Add ChatGPT through the OpenAIChatCompletion transformer (#1887)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 authored Mar 23, 2023
1 parent 7657089 commit 9f634a6
Show file tree
Hide file tree
Showing 16 changed files with 299 additions and 210 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,22 @@ trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
}
}

trait HasAPIVersion extends HasServiceParams {
val apiVersion: ServiceParam[String] = new ServiceParam[String](
this, "apiVersion", "version of the api", isRequired = true, isURLParam = true) {
override val payloadName: String = "api-version"
}

def getApiVersion: String = getScalarParam(apiVersion)

def setApiVersion(v: String): this.type = setScalarParam(apiVersion, v)

def getApiVersionCol: String = getVectorParam(apiVersion)

def setApiVersionCol(v: String): this.type = setVectorParam(apiVersion, v)

}

object URLEncodingUtils {

private case class NameValuePairInternal(t: (String, String)) extends NameValuePair {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package com.microsoft.azure.synapse.ml.cognitive.form

import com.microsoft.azure.synapse.ml.cognitive._
import com.microsoft.azure.synapse.ml.cognitive.openai.HasAPIVersion
import com.microsoft.azure.synapse.ml.cognitive.vision.{BasicAsyncReply, HasImageInput}
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.param.ServiceParam
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package com.microsoft.azure.synapse.ml.cognitive.language

import com.microsoft.azure.synapse.ml.cognitive._
import com.microsoft.azure.synapse.ml.cognitive.openai.HasAPIVersion
import com.microsoft.azure.synapse.ml.cognitive.text.{TADocument, TextAnalyticsAutoBatch}
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.param.ServiceParam
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,14 @@
package com.microsoft.azure.synapse.ml.cognitive.openai

import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.cognitive.{
CognitiveServicesBase, HasCognitiveServiceInput,
HasInternalJsonOutputParser, HasServiceParams
}
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.cognitive.{HasAPIVersion, HasServiceParams}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._
import spray.json._

import scala.language.existentials

trait HasPrompt extends HasServiceParams {
trait HasPromptInputs extends HasServiceParams {
val prompt: ServiceParam[String] = new ServiceParam[String](
this, "prompt", "The text to complete", isRequired = false)

Expand All @@ -32,9 +22,7 @@ trait HasPrompt extends HasServiceParams {
def getPromptCol: String = getVectorParam(prompt)

def setPromptCol(v: String): this.type = setVectorParam(prompt, v)
}

trait HasBatchPrompt extends HasServiceParams {
val batchPrompt: ServiceParam[Seq[String]] = new ServiceParam[Seq[String]](
this, "batchPrompt", "Sequence of prompts to complete", isRequired = false)

Expand All @@ -45,65 +33,40 @@ trait HasBatchPrompt extends HasServiceParams {
def getBatchPromptCol: String = getVectorParam(batchPrompt)

def setBatchPromptCol(v: String): this.type = setVectorParam(batchPrompt, v)
}

trait HasIndexPrompt extends HasServiceParams {
val indexPrompt: ServiceParam[Seq[Int]] = new ServiceParam[Seq[Int]](
this, "indexPrompt", "Sequence of indexes to complete", isRequired = false)

def getIndexPrompt: Seq[Int] = getScalarParam(indexPrompt)

def setIndexPrompt(v: Seq[Int]): this.type = setScalarParam(indexPrompt, v)

def getIndexPromptCol: String = getVectorParam(indexPrompt)

def setIndexPromptCol(v: String): this.type = setVectorParam(indexPrompt, v)
}

trait HasBatchIndexPrompt extends HasServiceParams {
val batchIndexPrompt: ServiceParam[Seq[Seq[Int]]] = new ServiceParam[Seq[Seq[Int]]](
this, "batchIndexPrompt", "Sequence of index sequences to complete", isRequired = false)

def getBatchIndexPrompt: Seq[Seq[Int]] = getScalarParam(batchIndexPrompt)
trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

def setBatchIndexPrompt(v: Seq[Seq[Int]]): this.type = setScalarParam(batchIndexPrompt, v)

def getBatchIndexPromptCol: String = getVectorParam(batchIndexPrompt)

def setBatchIndexPromptCol(v: String): this.type = setVectorParam(batchIndexPrompt, v)
}
val deploymentName = new ServiceParam[String](
this, "deploymentName", "The name of the deployment", isRequired = true)

trait HasAPIVersion extends HasServiceParams {
val apiVersion: ServiceParam[String] = new ServiceParam[String](
this, "apiVersion", "version of the api", isRequired = true, isURLParam = true) {
override val payloadName: String = "api-version"
}
def getDeploymentName: String = getScalarParam(deploymentName)

def getApiVersion: String = getScalarParam(apiVersion)
def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v)

def setApiVersion(v: String): this.type = setScalarParam(apiVersion, v)
def getDeploymentNameCol: String = getVectorParam(deploymentName)

def getApiVersionCol: String = getVectorParam(apiVersion)
def setDeploymentNameCol(v: String): this.type = setVectorParam(deploymentName, v)

def setApiVersionCol(v: String): this.type = setVectorParam(apiVersion, v)
val user: ServiceParam[String] = new ServiceParam[String](
this, "user",
"The ID of the end-user, for use in tracking and rate-limiting.",
isRequired = false)

setDefault(apiVersion -> Left("2022-03-01-preview"))
}
def getUser: String = getScalarParam(user)

trait HasDeploymentName extends HasServiceParams {
val deploymentName = new ServiceParam[String](
this, "deploymentName", "The name of the deployment", isRequired = true)
def setUser(v: String): this.type = setScalarParam(user, v)

def getDeploymentName: String = getScalarParam(deploymentName)
def getUserCol: String = getVectorParam(user)

def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v)
def setUserCol(v: String): this.type = setVectorParam(user, v)

def getDeploymentNameCol: String = getVectorParam(deploymentName)
setDefault(apiVersion -> Left("2023-03-15-preview"))

def setDeploymentNameCol(v: String): this.type = setVectorParam(deploymentName, v)
}

trait HasMaxTokens extends HasServiceParams {
trait HasOpenAITextParams extends HasOpenAISharedParams {

val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
this, "maxTokens",
Expand All @@ -118,9 +81,6 @@ trait HasMaxTokens extends HasServiceParams {

def setMaxTokensCol(v: String): this.type = setVectorParam(maxTokens, v)

}

trait HasTemperature extends HasServiceParams {
val temperature: ServiceParam[Double] = new ServiceParam[Double](
this, "temperature",
"What sampling temperature to use. Higher values means the model will take more risks." +
Expand All @@ -135,24 +95,7 @@ trait HasTemperature extends HasServiceParams {
def getTemperatureCol: String = getVectorParam(temperature)

def setTemperatureCol(v: String): this.type = setVectorParam(temperature, v)
}

trait HasModel extends HasServiceParams {
val model: ServiceParam[String] = new ServiceParam[String](
this, "model",
"The name of the model to use",
isRequired = false)

def getModel: String = getScalarParam(model)

def setModel(v: String): this.type = setScalarParam(model, v)

def getModelCol: String = getVectorParam(model)

def setModelCol(v: String): this.type = setVectorParam(model, v)
}

trait HasStop extends HasServiceParams {
val stop: ServiceParam[String] = new ServiceParam[String](
this, "stop",
"A sequence which indicates the end of the current document.",
Expand All @@ -165,12 +108,6 @@ trait HasStop extends HasServiceParams {
def getStopCol: String = getVectorParam(stop)

def setStopCol(v: String): this.type = setVectorParam(stop, v)
}

trait HasOpenAIParams extends HasServiceParams
with HasPrompt with HasBatchPrompt with HasIndexPrompt with HasBatchIndexPrompt
with HasTemperature with HasModel with HasStop
with HasAPIVersion with HasDeploymentName with HasMaxTokens {

val topP: ServiceParam[Double] = new ServiceParam[Double](
this, "topP",
Expand All @@ -189,19 +126,6 @@ trait HasOpenAIParams extends HasServiceParams

def setTopPCol(v: String): this.type = setVectorParam(topP, v)

val user: ServiceParam[String] = new ServiceParam[String](
this, "user",
"The ID of the end-user, for use in tracking and rate-limiting.",
isRequired = false)

def getUser: String = getScalarParam(user)

def setUser(v: String): this.type = setScalarParam(user, v)

def getUserCol: String = getVectorParam(user)

def setUserCol(v: String): this.type = setVectorParam(user, v)

val n: ServiceParam[Int] = new ServiceParam[Int](
this, "n",
"How many snippets to generate for each prompt. Minimum of 1 and maximum of 128 allowed.",
Expand Down Expand Up @@ -299,5 +223,24 @@ trait HasOpenAIParams extends HasServiceParams

def setBestOfCol(v: String): this.type = setVectorParam(bestOf, v)

private[ml] def getOptionalParams(r: Row): Map[String, Any] = {
Seq(
maxTokens,
temperature,
topP,
user,
n,
echo,
stop,
cacheLevel,
presencePenalty,
frequencyPenalty,
bestOf
).flatMap(param =>
getValueOpt(r, param).map(v => (GenerationUtils.camelToSnake(param.name), v))
).++(Seq(
getValueOpt(r, logProbs).map(v => ("logprobs", v))
).flatten).toMap
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.cognitive.openai

import com.microsoft.azure.synapse.ml.cognitive.{
CognitiveServicesBase, HasCognitiveServiceInput, HasInternalJsonOutputParser
}
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json._
import spray.json.DefaultJsonProtocol._

import scala.language.existentials

object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends CognitiveServicesBase(uid)
with HasOpenAITextParams with HasCognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass()

val messagesCol: Param[String] = new Param[String](
this, "messagesCol", "The column messages to generate chat completions for," +
" in the chat format. This column should have type Array(Struct(role: String, content: String)).")

def getMessagesCol: String = $(messagesCol)

def setMessagesCol(v: String): this.type = set(messagesCol, v)

def this() = this(Identifiable.randomUID("OpenAIChatCompletion"))

def urlPath: String = ""

override private[ml] def internalServiceType: String = "openai"

override def setCustomServiceName(v: String): this.type = {
setUrl(s"https://$v.openai.azure.com/" + urlPath.stripPrefix("/"))
}

override protected def prepareUrlRoot: Row => String = { row =>
s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions"
}

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
lazy val optionalParams: Map[String, Any] = getOptionalParams(r)
val messages = r.getAs[Seq[Row]](getMessagesCol)
Some(getStringEntity(messages, optionalParams))
}

override val subscriptionKeyHeaderName: String = "api-key"

override def shouldSkip(row: Row): Boolean =
super.shouldSkip(row) || Option(row.getAs[Row](getMessagesCol)).isEmpty

override protected def getVectorParamMap: Map[String, String] = super.getVectorParamMap
.updated("messages", getMessagesCol)

override def responseDataType: DataType = ChatCompletionResponse.schema

private[this] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = {
val mappedMessages: Seq[Map[String, String]] = messages.map { m =>
Seq("role", "content", "name").map(n =>
n -> Option(m.getAs[String](n))
).toMap.filter(_._2.isDefined).mapValues(_.get)
}
val fullPayload = optionalParams.updated("messages", mappedMessages)
new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON)
}

}



Loading

0 comments on commit 9f634a6

Please sign in to comment.