-
Notifications
You must be signed in to change notification settings - Fork 397
/
Copy pathdatatypes.py
73 lines (52 loc) · 1.51 KB
/
datatypes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
Data model for the riffusion API.
"""
from __future__ import annotations
import typing as T
from dataclasses import dataclass
@dataclass(frozen=True)
class PromptInput:
    """
    Parameters for one end of interpolation.

    Instances are immutable (frozen dataclass). Only ``prompt`` and ``seed``
    are required; the remaining fields fall back to the defaults below.
    """

    # Text prompt fed into a CLIP model
    prompt: str

    # Random seed for denoising
    seed: int

    # Negative prompt to avoid (optional; None means no negative prompt given)
    negative_prompt: T.Optional[str] = None

    # Denoising strength
    denoising: float = 0.75

    # Classifier-free guidance strength
    guidance: float = 7.0
@dataclass(frozen=True)
class InferenceInput:
    """
    Parameters for a single run of the riffusion model, interpolating between
    a start and end set of PromptInputs. This is the API required for a request
    to the model server.

    Instances are immutable (frozen dataclass). ``start``, ``end`` and
    ``alpha`` are required; the remaining fields have defaults.
    """

    # Start point of interpolation
    start: PromptInput

    # End point of interpolation
    end: PromptInput

    # Interpolation alpha [0, 1]. A value of 0 uses start fully, a value of 1
    # uses end fully. NOTE: the range is not validated by this class.
    alpha: float

    # Number of inner loops of the diffusion model
    num_inference_steps: int = 50

    # Which seed image to use
    seed_image_id: str = "og_beat"

    # ID of mask image to use (optional; None means no mask ID given)
    mask_image_id: T.Optional[str] = None
@dataclass(frozen=True)
class InferenceOutput:
    """
    Response from the model inference server.

    Instances are immutable (frozen dataclass); all three fields are required.
    """

    # base64 encoded spectrogram image as a JPEG
    image: str

    # base64 encoded audio clip as an MP3
    audio: str

    # The duration of the audio clip (presumably in seconds, per the `_s`
    # suffix — confirm against the server code)
    duration_s: float