Add sample_size as a global preprocessing parameter (#3650)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Infernaught and pre-commit-ci[bot] authored Oct 12, 2023
1 parent fd91478 commit df6f5ef
Showing 7 changed files with 196 additions and 11 deletions.
8 changes: 8 additions & 0 deletions ludwig/config_validation/checks.py
@@ -718,3 +718,11 @@ def check_prompt_requirements(config: "ModelConfig") -> None:  # noqa: F821
"A template must contain at least one reference to a column or the sample keyword {__sample__} for "
"a JSON-serialized representation of non-output feature columns."
)


@register_config_check
def check_sample_ratio_and_size_compatible(config: "ModelConfig") -> None:
    sample_ratio = config.preprocessing.sample_ratio
    sample_size = config.preprocessing.sample_size
    if sample_size is not None and sample_ratio < 1.0:
        raise ConfigValidationError("sample_size cannot be used when sample_ratio < 1.0")
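For orientation, here is a standalone sketch (plain dicts, not Ludwig's ModelConfig API) of the rule this check enforces; the defaults noted in the comments follow the schema change further down in this commit.

# Standalone sketch: mirrors the logic of check_sample_ratio_and_size_compatible
# on plain preprocessing dicts, outside of Ludwig's validation machinery.
def is_sampling_config_valid(preprocessing: dict) -> bool:
    sample_ratio = preprocessing.get("sample_ratio", 1.0)  # 1.0 is the schema default
    sample_size = preprocessing.get("sample_size")  # defaults to None in this change
    return not (sample_size is not None and sample_ratio < 1.0)

assert is_sampling_config_valid({"sample_size": 1000})  # ok: ratio stays at its default of 1.0
assert is_sampling_config_valid({"sample_ratio": 0.5})  # ok: only the ratio subsamples
assert not is_sampling_config_valid({"sample_ratio": 0.5, "sample_size": 10})  # rejected by the new check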
34 changes: 25 additions & 9 deletions ludwig/data/preprocessing.py
@@ -1201,15 +1201,8 @@ def build_dataset(

    if mode == "training":
        sample_ratio = global_preprocessing_parameters["sample_ratio"]
        if sample_ratio < 1.0:
            if not df_engine.partitioned and len(dataset_df) * sample_ratio < 1:
                raise ValueError(
                    f"sample_ratio {sample_ratio} is too small for dataset of length {len(dataset_df)}. "
                    f"Please increase sample_ratio or use a larger dataset."
                )

            logger.debug(f"sample {sample_ratio} of data")
            dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)
        sample_size = global_preprocessing_parameters["sample_size"]
        dataset_df = _get_sampled_dataset_df(dataset_df, df_engine, sample_ratio, sample_size, random_seed)

    # If persisting DataFrames in memory is enabled, we want to do this after
    # each batch of parallel ops in order to avoid redundant computation
@@ -1396,6 +1389,29 @@ def embed_fixed_features(
    return results


def _get_sampled_dataset_df(dataset_df, df_engine, sample_ratio, sample_size, random_seed):
    df_len = len(dataset_df)
    if sample_ratio < 1.0:
        if not df_engine.partitioned and df_len * sample_ratio < 1:
            raise ValueError(
                f"sample_ratio {sample_ratio} is too small for dataset of length {df_len}. "
                f"Please increase sample_ratio or use a larger dataset."
            )

        logger.debug(f"sample {sample_ratio} of data")
        dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)

    if sample_size:
        if sample_size < df_len:
            # Cannot use 'n' parameter when using dask DataFrames -- only 'frac' is supported
            sample_ratio = sample_size / df_len
            dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)
        else:
            logger.warning("sample_size is larger than dataset size, ignoring sample_size")

    return dataset_df


def get_features_with_cacheable_fixed_embeddings(
    feature_configs: List[FeatureConfigDict], metadata: TrainingSetMetadataDict
) -> List[FeatureConfigDict]:
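Setting aside Ludwig's df_engine abstraction and the partitioned/Dask checks, a pandas-only sketch of the sampling behavior implemented by _get_sampled_dataset_df above looks roughly like this (illustrative, not the library code):

import pandas as pd

def sample_dataframe(df: pd.DataFrame, sample_ratio: float, sample_size, random_seed: int) -> pd.DataFrame:
    """Pandas-only sketch: ratio-based sampling first, then size-based sampling expressed as a ratio."""
    df_len = len(df)
    if sample_ratio < 1.0:
        df = df.sample(frac=sample_ratio, random_state=random_seed)
    if sample_size:
        if sample_size < df_len:
            # Mirrors the helper: the absolute size is converted to a fraction so the same
            # code path also works for Dask DataFrames, which only support `frac`.
            df = df.sample(frac=sample_size / df_len, random_state=random_seed)
        # else: sample_size >= dataset size, so it is ignored (the helper logs a warning)
    return df

df = pd.DataFrame({"x": range(100)})
assert len(sample_dataframe(df, sample_ratio=1.0, sample_size=25, random_seed=42)) == 25
assert len(sample_dataframe(df, sample_ratio=0.5, sample_size=None, random_seed=42)) == 50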
7 changes: 5 additions & 2 deletions ludwig/explain/captum.py
@@ -279,9 +279,11 @@ def get_input_tensors(
    :return: A list of variables, one for each input feature. Shape of each variable is [batch size, embedding size].
    """
    # Ignore sample_ratio from the model config, since we want to explain all the data.
    # Ignore sample_ratio and sample_size from the model config, since we want to explain all the data.
    sample_ratio_bak = model.config_obj.preprocessing.sample_ratio
    sample_size_bak = model.config_obj.preprocessing.sample_size
    model.config_obj.preprocessing.sample_ratio = 1.0
    model.config_obj.preprocessing.sample_size = None

    config = model.config_obj.to_dict()
    training_set_metadata = copy.deepcopy(model.training_set_metadata)
@@ -302,8 +304,9 @@
        callbacks=model.callbacks,
    )

    # Restore sample_ratio
    # Restore sample_ratio and sample_size
    model.config_obj.preprocessing.sample_ratio = sample_ratio_bak
    model.config_obj.preprocessing.sample_size = sample_size_bak

    # Make sure the number of rows in the preprocessed dataset matches the number of rows in the input data
    assert (
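The explain path now backs up and restores two preprocessing attributes by hand. This is not part of the change, but if that list keeps growing, a small context manager could keep the override and the restore paired automatically; a hedged sketch:

from contextlib import contextmanager

@contextmanager
def override_attrs(obj, **overrides):
    """Temporarily set attributes on `obj`, restoring the originals on exit."""
    saved = {name: getattr(obj, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)

# Usage sketch: explain on the full dataset regardless of the configured sampling.
# with override_attrs(model.config_obj.preprocessing, sample_ratio=1.0, sample_size=None):
#     config = model.config_obj.to_dict()
#     ...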
16 changes: 16 additions & 0 deletions ludwig/schema/metadata/configs/preprocessing.yaml
@@ -44,6 +44,22 @@ sample_ratio:
  expected_impact: 2
  suggested_values: Depends on data size
  ui_display_name: Sample Ratio
sample_size:
  default_value_reasoning:
    The default value is None because we do not want to shrink the dataset by
    default, and we do not know the size of an arbitrary dataset. By setting the
    default to None, we fall back on sample_ratio to determine the size of the
    dataset.
  description_implications:
    Decreases the amount of data fed into the model. Useful if you have more data
    than you need and you are concerned about computational costs. More useful than
    sample_ratio if you know the exact number of samples you want to train on
    rather than the proportion.
  example_value:
    - 1000
  expected_impact: 2
  suggested_values: Depends on data size
  ui_display_name: Sample Size
column:
  expected_impact: 3
  ui_display_name: Split Column
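From the user's perspective, the new parameter sits in the top-level preprocessing section of a Ludwig config, alongside sample_ratio. A hedged usage sketch (the feature definitions and dataset path are placeholders; the same keys apply in a YAML config file):

from ludwig.api import LudwigModel

# Placeholder features and dataset; only the preprocessing section matters here.
config = {
    "input_features": [{"name": "text", "type": "text"}],
    "output_features": [{"name": "label", "type": "category"}],
    "preprocessing": {
        # sample_ratio defaults to 1.0, so setting sample_size alone passes the new config check.
        "sample_size": 1000,
    },
}

model = LudwigModel(config)
# model.train(dataset="data.csv")  # placeholder path; training itself is out of scope here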
8 changes: 8 additions & 0 deletions ludwig/schema/preprocessing.py
@@ -18,6 +18,14 @@ class PreprocessingConfig(schema_utils.BaseMarshmallowConfig):
        parameter_metadata=PREPROCESSING_METADATA["sample_ratio"],
    )

    sample_size: int = schema_utils.NonNegativeInteger(
        default=None,
        allow_none=True,
        description="The maximum number of samples from the dataset to use. Cannot be set if sample_ratio is less "
        "than 1.0. If sample_ratio is 1.0, this determines the number of samples used.",
        parameter_metadata=PREPROCESSING_METADATA["sample_size"],
    )

    oversample_minority: float = schema_utils.NonNegativeFloat(
        default=None,
        allow_none=True,
102 changes: 102 additions & 0 deletions tests/integration_tests/test_preprocessing.py
@@ -162,6 +162,108 @@ def test_sample_ratio_deterministic(backend, tmpdir, ray_cluster_2cpu):
        assert test_set_1.to_df().compute().equals(test_set_2.to_df().compute())


@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("local", id="local"),
        pytest.param("ray", id="ray", marks=pytest.mark.distributed),
    ],
)
def test_sample_size(backend, tmpdir, ray_cluster_2cpu):
    num_examples = 100
    sample_size = 25

    input_features = [sequence_feature(encoder={"reduce_output": "sum"}), audio_feature(folder=tmpdir)]
    output_features = [category_feature(decoder={"vocab_size": 5}, reduce_input="sum")]
    data_csv = generate_data(
        input_features, output_features, os.path.join(tmpdir, "dataset.csv"), num_examples=num_examples
    )
    config = {
        INPUT_FEATURES: input_features,
        OUTPUT_FEATURES: output_features,
        TRAINER: {
            EPOCHS: 2,
        },
        PREPROCESSING: {"sample_size": sample_size},
    }

    model = LudwigModel(config, backend=backend)
    train_set, val_set, test_set, training_set_metadata = model.preprocess(
        data_csv,
        skip_save_processed_input=True,
    )

    count = len(train_set) + len(val_set) + len(test_set)
    assert sample_size == count

    # Check that sample size is disabled when doing preprocessing for prediction
    dataset, _ = preprocess_for_prediction(
        model.config_obj.to_dict(),
        dataset=data_csv,
        training_set_metadata=training_set_metadata,
        split=FULL,
        include_outputs=True,
        backend=model.backend,
    )
    assert "sample_size" in model.config_obj.preprocessing.to_dict()
    assert len(dataset) == num_examples


@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("local", id="local"),
        pytest.param("ray", id="ray", marks=pytest.mark.distributed),
    ],
)
def test_sample_size_deterministic(backend, tmpdir, ray_cluster_2cpu):
    """Ensures that the sampled dataset is the same when using a fixed random seed.

    model.preprocess returns a PandasDataset object when using the local backend, and a RayDataset object when
    using the Ray backend.
    """
    num_examples = 100
    sample_size = 30

    input_features = [binary_feature()]
    output_features = [category_feature()]
    data_csv = generate_data(
        input_features, output_features, os.path.join(tmpdir, "dataset.csv"), num_examples=num_examples
    )

    config = {
        INPUT_FEATURES: input_features,
        OUTPUT_FEATURES: output_features,
        PREPROCESSING: {"sample_size": sample_size},
    }

    model1 = LudwigModel(config, backend=backend)
    train_set_1, val_set_1, test_set_1, _ = model1.preprocess(
        data_csv,
        skip_save_processed_input=True,
    )

    model2 = LudwigModel(config, backend=backend)
    train_set_2, val_set_2, test_set_2, _ = model2.preprocess(
        data_csv,
        skip_save_processed_input=True,
    )

    # Ensure sizes are the same
    assert sample_size == len(train_set_1) + len(val_set_1) + len(test_set_1)
    assert sample_size == len(train_set_2) + len(val_set_2) + len(test_set_2)

    # Ensure actual rows are the same
    if backend == "local":
        assert train_set_1.to_df().equals(train_set_2.to_df())
        assert val_set_1.to_df().equals(val_set_2.to_df())
        assert test_set_1.to_df().equals(test_set_2.to_df())
    else:
        assert train_set_1.to_df().compute().equals(train_set_2.to_df().compute())
        assert val_set_1.to_df().compute().equals(val_set_2.to_df().compute())
        assert test_set_1.to_df().compute().equals(test_set_2.to_df().compute())


def test_strip_whitespace_category(csv_filename, tmpdir):
    data_csv_path = os.path.join(tmpdir, csv_filename)

32 changes: 32 additions & 0 deletions tests/ludwig/config_validation/test_checks.py
@@ -489,3 +489,35 @@ def test_check_prompt_requirements():

config["prompt"] = {"task": "Some task", "template": "{__task__}"}
ModelConfig.from_dict(config)


def test_check_sample_ratio_and_size_compatible():
    config = {
        "input_features": [binary_feature()],
        "output_features": [binary_feature()],
        "model_type": "ecd",
    }
    ModelConfig.from_dict(config)

    config["preprocessing"] = {"sample_size": 10}
    ModelConfig.from_dict(config)

    config["preprocessing"]["sample_ratio"] = 1
    ModelConfig.from_dict(config)

    config["preprocessing"]["sample_ratio"] = 0.1
    with pytest.raises(ConfigValidationError):
        ModelConfig.from_dict(config)

    config["preprocessing"]["sample_size"] = 0
    with pytest.raises(ConfigValidationError):
        ModelConfig.from_dict(config)

    del config["preprocessing"]["sample_size"]
    ModelConfig.from_dict(config)
