Skip to content

Commit

Permalink
✨ terminate/delete RayJob on deployment timeout (#53)
Browse files Browse the repository at this point in the history
* add job termination on deployment timeout
  • Loading branch information
danielgafni authored Nov 19, 2024
1 parent f08e7bd commit a590efd
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 25 deletions.
11 changes: 7 additions & 4 deletions dagster_ray/kuberay/client/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar

if TYPE_CHECKING:
from kubernetes import client
Expand All @@ -18,7 +18,10 @@ def load_kubeconfig(context: Optional[str] = None, config_file: Optional[str] =
pass


class BaseKubeRayClient:
T_Status = TypeVar("T_Status")


class BaseKubeRayClient(Generic[T_Status]):
def __init__(
self,
group: str,
Expand All @@ -37,7 +40,7 @@ def __init__(
self._api = client.CustomObjectsApi(api_client=api_client)
self._core_v1_api = client.CoreV1Api(api_client=api_client)

def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_interval: int = 5, timeout: int = 60):
def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_interval: int = 5, timeout: int = 600):
from kubernetes.client import ApiException

start_time = time.time()
Expand All @@ -63,7 +66,7 @@ def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_int

time.sleep(poll_interval)

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> Dict[str, Any]:
def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> T_Status:
from kubernetes.client import ApiException

while timeout > 0:
Expand Down
8 changes: 1 addition & 7 deletions dagster_ray/kuberay/client/raycluster/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class RayClusterStatus(TypedDict):
state: NotRequired[str]


class RayClusterClient(BaseKubeRayClient):
class RayClusterClient(BaseKubeRayClient[RayClusterStatus]):
def __init__(
self,
config_file: Optional[str] = None,
Expand All @@ -91,12 +91,6 @@ def __init__(
self.config_file = config_file
self.context = context

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayClusterStatus: # type: ignore
return cast(
RayClusterStatus,
super().get_status(name=name, namespace=namespace, timeout=timeout, poll_interval=poll_interval),
)

def wait_until_ready(
self,
name: str,
Expand Down
30 changes: 17 additions & 13 deletions dagster_ray/kuberay/client/rayjob/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import time
from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict, cast
from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict

from typing_extensions import NotRequired

Expand Down Expand Up @@ -31,7 +31,7 @@ class RayJobStatus(TypedDict):
message: NotRequired[str]


class RayJobClient(BaseKubeRayClient):
class RayJobClient(BaseKubeRayClient[RayJobStatus]):
def __init__(
self,
config_file: Optional[str] = None,
Expand All @@ -46,12 +46,6 @@ def __init__(

super().__init__(group=GROUP, version=VERSION, kind=KIND, plural=PLURAL, api_client=api_client)

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayJobStatus: # type: ignore
return cast(
RayJobStatus,
super().get_status(name=name, namespace=namespace, timeout=timeout, poll_interval=poll_interval),
)

def get_ray_cluster_name(self, name: str, namespace: str) -> str:
return self.get_status(name, namespace)["rayClusterName"]

Expand All @@ -66,8 +60,10 @@ def wait_until_running(
self,
name: str,
namespace: str,
timeout: int = 300,
timeout: int = 600,
poll_interval: int = 5,
terminate_on_timeout: bool = True,
port_forward: bool = False,
) -> bool:
start_time = time.time()

Expand All @@ -80,9 +76,17 @@ def wait_until_running(
raise RuntimeError(f"RayJob {namespace}/{name} deployment failed. Status:\n{status}")

if time.time() - start_time > timeout:
raise TimeoutError(
f"Timed out waiting for RayJob {namespace}/{name} deployment to become available. Status:\n{status}"
)
if terminate_on_timeout:
logger.warning(f"Terminating RayJob {namespace}/{name} because of timeout {timeout}s")
try:
self.terminate(name, namespace, port_forward=port_forward)
except Exception as e:
logger.warning(
f"Failed to gracefully terminate RayJob {namespace}/{name}: {e}, will delete it instead."
)
self.delete(name, namespace)

raise TimeoutError(f"Timed out waiting for RayJob {namespace}/{name} to start. Status:\n{status}")

time.sleep(poll_interval)

Expand All @@ -103,7 +107,7 @@ def _wait_for_job_submission(
self,
name: str,
namespace: str,
timeout: int = 300,
timeout: int = 600,
poll_interval: int = 10,
):
start_time = time.time()
Expand Down
5 changes: 4 additions & 1 deletion dagster_ray/kuberay/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class PipesKubeRayJobClient(PipesClient, TreatAsResourceParam):
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
from the glue job run. Defaults to :py:class:`PipesRayJobMessageReader`.
client (Optional[boto3.client]): The Kubernetes API client.
forward_termination (bool): Whether to cancel the `RayJob` job run when the Dagster process receives a termination signal.
forward_termination (bool): Whether to terminate the Ray job when the Dagster process receives a termination signal,
or if the startup timeout is reached. Defaults to ``True``.
timeout (int): Timeout for various internal interactions with the Kubernetes RayJob.
poll_interval (int): Interval at which to poll the Kubernetes for status updates.
port_forward (bool): Whether to use Kubernetes port-forwarding to connect to the KubeRay cluster.
Expand Down Expand Up @@ -169,6 +170,8 @@ def _start(self, context: OpExecutionContext, ray_job: Dict[str, Any]) -> Dict[s
namespace=namespace,
timeout=self.timeout,
poll_interval=self.poll_interval,
terminate_on_timeout=self.forward_termination,
port_forward=self.port_forward,
)

return self.client.get(
Expand Down

0 comments on commit a590efd

Please sign in to comment.