Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
sss04 committed Dec 20, 2024
1 parent e703bd2 commit ad42fcd
Showing 1 changed file with 14 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from synapse.ml.services.openai.OpenAIDefaults import OpenAIDefaults
from synapse.ml.services.openai.OpenAIPrompt import OpenAIPrompt
import unittest,os, json, subprocess
import unittest, os, json, subprocess
from pyspark.sql import SQLContext
from pyspark.sql.functions import col

Expand Down Expand Up @@ -74,11 +74,14 @@ def test_prompt_w_defaults(self):
)
openai_api_key = json.loads(secretJson)["value"]

df = spark.createDataFrame([
("apple", "fruits"),
("mercedes", "cars"),
("cake", "dishes"),
], ["text", "category"])
df = spark.createDataFrame(
[
("apple", "fruits"),
("mercedes", "cars"),
("cake", "dishes"),
],
["text", "category"],
)

defaults = OpenAIDefaults()
defaults.set_deployment_name("gpt-35-turbo-0125")
Expand All @@ -88,11 +91,13 @@ def test_prompt_w_defaults(self):

prompt = OpenAIPrompt()
prompt = prompt.setOutputCol("outParsed")
prompt = prompt.setPromptTemplate("Complete this comma separated list of 5 {category}: {text}, ")
prompt = prompt.setPromptTemplate(
"Complete this comma separated list of 5 {category}: {text}, "
)
results = prompt.transform(df)
results.select("outParsed").show(truncate = False)
results.select("outParsed").show(truncate=False)
nonNullCount = results.filter(col("outParsed").isNotNull()).count()
assert (nonNullCount == 3)
assert nonNullCount == 3


if __name__ == "__main__":
Expand Down

0 comments on commit ad42fcd

Please sign in to comment.