From c1ea13d32122a4c7b364b8d79c6ab31229931d33 Mon Sep 17 00:00:00 2001
From: GonzaloSaez <11050889+GonzaloSaez@users.noreply.github.com>
Date: Fri, 10 Jan 2025 22:58:30 +0000
Subject: [PATCH] Do not create the launcher job if the job starts suspended

---
 pkg/controller/mpi_job_controller.go        |  4 +-
 test/integration/mpi_job_controller_test.go | 54 ++++++++++++++++++++-
 2 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/pkg/controller/mpi_job_controller.go b/pkg/controller/mpi_job_controller.go
index 4fcca25f..909b7802 100644
--- a/pkg/controller/mpi_job_controller.go
+++ b/pkg/controller/mpi_job_controller.go
@@ -658,7 +658,9 @@ func (c *MPIJobController) syncHandler(key string) error {
 			return err
 		}
 	}
-	if launcher == nil {
+	// If the job is suspended, the list of worker pods will be incorrect. We also do
+	// not want to start the launcher job if the MPIJob starts suspended.
+	if launcher == nil && !isMPIJobSuspended(mpiJob) {
 		if mpiJob.Spec.LauncherCreationPolicy == kubeflow.LauncherCreationPolicyAtStartup || c.countReadyWorkerPods(worker) == len(worker) {
 			launcher, err = c.kubeClient.BatchV1().Jobs(namespace).Create(context.TODO(), c.newLauncherJob(mpiJob), metav1.CreateOptions{})
 			if err != nil {
diff --git a/test/integration/mpi_job_controller_test.go b/test/integration/mpi_job_controller_test.go
index f260d829..2a2a3606 100644
--- a/test/integration/mpi_job_controller_test.go
+++ b/test/integration/mpi_job_controller_test.go
@@ -173,6 +173,27 @@ func TestMPIJobSuccess(t *testing.T) {
 }
 
 func TestMPIJobWaitWorkers(t *testing.T) {
+	testcases := []struct {
+		name           string
+		startSuspended bool
+	}{
+		{
+			name:           "don't start suspended",
+			startSuspended: false,
+		},
+		{
+			name:           "start suspended",
+			startSuspended: true,
+		},
+	}
+	for _, tc := range testcases {
+		t.Run(tc.name, func(t *testing.T) {
+			testMpiJobWaitWorkers(t, tc.startSuspended)
+		})
+	}
+}
+
+func testMpiJobWaitWorkers(t *testing.T, startSuspended bool) {
 	ctx, cancel := context.WithCancel(context.Background())
 	t.Cleanup(cancel)
 	s := newTestSetup(ctx, t)
@@ -187,6 +208,7 @@ func TestMPIJobWaitWorkers(t *testing.T) {
 			SlotsPerWorker:         ptr.To[int32](1),
 			LauncherCreationPolicy: "WaitForWorkersReady",
 			RunPolicy: kubeflow.RunPolicy{
+				Suspend:        ptr.To(startSuspended),
 				CleanPodPolicy: ptr.To(kubeflow.CleanPodPolicyRunning),
 			},
 			MPIReplicaSpecs: map[kubeflow.MPIReplicaType]*kubeflow.ReplicaSpec{
@@ -237,9 +259,37 @@ func TestMPIJobWaitWorkers(t *testing.T) {
 	}
 	s.events.verify(t)
 
-	workerPods, err := getPodsForJob(ctx, s.kClient, mpiJob)
+	// The launcher job should not be created until all workers are ready even when we start in suspended mode.
+	job, err := getLauncherJobForMPIJob(ctx, s.kClient, mpiJob)
 	if err != nil {
-		t.Fatalf("Cannot get worker pods from job: %v", err)
+		t.Fatalf("Cannot get launcher job from job: %v", err)
+	}
+	if job != nil {
+		t.Fatalf("Launcher is created before workers")
+	}
+
+	if startSuspended {
+		// Resume the MPIJob so that the test can follow the normal path.
+		mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
+		mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(mpiJob.Namespace).Update(ctx, mpiJob, metav1.UpdateOptions{})
+		if err != nil {
+			t.Fatalf("Error updating MPIJob: %v", err)
+		}
+	}
+
+	var workerPods []corev1.Pod
+	if err = wait.PollUntilContextTimeout(ctx, util.WaitInterval, wait.ForeverTestTimeout, false, func(ctx context.Context) (bool, error) {
+		var err error
+		workerPods, err = getPodsForJob(ctx, s.kClient, mpiJob)
+		if err != nil {
+			return false, err
+		}
+		if len(workerPods) != 2 {
+			return false, nil
+		}
+		return true, nil
+	}); err != nil {
+		t.Errorf("Failed waiting for worker pods to be created: %v", err)
 	}
 
 	err = updatePodsToPhase(ctx, s.kClient, workerPods, corev1.PodRunning)