-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmanualHuggingFace.py
86 lines (67 loc) · 3.31 KB
/
manualHuggingFace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import os
from typing import List

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field

from smells import getAllSmellsInASingleString


class DetectedSmells(BaseModel):
    """Structured result of test-smell detection for one piece of test code.

    Attributes are validated by pydantic; ``smellNames`` is required.
    """

    # Names of the smells found in the analysed test (may be empty).
    smellNames: List[str] = Field(
        ...,
        description="List of detected smells",
    )


# Populate os.environ from a local .env file (if present) before reading the key.
load_dotenv()

# None when the variable is unset; it is passed to InferenceClient unchanged either way.
huggingface_token = os.environ.get("HUGGING_FACE_API_KEY")

# Shared Hugging Face Inference client used by get_smells().
client = InferenceClient(api_key=huggingface_token)

# Sample JUnit-style Java test method (an @Test with several Assert.assertEquals
# checks). It is passed verbatim as the code to analyse by the get_smells() call
# at the bottom of this file.
test_code = """
@Test
public void realCase() {
Point p34 = new Point("34", 556506.667, 172513.91, 620.34, true);
Point p45 = new Point("45", 556495.16, 172493.912, 623.37, true);
Point p47 = new Point("47", 556612.21, 172489.274, 0.0, true);
Abriss a = new Abriss(p34, false);
a.removeDAO(CalculationsDataSource.getInstance());
a.getMeasures().add(new Measure(p45, 0.0, 91.6892, 23.277, 1.63));
a.getMeasures().add(new Measure(p47, 281.3521, 100.0471, 108.384, 1.63));
try {
a.compute();
} catch (CalculationException e) {
Assert.fail(e.getMessage());
}
// test intermediate values with point 45
Assert.assertEquals("233.2405",
this.df4.format(a.getResults().get(0).getUnknownOrientation()));
Assert.assertEquals("233.2435",
this.df4.format(a.getResults().get(0).getOrientedDirection()));
Assert.assertEquals("-0.1", this.df1.format(
a.getResults().get(0).getErrTrans()));
// test intermediate values with point 47
Assert.assertEquals("233.2466",
this.df4.format(a.getResults().get(1).getUnknownOrientation()));
Assert.assertEquals("114.5956",
this.df4.format(a.getResults().get(1).getOrientedDirection()));
Assert.assertEquals("0.5", this.df1.format(
a.getResults().get(1).getErrTrans()));
// test final results
Assert.assertEquals("233.2435", this.df4.format(a.getMean()));
Assert.assertEquals("43", this.df0.format(a.getMSE()));
Assert.assertEquals("30", this.df0.format(a.getMeanErrComp()));
}
"""
# Test-smell catalogue from the project's smells module, embedded in every prompt.
all_smells = getAllSmellsInASingleString()


def _build_prompt(test_code: str, production_code: str = "") -> str:
    """Assemble the single user-turn prompt asking the model to classify *test_code*.

    Sections are separated by blank lines. The production-code section (and the
    sentence announcing it) is included only when *production_code* is non-empty,
    so the model is never promised code it does not receive. The final section
    asks for a JSON array so the reply can be parsed mechanically.
    """
    if production_code:
        given = "a test code block and a production code block"
    else:
        given = "a test code block"
    sections = [
        "You are an expert in detecting test smells. "
        f"You will be given {given}. "
        "You have to identify which smells the test code contains. "
        "The test block may contain multiple smells, single smell or no smells at all.",
        "Here is the list of test smells:\n" + all_smells,
        "Here is the test code block:\n" + test_code,
    ]
    if production_code:
        sections.append("Here is the production code block:\n" + production_code)
    # Pin the output format so the reply can be parsed by _parse_smell_names().
    sections.append(
        "Tell the name of the test smell(s) the test code block contains. "
        "Answer with a JSON array of smell names taken from the list above "
        "(use [] if there are none) and nothing else."
    )
    return "\n\n".join(sections)


def _parse_smell_names(response: str) -> List[str]:
    """Return the smell names from the first JSON array of strings in *response*.

    The reply may wrap the array in prose or a Markdown code fence. For that
    reason, JSON decoding is attempted at each ``[`` in turn until one decodes
    to a list of strings. Names are stripped, blank names are dropped, and
    duplicates are removed (the first occurrence is kept).

    Raises:
        ValueError: if no JSON array of strings can be found in *response*.
    """
    decoder = json.JSONDecoder()
    start = response.find("[")
    while start != -1:
        try:
            value, _ = decoder.raw_decode(response, start)
        except json.JSONDecodeError:
            value = None  # not valid JSON from this '['; try the next one
        if isinstance(value, list) and all(isinstance(item, str) for item in value):
            stripped = (item.strip() for item in value)
            return list(dict.fromkeys(name for name in stripped if name))
        start = response.find("[", start + 1)
    raise ValueError(f"No JSON array of smell names in model reply: {response!r}")


def get_smells(
    test_code: str,
    production_code: str = "",
    *,
    model: str = "google/gemma-7b-it",
    max_tokens: int = 500,
) -> DetectedSmells:
    """Ask the chat model which test smells *test_code* contains.

    Args:
        test_code: Source of the test to analyse.
        production_code: Optional source of the code under test. When empty,
            the prompt does not mention production code at all.
        model: Hugging Face model id passed to the chat-completion call.
        max_tokens: Maximum number of tokens the model may generate.

    Returns:
        A ``DetectedSmells`` holding the smell names the model reported
        (possibly empty).

    Raises:
        ValueError: if the model's reply contains no JSON array of smell names.
    """
    # All instructions go in one user turn. The system-role variant was disabled
    # in the original, presumably because Gemma's chat template rejects a
    # "system" role.
    messages = [
        {"role": "user", "content": _build_prompt(test_code, production_code)},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
    )
    # Guard against a missing/None content field so parsing reports a clear error.
    response = completion.choices[0].message.content or ""
    return DetectedSmells(smellNames=_parse_smell_names(response))


if __name__ == "__main__":
    # Only call the remote model when run as a script (not on import), and show
    # the detected smells instead of discarding the result.
    print(get_smells(test_code))