-
Notifications
You must be signed in to change notification settings - Fork 834
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: Wrap OpenAIDefaults for Python (#2327)
* Add PySpark version of OpenAIDefaults - WIP * Add getters and resetters to OpenAIDefaults, and add Python version too! * Fix python OpenAIDefaults and add tests! * Adding tests and fixing style * Add python tests * Add URL to OpenAIDefaults and add new tests * Fix style
- Loading branch information
Showing
7 changed files
with
257 additions
and
11 deletions.
There are no files selected for viewing
60 changes: 60 additions & 0 deletions
60
cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright (C) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
import sys | ||
|
||
if sys.version >= "3": | ||
basestring = str | ||
|
||
import pyspark | ||
from pyspark import SparkContext | ||
|
||
|
||
def getOption(opt): | ||
if opt.isDefined(): | ||
return opt.get() | ||
else: | ||
return None | ||
|
||
|
||
class OpenAIDefaults: | ||
def __init__(self): | ||
self.defaults = ( | ||
SparkContext.getOrCreate()._jvm.com.microsoft.azure.synapse.ml.services.openai.OpenAIDefaults | ||
) | ||
|
||
def set_deployment_name(self, name): | ||
self.defaults.setDeploymentName(name) | ||
|
||
def get_deployment_name(self): | ||
return getOption(self.defaults.getDeploymentName()) | ||
|
||
def reset_deployment_name(self): | ||
self.defaults.resetDeploymentName() | ||
|
||
def set_subscription_key(self, key): | ||
self.defaults.setSubscriptionKey(key) | ||
|
||
def get_subscription_key(self): | ||
return getOption(self.defaults.getSubscriptionKey()) | ||
|
||
def reset_subscription_key(self): | ||
self.defaults.resetSubscriptionKey() | ||
|
||
def set_temperature(self, temp): | ||
self.defaults.setTemperature(float(temp)) | ||
|
||
def get_temperature(self): | ||
return getOption(self.defaults.getTemperature()) | ||
|
||
def reset_temperature(self): | ||
self.defaults.resetTemperature() | ||
|
||
def set_URL(self, URL): | ||
self.defaults.setURL(URL) | ||
|
||
def get_URL(self): | ||
return getOption(self.defaults.getURL()) | ||
|
||
def reset_URL(self): | ||
self.defaults.resetURL() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright (C) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
from synapse.ml.services.openai.OpenAIDefaults import OpenAIDefaults | ||
from synapse.ml.services.openai.OpenAIPrompt import OpenAIPrompt | ||
import unittest, os, json, subprocess | ||
from pyspark.sql import SQLContext | ||
from pyspark.sql.functions import col | ||
|
||
|
||
from synapse.ml.core.init_spark import * | ||
|
||
spark = init_spark() | ||
sc = SQLContext(spark.sparkContext) | ||
|
||
|
||
class TestOpenAIDefaults(unittest.TestCase): | ||
def test_setters_and_getters(self): | ||
defaults = OpenAIDefaults() | ||
|
||
defaults.set_deployment_name("Bing Bong") | ||
defaults.set_subscription_key("SubKey") | ||
defaults.set_temperature(0.05) | ||
defaults.set_URL("Test URL") | ||
|
||
self.assertEqual(defaults.get_deployment_name(), "Bing Bong") | ||
self.assertEqual(defaults.get_subscription_key(), "SubKey") | ||
self.assertEqual(defaults.get_temperature(), 0.05) | ||
self.assertEqual(defaults.get_URL(), "Test URL") | ||
|
||
def test_resetters(self): | ||
defaults = OpenAIDefaults() | ||
|
||
defaults.set_deployment_name("Bing Bong") | ||
defaults.set_subscription_key("SubKey") | ||
defaults.set_temperature(0.05) | ||
defaults.set_URL("Test URL") | ||
|
||
self.assertEqual(defaults.get_deployment_name(), "Bing Bong") | ||
self.assertEqual(defaults.get_subscription_key(), "SubKey") | ||
self.assertEqual(defaults.get_temperature(), 0.05) | ||
self.assertEqual(defaults.get_URL(), "Test URL") | ||
|
||
defaults.reset_deployment_name() | ||
defaults.reset_subscription_key() | ||
defaults.reset_temperature() | ||
defaults.reset_URL() | ||
|
||
self.assertEqual(defaults.get_deployment_name(), None) | ||
self.assertEqual(defaults.get_subscription_key(), None) | ||
self.assertEqual(defaults.get_temperature(), None) | ||
self.assertEqual(defaults.get_URL(), None) | ||
|
||
def test_two_defaults(self): | ||
defaults = OpenAIDefaults() | ||
|
||
defaults.set_deployment_name("Bing Bong") | ||
self.assertEqual(defaults.get_deployment_name(), "Bing Bong") | ||
|
||
defaults2 = OpenAIDefaults() | ||
defaults.set_deployment_name("Bing Bong") | ||
defaults2.set_deployment_name("Vamos") | ||
self.assertEqual(defaults.get_deployment_name(), "Vamos") | ||
|
||
defaults2.set_deployment_name("Test 2") | ||
defaults.set_deployment_name("Test 1") | ||
self.assertEqual(defaults.get_deployment_name(), "Test 1") | ||
|
||
def test_prompt_w_defaults(self): | ||
|
||
secretJson = subprocess.check_output( | ||
"az keyvault secret show --vault-name mmlspark-build-keys --name openai-api-key-2", | ||
shell=True, | ||
) | ||
openai_api_key = json.loads(secretJson)["value"] | ||
|
||
df = spark.createDataFrame( | ||
[ | ||
("apple", "fruits"), | ||
("mercedes", "cars"), | ||
("cake", "dishes"), | ||
], | ||
["text", "category"], | ||
) | ||
|
||
defaults = OpenAIDefaults() | ||
defaults.set_deployment_name("gpt-35-turbo-0125") | ||
defaults.set_subscription_key(openai_api_key) | ||
defaults.set_temperature(0.05) | ||
defaults.set_URL("https://synapseml-openai-2.openai.azure.com/") | ||
|
||
prompt = OpenAIPrompt() | ||
prompt = prompt.setOutputCol("outParsed") | ||
prompt = prompt.setPromptTemplate( | ||
"Complete this comma separated list of 5 {category}: {text}, " | ||
) | ||
results = prompt.transform(df) | ||
results.select("outParsed").show(truncate=False) | ||
nonNullCount = results.filter(col("outParsed").isNotNull()).count() | ||
assert nonNullCount == 3 | ||
|
||
|
||
if __name__ == "__main__": | ||
result = unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters