diff --git a/ludwig/models/base.py b/ludwig/models/base.py index a468cf86cb1..5f79bbbcbf9 100644 --- a/ludwig/models/base.py +++ b/ludwig/models/base.py @@ -10,6 +10,7 @@ from ludwig.constants import COMBINED, LOSS, NAME, TIED, TYPE from ludwig.features.base_feature import InputFeature, OutputFeature from ludwig.features.feature_registries import input_type_registry, output_type_registry +from ludwig.features.feature_utils import LudwigFeatureDict from ludwig.utils.algorithms_utils import topological_sort_feature_dependencies from ludwig.utils.metric_utils import get_scalar_from_ludwig_metric from ludwig.utils.misc_utils import get_from_registry @@ -42,6 +43,9 @@ def __init__(self, random_seed: int = None): super().__init__() + self.input_features = LudwigFeatureDict() + self.output_features = LudwigFeatureDict() + @classmethod def build_inputs(cls, input_features_def: List[Dict[str, Any]]) -> Dict[str, InputFeature]: """Builds and returns input features in topological order.""" @@ -245,7 +249,7 @@ def update_metrics(self, targets, predictions): self.eval_loss_metric.update(eval_loss) self.eval_additional_losses_metrics.update(additional_losses) - def get_metrics(self): + def get_metrics(self) -> Dict[str, Dict[str, float]]: """Returns a dictionary of metrics for each output feature of the model.""" all_of_metrics = {} for of_name, of_obj in self.output_features.items(): @@ -278,11 +282,11 @@ def collect_weights(self, tensor_names=None, **kwargs): return [named_param for named_param in self.named_parameters() if named_param[0] in tensor_set] @abstractmethod - def save(self, save_path): + def save(self, save_path: str): """Saves the model to the given path.""" @abstractmethod - def load(self, save_path): + def load(self, save_path: str): """Loads the model from the given path.""" @abstractmethod diff --git a/ludwig/models/ecd.py b/ludwig/models/ecd.py index d9375a59eec..7169a3e33e4 100644 --- a/ludwig/models/ecd.py +++ b/ludwig/models/ecd.py @@ -9,7 +9,6 @@ from ludwig.combiners.combiners import get_combiner_class from ludwig.constants import MODEL_ECD, TYPE -from ludwig.features.feature_utils import LudwigFeatureDict from ludwig.globals import MODEL_WEIGHTS_FILE_NAME from ludwig.models.base import BaseModel from ludwig.schema.utils import load_config_with_kwargs @@ -42,7 +41,6 @@ def __init__( super().__init__(random_seed=self._random_seed) # ================ Inputs ================ - self.input_features = LudwigFeatureDict() try: self.input_features.update(self.build_inputs(self._input_features_def)) except KeyError as e: @@ -60,7 +58,6 @@ def __init__( self.combiner = combiner_class(input_features=self.input_features, config=config, **kwargs) # ================ Outputs ================ - self.output_features = LudwigFeatureDict() self.output_features.update(self.build_outputs(self._output_features_def, self.combiner)) # ================ Combined loss metric ================ diff --git a/ludwig/models/gbm.py b/ludwig/models/gbm.py index 10c1958b24a..4c49d599d3f 100644 --- a/ludwig/models/gbm.py +++ b/ludwig/models/gbm.py @@ -10,7 +10,6 @@ from ludwig.constants import BINARY, CATEGORY, LOGITS, MODEL_GBM, NAME, NUMBER from ludwig.features.base_feature import OutputFeature -from ludwig.features.feature_utils import LudwigFeatureDict from ludwig.globals import MODEL_WEIGHTS_FILE_NAME from ludwig.models.base import BaseModel from ludwig.utils import output_feature_utils @@ -35,7 +34,6 @@ def __init__( self._output_features_def = copy.deepcopy(output_features) # ================ Inputs ================ - self.input_features = LudwigFeatureDict() try: self.input_features.update(self.build_inputs(self._input_features_def)) except KeyError as e: @@ -44,7 +42,6 @@ def __init__( ) # ================ Outputs ================ - self.output_features = LudwigFeatureDict() self.output_features.update(self.build_outputs(self._output_features_def, input_size=self.input_shape[-1])) # ================ Combined loss metric ================ diff --git a/ludwig/schema/trainer.py b/ludwig/schema/trainer.py index 20022c8c02b..3595f32f8e7 100644 --- a/ludwig/schema/trainer.py +++ b/ludwig/schema/trainer.py @@ -338,8 +338,8 @@ class GBMTrainerConfig(BaseTrainerConfig): parameter_metadata=TRAINER_METADATA["learning_rate"], ) - boosting_round_log_frequency: int = schema_utils.PositiveInteger( - default=10, description="Number of boosting rounds per log of the training progress." + boosting_rounds_per_checkpoint: int = schema_utils.PositiveInteger( + default=10, description="Number of boosting rounds per checkpoint / evaluation round." ) # LightGBM core parameters (https://lightgbm.readthedocs.io/en/latest/Parameters.html) @@ -529,7 +529,7 @@ class GBMTrainerConfig(BaseTrainerConfig): description="Smoothing factor applied to tree nodes in the GBM trainer.", ) - verbose: int = schema_utils.IntegerRange(default=0, min=-1, max=2, description="Verbosity level for GBM trainer.") + verbose: int = schema_utils.IntegerRange(default=-1, min=-1, max=2, description="Verbosity level for GBM trainer.") # LightGBM IO params max_bin: int = schema_utils.PositiveInteger( diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index 78a7bb2fa92..aa66b97dfde 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -57,6 +57,7 @@ from ludwig.utils.misc_utils import set_random_seed from ludwig.utils.torch_utils import get_torch_device from ludwig.utils.trainer_utils import ( + append_metrics, get_final_steps_per_checkpoint, get_new_progress_tracker, get_total_steps, @@ -1050,37 +1051,13 @@ def validation_field(self): def validation_metric(self): return self._validation_metric - def append_metrics(self, dataset_name, results, metrics_log, tables, progress_tracker): - epoch = progress_tracker.epoch - steps = progress_tracker.steps - for output_feature in self.model.output_features: - scores = [dataset_name] - - # collect metric names based on output features metrics to - # ensure consistent order of reporting metrics - metric_names = self.model.output_features[output_feature].metric_functions.keys() - - for metric in metric_names: - if metric in results[output_feature]: - # Some metrics may have been excepted and excluded from results. - score = results[output_feature][metric] - metrics_log[output_feature][metric].append(TrainerMetric(epoch=epoch, step=steps, value=score)) - scores.append(score) - - tables[output_feature].append(scores) - - metrics_log[COMBINED][LOSS].append(TrainerMetric(epoch=epoch, step=steps, value=results[COMBINED][LOSS])) - tables[COMBINED].append([dataset_name, results[COMBINED][LOSS]]) - - return metrics_log, tables - def evaluation(self, dataset, dataset_name, metrics_log, tables, batch_size, progress_tracker): predictor = Predictor( self.model, batch_size=batch_size, horovod=self.horovod, report_tqdm_to_ray=self.report_tqdm_to_ray ) metrics, predictions = predictor.batch_evaluation(dataset, collect_predictions=False, dataset_name=dataset_name) - self.append_metrics(dataset_name, metrics, metrics_log, tables, progress_tracker) + append_metrics(self.model, dataset_name, metrics, metrics_log, tables, progress_tracker) return metrics_log, tables diff --git a/ludwig/trainers/trainer_lightgbm.py b/ludwig/trainers/trainer_lightgbm.py index a377886920f..502ecb5cc9d 100644 --- a/ludwig/trainers/trainer_lightgbm.py +++ b/ludwig/trainers/trainer_lightgbm.py @@ -1,8 +1,11 @@ import logging import os +import signal +import sys +import threading import time from collections import OrderedDict -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import lightgbm as lgb import torch @@ -11,10 +14,11 @@ from ludwig.constants import BINARY, CATEGORY, COMBINED, LOSS, MODEL_GBM, NUMBER, TEST, TRAINING, VALIDATION from ludwig.features.feature_utils import LudwigFeatureDict -from ludwig.globals import TRAINING_CHECKPOINTS_DIR_PATH, TRAINING_PROGRESS_TRACKER_FILE_NAME +from ludwig.globals import is_progressbar_disabled, TRAINING_CHECKPOINTS_DIR_PATH, TRAINING_PROGRESS_TRACKER_FILE_NAME from ludwig.models.gbm import GBM from ludwig.models.predictor import Predictor -from ludwig.modules.metric_modules import get_initial_validation_value +from ludwig.modules.metric_modules import get_improved_fun, get_initial_validation_value +from ludwig.progress_bar import LudwigProgressBar from ludwig.schema.trainer import BaseTrainerConfig, GBMTrainerConfig from ludwig.trainers.base import BaseTrainer from ludwig.trainers.registry import register_ray_trainer, register_trainer @@ -22,7 +26,8 @@ from ludwig.utils.checkpoint_utils import Checkpoint, CheckpointManager from ludwig.utils.defaults import default_random_seed from ludwig.utils.metric_utils import get_metric_names, TrainerMetric -from ludwig.utils.trainer_utils import get_new_progress_tracker, ProgressTracker +from ludwig.utils.misc_utils import set_random_seed +from ludwig.utils.trainer_utils import append_metrics, get_new_progress_tracker, ProgressTracker def iter_feature_metrics(features: LudwigFeatureDict) -> Iterable[Tuple[str, str]]: @@ -58,9 +63,11 @@ def __init__( self.random_seed = random_seed self.model = model self.horovod = horovod + self.received_sigint = False self.report_tqdm_to_ray = report_tqdm_to_ray self.callbacks = callbacks or [] self.skip_save_progress = skip_save_progress + self.skip_save_log = skip_save_log self.skip_save_model = skip_save_model self.eval_batch_size = config.eval_batch_size @@ -78,7 +85,7 @@ def __init__( self.boosting_type = config.boosting_type self.tree_learner = config.tree_learner self.num_boost_round = config.num_boost_round - self.boosting_round_log_frequency = config.boosting_round_log_frequency + self.boosting_rounds_per_checkpoint = min(self.num_boost_round, config.boosting_rounds_per_checkpoint) self.max_depth = config.max_depth self.num_leaves = config.num_leaves self.min_data_in_leaf = config.min_data_in_leaf @@ -121,6 +128,12 @@ def __init__( if self.device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" + # when training starts the sigint handler will be replaced with + # set_steps_to_1_or_quit so this is needed to remember + # the original sigint to restore at the end of training + # and before set_steps_to_1_or_quit returns + self.original_sigint_handler = None + @staticmethod def get_schema_cls() -> BaseTrainerConfig: return GBMTrainerConfig @@ -149,30 +162,6 @@ def validation_field(self) -> str: def validation_metric(self) -> str: return self._validation_metric - def append_metrics(self, dataset_name, results, metrics_log, tables, progress_tracker): - epoch = progress_tracker.epoch - steps = progress_tracker.steps - for output_feature in self.model.output_features: - scores = [dataset_name] - - # collect metric names based on output features metrics to - # ensure consistent order of reporting metrics - metric_names = self.model.output_features[output_feature].metric_functions.keys() - - for metric in metric_names: - if metric in results[output_feature]: - # Some metrics may have been excepted and excluded from results. - score = results[output_feature][metric] - metrics_log[output_feature][metric].append(TrainerMetric(epoch=epoch, step=steps, value=score)) - scores.append(score) - - tables[output_feature].append(scores) - - metrics_log[COMBINED][LOSS].append(TrainerMetric(epoch=epoch, step=steps, value=results[COMBINED][LOSS])) - tables[COMBINED].append([dataset_name, results[COMBINED][LOSS]]) - - return metrics_log, tables - def evaluation( self, dataset: "Dataset", # noqa: F821 @@ -187,7 +176,7 @@ def evaluation( ) metrics, predictions = predictor.batch_evaluation(dataset, collect_predictions=False, dataset_name=dataset_name) - self.append_metrics(dataset_name, metrics, metrics_log, tables, progress_tracker) + append_metrics(self.model, dataset_name, metrics, metrics_log, tables, progress_tracker) return metrics_log, tables @@ -281,7 +270,7 @@ def run_evaluation( # eval metrics on validation set self.evaluation( validation_set, - "vali", + VALIDATION, progress_tracker.validation_metrics, tables, self.eval_batch_size, @@ -319,9 +308,27 @@ def run_evaluation( for output_feature, table in tables.items(): logging.info(tabulate(table, headers="firstrow", tablefmt="fancy_grid", floatfmt=".4f")) + # ================ Validation Logic ================ + should_break = False + if validation_set is not None and validation_set.size > 0: + should_break = self.check_progress_on_validation( + progress_tracker, + self.validation_field, + self.validation_metric, + save_path, + self.early_stop, + self.skip_save_model, + ) + else: + # There's no validation, so we save the model. + if self.is_coordinator() and not self.skip_save_model: + self.model.save(save_path) + # Trigger eval end callback after any model weights save for complete checkpoint self.callback(lambda c: c.on_eval_end(self, progress_tracker, save_path)) + return should_break + def _train_loop( self, params: Dict[str, Any], @@ -329,41 +336,121 @@ def _train_loop( eval_sets: List[lgb.Dataset], eval_names: List[str], progress_tracker: ProgressTracker, + progress_bar: LudwigProgressBar, save_path: str, - ) -> lgb.Booster: - name_to_metrics_log = { - LightGBMTrainer.TRAIN_KEY: progress_tracker.train_metrics, - LightGBMTrainer.VALID_KEY: progress_tracker.validation_metrics, - LightGBMTrainer.TEST_KEY: progress_tracker.test_metrics, - } - tables = OrderedDict() + training_set: Union["Dataset", "RayDataset"], # noqa: F821 + validation_set: Union["Dataset", "RayDataset"], # noqa: F821 + test_set: Union["Dataset", "RayDataset"], # noqa: F821 + train_summary_writer: SummaryWriter, + validation_summary_writer: SummaryWriter, + test_summary_writer: SummaryWriter, + ) -> bool: + self.callback(lambda c: c.on_batch_start(self, progress_tracker, save_path)) + + booster = None + evals_result = {} + booster = self.train_step( + params, lgb_train, eval_sets, eval_names, booster, self.boosting_rounds_per_checkpoint, evals_result + ) + + progress_bar.update(self.boosting_rounds_per_checkpoint) + progress_tracker.steps += self.boosting_rounds_per_checkpoint + progress_tracker.last_improvement_steps = booster.best_iteration + + # convert to pytorch for inference + self.model.lgb_booster = booster + self.model.compile() + self.model = self.model.to(self.device) + output_features = self.model.output_features metrics_names = get_metric_names(output_features) - for output_feature_name, output_feature in output_features.items(): - tables[output_feature_name] = [[output_feature_name] + metrics_names[output_feature_name]] - tables[COMBINED] = [[COMBINED, LOSS]] - booster = None + output_feature_name = next(iter(output_features)) + + loss_name = params["metric"][0] + loss = evals_result["train"][loss_name][-1] + loss = torch.tensor(loss, dtype=torch.float32) + + should_break = self.run_evaluation( + training_set, + validation_set, + test_set, + progress_tracker, + train_summary_writer, + validation_summary_writer, + test_summary_writer, + output_features, + metrics_names, + save_path, + loss, + {output_feature_name: loss}, + ) + + self.callback(lambda c: c.on_batch_end(self, progress_tracker, save_path)) - for epoch, steps in enumerate(range(0, self.num_boost_round, self.boosting_round_log_frequency), start=1): - progress_tracker.epoch = epoch + return should_break - evals_result = {} - booster = self.train_step( - params, lgb_train, eval_sets, eval_names, booster, self.boosting_round_log_frequency, evals_result - ) + def check_progress_on_validation( + self, + progress_tracker: ProgressTracker, + validation_output_feature_name: str, + validation_metric: str, + save_path: str, + early_stopping_steps: int, + skip_save_model: bool, + ) -> bool: + """Checks the history of validation scores. + + Uses history of validation scores to decide whether training + should stop. + + Saves the model if scores have improved. + """ + should_break = False + # record how long its been since an improvement + improved = get_improved_fun(validation_metric) + validation_metrics = progress_tracker.validation_metrics[validation_output_feature_name] + last_validation_metric = validation_metrics[validation_metric][-1] + last_validation_metric_value = last_validation_metric[-1] + + if improved(last_validation_metric_value, progress_tracker.best_eval_metric): + progress_tracker.last_improvement_steps = progress_tracker.steps + progress_tracker.best_eval_metric = last_validation_metric_value + + if self.is_coordinator() and not skip_save_model: + self.model.save(save_path) + logging.info( + f"Validation {validation_metric} on {validation_output_feature_name} improved, model saved.\n" + ) - progress_tracker.steps = steps + self.boosting_round_log_frequency - # log training progress - of_name = self.model.output_features.keys()[0] - for data_name in eval_names: - loss_name = params["metric"][0] - loss = evals_result[data_name][loss_name][-1] - metrics = {of_name: {"Survived": {LOSS: loss}}, COMBINED: {LOSS: loss}} - self.append_metrics(data_name, metrics, name_to_metrics_log[data_name], tables, progress_tracker) - self.callback(lambda c: c.on_eval_end(self, progress_tracker, save_path)) - self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path)) + progress_tracker.last_improvement = progress_tracker.steps - progress_tracker.last_improvement_steps + if progress_tracker.last_improvement != 0 and self.is_coordinator(): + logging.info( + f"Last improvement of {validation_output_feature_name} validation {validation_metric} happened " + + f"{progress_tracker.last_improvement} step(s) ago.\n" + ) - return booster + # ========== Early Stop logic ========== + # If any early stopping condition is satisfied, either lack of improvement for many steps, or via callbacks on + # any worker, then trigger early stopping. + early_stop_bool = 0 < early_stopping_steps <= progress_tracker.last_improvement + if not early_stop_bool: + for callback in self.callbacks: + if callback.should_early_stop(self, progress_tracker, self.is_coordinator()): + early_stop_bool = True + break + + should_early_stop = torch.as_tensor([early_stop_bool], dtype=torch.int) + if self.horovod: + should_early_stop = self.horovod.allreduce(should_early_stop) + if should_early_stop.item(): + if self.is_coordinator(): + logging.info( + "\nEARLY STOPPING due to lack of validation improvement. " + f"It has been {progress_tracker.steps - progress_tracker.last_improvement_steps} step(s) since " + f"last validation improvement.\n" + ) + should_break = True + return should_break def train_step( self, @@ -372,7 +459,7 @@ def train_step( eval_sets: List[lgb.Dataset], eval_names: List[str], booster: lgb.Booster, - steps_per_epoch: int, + boost_rounds_per_train_step: int, evals_result: Dict, ) -> lgb.Booster: """Trains a LightGBM model. @@ -390,17 +477,13 @@ def train_step( params, lgb_train, init_model=booster, - num_boost_round=steps_per_epoch, + num_boost_round=boost_rounds_per_train_step, valid_sets=eval_sets, valid_names=eval_names, feature_name=list(self.model.input_features.keys()), # NOTE: hummingbird does not support categorical features # categorical_feature=categorical_features, evals_result=evals_result, - callbacks=[ - lgb.early_stopping(stopping_rounds=self.early_stop), - lgb.log_evaluation(), - ], ) return gbm @@ -413,91 +496,169 @@ def train( save_path="model", **kwargs, ): + # ====== General setup ======= + output_features = self.model.output_features + + # Only use signals when on the main thread to avoid issues with CherryPy + # https://github.com/ludwig-ai/ludwig/issues/286 + if threading.current_thread() == threading.main_thread(): + # set the original sigint signal handler + # as we want to restore it at the end of training + self.original_sigint_handler = signal.getsignal(signal.SIGINT) + signal.signal(signal.SIGINT, self.set_steps_to_1_or_quit) + # TODO: construct new datasets by running encoders (for text, image) # TODO: only single task currently - if len(self.model.output_features) > 1: + if len(output_features) > 1: raise ValueError("Only single task currently supported") + metrics_names = get_metric_names(output_features) + + # check if validation_field is valid + valid_validation_field = False + if self.validation_field == "combined": + valid_validation_field = True + if self.validation_metric is not LOSS and len(output_features) == 1: + only_of = next(iter(output_features)) + if self.validation_metric in metrics_names[only_of]: + self._validation_field = only_of + logging.warning( + "Replacing 'combined' validation field " + "with '{}' as the specified validation " + "metric {} is invalid for 'combined' " + "but is valid for '{}'.".format(only_of, self.validation_metric, only_of) + ) + else: + for output_feature in output_features: + if self.validation_field == output_feature: + valid_validation_field = True + + if not valid_validation_field: + raise ValueError( + "The specified validation_field {} is not valid." + "Available ones are: {}".format(self.validation_field, list(output_features.keys()) + ["combined"]) + ) + + # check if validation_metric is valid + valid_validation_metric = self.validation_metric in metrics_names[self.validation_field] + if not valid_validation_metric: + raise ValueError( + "The specified metric {} is not valid. " + "Available metrics for {} output feature are: {}".format( + self.validation_metric, self.validation_field, metrics_names[self.validation_field] + ) + ) + + # ====== Setup file names ======= + training_checkpoints_path = None + tensorboard_log_dir = None + if self.is_coordinator(): + os.makedirs(save_path, exist_ok=True) + training_checkpoints_path = os.path.join(save_path, TRAINING_CHECKPOINTS_DIR_PATH) + tensorboard_log_dir = os.path.join(save_path, "logs") + self.callback( lambda c: c.on_trainer_train_setup(self, save_path, self.is_coordinator()), coordinator_only=False ) + # ====== Setup session ======= + checkpoint = checkpoint_manager = None + if self.is_coordinator() and not self.skip_save_progress: + checkpoint = Checkpoint(model=self.model) + checkpoint_manager = CheckpointManager(checkpoint, training_checkpoints_path, device=self.device) + + train_summary_writer = None + validation_summary_writer = None + test_summary_writer = None + if self.is_coordinator() and not self.skip_save_log and tensorboard_log_dir: + train_summary_writer = SummaryWriter(os.path.join(tensorboard_log_dir, TRAINING)) + if validation_set is not None and validation_set.size > 0: + validation_summary_writer = SummaryWriter(os.path.join(tensorboard_log_dir, VALIDATION)) + if test_set is not None and test_set.size > 0: + test_summary_writer = SummaryWriter(os.path.join(tensorboard_log_dir, TEST)) + progress_tracker = get_new_progress_tracker( batch_size=-1, learning_rate=self.base_learning_rate, best_eval_metric=get_initial_validation_value(self.validation_metric), best_reduce_learning_rate_eval_metric=float("inf"), best_increase_batch_size_eval_metric=float("inf"), - output_features=self.model.output_features, + output_features=output_features, ) - params = self._construct_lgb_params() - - lgb_train, eval_sets, eval_names = self._construct_lgb_datasets(training_set, validation_set, test_set) - - # epoch init - start_time = time.time() - - # Reset the metrics at the start of the next epoch - self.model.reset_metrics() - - self.callback(lambda c: c.on_epoch_start(self, progress_tracker, save_path)) - self.callback(lambda c: c.on_batch_start(self, progress_tracker, save_path)) - - gbm = self._train_loop(params, lgb_train, eval_sets, eval_names, progress_tracker, save_path) - - self.callback(lambda c: c.on_batch_end(self, progress_tracker, save_path)) - # ================ Post Training Epoch ================ - progress_tracker.steps = gbm.current_iteration() - progress_tracker.last_improvement_steps = gbm.best_iteration - self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path)) - - if self.is_coordinator(): - # ========== Save training progress ========== - logging.debug( - f"Epoch {progress_tracker.epoch} took: {time_utils.strdelta((time.time()- start_time) * 1000.0)}." - ) - if not self.skip_save_progress: - progress_tracker.save(os.path.join(save_path, TRAINING_PROGRESS_TRACKER_FILE_NAME)) - - # convert to pytorch for inference, fine tuning - self.model.lgb_booster = gbm - self.model.compile() - self.model = self.model.to(self.device) + set_random_seed(self.random_seed) - # evaluate - train_summary_writer = None - validation_summary_writer = None - test_summary_writer = None try: - os.makedirs(save_path, exist_ok=True) - tensorboard_log_dir = os.path.join(save_path, "logs") + params = self._construct_lgb_params() - train_summary_writer = SummaryWriter(os.path.join(tensorboard_log_dir, TRAINING)) - if validation_set is not None and validation_set.size > 0: - validation_summary_writer = SummaryWriter(os.path.join(tensorboard_log_dir, VALIDATION)) - if test_set is not None and test_set.size > 0: - test_summary_writer = SummaryWriter(os.path.join(tensorboard_log_dir, TEST)) + lgb_train, eval_sets, eval_names = self._construct_lgb_datasets(training_set, validation_set, test_set) - output_features = self.model.output_features - metrics_names = get_metric_names(output_features) + # use separate total steps variable to allow custom SIGINT logic + self.total_steps = self.num_boost_round - self.run_evaluation( - training_set, - validation_set, - test_set, - progress_tracker, - train_summary_writer, - validation_summary_writer, - test_summary_writer, - output_features, - metrics_names, - save_path, - None, - None, - ) + if self.is_coordinator(): + logging.info( + f"Training for {self.total_steps} boosting round(s), approximately " + f"{int(self.total_steps / self.boosting_rounds_per_checkpoint)} round(s) of evaluation." + ) + logging.info(f"Early stopping policy: {self.early_stop} boosting round(s).\n") + + logging.info(f"Starting with step {progress_tracker.steps}") + + progress_bar_config = { + "desc": "Training", + "total": self.total_steps, + "disable": is_progressbar_disabled(), + "file": sys.stdout, + } + progress_bar = LudwigProgressBar(self.report_tqdm_to_ray, progress_bar_config, self.is_coordinator()) + + while progress_tracker.steps < self.total_steps: + # epoch init + start_time = time.time() + + # Reset the metrics at the start of the next epoch + self.model.reset_metrics() + + self.callback(lambda c: c.on_epoch_start(self, progress_tracker, save_path)) + + should_break = self._train_loop( + params, + lgb_train, + eval_sets, + eval_names, + progress_tracker, + progress_bar, + save_path, + training_set, + validation_set, + test_set, + train_summary_writer, + validation_summary_writer, + test_summary_writer, + ) + + # ================ Post Training Epoch ================ + progress_tracker.epoch += 1 + self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path)) + + if self.is_coordinator(): + # ========== Save training progress ========== + logging.debug( + f"Epoch {progress_tracker.epoch} took: " + f"{time_utils.strdelta((time.time()- start_time) * 1000.0)}." + ) + if not self.skip_save_progress: + checkpoint_manager.checkpoint.model = self.model + checkpoint_manager.save(progress_tracker.steps) + progress_tracker.save(os.path.join(save_path, TRAINING_PROGRESS_TRACKER_FILE_NAME)) + + # Early stop if needed. + if should_break: + break finally: + # ================ Finished Training ================ self.callback( lambda c: c.on_trainer_train_teardown(self, progress_tracker, save_path, self.is_coordinator()), coordinator_only=False, @@ -510,8 +671,16 @@ def train( if test_summary_writer is not None: test_summary_writer.close() + if self.is_coordinator() and not self.skip_save_progress: + checkpoint_manager.close() + + # Load the best weights from saved checkpoint if self.is_coordinator() and not self.skip_save_model: - self._save(save_path) + self.model.load(save_path) + + # restore original sigint signal handler + if self.original_sigint_handler and threading.current_thread() == threading.main_thread(): + signal.signal(signal.SIGINT, self.original_sigint_handler) return ( self.model, @@ -520,30 +689,46 @@ def train( progress_tracker.test_metrics, ) + def set_steps_to_1_or_quit(self, signum, frame): + """Custom SIGINT handler used to elegantly exit training. + + A single SIGINT will stop training after the next training step. A second SIGINT will stop training immediately. + """ + if not self.received_sigint: + self.total_steps = 1 + self.received_sigint = True + logging.critical("\nReceived SIGINT, will finish this training step and then conclude training.") + logging.critical("Send another SIGINT to immediately interrupt the process.") + else: + logging.critical("\nReceived a second SIGINT, will now quit") + if self.original_sigint_handler: + signal.signal(signal.SIGINT, self.original_sigint_handler) + sys.exit(1) + def _construct_lgb_params(self) -> Tuple[dict, dict]: output_params = {} - for feature in self.model.output_features.values(): - if feature.type() == CATEGORY: - output_params = { - "objective": "multiclass", - "metric": ["multi_logloss"], - "num_class": feature.num_classes, - } - elif feature.type() == BINARY: - output_params = { - "objective": "binary", - "metric": ["binary_logloss"], - } - elif feature.type() == NUMBER: - output_params = { - "objective": "regression", - "metric": ["l2", "l1"], - } - else: - raise ValueError( - f"Model type GBM only supports numerical, categorical, or binary output features," - f" found: {feature.type()}" - ) + feature = next(iter(self.model.output_features.values())) + if feature.type() == CATEGORY: + output_params = { + "objective": "multiclass", + "metric": ["multi_logloss"], + "num_class": feature.num_classes, + } + elif feature.type() == BINARY: + output_params = { + "objective": "binary", + "metric": ["binary_logloss"], + } + elif feature.type() == NUMBER: + output_params = { + "objective": "regression", + "metric": ["l2", "l1"], + } + else: + raise ValueError( + f"Model type GBM only supports numerical, categorical, or binary output features," + f" found: {feature.type()}" + ) # from: https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/advanced_example.py params = { @@ -627,16 +812,6 @@ def _construct_lgb_datasets( return lgb_train, eval_sets, eval_names - def _save(self, save_path: str): - os.makedirs(save_path, exist_ok=True) - training_checkpoints_path = os.path.join(save_path, TRAINING_CHECKPOINTS_DIR_PATH) - checkpoint = Checkpoint(model=self.model) - checkpoint_manager = CheckpointManager(checkpoint, training_checkpoints_path, device=self.device) - checkpoint_manager.save(1) - checkpoint_manager.close() - - self.model.save(save_path) - def is_coordinator(self) -> bool: if not self.horovod: return True @@ -648,26 +823,6 @@ def callback(self, fn, coordinator_only=True): fn(callback) -def log_eval_distributed(period: int = 1, show_stdv: bool = True) -> Callable: - from lightgbm_ray.tune import _TuneLGBMRank0Mixin - - class LogEvalDistributed(_TuneLGBMRank0Mixin): - def __init__(self, period: int, show_stdv: bool = True): - self.period = period - self.show_stdv = show_stdv - - def __call__(self, env: lgb.callback.CallbackEnv): - if not self.is_rank_0: - return - if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0: - result = "\t".join( - [lgb.callback._format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list] - ) - lgb.callback._log_info(f"[{env.iteration + 1}]\t{result}") - - return LogEvalDistributed(period=period, show_stdv=show_stdv) - - def _map_to_lgb_ray_params(params: Dict[str, Any]) -> Dict[str, Any]: from lightgbm_ray import RayParams @@ -705,16 +860,16 @@ def __init__( **kwargs, ): super().__init__( - config, - model, - resume, - skip_save_model, - skip_save_progress, - skip_save_log, - callbacks, - random_seed, - horovod, - device, + config=config, + model=model, + resume=resume, + skip_save_model=skip_save_model, + skip_save_progress=skip_save_progress, + skip_save_log=skip_save_log, + callbacks=callbacks, + random_seed=random_seed, + horovod=horovod, + device=device, **kwargs, ) @@ -733,7 +888,7 @@ def train_step( eval_sets: List["RayDMatrix"], # noqa: F821 eval_names: List[str], booster: lgb.Booster, - steps_per_epoch: int, + boost_rounds_per_train_step: int, evals_result: Dict, ) -> lgb.Booster: """Trains a LightGBM model using ray. @@ -753,44 +908,18 @@ def train_step( params, lgb_train, init_model=booster, - num_boost_round=steps_per_epoch, + num_boost_round=boost_rounds_per_train_step, valid_sets=eval_sets, valid_names=eval_names, feature_name=list(self.model.input_features.keys()), evals_result=evals_result, # NOTE: hummingbird does not support categorical features # categorical_feature=categorical_features, - callbacks=[ - lgb.early_stopping(stopping_rounds=self.early_stop), - log_eval_distributed(10), - ], ray_params=_map_to_lgb_ray_params(self.trainer_kwargs), ) return gbm.booster_ - def evaluation(self, dataset, dataset_name, metrics_log, tables, batch_size, progress_tracker): - from ludwig.backend.ray import _get_df_engine, RayPredictor - - predictor_kwargs = self.executable_kwargs.copy() - if "callbacks" in predictor_kwargs: - # remove unused (non-serializable) callbacks - del predictor_kwargs["callbacks"] - - predictor = RayPredictor( - model=self.model, - df_engine=_get_df_engine(None), - trainer_kwargs=self.trainer_kwargs, - data_loader_kwargs=self.data_loader_kwargs, - batch_size=batch_size, - **predictor_kwargs, - ) - metrics, _ = predictor.batch_evaluation(dataset, collect_predictions=False, dataset_name=dataset_name) - - self.append_metrics(dataset_name, metrics, metrics_log, tables, progress_tracker) - - return metrics_log, tables - def _construct_lgb_datasets( self, training_set: "RayDataset", # noqa: F821 @@ -801,7 +930,8 @@ def _construct_lgb_datasets( from lightgbm_ray import RayDMatrix - label_col = self.model.output_features.values()[0].proc_column + output_feature = next(iter(self.model.output_features.values())) + label_col = output_feature.proc_column in_feat = [f.proc_column for f in self.model.input_features.values()] out_feat = [f.proc_column for f in self.model.output_features.values()] diff --git a/ludwig/utils/trainer_utils.py b/ludwig/utils/trainer_utils.py index 8f76392d2b4..945b98a2523 100644 --- a/ludwig/utils/trainer_utils.py +++ b/ludwig/utils/trainer_utils.py @@ -2,8 +2,14 @@ from collections import OrderedDict from typing import Dict, List, Tuple +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + from ludwig.constants import COMBINED, LOSS from ludwig.features.base_feature import OutputFeature +from ludwig.models.base import BaseModel from ludwig.modules.metric_modules import get_best_function from ludwig.utils.data_utils import load_json, save_json from ludwig.utils.metric_utils import TrainerMetric @@ -154,6 +160,38 @@ def log_metrics(self): return log_metrics +def append_metrics( + model: BaseModel, + dataset_name: Literal["train", "validation", "test"], + results: Dict[str, Dict[str, float]], + metrics_log: Dict[str, Dict[str, List[TrainerMetric]]], + tables: Dict[str, List[List[str]]], + progress_tracker: ProgressTracker, +) -> Tuple[Dict[str, Dict[str, List[TrainerMetric]]], Dict[str, List[List[str]]]]: + epoch = progress_tracker.epoch + steps = progress_tracker.steps + for output_feature in model.output_features: + scores = [dataset_name] + + # collect metric names based on output features metrics to + # ensure consistent order of reporting metrics + metric_names = model.output_features[output_feature].metric_functions.keys() + + for metric in metric_names: + if metric in results[output_feature]: + # Some metrics may have been excepted and excluded from results. + score = results[output_feature][metric] + metrics_log[output_feature][metric].append(TrainerMetric(epoch=epoch, step=steps, value=score)) + scores.append(score) + + tables[output_feature].append(scores) + + metrics_log[COMBINED][LOSS].append(TrainerMetric(epoch=epoch, step=steps, value=results[COMBINED][LOSS])) + tables[COMBINED].append([dataset_name, results[COMBINED][LOSS]]) + + return metrics_log, tables + + def get_total_steps(epochs: int, steps_per_epoch: int, train_steps: int): """Returns train_steps if non-negative. diff --git a/tests/conftest.py b/tests/conftest.py index ec92d8afa83..dba64e08185 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -114,8 +114,8 @@ def ray_cluster_2cpu(request): @pytest.fixture(scope="module") -def ray_cluster_3cpu(request): - with _ray_start(request, num_cpus=3): +def ray_cluster_4cpu(request): + with _ray_start(request, num_cpus=4): yield diff --git a/tests/integration_tests/test_gbm.py b/tests/integration_tests/test_gbm.py index 6e7d37f70c8..f7914898755 100644 --- a/tests/integration_tests/test_gbm.py +++ b/tests/integration_tests/test_gbm.py @@ -21,7 +21,7 @@ def local_backend(): @pytest.fixture(scope="module") def ray_backend(): num_workers = 2 - num_cpus_per_worker = 1 + num_cpus_per_worker = 2 return { "type": "ray", "processor": { @@ -60,7 +60,7 @@ def test_local_gbm_output_not_supported(tmpdir, local_backend): @pytest.mark.distributed -def test_ray_gbm_output_not_supported(tmpdir, ray_backend, ray_cluster_3cpu): +def test_ray_gbm_output_not_supported(tmpdir, ray_backend, ray_cluster_4cpu): run_test_gbm_output_not_supported(tmpdir, ray_backend) @@ -93,7 +93,7 @@ def test_local_gbm_multiple_outputs(tmpdir, local_backend): @pytest.mark.distributed -def test_ray_gbm_multiple_outputs(tmpdir, ray_backend, ray_cluster_3cpu): +def test_ray_gbm_multiple_outputs(tmpdir, ray_backend, ray_cluster_4cpu): run_test_gbm_multiple_outputs(tmpdir, ray_backend) @@ -137,7 +137,7 @@ def test_local_gbm_binary(tmpdir, local_backend): @pytest.mark.distributed -def test_ray_gbm_binary(tmpdir, ray_backend, ray_cluster_3cpu): +def test_ray_gbm_binary(tmpdir, ray_backend, ray_cluster_4cpu): run_test_gbm_binary(tmpdir, ray_backend) @@ -181,7 +181,7 @@ def test_local_gbm_non_number_inputs(tmpdir, local_backend): @pytest.mark.distributed -def test_ray_gbm_non_number_inputs(tmpdir, ray_backend, ray_cluster_3cpu): +def test_ray_gbm_non_number_inputs(tmpdir, ray_backend, ray_cluster_4cpu): run_test_gbm_non_number_inputs(tmpdir, ray_backend) @@ -227,7 +227,7 @@ def test_local_gbm_category(tmpdir, local_backend): @pytest.mark.distributed -def test_ray_gbm_category(tmpdir, ray_backend, ray_cluster_3cpu): +def test_ray_gbm_category(tmpdir, ray_backend, ray_cluster_4cpu): run_test_gbm_category(tmpdir, ray_backend) @@ -278,7 +278,7 @@ def test_local_gbm_number(tmpdir, local_backend): @pytest.mark.distributed -def test_ray_gbm_number(tmpdir, ray_backend, ray_cluster_3cpu): +def test_ray_gbm_number(tmpdir, ray_backend, ray_cluster_4cpu): run_test_gbm_number(tmpdir, ray_backend) @@ -307,5 +307,5 @@ def test_local_gbm_schema(local_backend): @pytest.mark.distributed -def test_ray_gbm_schema(ray_backend, ray_cluster_3cpu): +def test_ray_gbm_schema(ray_backend, ray_cluster_4cpu): run_test_gbm_schema(ray_backend)