Skip to content

Commit

Permalink
Update get_repeatable_train_val_test_split to handle non-stratified s…
Browse files Browse the repository at this point in the history
…plit w/ no existing split (#2237)

* Update get_repeatable_train_val_test_split to handle non-stratified split w/no existing split

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Avoid AutoML crash on dataset having column with only NaNs (kdd)

* update

* Address review comment

* Correct return Tuple on get_distinct_values

Co-authored-by: Anne Holler <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 6, 2022
1 parent 27e0b9b commit d121eeb
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 13 deletions.
13 changes: 8 additions & 5 deletions ludwig/automl/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_dtype(self, column: str) -> str:
raise NotImplementedError()

@abstractmethod
def get_distinct_values(self, column: str, max_values_to_return: int) -> Tuple[int, List[str]]:
def get_distinct_values(self, column: str, max_values_to_return: int) -> Tuple[int, List[str], float]:
raise NotImplementedError()

@abstractmethod
Expand Down Expand Up @@ -48,13 +48,16 @@ def columns(self) -> List[str]:
def get_dtype(self, column: str) -> str:
return self.df[column].dtype.name

def get_distinct_values(self, column, max_values_to_return: int) -> Tuple[int, List[str]]:
def get_distinct_values(self, column, max_values_to_return: int) -> Tuple[int, List[str], float]:
unique_values = self.df[column].dropna().unique()
num_unique_values = len(unique_values)
unique_values_counts = self.df[column].value_counts()
unique_majority_values = unique_values_counts[unique_values_counts.idxmax()]
unique_minority_values = unique_values_counts[unique_values_counts.idxmin()]
unique_values_balance = unique_minority_values / unique_majority_values
if len(unique_values_counts) != 0:
unique_majority_values = unique_values_counts[unique_values_counts.idxmax()]
unique_minority_values = unique_values_counts[unique_values_counts.idxmin()]
unique_values_balance = unique_minority_values / unique_majority_values
else:
unique_values_balance = 1.0
return num_unique_values, unique_values[:max_values_to_return], unique_values_balance

def get_nonnull_values(self, column: str) -> int:
Expand Down
23 changes: 15 additions & 8 deletions ludwig/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
from sklearn.model_selection import train_test_split

from ludwig.constants import TEST_SPLIT, TRAIN_SPLIT, VALIDATION_SPLIT
from ludwig.utils.defaults import default_random_seed


def get_repeatable_train_val_test_split(
df_input, stratify_colname, random_seed, frac_train=0.7, frac_val=0.1, frac_test=0.2
df_input, stratify_colname="", random_seed=default_random_seed, frac_train=0.7, frac_val=0.1, frac_test=0.2
):
"""Return df_input with split column containing (if possible) non-zero rows in the train, validation, and test
data subset categories.
If the input dataframe does not contain an existing split column or if the
number of rows in both the validation and test split is 0, return df_input
with split column set according to frac_<subset_name> and stratify_colname.
number of rows in both the validation and test split is 0 and non-empty
stratify_colname specified, return df_input with split column set according
to frac_<subset_name> and stratify_colname.
Else stratify_colname is ignored, and:
If the input dataframe contains an existing split column and non-zero row
Expand All @@ -26,7 +28,7 @@ def get_repeatable_train_val_test_split(
df_input : Pandas dataframe
Input dataframe to be split.
stratify_colname : str
The column used for stratification; usually the label column.
The column used for stratification (if desired); usually the label column.
random_seed : int
Seed used to get repeatable split.
frac_train : float
Expand All @@ -43,15 +45,20 @@ def get_repeatable_train_val_test_split(

if frac_train + frac_val + frac_test != 1.0:
raise ValueError(f"fractions {frac_train:f}, {frac_val:f}, {frac_test:f} do not add up to 1.0")
if stratify_colname not in df_input.columns:
raise ValueError("%s is not a column in the dataframe" % (stratify_colname))
if stratify_colname:
do_stratify_split = True
if stratify_colname not in df_input.columns:
raise ValueError("%s is not a column in the dataframe" % (stratify_colname))
else:
do_stratify_split = False
if "split" not in df_input.columns:
df_input["split"] = 0 # set up for non-stratified split path

do_stratify_split = True
if "split" in df_input.columns:
df_train = df_input[df_input["split"] == TRAIN_SPLIT]
df_val = df_input[df_input["split"] == VALIDATION_SPLIT]
df_test = df_input[df_input["split"] == TEST_SPLIT]
if len(df_val) != 0 or len(df_test) != 0:
if not do_stratify_split or len(df_val) != 0 or len(df_test) != 0:
if len(df_val) == 0:
df_val = df_train.sample(frac=frac_val, replace=False, random_state=random_seed)
df_train = df_train.drop(df_val.index)
Expand Down
55 changes: 55 additions & 0 deletions tests/ludwig/utils/test_dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,61 @@ def test_get_repeatable_train_val_test_split():
)
)

# Test adding split without stratify
df = pd.DataFrame(
[
[0, 0],
[1, 0],
[2, 0],
[3, 0],
[4, 0],
[5, 1],
[6, 1],
[7, 1],
[8, 1],
[9, 1],
[10, 0],
[11, 0],
[12, 0],
[13, 0],
[14, 0],
[15, 1],
[16, 1],
[17, 1],
[18, 1],
[19, 1],
],
columns=["input", "target"],
)
split_df = get_repeatable_train_val_test_split(df, random_seed=42)
assert split_df.equals(
pd.DataFrame(
[
[3, 0, 0],
[4, 0, 0],
[5, 1, 0],
[7, 1, 0],
[8, 1, 0],
[10, 0, 0],
[11, 0, 0],
[12, 0, 0],
[13, 0, 0],
[14, 0, 0],
[15, 1, 0],
[16, 1, 0],
[18, 1, 0],
[19, 1, 0],
[0, 0, 1],
[17, 1, 1],
[1, 0, 2],
[2, 0, 2],
[9, 1, 2],
[6, 1, 2],
],
columns=["input", "target", "split"],
)
)

# Test needing no change
df = pd.DataFrame(
[
Expand Down

0 comments on commit d121eeb

Please sign in to comment.