Skip to content

Commit

Permalink
Merge branch 'master' into serena/mvad
Browse files Browse the repository at this point in the history
  • Loading branch information
serena-ruan authored Apr 18, 2023
2 parents ea495ff + 87d5bc5 commit 685f0e6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ trait HasAADToken extends HasServiceParams {
def setAADTokenCol(v: String): this.type = setVectorParam(AADToken, v)

def getAADTokenCol: String = getVectorParam(AADToken)

def setDefaultAADToken(v: String): this.type = {
setDefault(AADToken -> Left(v))
}
}

trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
Expand All @@ -170,9 +174,10 @@ trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
setUrl(v + urlPath.stripPrefix("/"))
}

def setInternalEndpoint(v: String): this.type = {
setUrl(v + s"/cognitive/${this.internalServiceType}/" + urlPath.stripPrefix("/"))
}
override def getUrl: String = this.getOrDefault(url)

def setDefaultInternalEndpoint(v: String): this.type = setDefault(
url, v + s"/cognitive/${this.internalServiceType}/" + urlPath.stripPrefix("/"))

private[ml] def internalServiceType: String = ""

Expand All @@ -185,16 +190,16 @@ trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
| self._java_obj = self._java_obj.setEndpoint(value)
| return self
|
|def setInternalEndpoint(self, value):
| self._java_obj = self._java_obj.setInternalEndpoint(value)
|def setDefaultInternalEndpoint(self, value):
| self._java_obj = self._java_obj.setDefaultInternalEndpoint(value)
| return self
|
|def _transform(self, dataset: DataFrame) -> DataFrame:
| if running_on_synapse_internal():
| from synapse.ml.mlflow import get_mlflow_env_config
| mlflow_env_configs = get_mlflow_env_config()
| self.setAADToken(mlflow_env_configs.driver_aad_token)
| self.setInternalEndpoint(mlflow_env_configs.workload_endpoint)
| self._java_obj.setDefaultAADToken(mlflow_env_configs.driver_aad_token)
| self.setDefaultInternalEndpoint(mlflow_env_configs.workload_endpoint)
| return super()._transform(dataset)
|""".stripMargin
}
Expand All @@ -219,16 +224,6 @@ trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
|/// <returns> New $dotnetClassName object </returns>
|public $dotnetClassName SetEndpoint(string value) =>
| $dotnetClassWrapperName(Reference.Invoke(\"setEndpoint\", value));
|
|/// <summary>
|/// Sets value for internal endpoint
|/// </summary>
|/// <param name=\"value\">
|/// Endpoint of the cognitive service
|/// </param>
|/// <returns> New $dotnetClassName object </returns>
|public $dotnetClassName setInternalEndpoint(string value) =>
| $dotnetClassWrapperName(Reference.Invoke(\"setInternalEndpoint\", value));
|""".stripMargin
}
}
Expand Down Expand Up @@ -269,6 +264,8 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
case _ => p.name
}

override def getUrl: String = this.getOrDefault(url)

protected def prepareUrlRoot: Row => String = {
_ => getUrl
}
Expand Down
3 changes: 1 addition & 2 deletions website/docs/getting_started/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ your `build.sbt`:
```scala
resolvers += "SynapseML" at "https://mmlspark.azureedge.net/maven"
// Use 0.11.0 version for Spark3.2 and 0.9.5-13-d1b51517-SNAPSHOT version for Spark3.1
libraryDependencies += "com.microsoft.azure" %% "synapseml_2.12" % "0.11.0"

libraryDependencies += "com.microsoft.azure" % "synapseml_2.12" % "0.11.0"
```

## Spark package
Expand Down

0 comments on commit 685f0e6

Please sign in to comment.