wait until RayCluster is ready
danielgafni committed Sep 17, 2024
1 parent 5c09f36 commit 3b32496
Showing 11 changed files with 193 additions and 132 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
@@ -10,6 +10,8 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && apt-get install -y git jq curl gcc python3-dev libpq-dev wget

COPY --from=bitnami/kubectl:1.30.3 /opt/bitnami/kubectl/bin/kubectl /usr/local/bin/

# install poetry
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
26 changes: 19 additions & 7 deletions README.md
@@ -24,7 +24,7 @@ The following backends are implemented:
Documentation can be found below.

> [!NOTE]
> This project is in early development. Contributions are very welcome! See the [Development](#development) section below.
> This project is in early development. APIs are unstable and can change at any time. Contributions are very welcome! See the [Development](#development) section below.
# Backends

@@ -142,7 +142,9 @@ def my_asset(context: AssetExecutionContext, pipes_rayjob_client: PipesRayJobCli
pipes_rayjob_client.run(
context=context,
ray_job={
# RayJob manifest goes here, only .metadata.name is not required and will be generated if not provided
# RayJob manifest goes here
# .metadata.name is not required and will be generated if not provided
# *.container.image is not required and will be set to the current `dagster/image` tag if not provided
# full reference: https://ray-project.github.io/kuberay/reference/api/#rayjob
...
},
@@ -162,7 +164,7 @@ from dagster_pipes import open_dagster_pipes


with open_dagster_pipes() as pipes:
pipes.log.info("Hello from Ray!")
pipes.log.info("Hello from Ray Pipes!")
pipes.report_asset_materialization(
metadata={"some_metric": {"raw_value": 0, "type": "int"}},
data_version="alpha",
@@ -177,8 +179,18 @@ import yaml
ray_job = {"spec": {"runtimeEnvYAML": yaml.safe_dump({"pip": ["dagster-pipes"]})}}
```
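For a slightly fuller picture, here is a hedged sketch of the same idea with extra (hypothetical) pip packages and environment variables. As the `_enrich_ray_job` logic in this commit shows, `PipesRayJobClient` merges its own Pipes variables into `env_vars`, so user-provided entries are kept:

```python
import yaml

# Hypothetical runtime environment: the extra package and variable are illustrative only.
runtime_env = {
    "pip": ["dagster-pipes", "pandas"],   # "pandas" is an arbitrary example dependency
    "env_vars": {"MY_SETTING": "value"},  # Pipes bootstrap variables are merged in on top
}

ray_job = {"spec": {"runtimeEnvYAML": yaml.safe_dump(runtime_env)}}
```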

The logs and events emitted by the Ray job will be captured by the `PipesRayJobClient` and will become available in the Dagster event log.
Normal stdout & stderr will be forwarded to stdout.
The logs and events emitted by the Ray job will be captured by the `PipesRayJobClient` and will become available in the Dagster event log. Standard output and standard error streams will be forwarded to the standard output of the Dagster process.


**Running locally**

When running locally, the `port_forward` option has to be set to `True` on the `PipesRayJobClient` resource so that Dagster can reach the Ray job. For convenience, it can be set automatically with:

```python
from dagster_ray.kuberay.configs import in_k8s

pipes_rayjob_client = PipesRayJobClient(..., port_forward=not in_k8s)
```
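As a minimal sketch (the import path for `PipesRayJobClient` is assumed from this repository's layout, and other constructor arguments are omitted), the resource can then be wired into a Dagster `Definitions` object so that the `my_asset` example above receives it:

```python
from dagster import Definitions
from dagster_ray.kuberay.configs import in_k8s
from dagster_ray.kuberay.pipes import PipesRayJobClient  # path assumed, not confirmed by this commit

definitions = Definitions(
    assets=[my_asset],  # the asset from the example above
    resources={
        # the resource key must match the asset's parameter name
        "pipes_rayjob_client": PipesRayJobClient(port_forward=not in_k8s),
    },
)
```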

### Resources

@@ -300,13 +312,13 @@ Running `pytest` will **automatically**:
- build an image with the local `dagster-ray` code
- start a `minikube` Kubernetes cluster
- load the built `dagster-ray` image and the `kuberay-operator` image into the cluster
- install the `KubeRay Operator` in the cluster with `helm`
- install `KubeRay Operator` into the cluster with `helm`
- run the tests

Thus, no manual setup is required, just the presence of the tools listed above. This makes testing a breeze!

> [!NOTE]
> Specifying a comma-separated list of `KubeRay Operator` versions in the `KUBE_RAY_OPERATOR_VERSIONS` environment variable will spawn a new test for each version.
> Specifying a comma-separated list of `KubeRay Operator` versions in the `PYTEST_KUBERAY_VERSIONS` environment variable will spawn a new test for each version.
> [!NOTE]
> It may take a while to download the `minikube` and `kuberay-operator` images and build the local `dagster-ray` image during the first test invocation.
13 changes: 4 additions & 9 deletions dagster_ray/kuberay/client/base.py
@@ -27,20 +27,15 @@ def __init__(
version: str,
kind: str,
plural: str,
context: Optional[str] = None,
config_file: Optional[str] = None,
api_client: Optional[client.ApiClient] = None,
):
self.group = group
self.version = version
self.kind = kind
self.plural = plural
self.config_file = config_file
self.context = context

self.kube_config: Optional[Any] = load_kubeconfig(config_file=config_file, context=context)

self._api = client.CustomObjectsApi()
self._core_v1_api = client.CoreV1Api()
self.api_client = api_client
self._api = client.CustomObjectsApi(api_client=api_client)
self._core_v1_api = client.CoreV1Api(api_client=api_client)

def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_interval: int = 5, timeout: int = 60):
start_time = time.time()
23 changes: 16 additions & 7 deletions dagster_ray/kuberay/client/raycluster/client.py
@@ -84,10 +84,13 @@ def __init__(self, config_file: Optional[str] = None, context: Optional[str] = N
version=VERSION,
kind=KIND,
plural=PLURAL,
config_file=config_file,
context=context,
)

# these are only used because of the kubectl port-forward CLI command
# TODO: remove kubectl usage and remove these attributes
self.config_file = config_file
self.context = context

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayClusterStatus: # type: ignore
return cast(
RayClusterStatus,
@@ -113,6 +116,7 @@ def wait_until_ready(
# TODO: use get_namespaced_custom_object instead
# once https://github.com/kubernetes-client/python/issues/1679
# is solved

for event in w.stream(
self._api.list_namespaced_custom_object,
self.group,
@@ -129,7 +133,7 @@

if status.get("state") == "failed":
raise Exception(
f"RayCluster {namespace}/{name} failed to start. More details: `kubectl -n {namespace} describe RayCluster {name}`"
f"RayCluster {namespace}/{name} failed to start. Reason:\n{status.get('reason')}\nMore details: `kubectl -n {namespace} describe RayCluster {name}`"
)

if (
@@ -196,6 +200,9 @@ def port_forward(
if self.context:
cmd.extend(["--context", self.context])

if self.config_file:
cmd.extend(["--kubeconfig", self.config_file])

process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

queue = Queue()
@@ -226,7 +233,9 @@ def should_stop():
if "Forwarding from" in line:
break

logger.debug(f"port-forwarding for ports {local_dashboard_port} and {local_gcs_port} started")
logger.info(
f"Connecting to {namespace}/{name} via port-forwarding for ports {local_dashboard_port} and {local_gcs_port}..."
)

yield local_dashboard_port, local_gcs_port
finally:
@@ -235,11 +244,11 @@ def should_stop():
process.kill()
process.wait()
t.join()
logger.debug(f"port-forwarding for ports {local_dashboard_port} and {local_gcs_port} stopped")
logger.info(f"Port-forwarding for ports {local_dashboard_port} and {local_gcs_port} has been stopped.")

@contextmanager
def job_submission_client(
self, name: str, namespace: str, port_forward: bool = False
self, name: str, namespace: str, port_forward: bool = False, timeout: int = 60
) -> Iterator["JobSubmissionClient"]:
"""
Returns a JobSubmissionClient object that can be used to interact with Ray jobs running in the KubeRay cluster.
@@ -259,7 +268,7 @@ def job_submission_client(

yield JobSubmissionClient(address=f"http://{host}:{dashboard_port}")
else:
self.wait_for_service_endpoints(service_name=f"{name}-head-svc", namespace=namespace)
self.wait_for_service_endpoints(service_name=f"{name}-head-svc", namespace=namespace, timeout=timeout)
with self.port_forward(name=name, namespace=namespace, local_dashboard_port=0, local_gcs_port=0) as (
local_dashboard_port,
_,
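For reference, a minimal usage sketch of the `job_submission_client` context manager changed above (the cluster name and namespace are placeholders; `submit_job` is the standard Ray `JobSubmissionClient` API):

```python
from dagster_ray.kuberay.client.raycluster import RayClusterClient

client = RayClusterClient()  # optionally pass config_file and/or context

# "my-raycluster" and "ray" are hypothetical values
with client.job_submission_client(
    name="my-raycluster", namespace="ray", port_forward=True, timeout=60
) as job_client:
    # job_client is a ray.job_submission.JobSubmissionClient reachable through the port-forward
    job_id = job_client.submit_job(entrypoint="python -c 'print(42)'")
```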
22 changes: 13 additions & 9 deletions dagster_ray/kuberay/client/rayjob/client.py
@@ -4,7 +4,7 @@

from typing_extensions import NotRequired

from dagster_ray.kuberay.client.base import BaseKubeRayClient
from dagster_ray.kuberay.client.base import BaseKubeRayClient, load_kubeconfig
from dagster_ray.kuberay.client.raycluster import RayClusterClient, RayClusterStatus

GROUP = "ray.io"
@@ -30,13 +30,17 @@ class RayJobStatus(TypedDict):

class RayJobClient(BaseKubeRayClient):
def __init__(self, config_file: Optional[str] = None, context: Optional[str] = None) -> None:
# this call must happen BEFORE creating K8s apis
load_kubeconfig(config_file=config_file, context=context)

self.config_file = config_file
self.context = context

super().__init__(
group=GROUP,
version=VERSION,
kind=KIND,
plural=PLURAL,
config_file=config_file,
context=context,
)

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayJobStatus: # type: ignore
@@ -53,13 +57,13 @@ def get_job_sumission_id(self, name: str, namespace: str) -> str:

@property
def ray_cluster_client(self) -> RayClusterClient:
return RayClusterClient(config_file=self.kube_config, context=self.context)
return RayClusterClient(config_file=self.config_file, context=self.context)

def wait_until_running(
self,
name: str,
namespace: str,
timeout: int = 60 * 60,
timeout: int = 300,
poll_interval: int = 5,
) -> bool:
start_time = time.time()
@@ -70,11 +74,11 @@ def wait_until_running(
if status in ["Running", "Complete"]:
break
elif status == "Failed":
return False
raise RuntimeError(f"RayJob {namespace}/{name} deployment failed. Status:\n{status}")

if time.time() - start_time > timeout:
raise TimeoutError(
f"Timed out waiting for RayJob {name} deployment to become available." f"Status: {status}"
f"Timed out waiting for RayJob {namespace}/{name} deployment to become available. Status:\n{status}"
)

time.sleep(poll_interval)
@@ -86,7 +90,7 @@ def wait_until_running(
break

if time.time() - start_time > timeout:
raise TimeoutError(f"Timed out waiting for RayJob {name} to start. " f"Status: {status}")
raise TimeoutError(f"Timed out waiting for RayJob {namespace}/{name} to start. Status:\n{status}")

time.sleep(poll_interval)

Expand All @@ -96,7 +100,7 @@ def _wait_for_job_submission(
self,
name: str,
namespace: str,
timeout: int = 600,
timeout: int = 300,
poll_interval: int = 10,
):
start_time = time.time()
4 changes: 2 additions & 2 deletions dagster_ray/kuberay/configs.py
@@ -13,8 +13,8 @@
"env": [],
"envFrom": [],
"resources": {
"limits": {"cpu": "1000m", "memory": "1Gi"},
"requests": {"cpu": "1000m", "memory": "1Gi"},
"limits": {"cpu": "50m", "memory": "0.1Gi"},
"requests": {"cpu": "50m", "memory": "0.1Gi"},
},
}
DEFAULT_HEAD_GROUP_SPEC = {
40 changes: 18 additions & 22 deletions dagster_ray/kuberay/pipes.py
@@ -42,8 +42,6 @@ class PipesRayJobClient(PipesClient, TreatAsResourceParam):
Is useful when running in a local environment.
"""

client: RayJobClient

def __init__(
self,
client: Optional[RayJobClient] = None,
@@ -101,7 +99,9 @@ def run( # type: ignore
namespace = ray_job["metadata"]["namespace"]

with self.client.ray_cluster_client.job_submission_client(
name=name, namespace=namespace, port_forward=self.port_forward
name=self.client.get_ray_cluster_name(name=name, namespace=namespace),
namespace=namespace,
port_forward=self.port_forward,
) as job_submission_client:
self._job_submission_client = job_submission_client

@@ -130,27 +130,19 @@ def _enrich_ray_job(
ray_job["metadata"] = ray_job.get("metadata", {})
ray_job["metadata"]["labels"] = ray_job["metadata"].get("labels", {})

ray_job["metadata"]["name"] = ray_job["metadata"].get("name", f"dg-{context.run_id[:6]}")
ray_job["metadata"]["name"] = ray_job["metadata"].get("name", f"dg-{context.run.run_id[:8]}")
ray_job["metadata"]["labels"].update(self.get_dagster_tags(context))

# update env vars in runtimeEnv
runtime_env_yaml = ray_job["spec"].get("runtimeEnvYAML")
runtime_env_yaml = ray_job["spec"].get("runtimeEnvYAML", "{}")

if runtime_env_yaml is None:
runtime_env_yaml = yaml.safe_dump(
{
"env_vars": env_vars,
}
)
else:
runtime_env = yaml.safe_load(runtime_env_yaml)
runtime_env["env_vars"] = runtime_env.get("env_vars", {})
runtime_env["env_vars"].update(env_vars)
runtime_env = yaml.safe_load(runtime_env_yaml)
runtime_env["env_vars"] = runtime_env.get("env_vars", {})
runtime_env["env_vars"].update(env_vars)

ray_job["spec"]["runtimeEnvYAML"] = yaml.safe_dump(runtime_env)
ray_job["spec"]["runtimeEnvYAML"] = yaml.safe_dump(runtime_env)

# set image from tag context.dagster_run.tags["dagster/image"] if not set
image_from_run_tag = context.dagster_run.tags.get("dagster/image")
image_from_run_tag = context.run.tags.get("dagster/image")

for container in ray_job["spec"]["rayClusterSpec"]["headGroupSpec"]["template"]["spec"]["containers"]:
container["image"] = container.get("image") or image_from_run_tag
@@ -189,7 +181,9 @@ def _read_messages(self, context: OpExecutionContext, start_response: Dict[str,

if isinstance(self._message_reader, PipesRayJobSubmissionClientMessageReader):
# starts a thread
self._message_reader.tail_job_logs(client=self.job_submission_client, job_id=status["jobId"])
self._message_reader.consume_job_logs(
client=self.job_submission_client, job_id=status["jobId"], blocking=False
)

def _wait_for_completion(self, context: OpExecutionContext, start_response: Dict[str, Any]) -> RayJobStatus:
context.log.info("[pipes] Waiting for RayJob to complete...")
@@ -209,12 +203,14 @@ def _wait_for_completion(self, context: OpExecutionContext, start_response: Dict
elif job_status == "SUCCEEDED":
context.log.info(f"[pipes] RayJob {namespace}/{name} is complete!")
return status
elif job_status == ["STOPPED", "FAILED"]:
elif job_status in ["STOPPED", "FAILED"]:
raise RuntimeError(
f"RayJob {namespace}/{name} status is {job_status}. Reason:\n{status.get('message')}"
f"RayJob {namespace}/{name} status is {job_status}. Message:\n{status.get('message')}"
)
else:
raise RuntimeError(f"RayJob {namespace}/{name} has an unknown status: {job_status}")
raise RuntimeError(
f"RayJob {namespace}/{name} has an unknown status: {job_status}. Message:\n{status.get('message')}"
)

time.sleep(self.poll_interval)
