Skip to content

Commit

Permalink
Using encoder/decoder registries
Browse files Browse the repository at this point in the history
  • Loading branch information
connor-mccorm committed Jun 22, 2022
1 parent 7f299e9 commit 864b8d0
Show file tree
Hide file tree
Showing 13 changed files with 53 additions and 22 deletions.
4 changes: 3 additions & 1 deletion ludwig/schema/features/audio_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -15,7 +17,7 @@ class AudioInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["passthrough", "parallel_cnn", "stacked_cnn", "stacked_parallel_cnn", "rnn", "cnnrnn"],
list(get_encoder_classes('audio').keys()),
default="parallel_cnn",
description="Encoder to use for this audio feature.",
)
Expand Down
4 changes: 3 additions & 1 deletion ludwig/schema/features/bag_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -17,7 +19,7 @@ class BagInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["embed"],
list(get_encoder_classes('bag').keys()),
default="embed",
description="Encoder to use for this bag feature.",
)
Expand Down
7 changes: 5 additions & 2 deletions ludwig/schema/features/binary_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes
from ludwig.decoders.registry import get_decoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -15,7 +18,7 @@ class BinaryInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["passthrough", "dense"],
list(get_encoder_classes('binary').keys()),
default="passthrough",
description="Encoder to use for this binary feature.",
)
Expand All @@ -34,7 +37,7 @@ class BinaryOutputFeatureConfig(schema_utils.BaseMarshmallowConfig):
"""BinaryOutputFeature is a dataclass that configures the parameters used for a binary output feature."""

decoder: Optional[str] = schema_utils.StringOptions(
["regressor"],
list(get_decoder_classes('binary').keys()),
default="regressor",
allow_none=True,
description="Decoder to use for this binary feature.",
Expand Down
7 changes: 5 additions & 2 deletions ludwig/schema/features/category_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes
from ludwig.decoders.registry import get_decoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -15,7 +18,7 @@ class CategoryInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["passthrough", "dense", "sparse"],
list(get_encoder_classes('category').keys()),
default="dense",
description="Encoder to use for this category feature.",
)
Expand All @@ -34,7 +37,7 @@ class CategoryOutputFeatureConfig(schema_utils.BaseMarshmallowConfig):
"""CategoryOutputFeature is a dataclass that configures the parameters used for a category output feature."""

decoder: Optional[str] = schema_utils.StringOptions(
["classifier"],
list(get_decoder_classes('category').keys()),
default="classifier",
allow_none=True,
description="Decoder to use for this category feature.",
Expand Down
4 changes: 3 additions & 1 deletion ludwig/schema/features/date_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -17,7 +19,7 @@ class DateInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["embed", "wave"],
list(get_encoder_classes('date').keys()),
default="embed",
description="Encoder to use for this date feature.",
)
Expand Down
4 changes: 3 additions & 1 deletion ludwig/schema/features/h3_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -17,7 +19,7 @@ class H3InputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["embed", "weighted_sum", "rnn"],
list(get_encoder_classes('h3').keys()),
default="embed",
description="Encoder to use for this h3 feature.",
)
Expand Down
4 changes: 3 additions & 1 deletion ludwig/schema/features/image_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -17,7 +19,7 @@ class ImageInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["stacked_cnn", "resnet", "mlp_mixer", "vit"],
list(get_encoder_classes('image').keys()),
default="stacked_cnn",
description="Encoder to use for this image feature.",
)
Expand Down
7 changes: 5 additions & 2 deletions ludwig/schema/features/number_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes
from ludwig.decoders.registry import get_decoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -15,7 +18,7 @@ class NumberInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["passthrough", "dense"],
list(get_encoder_classes('number').keys()),
default="passthrough",
description="Encoder to use for this number feature.",
)
Expand All @@ -33,7 +36,7 @@ class NumberInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
class NumberOutputFeatureConfig(schema_utils.BaseMarshmallowConfig):

decoder: Optional[str] = schema_utils.StringOptions(
["regressor"],
list(get_decoder_classes('number').keys()),
default="regressor",
allow_none=True,
description="Decoder to use for this number feature.",
Expand Down
7 changes: 5 additions & 2 deletions ludwig/schema/features/sequence_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes
from ludwig.decoders.registry import get_decoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -17,7 +20,7 @@ class SequenceInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["passthrough", "embed", "parallel_cnn", "stacked_cnn", "stacked_parallel_cnn", "rnn", "cnnrnn", "transformer"],
list(get_encoder_classes('sequence').keys()),
default="embed",
description="Encoder to use for this sequence feature.",
)
Expand All @@ -38,7 +41,7 @@ class SequenceOutputFeatureConfig(schema_utils.BaseMarshmallowConfig):
"""

decoder: Optional[str] = schema_utils.StringOptions(
["generator", "tagger"],
list(get_decoder_classes('sequence').keys()),
default="generator",
allow_none=True,
description="Decoder to use for this sequence feature.",
Expand Down
7 changes: 5 additions & 2 deletions ludwig/schema/features/set_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes
from ludwig.decoders.registry import get_decoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -17,7 +20,7 @@ class SetInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["embed"],
list(get_encoder_classes('set').keys()),
default="embed",
description="Encoder to use for this set feature.",
)
Expand All @@ -38,7 +41,7 @@ class SetOutputFeatureConfig(schema_utils.BaseMarshmallowConfig):
"""

decoder: Optional[str] = schema_utils.StringOptions(
["classifier"],
list(get_decoder_classes('set').keys()),
default="classifier",
allow_none=True,
description="Decoder to use for this set feature.",
Expand Down
9 changes: 5 additions & 4 deletions ludwig/schema/features/text_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes
from ludwig.decoders.registry import get_decoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -17,9 +20,7 @@ class TextInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["passthrough", "embed", "parallel_cnn", "stacked_cnn", "stacked_parallel_cnn", "rnn", "cnnrnn", "transformer",
"albert", "mt5", "xlmroberta", "bert", "xlm", "gpt", "gpt2", "roberta", "transformer_xl", "xlnet",
"distilbert", "ctrl", "camembert", "t5", "flaubert", "electra", "longformer", "auto_transformer"],
list(get_encoder_classes('text').keys()),
default="parallel_cnn",
description="Encoder to use for this text feature.",
)
Expand All @@ -40,7 +41,7 @@ class TextOutputFeatureConfig(schema_utils.BaseMarshmallowConfig):
"""

decoder: Optional[str] = schema_utils.StringOptions(
["tagger", "generator"],
list(get_decoder_classes('text').keys()),
default="generator",
description="Decoder to use for this text output feature.",
)
4 changes: 3 additions & 1 deletion ludwig/schema/features/timeseries_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -17,7 +19,7 @@ class TimeseriesInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["passthrough", "parallel_cnn", "stacked_cnn", "stacked_parallel_cnn", "rnn", "cnnrnn", "transformer"],
list(get_encoder_classes('timeseries').keys()),
default="parallel_cnn",
description="Encoder to use for this timeseries feature.",
)
Expand Down
7 changes: 5 additions & 2 deletions ludwig/schema/features/vector_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from marshmallow_dataclass import dataclass

from ludwig.encoders.registry import get_encoder_classes
from ludwig.decoders.registry import get_decoder_classes

from ludwig.schema import utils as schema_utils
from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField

Expand All @@ -17,7 +20,7 @@ class VectorInputFeatureConfig(schema_utils.BaseMarshmallowConfig):
)

encoder: Optional[str] = schema_utils.StringOptions(
["passthrough", "dense"],
list(get_encoder_classes('vector').keys()),
default="dense",
description="Encoder to use for this vector feature.",
)
Expand All @@ -38,7 +41,7 @@ class VectorOutputFeatureConfig(schema_utils.BaseMarshmallowConfig):
"""

decoder: Optional[str] = schema_utils.StringOptions(
["projector"],
list(get_decoder_classes('vector').keys()),
default="projector",
description="Decoder to use for this vector feature.",
)
Expand Down

0 comments on commit 864b8d0

Please sign in to comment.