From 864b8d038a65094827280e37e3e312bd36820ff5 Mon Sep 17 00:00:00 2001 From: connor-mccorm Date: Wed, 22 Jun 2022 14:11:45 -0700 Subject: [PATCH] Using encoder/decoder registries --- ludwig/schema/features/audio_feature.py | 4 +++- ludwig/schema/features/bag_feature.py | 4 +++- ludwig/schema/features/binary_feature.py | 7 +++++-- ludwig/schema/features/category_feature.py | 7 +++++-- ludwig/schema/features/date_feature.py | 4 +++- ludwig/schema/features/h3_feature.py | 4 +++- ludwig/schema/features/image_feature.py | 4 +++- ludwig/schema/features/number_feature.py | 7 +++++-- ludwig/schema/features/sequence_feature.py | 7 +++++-- ludwig/schema/features/set_feature.py | 7 +++++-- ludwig/schema/features/text_feature.py | 9 +++++---- ludwig/schema/features/timeseries_feature.py | 4 +++- ludwig/schema/features/vector_feature.py | 7 +++++-- 13 files changed, 53 insertions(+), 22 deletions(-) diff --git a/ludwig/schema/features/audio_feature.py b/ludwig/schema/features/audio_feature.py index cffaaacdf0a..d69eb3d15a4 100644 --- a/ludwig/schema/features/audio_feature.py +++ b/ludwig/schema/features/audio_feature.py @@ -2,6 +2,8 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -15,7 +17,7 @@ class AudioInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["passthrough", "parallel_cnn", "stacked_cnn", "stacked_parallel_cnn", "rnn", "cnnrnn"], + list(get_encoder_classes('audio').keys()), default="parallel_cnn", description="Encoder to use for this audio feature.", ) diff --git a/ludwig/schema/features/bag_feature.py b/ludwig/schema/features/bag_feature.py index ef020c57d0c..39c33d2c018 100644 --- a/ludwig/schema/features/bag_feature.py +++ b/ludwig/schema/features/bag_feature.py @@ -2,6 +2,8 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -17,7 +19,7 @@ class BagInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["embed"], + list(get_encoder_classes('bag').keys()), default="embed", description="Encoder to use for this bag feature.", ) diff --git a/ludwig/schema/features/binary_feature.py b/ludwig/schema/features/binary_feature.py index 29a57c0cbf4..4bc38881592 100644 --- a/ludwig/schema/features/binary_feature.py +++ b/ludwig/schema/features/binary_feature.py @@ -2,6 +2,9 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes +from ludwig.decoders.registry import get_decoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -15,7 +18,7 @@ class BinaryInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["passthrough", "dense"], + list(get_encoder_classes('binary').keys()), default="passthrough", description="Encoder to use for this binary feature.", ) @@ -34,7 +37,7 @@ class BinaryOutputFeatureConfig(schema_utils.BaseMarshmallowConfig): """BinaryOutputFeature is a dataclass that configures the parameters used for a binary output feature.""" decoder: Optional[str] = schema_utils.StringOptions( - ["regressor"], + list(get_decoder_classes('binary').keys()), default="regressor", allow_none=True, description="Decoder to use for this binary feature.", diff --git a/ludwig/schema/features/category_feature.py b/ludwig/schema/features/category_feature.py index 5c868f85888..753059db286 100644 --- a/ludwig/schema/features/category_feature.py +++ b/ludwig/schema/features/category_feature.py @@ -2,6 +2,9 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes +from ludwig.decoders.registry import get_decoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -15,7 +18,7 @@ class CategoryInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["passthrough", "dense", "sparse"], + list(get_encoder_classes('category').keys()), default="dense", description="Encoder to use for this category feature.", ) @@ -34,7 +37,7 @@ class CategoryOutputFeatureConfig(schema_utils.BaseMarshmallowConfig): """CategoryOutputFeature is a dataclass that configures the parameters used for a category output feature.""" decoder: Optional[str] = schema_utils.StringOptions( - ["classifier"], + list(get_decoder_classes('category').keys()), default="classifier", allow_none=True, description="Decoder to use for this category feature.", diff --git a/ludwig/schema/features/date_feature.py b/ludwig/schema/features/date_feature.py index 04b29e783df..eb54ded11eb 100644 --- a/ludwig/schema/features/date_feature.py +++ b/ludwig/schema/features/date_feature.py @@ -2,6 +2,8 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -17,7 +19,7 @@ class DateInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["embed", "wave"], + list(get_encoder_classes('date').keys()), default="embed", description="Encoder to use for this date feature.", ) diff --git a/ludwig/schema/features/h3_feature.py b/ludwig/schema/features/h3_feature.py index 4caa3db8a17..5203e041871 100644 --- a/ludwig/schema/features/h3_feature.py +++ b/ludwig/schema/features/h3_feature.py @@ -2,6 +2,8 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -17,7 +19,7 @@ class H3InputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["embed", "weighted_sum", "rnn"], + list(get_encoder_classes('h3').keys()), default="embed", description="Encoder to use for this h3 feature.", ) diff --git a/ludwig/schema/features/image_feature.py b/ludwig/schema/features/image_feature.py index c67cab7638a..a87df62dc7f 100644 --- a/ludwig/schema/features/image_feature.py +++ b/ludwig/schema/features/image_feature.py @@ -2,6 +2,8 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -17,7 +19,7 @@ class ImageInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["stacked_cnn", "resnet", "mlp_mixer", "vit"], + list(get_encoder_classes('image').keys()), default="stacked_cnn", description="Encoder to use for this image feature.", ) diff --git a/ludwig/schema/features/number_feature.py b/ludwig/schema/features/number_feature.py index 37cc2df0136..fbeb6735c22 100644 --- a/ludwig/schema/features/number_feature.py +++ b/ludwig/schema/features/number_feature.py @@ -2,6 +2,9 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes +from ludwig.decoders.registry import get_decoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -15,7 +18,7 @@ class NumberInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["passthrough", "dense"], + list(get_encoder_classes('number').keys()), default="passthrough", description="Encoder to use for this number feature.", ) @@ -33,7 +36,7 @@ class NumberInputFeatureConfig(schema_utils.BaseMarshmallowConfig): class NumberOutputFeatureConfig(schema_utils.BaseMarshmallowConfig): decoder: Optional[str] = schema_utils.StringOptions( - ["regressor"], + list(get_decoder_classes('number').keys()), default="regressor", allow_none=True, description="Decoder to use for this number feature.", diff --git a/ludwig/schema/features/sequence_feature.py b/ludwig/schema/features/sequence_feature.py index c626465f504..ebf95d5a4ed 100644 --- a/ludwig/schema/features/sequence_feature.py +++ b/ludwig/schema/features/sequence_feature.py @@ -2,6 +2,9 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes +from ludwig.decoders.registry import get_decoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -17,7 +20,7 @@ class SequenceInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["passthrough", "embed", "parallel_cnn", "stacked_cnn", "stacked_parallel_cnn", "rnn", "cnnrnn", "transformer"], + list(get_encoder_classes('sequence').keys()), default="embed", description="Encoder to use for this sequence feature.", ) @@ -38,7 +41,7 @@ class SequenceOutputFeatureConfig(schema_utils.BaseMarshmallowConfig): """ decoder: Optional[str] = schema_utils.StringOptions( - ["generator", "tagger"], + list(get_decoder_classes('sequence').keys()), default="generator", allow_none=True, description="Decoder to use for this sequence feature.", diff --git a/ludwig/schema/features/set_feature.py b/ludwig/schema/features/set_feature.py index a56f5326e43..8eb89ecc954 100644 --- a/ludwig/schema/features/set_feature.py +++ b/ludwig/schema/features/set_feature.py @@ -2,6 +2,9 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes +from ludwig.decoders.registry import get_decoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -17,7 +20,7 @@ class SetInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["embed"], + list(get_encoder_classes('set').keys()), default="embed", description="Encoder to use for this set feature.", ) @@ -38,7 +41,7 @@ class SetOutputFeatureConfig(schema_utils.BaseMarshmallowConfig): """ decoder: Optional[str] = schema_utils.StringOptions( - ["classifier"], + list(get_decoder_classes('set').keys()), default="classifier", allow_none=True, description="Decoder to use for this set feature.", diff --git a/ludwig/schema/features/text_feature.py b/ludwig/schema/features/text_feature.py index a380786fdf6..7c501f1d85d 100644 --- a/ludwig/schema/features/text_feature.py +++ b/ludwig/schema/features/text_feature.py @@ -2,6 +2,9 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes +from ludwig.decoders.registry import get_decoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -17,9 +20,7 @@ class TextInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["passthrough", "embed", "parallel_cnn", "stacked_cnn", "stacked_parallel_cnn", "rnn", "cnnrnn", "transformer", - "albert", "mt5", "xlmroberta", "bert", "xlm", "gpt", "gpt2", "roberta", "transformer_xl", "xlnet", - "distilbert", "ctrl", "camembert", "t5", "flaubert", "electra", "longformer", "auto_transformer"], + list(get_encoder_classes('text').keys()), default="parallel_cnn", description="Encoder to use for this text feature.", ) @@ -40,7 +41,7 @@ class TextOutputFeatureConfig(schema_utils.BaseMarshmallowConfig): """ decoder: Optional[str] = schema_utils.StringOptions( - ["tagger", "generator"], + list(get_decoder_classes('text').keys()), default="generator", description="Decoder to use for this text output feature.", ) diff --git a/ludwig/schema/features/timeseries_feature.py b/ludwig/schema/features/timeseries_feature.py index 2b7dbf366ff..5f990f7e619 100644 --- a/ludwig/schema/features/timeseries_feature.py +++ b/ludwig/schema/features/timeseries_feature.py @@ -2,6 +2,8 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -17,7 +19,7 @@ class TimeseriesInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["passthrough", "parallel_cnn", "stacked_cnn", "stacked_parallel_cnn", "rnn", "cnnrnn", "transformer"], + list(get_encoder_classes('timeseries').keys()), default="parallel_cnn", description="Encoder to use for this timeseries feature.", ) diff --git a/ludwig/schema/features/vector_feature.py b/ludwig/schema/features/vector_feature.py index 2d26194f563..136532da0f4 100644 --- a/ludwig/schema/features/vector_feature.py +++ b/ludwig/schema/features/vector_feature.py @@ -2,6 +2,9 @@ from marshmallow_dataclass import dataclass +from ludwig.encoders.registry import get_encoder_classes +from ludwig.decoders.registry import get_decoder_classes + from ludwig.schema import utils as schema_utils from ludwig.schema.preprocessing import BasePreprocessingConfig, PreprocessingDataclassField @@ -17,7 +20,7 @@ class VectorInputFeatureConfig(schema_utils.BaseMarshmallowConfig): ) encoder: Optional[str] = schema_utils.StringOptions( - ["passthrough", "dense"], + list(get_encoder_classes('vector').keys()), default="dense", description="Encoder to use for this vector feature.", ) @@ -38,7 +41,7 @@ class VectorOutputFeatureConfig(schema_utils.BaseMarshmallowConfig): """ decoder: Optional[str] = schema_utils.StringOptions( - ["projector"], + list(get_decoder_classes('vector').keys()), default="projector", description="Decoder to use for this vector feature.", )