diff --git a/dagster_ray/kuberay/client/rayjob/client.py b/dagster_ray/kuberay/client/rayjob/client.py index d553c24..1ca8d85 100644 --- a/dagster_ray/kuberay/client/rayjob/client.py +++ b/dagster_ray/kuberay/client/rayjob/client.py @@ -62,6 +62,8 @@ def wait_until_running( namespace: str, timeout: int = 600, poll_interval: int = 5, + terminate_on_timeout: bool = True, + port_forward: bool = False, ) -> bool: start_time = time.time() @@ -74,9 +76,17 @@ def wait_until_running( raise RuntimeError(f"RayJob {namespace}/{name} deployment failed. Status:\n{status}") if time.time() - start_time > timeout: - raise TimeoutError( - f"Timed out waiting for RayJob {namespace}/{name} deployment to become available. Status:\n{status}" - ) + if terminate_on_timeout: + logger.warning(f"Terminating RayJob {namespace}/{name} because of timeout {timeout}s") + try: + self.terminate(name, namespace, port_forward=port_forward) + except Exception as e: + logger.warning( + f"Failed to gracefully terminate RayJob {namespace}/{name}: {e}, will delete it instead." + ) + self.delete(name, namespace) + + raise TimeoutError(f"Timed out waiting for RayJob {namespace}/{name} to start. Status:\n{status}") time.sleep(poll_interval) diff --git a/dagster_ray/kuberay/pipes.py b/dagster_ray/kuberay/pipes.py index d5eee69..ea2ecf3 100644 --- a/dagster_ray/kuberay/pipes.py +++ b/dagster_ray/kuberay/pipes.py @@ -35,7 +35,8 @@ class PipesKubeRayJobClient(PipesClient, TreatAsResourceParam): message_reader (Optional[PipesMessageReader]): A message reader to use to read messages from the glue job run. Defaults to :py:class:`PipesRayJobMessageReader`. client (Optional[boto3.client]): The Kubernetes API client. - forward_termination (bool): Whether to cancel the `RayJob` job run when the Dagster process receives a termination signal. + forward_termination (bool): Whether to terminate the Ray job when the Dagster process receives a termination signal, + or if the startup timeout is reached. Defaults to ``True``. timeout (int): Timeout for various internal interactions with the Kubernetes RayJob. poll_interval (int): Interval at which to poll the Kubernetes for status updates. port_forward (bool): Whether to use Kubernetes port-forwarding to connect to the KubeRay cluster. @@ -169,6 +170,8 @@ def _start(self, context: OpExecutionContext, ray_job: Dict[str, Any]) -> Dict[s namespace=namespace, timeout=self.timeout, poll_interval=self.poll_interval, + terminate_on_timeout=self.forward_termination, + port_forward=self.port_forward, ) return self.client.get(