diff --git a/ludwig/contribs/mlflow/__init__.py b/ludwig/contribs/mlflow/__init__.py
index fa1f99c384d..a966a0f5045 100644
--- a/ludwig/contribs/mlflow/__init__.py
+++ b/ludwig/contribs/mlflow/__init__.py
@@ -161,15 +161,26 @@ def on_trainer_train_setup(self, trainer, save_path, is_coordinator):
     def on_eval_end(self, trainer, progress_tracker, save_path):
         if progress_tracker.steps not in self.logged_steps:
             self.logged_steps.add(progress_tracker.steps)
-            self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, True))  # Why True?
+            # Adds a tuple to the logging queue.
+            # True is passed to indicate that the background saving loop should continue.
+            self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, True))
 
     def on_trainer_train_teardown(self, trainer, progress_tracker, save_path, is_coordinator):
         if is_coordinator:
             if progress_tracker.steps not in self.logged_steps:
                 self.logged_steps.add(progress_tracker.steps)
-                self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, False))  # Why False?
-            if self.save_thread is not None:
-                self.save_thread.join()
+                # Adds a tuple to the logging queue.
+                # False is passed to indicate that the background saving loop should break.
+                self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, False))
+            # False ensures that the background saving loop breaks.
+            # TODO(Justin): This should probably live in on_ludwig_end, once that's implemented.
+            self.save_fn((None, None, None, False))
+
+            # Close the save_thread.
+            if self.save_thread is not None:
+                self.save_thread.join()
+                # if self.save_thread.is_alive():
+                #     logger.warning("MLFlow save thread timed out and did not close properly.")
 
     def on_visualize_figure(self, fig):
         # TODO: need to also include a filename for this figure
@@ -205,10 +216,15 @@ def __setstate__(self, d):
 
 
 def _log_mlflow_loop(q: queue.Queue, log_artifacts: bool = True):
+    """The save_fn for the background thread that logs to MLFlow when save_in_background is True."""
     should_continue = True
     while should_continue:
         elem = q.get()
         log_metrics, steps, save_path, should_continue = elem
+        if log_metrics is None:
+            # Break out of the loop if we're not going to log anything.
+            break
+
         mlflow.log_metrics(log_metrics, step=steps)
 
         if not q.empty():
@@ -221,9 +237,14 @@ def _log_mlflow_loop(q: queue.Queue, log_artifacts: bool = True):
 
 
 def _log_mlflow(log_metrics, steps, save_path, should_continue, log_artifacts: bool = True):
-    mlflow.log_metrics(log_metrics, step=steps)
-    if log_artifacts:
-        _log_model(save_path)
+    """The save_fn for the MlflowCallback.
+
+    This is used when save_in_background is False.
+    """
+    if log_metrics is not None:
+        mlflow.log_metrics(log_metrics, step=steps)
+        if log_artifacts:
+            _log_model(save_path)
 
 
 def _log_artifacts(output_directory):
diff --git a/tests/integration_tests/test_contrib_aim.py b/tests/integration_tests/test_contrib_aim.py
index 9a62d5db7eb..11a65fe3f2b 100644
--- a/tests/integration_tests/test_contrib_aim.py
+++ b/tests/integration_tests/test_contrib_aim.py
@@ -12,6 +12,7 @@
 TEST_SCRIPT = os.path.join(os.path.dirname(__file__), "scripts", "run_train_aim.py")
 
 
+@pytest.mark.skip(reason="Aim integration not compatible with Aim 4.0.")
 @pytest.mark.distributed
 def test_contrib_experiment(csv_filename, tmpdir):
     aim_test_path = os.path.join(tmpdir, "results")
diff --git a/tests/integration_tests/test_torchscript.py b/tests/integration_tests/test_torchscript.py
index 8e0b1d1fe73..b62f078c6ed 100644
--- a/tests/integration_tests/test_torchscript.py
+++ b/tests/integration_tests/test_torchscript.py
@@ -25,7 +25,7 @@
 
 from ludwig.api import LudwigModel
 from ludwig.backend import RAY
-from ludwig.constants import BATCH_SIZE, COMBINER, LOGITS, NAME, PREDICTIONS, PROBABILITIES, TRAINER
+from ludwig.constants import BATCH_SIZE, COMBINER, EVAL_BATCH_SIZE, LOGITS, NAME, PREDICTIONS, PROBABILITIES, TRAINER
 from ludwig.data.preprocessing import preprocess_for_prediction
 from ludwig.features.number_feature import numeric_transformation_registry
 from ludwig.globals import TRAIN_SET_METADATA_FILE_NAME
@@ -415,7 +415,7 @@ def test_torchscript_e2e_text_hf_tokenizer(tmpdir, csv_filename):
     config = {
         "input_features": input_features,
         "output_features": output_features,
-        TRAINER: {"epochs": 2, BATCH_SIZE: 128},
+        TRAINER: {"epochs": 2, BATCH_SIZE: 128, EVAL_BATCH_SIZE: 128},
     }
 
     training_data_csv_path = generate_data(input_features, output_features, data_csv_path)