diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py
index d6ed2cc1b60..a9724219671 100644
--- a/ludwig/config_validation/checks.py
+++ b/ludwig/config_validation/checks.py
@@ -718,3 +718,11 @@ def check_prompt_requirements(config: "ModelConfig") -> None:  # noqa: F821
             "A template must contain at least one reference to a column or the sample keyword {__sample__} for "
             "a JSON-serialized representation of non-output feature columns."
         )
+
+
+@register_config_check
+def check_sample_ratio_and_size_compatible(config: "ModelConfig") -> None:
+    sample_ratio = config.preprocessing.sample_ratio
+    sample_size = config.preprocessing.sample_size
+    if sample_size is not None and sample_ratio < 1.0:
+        raise ConfigValidationError("sample_size cannot be used when sample_ratio < 1.0")
diff --git a/ludwig/data/preprocessing.py b/ludwig/data/preprocessing.py
index ba7ba4dfed9..708edea346d 100644
--- a/ludwig/data/preprocessing.py
+++ b/ludwig/data/preprocessing.py
@@ -1201,15 +1201,8 @@ def build_dataset(
 
     if mode == "training":
         sample_ratio = global_preprocessing_parameters["sample_ratio"]
-        if sample_ratio < 1.0:
-            if not df_engine.partitioned and len(dataset_df) * sample_ratio < 1:
-                raise ValueError(
-                    f"sample_ratio {sample_ratio} is too small for dataset of length {len(dataset_df)}. "
-                    f"Please increase sample_ratio or use a larger dataset."
-                )
-
-            logger.debug(f"sample {sample_ratio} of data")
-            dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)
+        sample_size = global_preprocessing_parameters["sample_size"]
+        dataset_df = _get_sampled_dataset_df(dataset_df, df_engine, sample_ratio, sample_size, random_seed)
 
     # If persisting DataFrames in memory is enabled, we want to do this after
     # each batch of parallel ops in order to avoid redundant computation
@@ -1396,6 +1389,29 @@ def embed_fixed_features(
     return results
 
 
+def _get_sampled_dataset_df(dataset_df, df_engine, sample_ratio, sample_size, random_seed):
+    df_len = len(dataset_df)
+    if sample_ratio < 1.0:
+        if not df_engine.partitioned and df_len * sample_ratio < 1:
+            raise ValueError(
+                f"sample_ratio {sample_ratio} is too small for dataset of length {df_len}. "
+                f"Please increase sample_ratio or use a larger dataset."
+            )
+
+        logger.debug(f"sample {sample_ratio} of data")
+        dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)
+
+    if sample_size:
+        if sample_size < df_len:
+            # Cannot use 'n' parameter when using dask DataFrames -- only 'frac' is supported
+            sample_ratio = sample_size / df_len
+            dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)
+        else:
+            logger.warning("sample_size is larger than dataset size, ignoring sample_size")
+
+    return dataset_df
+
+
 def get_features_with_cacheable_fixed_embeddings(
     feature_configs: List[FeatureConfigDict], metadata: TrainingSetMetadataDict
 ) -> List[FeatureConfigDict]:
diff --git a/ludwig/explain/captum.py b/ludwig/explain/captum.py
index 6194defd88f..081568e18f7 100644
--- a/ludwig/explain/captum.py
+++ b/ludwig/explain/captum.py
@@ -279,9 +279,11 @@ def get_input_tensors(
 
     :return: A list of variables, one for each input feature. Shape of each variable is [batch size, embedding size].
     """
-    # Ignore sample_ratio from the model config, since we want to explain all the data.
+    # Ignore sample_ratio and sample_size from the model config, since we want to explain all the data.
     sample_ratio_bak = model.config_obj.preprocessing.sample_ratio
+    sample_size_bak = model.config_obj.preprocessing.sample_size
     model.config_obj.preprocessing.sample_ratio = 1.0
+    model.config_obj.preprocessing.sample_size = None
     config = model.config_obj.to_dict()
 
     training_set_metadata = copy.deepcopy(model.training_set_metadata)
@@ -302,8 +304,9 @@ def get_input_tensors(
         callbacks=model.callbacks,
     )
 
-    # Restore sample_ratio
+    # Restore sample_ratio and sample_size
     model.config_obj.preprocessing.sample_ratio = sample_ratio_bak
+    model.config_obj.preprocessing.sample_size = sample_size_bak
 
     # Make sure the number of rows in the preprocessed dataset matches the number of rows in the input data
     assert (
diff --git a/ludwig/schema/metadata/configs/preprocessing.yaml b/ludwig/schema/metadata/configs/preprocessing.yaml
index 688f2084732..a29d2ece63a 100644
--- a/ludwig/schema/metadata/configs/preprocessing.yaml
+++ b/ludwig/schema/metadata/configs/preprocessing.yaml
@@ -44,6 +44,22 @@ sample_ratio:
   expected_impact: 2
   suggested_values: Depends on data size
   ui_display_name: Sample Ratio
+sample_size:
+  default_value_reasoning:
+    The default value is None because we do not want to shrink
+    the dataset by default, and we do not know the size of an arbitrary dataset.
+    By setting the default to None, we fall back on the sample_ratio to determine
+    the size of the dataset.
+  description_implications:
+    Decreases the amount of data you are inputting into
+    the model. Could be useful if you have more data than you need and you are
+    concerned with computational costs. More useful than sample_ratio if you
+    know the exact number of samples you want to train on instead of knowing the proportion.
+  example_value:
+    - 1000
+  expected_impact: 2
+  suggested_values: Depends on data size
+  ui_display_name: Sample Size
 column:
   expected_impact: 3
   ui_display_name: Split Column
diff --git a/ludwig/schema/preprocessing.py b/ludwig/schema/preprocessing.py
index 075c63fe4a3..4963f3783ba 100644
--- a/ludwig/schema/preprocessing.py
+++ b/ludwig/schema/preprocessing.py
@@ -18,6 +18,14 @@ class PreprocessingConfig(schema_utils.BaseMarshmallowConfig):
         parameter_metadata=PREPROCESSING_METADATA["sample_ratio"],
     )
 
+    sample_size: int = schema_utils.NonNegativeInteger(
+        default=None,
+        allow_none=True,
+        description="The maximum number of samples from the dataset to use. Cannot be set if sample_ratio is set to be "
+        "< 1.0. If sample_ratio is set to 1.0, this will override the number of samples used.",
+        parameter_metadata=PREPROCESSING_METADATA["sample_size"],
+    )
+
     oversample_minority: float = schema_utils.NonNegativeFloat(
         default=None,
         allow_none=True,
diff --git a/tests/integration_tests/test_preprocessing.py b/tests/integration_tests/test_preprocessing.py
index 9f9812c6d85..671327e8e12 100644
--- a/tests/integration_tests/test_preprocessing.py
+++ b/tests/integration_tests/test_preprocessing.py
@@ -162,6 +162,108 @@ def test_sample_ratio_deterministic(backend, tmpdir, ray_cluster_2cpu):
     assert test_set_1.to_df().compute().equals(test_set_2.to_df().compute())
 
 
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("local", id="local"),
+        pytest.param("ray", id="ray", marks=pytest.mark.distributed),
+    ],
+)
+def test_sample_size(backend, tmpdir, ray_cluster_2cpu):
+    num_examples = 100
+    sample_size = 25
+
+    input_features = [sequence_feature(encoder={"reduce_output": "sum"}), audio_feature(folder=tmpdir)]
+    output_features = [category_feature(decoder={"vocab_size": 5}, reduce_input="sum")]
+    data_csv = generate_data(
+        input_features, output_features, os.path.join(tmpdir, "dataset.csv"), num_examples=num_examples
+    )
+    config = {
+        INPUT_FEATURES: input_features,
+        OUTPUT_FEATURES: output_features,
+        TRAINER: {
+            EPOCHS: 2,
+        },
+        PREPROCESSING: {"sample_size": sample_size},
+    }
+
+    model = LudwigModel(config, backend=backend)
+    train_set, val_set, test_set, training_set_metadata = model.preprocess(
+        data_csv,
+        skip_save_processed_input=True,
+    )
+
+    count = len(train_set) + len(val_set) + len(test_set)
+    assert sample_size == count
+
+    # Check that sample size is disabled when doing preprocessing for prediction
+    dataset, _ = preprocess_for_prediction(
+        model.config_obj.to_dict(),
+        dataset=data_csv,
+        training_set_metadata=training_set_metadata,
+        split=FULL,
+        include_outputs=True,
+        backend=model.backend,
+    )
+    assert "sample_size" in model.config_obj.preprocessing.to_dict()
+    assert len(dataset) == num_examples
+
+
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("local", id="local"),
+        pytest.param("ray", id="ray", marks=pytest.mark.distributed),
+    ],
+)
+def test_sample_size_deterministic(backend, tmpdir, ray_cluster_2cpu):
+    """Ensures that the sampled dataset is the same when using a random seed.
+
+    model.preprocess returns a PandasDataset object when using the local backend, and returns a RayDataset object when
+    using the Ray backend.
+    """
+    num_examples = 100
+    sample_size = 30
+
+    input_features = [binary_feature()]
+    output_features = [category_feature()]
+    data_csv = generate_data(
+        input_features, output_features, os.path.join(tmpdir, "dataset.csv"), num_examples=num_examples
+    )
+
+    config = {
+        INPUT_FEATURES: input_features,
+        OUTPUT_FEATURES: output_features,
+        PREPROCESSING: {"sample_size": sample_size},
+    }
+
+    model1 = LudwigModel(config, backend=backend)
+    train_set_1, val_set_1, test_set_1, _ = model1.preprocess(
+        data_csv,
+        skip_save_processed_input=True,
+    )
+
+    model2 = LudwigModel(config, backend=backend)
+    train_set_2, val_set_2, test_set_2, _ = model2.preprocess(
+        data_csv,
+        skip_save_processed_input=True,
+    )
+
+    # Ensure sizes are the same
+    assert sample_size == len(train_set_1) + len(val_set_1) + len(test_set_1)
+    assert sample_size == len(train_set_2) + len(val_set_2) + len(test_set_2)
+
+    # Ensure actual rows are the same
+    if backend == "local":
+        assert train_set_1.to_df().equals(train_set_2.to_df())
+        assert val_set_1.to_df().equals(val_set_2.to_df())
+        assert test_set_1.to_df().equals(test_set_2.to_df())
+    else:
+        assert train_set_1.to_df().compute().equals(train_set_2.to_df().compute())
+        assert val_set_1.to_df().compute().equals(val_set_2.to_df().compute())
+        assert test_set_1.to_df().compute().equals(test_set_2.to_df().compute())
+
+
 def test_strip_whitespace_category(csv_filename, tmpdir):
     data_csv_path = os.path.join(tmpdir, csv_filename)
 
diff --git a/tests/ludwig/config_validation/test_checks.py b/tests/ludwig/config_validation/test_checks.py
index 44630167fda..613a29a3fd3 100644
--- a/tests/ludwig/config_validation/test_checks.py
+++ b/tests/ludwig/config_validation/test_checks.py
@@ -489,3 +489,35 @@ def test_check_prompt_requirements():
 
     config["prompt"] = {"task": "Some task", "template": "{__task__}"}
     ModelConfig.from_dict(config)
+
+
+def test_check_sample_ratio_and_size_compatible():
+    config = {
+        "input_features": [binary_feature()],
+        "output_features": [binary_feature()],
+        "model_type": "ecd",
+    }
+    ModelConfig.from_dict(
+        {
+            "input_features": [binary_feature()],
+            "output_features": [binary_feature()],
+            "model_type": "ecd",
+        }
+    )
+
+    config["preprocessing"] = {"sample_size": 10}
+    ModelConfig.from_dict(config)
+
+    config["preprocessing"]["sample_ratio"] = 1
+    ModelConfig.from_dict(config)
+
+    config["preprocessing"]["sample_ratio"] = 0.1
+    with pytest.raises(ConfigValidationError):
+        ModelConfig.from_dict(config)
+
+    config["preprocessing"]["sample_size"] = 0
+    with pytest.raises(ConfigValidationError):
+        ModelConfig.from_dict(config)
+
+    del config["preprocessing"]["sample_size"]
+    ModelConfig.from_dict(config)
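
Usage sketch (not part of the patch; the feature names, dataset path, and values below are illustrative): with this change, a Ludwig config can cap the absolute number of rows used during training preprocessing with sample_size instead of specifying a fraction via sample_ratio. The new check_sample_ratio_and_size_compatible check rejects configs that set sample_size while sample_ratio is below 1.0.

    from ludwig.api import LudwigModel

    config = {
        "input_features": [{"name": "text", "type": "text"}],
        "output_features": [{"name": "label", "type": "category"}],
        # Cap training preprocessing at 1000 rows. sample_ratio keeps its default of 1.0;
        # setting it below 1.0 together with sample_size raises a ConfigValidationError.
        "preprocessing": {"sample_size": 1000},
    }

    model = LudwigModel(config)
    model.train(dataset="train.csv")  # illustrative dataset path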