Skip to content

Commit

Permalink
🚀 add ray_executor (#12)
Browse files Browse the repository at this point in the history
* 🚀 add ray_executor
  • Loading branch information
danielgafni authored Oct 7, 2024
1 parent 4a7473b commit c12c70a
Show file tree
Hide file tree
Showing 8 changed files with 521 additions and 1,423 deletions.
39 changes: 37 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,42 @@ To install with extra dependencies for a particular backend (like `kuberay`), ru
pip install 'dagster-ray[kuberay]'
```

# Executor

> [!WARNING]
> The `ray_executor` is a work in progress

The `ray_executor` can be used to execute Dagster steps on an existing remote Ray cluster.
The executor submits steps as Ray jobs. They are started directly in the Ray cluster. Example:


```python
from dagster import job, op
from dagster_ray import ray_executor


@op(
tags={
"dagster-ray/config": {
"num_cpus": 8,
"num_gpus": 2,
"runtime_env": {"pip": {"packages": ["torch"]}},
}
}
)
def my_op():
import torch

return torch.tensor([42])


@job(executor_def=ray_executor)
def my_job():
return my_op()
```

Fields in the `dagster-ray/config` tag **replace** corresponding fields in the Executor config.

# Backends

## KubeRay
Expand Down Expand Up @@ -288,8 +324,7 @@ definitions = Definitions(
)
```

# Executor
WIP


# Development

Expand Down
3 changes: 2 additions & 1 deletion dagster_ray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dagster_ray._base.resources import BaseRayResource
from dagster_ray.executor import ray_executor

# Public alias for the base resource class.
RayResource = BaseRayResource


# fix: removed a stale `__all__ = ["RayResource"]` assignment that was
# immediately overwritten by the line below (dead code left over from a merge).
__all__ = ["RayResource", "ray_executor"]
37 changes: 36 additions & 1 deletion dagster_ray/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,45 @@
from __future__ import annotations

from typing import Optional
from typing import Any, Dict, Optional

from dagster import Config
from pydantic import Field

# Step/op tag key under which users can supply per-step Ray options as a JSON
# payload (parsed by RayExecutionConfig.from_tags).
USER_DEFINED_RAY_KEY = "dagster-ray/config"


# Per-step Ray execution options. Any field left as None is expected to fall
# back to executor-level defaults at the call site.
class RayExecutionConfig(Config):
    runtime_env: Optional[Dict[str, Any]] = Field(default=None, description="The runtime environment to use.")
    num_cpus: Optional[int] = Field(default=None, description="The number of CPUs to allocate.")
    num_gpus: Optional[int] = Field(default=None, description="The number of GPUs to allocate.")
    memory: Optional[int] = Field(default=None, description="The amount of memory in bytes to allocate.")
    resources: Optional[Dict[str, float]] = Field(default=None, description="Custom resources to allocate.")

    @classmethod
    def from_tags(cls, tags: Dict[str, str]) -> RayExecutionConfig:
        """Build a config from Dagster tags.

        Parses the JSON payload stored under ``USER_DEFINED_RAY_KEY``;
        returns a default instance when the tag is absent.
        """
        raw = tags.get(USER_DEFINED_RAY_KEY)
        return cls() if raw is None else cls.parse_raw(raw)


# Connection settings for ray.job_submission.JobSubmissionClient; the fields
# mirror the client's constructor arguments.
class RayJobSubmissionClientConfig(Config):
    # Only `address` is required; everything else is optional client plumbing.
    address: str = Field(..., description="The address of the Ray cluster to connect to.")
    metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="""Arbitrary metadata to store along with all jobs. New metadata
        specified per job will be merged with the global metadata provided here
        via a simple dict update.""",
    )
    headers: Optional[Dict[str, str]] = Field(
        default=None,
        description="""Headers to use when sending requests to the HTTP job server, used
        for cases like authentication to a remote cluster.""",
    )
    cookies: Optional[Dict[str, str]] = Field(
        default=None, description="Cookies to use when sending requests to the HTTP job server."
    )


class ExecutionOptionsConfig(Config):
cpu: Optional[int] = None
Expand Down
216 changes: 216 additions & 0 deletions dagster_ray/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast

from dagster import (
_check as check,
)
from dagster import (
executor,
)
from dagster._core.definitions.executor_definition import multiple_process_executor_requirements
from dagster._core.definitions.metadata import MetadataValue
from dagster._core.events import DagsterEvent, EngineEventData
from dagster._core.execution.retries import RetryMode, get_retries_config
from dagster._core.execution.tags import get_tag_concurrency_limits_config
from dagster._core.executor.base import Executor
from dagster._core.executor.init import InitExecutorContext
from dagster._core.executor.step_delegating import (
CheckStepHealthResult,
StepDelegatingExecutor,
StepHandler,
StepHandlerContext,
)
from dagster._utils.merger import merge_dicts
from dagster_k8s.job import (
get_k8s_job_name,
)
from dagster_k8s.launcher import K8sRunLauncher

from dagster_ray.config import RayExecutionConfig, RayJobSubmissionClientConfig

if TYPE_CHECKING:
pass


# Combined executor config: Ray execution options (cpus/gpus/memory/runtime_env)
# plus JobSubmissionClient connection settings (address/metadata/headers/cookies).
class RayExecutorConfig(RayExecutionConfig, RayJobSubmissionClientConfig): ...


# Dagster config schema derived from the combined pydantic config class.
_RAY_CONFIG_SCHEMA = RayExecutorConfig.to_config_schema().as_field()

# Full executor schema: Ray options nested under "ray", plus the standard
# Dagster executor knobs (retries, tag_concurrency_limits).
_RAY_EXECUTOR_CONFIG_SCHEMA = merge_dicts(
    {"ray": _RAY_CONFIG_SCHEMA},  # type: ignore
    {"retries": get_retries_config(), "tag_concurrency_limits": get_tag_concurrency_limits_config()},
)


@executor(
    name="ray",
    config_schema=_RAY_EXECUTOR_CONFIG_SCHEMA,
    requirements=multiple_process_executor_requirements(),
)
def ray_executor(init_context: InitExecutorContext) -> Executor:
    """Executes steps by submitting them as Ray jobs.

    The steps are started inside the Ray cluster directly.
    Per-step overrides may be supplied via the "dagster-ray/config" op tag.
    """
    # TODO: some RunLauncher config values can be automatically passed to the executor
    run_launcher = (  # noqa
        init_context.instance.run_launcher if isinstance(init_context.instance.run_launcher, K8sRunLauncher) else None
    )

    exc_cfg = init_context.executor_config

    ray_cfg = RayExecutorConfig(**exc_cfg["ray"])  # type: ignore

    return StepDelegatingExecutor(
        # fix: previously only `address` and `runtime_env` were forwarded; the
        # remaining executor config fields (num_cpus, num_gpus, memory,
        # resources, metadata, headers, cookies) were silently ignored.
        RayStepHandler(
            address=ray_cfg.address,
            runtime_env=ray_cfg.runtime_env,
            num_cpus=ray_cfg.num_cpus,
            num_gpus=ray_cfg.num_gpus,
            memory=ray_cfg.memory,
            resources=ray_cfg.resources,
            metadata=ray_cfg.metadata,
            headers=ray_cfg.headers,
            cookies=ray_cfg.cookies,
        ),
        retries=RetryMode.from_config(exc_cfg["retries"]),  # type: ignore
        max_concurrent=check.opt_int_elem(exc_cfg, "max_concurrent"),
        tag_concurrency_limits=check.opt_list_elem(exc_cfg, "tag_concurrency_limits"),
        should_verify_step=True,
    )


class RayStepHandler(StepHandler):
    """Launches, monitors, and terminates Dagster steps as Ray jobs via the
    Ray job submission API.
    """

    @property
    def name(self):
        return "RayStepHandler"

    def __init__(
        self,
        address: str,
        runtime_env: Optional[Dict[str, Any]] = None,
        num_cpus: Optional[int] = None,
        num_gpus: Optional[int] = None,
        memory: Optional[int] = None,
        resources: Optional[Dict[str, float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
    ):
        """Store executor-level defaults and build the submission client.

        Args:
            address: address of the Ray cluster's job submission server.
            runtime_env: default Ray runtime environment for steps.
            num_cpus/num_gpus/memory/resources: default per-step resource
                requests; each may be overridden per step via the
                "dagster-ray/config" tag.
            metadata/headers/cookies: passed through to JobSubmissionClient.
        """
        super().__init__()

        from ray.job_submission import JobSubmissionClient

        self.client = JobSubmissionClient(address, metadata=metadata, headers=headers, cookies=cookies)
        self.runtime_env = runtime_env or {}
        self.num_cpus = num_cpus
        self.num_gpus = num_gpus
        self.memory = memory
        self.resources = resources

    def _get_step_key(self, step_handler_context: StepHandlerContext) -> str:
        # This executor delegates one step at a time.
        step_keys_to_execute = cast(List[str], step_handler_context.execute_step_args.step_keys_to_execute)
        assert len(step_keys_to_execute) == 1, "Launching multiple steps is not currently supported"
        return step_keys_to_execute[0]

    def _get_ray_job_submission_id(self, step_handler_context: StepHandlerContext):
        """Deterministic Ray submission id for this run/step (and retry attempt),
        so health checks and termination can re-derive it without extra state."""
        step_key = self._get_step_key(step_handler_context)

        name_key = get_k8s_job_name(
            step_handler_context.execute_step_args.run_id,
            step_key,
        )

        if step_handler_context.execute_step_args.known_state:
            retry_state = step_handler_context.execute_step_args.known_state.get_retry_state()
            # Retried attempts get a distinct suffix; Ray submission ids must be unique.
            if retry_state.get_attempt_count(step_key):
                return "dagster-step-%s-%d" % (name_key, retry_state.get_attempt_count(step_key))

        return "dagster-step-%s" % (name_key)

    def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
        """Submit the step as a Ray job and emit a step_worker_starting event."""
        step_key = self._get_step_key(step_handler_context)

        submission_id = self._get_ray_job_submission_id(step_handler_context)

        run = step_handler_context.dagster_run
        # Dagster metadata attached to the Ray job for discoverability.
        labels = {
            "dagster/job": run.job_name,
            "dagster/op": step_key,
            "dagster/run-id": step_handler_context.execute_step_args.run_id,
        }
        if run.external_job_origin:
            labels["dagster/code-location"] = (
                run.external_job_origin.repository_origin.code_location_origin.location_name
            )

        user_provided_config = RayExecutionConfig.from_tags({**step_handler_context.step_tags[step_key]})

        # A "dagster-ray/config" tag value replaces the executor-level default.
        runtime_env = (user_provided_config.runtime_env or self.runtime_env).copy()

        dagster_env_vars = {
            "DAGSTER_RUN_JOB_NAME": run.job_name,
            "DAGSTER_RUN_STEP_KEY": step_key,
            **{env["name"]: env["value"] for env in step_handler_context.execute_step_args.get_command_env()},
        }

        # User-supplied env_vars win over the Dagster-injected ones.
        runtime_env["env_vars"] = {**dagster_env_vars, **runtime_env.get("env_vars", {})}  # type: ignore

        # fix: per-step tag values must take precedence over executor defaults
        # (the old `self.X or user.X` order was inverted relative to the
        # documented behavior and to the runtime_env handling above).
        num_cpus = user_provided_config.num_cpus if user_provided_config.num_cpus is not None else self.num_cpus
        num_gpus = user_provided_config.num_gpus if user_provided_config.num_gpus is not None else self.num_gpus
        memory = user_provided_config.memory if user_provided_config.memory is not None else self.memory
        # fix: build a fresh dict — the old code mutated self.resources in
        # place, leaking one step's overrides into every subsequent step.
        resources = {**(self.resources or {}), **(user_provided_config.resources or {})}

        yield DagsterEvent.step_worker_starting(
            step_handler_context.get_step_context(step_key),
            message=f'Executing step "{step_key}" in Ray job {submission_id}.',
            metadata={
                "Ray Submission ID": MetadataValue.text(submission_id),
            },
        )

        self.client.submit_job(
            entrypoint=" ".join(
                step_handler_context.execute_step_args.get_command_args(skip_serialized_namedtuple=True)
            ),
            submission_id=submission_id,
            metadata=labels,
            runtime_env=runtime_env,
            entrypoint_num_cpus=num_cpus,
            entrypoint_num_gpus=num_gpus,
            entrypoint_memory=memory,
            entrypoint_resources=resources,
        )

    def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
        """Report unhealthy when the Ray job is missing or has FAILED."""
        from ray.job_submission import JobStatus

        step_key = self._get_step_key(step_handler_context)

        submission_id = self._get_ray_job_submission_id(step_handler_context)

        try:
            status = self.client.get_job_status(submission_id)
        except RuntimeError:
            # JobSubmissionClient raises RuntimeError for unknown submission ids.
            return CheckStepHealthResult.unhealthy(
                reason=f"Ray job {submission_id} for step {step_key} could not be found."
            )

        if status == JobStatus.FAILED:
            job_details = self.client.get_job_info(submission_id)

            reason = f"Discovered failed Ray job {submission_id} for step {step_key}."

            if job_details.error_type:
                reason += f" Error type: {job_details.error_type}."

            if job_details.message:
                reason += f" Message: {job_details.message}."

            return CheckStepHealthResult.unhealthy(reason=reason)

        return CheckStepHealthResult.healthy()

    def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
        """Emit an engine event and stop the step's Ray job."""
        step_key = self._get_step_key(step_handler_context)

        submission_id = self._get_ray_job_submission_id(step_handler_context)

        yield DagsterEvent.engine_event(
            step_handler_context.get_step_context(step_key),
            # fix: the old message was truncated ("... for step" with no key)
            message=f"Deleting Ray job {submission_id} for step {step_key}",
            event_specific_data=EngineEventData(),
        )

        self.client.stop_job(submission_id)
Loading

0 comments on commit c12c70a

Please sign in to comment.