diff --git a/.github/workflows/cover-ci.yml b/.github/workflows/cover-ci.yml index 15fc10de5..2b78ef030 100644 --- a/.github/workflows/cover-ci.yml +++ b/.github/workflows/cover-ci.yml @@ -30,7 +30,7 @@ jobs: echo $CONDA/bin >> $GITHUB_PATH - name: Install conda env & dependencies run: | - pip install -e '.[atari, mujoco, envpool]' + pip install -e '.[atari, mujoco, envpool, onnx]' conda list - name: Install codecov dependencies run: | diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 48842bffe..3b5d3e0f4 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -31,7 +31,7 @@ jobs: - name: Install conda env & dependencies run: | conda install python=${{ matrix.python-version }} - pip install -e '.[atari, mujoco, envpool, pettingzoo]' + pip install -e '.[atari, mujoco, envpool, pettingzoo, onnx]' conda list - name: Install test dependencies run: | @@ -75,7 +75,7 @@ jobs: - name: Install conda env & dependencies run: | conda install python=${{ matrix.python-version }} - pip install -e '.[atari, mujoco, pettingzoo]' + pip install -e '.[atari, mujoco, pettingzoo, onnx]' conda list - name: Install test dependencies run: | diff --git a/docs/03-customization/custom-environments.md b/docs/03-customization/custom-environments.md index 732ed5502..3c9b7427e 100644 --- a/docs/03-customization/custom-environments.md +++ b/docs/03-customization/custom-environments.md @@ -121,6 +121,35 @@ if __name__ == "__main__": You can now run evaluation with `python enjoy_custom_env.py --env=custom_env_name --experiment=CustomEnv` to measure the performance of the trained model, visualize agent's performance, or record a video file. +## ONNX export script template + +The exporting script is similar to the evaluation script, with a few key differences. +It uses the `export_onnx` function to convert your model to ONNX format. 
+ +```python3 +import sys + +from sample_factory.export_onnx import export_onnx +from train_custom_env import parse_args, register_custom_env_envs + + +def main(): + """Script entry point.""" + register_custom_env_envs() + cfg = parse_args(evaluation=True) + + # The export_onnx function takes the configuration and the output file path + status = export_onnx(cfg, "my_model.onnx") + + return status + + +if __name__ == "__main__": + sys.exit(main()) +``` + +For information on how to use the exported ONNX models, please refer to the [Exporting a Model to ONNX](../07-advanced-topics/exporting-to-onnx.md) section. + ## Examples * `sf_examples/train_custom_env_custom_model.py` - integrates an entirely custom toy environment. diff --git a/docs/07-advanced-topics/exporting-to-onnx.md b/docs/07-advanced-topics/exporting-to-onnx.md new file mode 100644 index 000000000..1a5dc4fc0 --- /dev/null +++ b/docs/07-advanced-topics/exporting-to-onnx.md @@ -0,0 +1,75 @@ +# Exporting a Model to ONNX + +[ONNX](https://onnx.ai/) is a standard format for representing machine learning models. Sample Factory can export models to ONNX format. + +Exporting to ONNX allows you to: + +- Deploy your model in various production environments +- Use hardware-specific optimizations provided by ONNX Runtime +- Integrate your model with other tools and frameworks that support ONNX + +## Usage Examples + +First, train a model using Sample Factory. + +```bash +python -m sf_examples.train_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False --reward_scale=0.1 +``` + +Then, use the following command to export it to ONNX: + +```bash +python -m sf_examples.export_onnx_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False +``` + +This creates `example_gym_cartpole-v1.onnx` in the current directory. 
+ +### Using the Exported Model + +Here's how to use the exported ONNX model: + +```python +import numpy as np +import onnxruntime + +ort_session = onnxruntime.InferenceSession("example_gym_cartpole-v1.onnx", providers=["CPUExecutionProvider"]) + +# The model expects a batch of observations as input. +batch_size = 3 +ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32)} + +ort_out = ort_session.run(None, ort_inputs) + +# The output is a list of actions, one for each observation in the batch. +selected_actions = ort_out[0] +print(selected_actions) # e.g. [1, 1, 0] +``` + +### RNN + +When exporting a model that uses RNN with `--use_rnn=True` (default), the model will expect RNN states as input. +Note that for RNN models, the batch size must be 1. + +```python +import numpy as np +import onnxruntime + +ort_session = onnxruntime.InferenceSession("rnn.onnx", providers=["CPUExecutionProvider"]) + +rnn_states_input = next(input for input in ort_session.get_inputs() if input.name == "rnn_states") +rnn_states = np.zeros(rnn_states_input.shape, dtype=np.float32) +batch_size = 1 # must be 1 + +for _ in range(10): + ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32), "rnn_states": rnn_states} + ort_out = ort_session.run(None, ort_inputs) + rnn_states = ort_out[1] # The second output is the updated rnn states +``` + +## Configuration + +The following key parameters will change the behavior of the exported model: + +- `--use_rnn` Whether the model uses RNN. See the RNN example above. + +- `--eval_deterministic` If `True`, actions are selected by argmax. 
diff --git a/mkdocs.yml b/mkdocs.yml index b9e4544af..9aaa6a8c6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -157,6 +157,7 @@ nav: - 07-advanced-topics/observer.md - 07-advanced-topics/profiling.md - 07-advanced-topics/action-masking.md + - 07-advanced-topics/exporting-to-onnx.md - Miscellaneous: - 08-miscellaneous/tests.md - 08-miscellaneous/v1-to-v2.md diff --git a/sample_factory/algo/sampling/batched_sampling.py b/sample_factory/algo/sampling/batched_sampling.py index 140feef9f..185de4e7b 100644 --- a/sample_factory/algo/sampling/batched_sampling.py +++ b/sample_factory/algo/sampling/batched_sampling.py @@ -27,7 +27,9 @@ from sample_factory.utils.utils import log -def preprocess_actions(env_info: EnvInfo, actions: Tensor | np.ndarray) -> Tensor | np.ndarray | List: +def preprocess_actions( + env_info: EnvInfo, actions: Tensor | np.ndarray, to_numpy: bool = True +) -> Tensor | np.ndarray | List: """ We expect actions to have shape [num_envs, num_actions]. For environments that require only one action per step we just squeeze the second dimension, @@ -38,15 +40,17 @@ def preprocess_actions(env_info: EnvInfo, actions: Tensor | np.ndarray) -> Tenso """ if env_info.all_discrete or isinstance(env_info.action_space, gym.spaces.Discrete): - return process_action_space(actions, env_info.gpu_actions, is_discrete=True) + return process_action_space(actions, env_info.gpu_actions, is_discrete=True, to_numpy=to_numpy) elif isinstance(env_info.action_space, gym.spaces.Box): - return process_action_space(actions, env_info.gpu_actions, is_discrete=False) + return process_action_space(actions, env_info.gpu_actions, is_discrete=False, to_numpy=to_numpy) elif isinstance(env_info.action_space, gym.spaces.Tuple): # input is (num_envs, num_actions) out_actions = [] for split, space in zip(torch.split(actions, env_info.action_splits, 1), env_info.action_space): out_actions.append( - process_action_space(split, env_info.gpu_actions, isinstance(space, gym.spaces.Discrete)) + 
process_action_space( + split, env_info.gpu_actions, isinstance(space, gym.spaces.Discrete), to_numpy=to_numpy + ) ) # this line can be used to transpose the actions, perhaps add as an option ? # out_actions = list(zip(*out_actions)) # transpose @@ -55,11 +59,15 @@ def preprocess_actions(env_info: EnvInfo, actions: Tensor | np.ndarray) -> Tenso raise NotImplementedError(f"Unknown action space type: {env_info.action_space}") -def process_action_space(actions: torch.Tensor, gpu_actions: bool, is_discrete: bool): +def process_action_space( + actions: torch.Tensor, gpu_actions: bool, is_discrete: bool, to_numpy: bool = True +) -> torch.Tensor | np.ndarray: if is_discrete: actions = actions.to(torch.int32) if not gpu_actions: - actions = actions.cpu().numpy() + actions = actions.cpu() + if to_numpy: + actions = actions.numpy() # action tensor/array should have two dimensions (num_agents, num_actions) where num_agents is a number of # individual actors in a vectorized environment (whether actually different agents or separate envs - does not diff --git a/sample_factory/enjoy.py b/sample_factory/enjoy.py index 9d2af9b25..cd927ef33 100644 --- a/sample_factory/enjoy.py +++ b/sample_factory/enjoy.py @@ -1,6 +1,6 @@ import time from collections import deque -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import gymnasium as gym import numpy as np @@ -11,13 +11,13 @@ from sample_factory.algo.sampling.batched_sampling import preprocess_actions from sample_factory.algo.utils.action_distributions import argmax_actions from sample_factory.algo.utils.env_info import extract_env_info -from sample_factory.algo.utils.make_env import make_env_func_batched +from sample_factory.algo.utils.make_env import BatchedVecEnv, make_env_func_batched from sample_factory.algo.utils.misc import ExperimentStatus from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor from 
sample_factory.cfg.arguments import load_from_checkpoint from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf -from sample_factory.model.actor_critic import create_actor_critic +from sample_factory.model.actor_critic import ActorCritic, create_actor_critic from sample_factory.model.model_utils import get_rnn_size from sample_factory.utils.attr_dict import AttrDict from sample_factory.utils.typing import Config, StatusCode @@ -82,6 +82,24 @@ def render_frame(cfg, env, video_frames, num_episodes, last_render_start) -> flo return render_start +def make_env(cfg: Config, render_mode: Optional[str] = None) -> BatchedVecEnv: + env = make_env_func_batched( + cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0), render_mode=render_mode + ) + return env + + +def load_state_dict(cfg: Config, actor_critic: ActorCritic, device: torch.device) -> None: + policy_id = cfg.policy_index + name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind] + checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*") + checkpoint_dict = Learner.load_checkpoint(checkpoints, device) + if checkpoint_dict: + actor_critic.load_state_dict(checkpoint_dict["model"]) + else: + raise RuntimeError("Could not load checkpoint") + + def enjoy(cfg: Config) -> Tuple[StatusCode, float]: verbose = False @@ -103,9 +121,7 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]: elif cfg.no_render: render_mode = None - env = make_env_func_batched( - cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0), render_mode=render_mode - ) + env = make_env(cfg, render_mode=render_mode) env_info = extract_env_info(env, cfg) if hasattr(env.unwrapped, "reset_on_init"): @@ -118,11 +134,7 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]: device = torch.device("cpu" if cfg.device == "cpu" else "cuda") actor_critic.model_to_device(device) - policy_id = cfg.policy_index - name_prefix = 
dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind] - checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*") - checkpoint_dict = Learner.load_checkpoint(checkpoints, device) - actor_critic.load_state_dict(checkpoint_dict["model"]) + load_state_dict(cfg, actor_critic, device) episode_rewards = [deque([], maxlen=100) for _ in range(env.num_agents)] true_objectives = [deque([], maxlen=100) for _ in range(env.num_agents)] diff --git a/sample_factory/export_onnx.py b/sample_factory/export_onnx.py new file mode 100644 index 000000000..b3bbe6fcc --- /dev/null +++ b/sample_factory/export_onnx.py @@ -0,0 +1,194 @@ +import types +from typing import List + +import gymnasium as gym +import torch +import torch.nn as nn +import torch.onnx +from torch import Tensor + +from sample_factory.algo.learning.learner import Learner +from sample_factory.algo.sampling.batched_sampling import preprocess_actions +from sample_factory.algo.utils.action_distributions import argmax_actions +from sample_factory.algo.utils.env_info import EnvInfo, extract_env_info +from sample_factory.algo.utils.make_env import BatchedVecEnv +from sample_factory.algo.utils.misc import ExperimentStatus +from sample_factory.algo.utils.rl_utils import prepare_and_normalize_obs +from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor +from sample_factory.cfg.arguments import load_from_checkpoint +from sample_factory.enjoy import load_state_dict, make_env +from sample_factory.model.actor_critic import ActorCritic, create_actor_critic +from sample_factory.model.model_utils import get_rnn_size +from sample_factory.utils.attr_dict import AttrDict +from sample_factory.utils.typing import Config + + +class OnnxExporter(nn.Module): + actor_critic: ActorCritic + cfg: Config + env_info: EnvInfo + rnn_states: Tensor + + def __init__(self, cfg: Config, env_info: EnvInfo, actor_critic: ActorCritic): + super(OnnxExporter, self).__init__() + self.cfg = cfg + 
self.env_info = env_info + self.actor_critic = actor_critic + + def forward(self, **obs): + if self.cfg.use_rnn: + rnn_states = obs.pop("rnn_states") + else: + rnn_states = generate_rnn_states(self.cfg) + + action_mask = obs.pop("action_mask", None) + normalized_obs = prepare_and_normalize_obs(self.actor_critic, obs) + policy_outputs = self.actor_critic(normalized_obs, rnn_states, action_mask=action_mask) + actions = policy_outputs["actions"] + rnn_states = policy_outputs["new_rnn_states"] + + if self.cfg.eval_deterministic: + action_distribution = self.actor_critic.action_distribution() + actions = argmax_actions(action_distribution) + + if actions.ndim == 1: + actions = unsqueeze_tensor(actions, dim=-1) + + actions = preprocess_actions(self.env_info, actions, to_numpy=False) + + if self.cfg.use_rnn: + return actions, rnn_states + else: + return actions + + +def create_onnx_exporter(cfg: Config, env: BatchedVecEnv, enable_jit=False) -> OnnxExporter: + env_info = extract_env_info(env, cfg) + device = torch.device("cpu") + + if enable_jit: + actor_critic = create_actor_critic(cfg, env.observation_space, env.action_space) + else: + try: + # HACK: disable torch.jit to avoid the following problem: + # https://github.com/pytorch/pytorch/issues/47887 + # + # The other workaround is to use torch.jit.trace, but it requires + # to change many things of models too + torch.jit._state.disable() # type: ignore[reportAttributeAccessIssue] + actor_critic = create_actor_critic(cfg, env.observation_space, env.action_space) + finally: + torch.jit._state.enable() # type: ignore[reportAttributeAccessIssue] + + actor_critic.eval() + actor_critic.model_to_device(device) + load_state_dict(cfg, actor_critic, device) + + model = OnnxExporter(cfg, env_info, actor_critic) + return model + + +def generate_args(space: gym.spaces.Space, batch_size: int = 1): + args = [unsqueeze_args(sample_space(space)) for _ in range(batch_size)] + args = [a for a in args if isinstance(a, dict)] + args = {k: 
torch.cat(tuple(a[k] for a in args), dim=0) for k in args[0].keys()} if len(args) > 0 else {} + return args + + +def generate_rnn_states(cfg): + return torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32) + + +def sample_space(space: gym.spaces.Space): + if isinstance(space, gym.spaces.Discrete): + return int(space.sample()) + elif isinstance(space, gym.spaces.Box): + return torch.from_numpy(space.sample()) + elif isinstance(space, gym.spaces.Dict): + return {k: sample_space(v) for k, v in space.spaces.items()} + elif isinstance(space, gym.spaces.Tuple): + return tuple(sample_space(s) for s in space.spaces) + else: + raise NotImplementedError(f"Unsupported space type: {type(space)}") + + +def unsqueeze_args(args): + if isinstance(args, int): + return torch.tensor(args).unsqueeze(0) + if isinstance(args, torch.Tensor): + return args.unsqueeze(0) + if isinstance(args, dict): + return {k: unsqueeze_args(v) for k, v in args.items()} + elif isinstance(args, tuple): + return tuple(unsqueeze_args(v) for v in args) + else: + raise NotImplementedError(f"Unsupported args type: {type(args)}") + + +def create_forward(original_forward, arg_names: List[str]): + args_str = ", ".join(arg_names) + + func_code = f""" +def forward(self, {args_str}): + bound_args = locals() + bound_args.pop('self') + return original_forward(**bound_args) + """ + + globals_vars = {"original_forward": original_forward} + local_vars = {} + exec(func_code, globals_vars, local_vars) + return local_vars["forward"] + + +def patch_forward(model: OnnxExporter, input_names: List[str]): + """ + Patch the forward method of the model to dynamically define the input arguments + since *args and **kwargs are not supported in `torch.onnx.export` + + see also: https://github.com/pytorch/pytorch/issues/96981 and https://github.com/pytorch/pytorch/issues/110439 + """ + forward = create_forward(model.forward, input_names) + model.forward = types.MethodType(forward, model) + + +def export_onnx(cfg: Config, f: str) -> int: + 
cfg = load_from_checkpoint(cfg) + env = make_env(cfg) + model = create_onnx_exporter(cfg, env) + args = generate_args(env.observation_space) + + # The args dict is mapped to the inputs of the model + # since usages of dictionaries is not recommended by pytorch + # see also: https://github.com/pytorch/pytorch/blob/v2.4.1/torch/onnx/utils.py#L768-L772 + input_names = list(args.keys()) + + # Append the "output_" prefix to avoid name confliction with the input names + # that causes to add ".N" suffix to the input names. + # see also: https://discuss.pytorch.org/t/onnx-export-same-input-and-output-names/93155 + output_names = ["output_actions"] + + if cfg.use_rnn: + input_names.append("rnn_states") + output_names.append("output_rnn_states") + args["rnn_states"] = generate_rnn_states(cfg) + + # batch size must be 1 when rnn is used + # See also https://github.com/onnx/onnx/issues/3182 + dynamic_axes = None + else: + dynamic_axes = {key: {0: "batch_size"} for key in input_names + output_names} + + patch_forward(model, input_names) + + torch.onnx.export( + model, + (args,), + f, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + + return ExperimentStatus.SUCCESS diff --git a/setup.py b/setup.py index 20fa95b1c..f5b511f6a 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ ] _envpool_deps = ["envpool"] _pettingzoo_deps = ["pettingzoo[classic]"] +_onnx_deps = ["onnx", "onnxruntime"] _docs_deps = [ "mkdocs-material", @@ -82,11 +83,13 @@ def is_macos(): + _docs_deps + _atari_deps + _mujoco_deps + + _onnx_deps + _pettingzoo_deps, "atari": _atari_deps, "envpool": _envpool_deps, "mujoco": _mujoco_deps, "nethack": _nethack_deps, + "onnx": _onnx_deps, "pettingzoo": _pettingzoo_deps, "vizdoom": ["vizdoom<2.0", "gymnasium[classic_control]"], # "dmlab": ["dm_env"], <-- these are just auxiliary packages, the main package has to be built from sources diff --git a/sf_examples/export_onnx_gym_env.py 
b/sf_examples/export_onnx_gym_env.py new file mode 100644 index 000000000..d34de0204 --- /dev/null +++ b/sf_examples/export_onnx_gym_env.py @@ -0,0 +1,23 @@ +""" +An example that shows how to export a SampleFactory model to the ONNX format. + +Example command line for CartPole-v1 that exports to "./example_gym_cartpole-v1.onnx" +python -m sf_examples.export_onnx_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False + +""" + +import sys + +from sample_factory.export_onnx import export_onnx +from sf_examples.train_gym_env import parse_custom_args, register_custom_components + + +def main(): + register_custom_components() + cfg = parse_custom_args(evaluation=True) + status = export_onnx(cfg, f"{cfg.experiment}.onnx") + return status + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/envs/atari/test_atari.py b/tests/envs/atari/test_atari.py index 2f4a34060..bce147694 100644 --- a/tests/envs/atari/test_atari.py +++ b/tests/envs/atari/test_atari.py @@ -8,6 +8,7 @@ from sample_factory.utils.utils import log from sf_examples.atari.train_atari import parse_atari_args from tests.envs.utils import eval_env_performance +from tests.export_onnx_utils import check_export_onnx from tests.utils import clean_test_dir @@ -45,23 +46,34 @@ def _run_test_env( experiment_name = "test_" + env - cfg = parse_atari_args(argv=["--algo=APPO", f"--env={env}", f"--experiment={experiment_name}"]) - cfg.serial_mode = serial_mode - cfg.async_rl = async_rl - cfg.batched_sampling = batched_sampling - cfg.num_workers = num_workers - cfg.num_envs_per_worker = 1 - cfg.train_for_env_steps = train_steps - cfg.batch_size = batch_size - cfg.decorrelate_envs_on_one_worker = False - cfg.seed = 0 - cfg.device = "cpu" + def parse_args(evaluation: bool = False): + cfg = parse_atari_args( + argv=["--algo=APPO", f"--env={env}", f"--experiment={experiment_name}"], evaluation=evaluation + ) + cfg.serial_mode = serial_mode + cfg.async_rl = async_rl + cfg.batched_sampling = 
batched_sampling + cfg.num_workers = num_workers + cfg.num_envs_per_worker = 1 + cfg.train_for_env_steps = train_steps + cfg.batch_size = batch_size + cfg.decorrelate_envs_on_one_worker = False + cfg.seed = 0 + cfg.device = "cpu" + cfg.eval_deterministic = True + return cfg + cfg = parse_args() directory = clean_test_dir(cfg) status = run_rl(cfg) assert status == ExperimentStatus.SUCCESS assert isdir(directory) - shutil.rmtree(directory, ignore_errors=True) + + try: + cfg = parse_args(evaluation=True) + check_export_onnx(cfg) + finally: + shutil.rmtree(directory, ignore_errors=True) @pytest.mark.parametrize( "env_name", diff --git a/tests/envs/pettingzoo/test_pettingzoo.py b/tests/envs/pettingzoo/test_pettingzoo.py index 9f8ff805a..6d82677f1 100644 --- a/tests/envs/pettingzoo/test_pettingzoo.py +++ b/tests/envs/pettingzoo/test_pettingzoo.py @@ -9,6 +9,7 @@ from sample_factory.utils.utils import log from sf_examples.train_pettingzoo_env import make_pettingzoo_env, parse_custom_args, register_custom_components from tests.envs.utils import eval_env_performance +from tests.export_onnx_utils import check_export_onnx from tests.utils import clean_test_dir @@ -42,22 +43,33 @@ def _run_test_env( experiment_name = "test_" + env - cfg = parse_custom_args(argv=["--algo=APPO", f"--env={env}", f"--experiment={experiment_name}"]) - cfg.serial_mode = serial_mode - cfg.async_rl = async_rl - cfg.batched_sampling = batched_sampling - cfg.num_workers = num_workers - cfg.train_for_env_steps = train_steps - cfg.batch_size = batch_size - cfg.decorrelate_envs_on_one_worker = False - cfg.seed = 0 - cfg.device = "cpu" + def parse_args(evaluation: bool = False): + cfg = parse_custom_args( + argv=["--algo=APPO", f"--env={env}", f"--experiment={experiment_name}"], evaluation=evaluation + ) + cfg.serial_mode = serial_mode + cfg.async_rl = async_rl + cfg.batched_sampling = batched_sampling + cfg.num_workers = num_workers + cfg.train_for_env_steps = train_steps + cfg.batch_size = batch_size + 
cfg.decorrelate_envs_on_one_worker = False + cfg.seed = 0 + cfg.device = "cpu" + cfg.eval_deterministic = True + return cfg + cfg = parse_args() directory = clean_test_dir(cfg) status = run_rl(cfg) assert status == ExperimentStatus.SUCCESS assert isdir(directory) - shutil.rmtree(directory, ignore_errors=True) + + try: + cfg = parse_args(evaluation=True) + check_export_onnx(cfg) + finally: + shutil.rmtree(directory, ignore_errors=True) @pytest.mark.parametrize("batched_sampling", [False, True]) def test_basic_envs(self, batched_sampling): diff --git a/tests/envs/tuple_action_envs/test_two_discrete_action_dist_env_batched.py b/tests/envs/tuple_action_envs/test_two_discrete_action_dist_env_batched.py index 68214609c..8b89a4a87 100644 --- a/tests/envs/tuple_action_envs/test_two_discrete_action_dist_env_batched.py +++ b/tests/envs/tuple_action_envs/test_two_discrete_action_dist_env_batched.py @@ -9,6 +9,7 @@ from sample_factory.envs.env_utils import register_env from sample_factory.train import run_rl from tests.envs.tuple_action_envs.test_two_discrete_action_dist_env_non_batched import DiscreteActions, get_reward +from tests.export_onnx_utils import check_export_onnx class IdentityEnvTwoDiscreteActions(gym.Env): @@ -118,6 +119,13 @@ def register_test_components(): ) +def parse_args(argv=None, evaluation=False): + parser, cfg = parse_sf_args(argv=argv, evaluation=evaluation) + override_defaults(parser) + cfg = parse_full_cfg(parser, argv=argv) + return cfg + + def test_batched_two_discrete_action_dists(): """Script entry point.""" register_test_components() @@ -128,10 +136,9 @@ def test_batched_two_discrete_action_dists(): "--restart_behavior=overwrite", "--device=cpu", ] - parser, cfg = parse_sf_args(argv=argv) - - override_defaults(parser) - cfg = parse_full_cfg(parser, argv=argv) - + cfg = parse_args(argv=argv) status = run_rl(cfg) + + cfg = parse_args(argv=argv, evaluation=True) + check_export_onnx(cfg) return status diff --git 
a/tests/envs/tuple_action_envs/test_two_discrete_action_dist_env_non_batched.py b/tests/envs/tuple_action_envs/test_two_discrete_action_dist_env_non_batched.py index a890ff7d1..78a1622a1 100644 --- a/tests/envs/tuple_action_envs/test_two_discrete_action_dist_env_non_batched.py +++ b/tests/envs/tuple_action_envs/test_two_discrete_action_dist_env_non_batched.py @@ -9,6 +9,7 @@ from sample_factory.envs.env_utils import register_env from sample_factory.train import run_rl from sample_factory.utils.utils import debug_log_every_n +from tests.export_onnx_utils import check_export_onnx DiscreteActions = Union[List[int], Tuple[int, ...], np.ndarray] @@ -85,6 +86,13 @@ def register_test_components(): ) +def parse_args(argv=None, evaluation=False): + parser, cfg = parse_sf_args(argv=argv, evaluation=evaluation) + override_defaults(parser) + cfg = parse_full_cfg(parser, argv=argv) + return cfg + + def test_non_batched_two_discrete_action_dists(): """Script entry point.""" register_test_components() @@ -97,9 +105,9 @@ def test_non_batched_two_discrete_action_dists(): "--device=cpu", ] - parser, cfg = parse_sf_args(argv=argv) - override_defaults(parser) - cfg = parse_full_cfg(parser, argv=argv) - + cfg = parse_args(argv=argv) status = run_rl(cfg) + + cfg = parse_args(argv=argv, evaluation=True) + check_export_onnx(cfg) return status diff --git a/tests/examples/test_example.py b/tests/examples/test_example.py index 7d8d029ad..795fc7b8c 100644 --- a/tests/examples/test_example.py +++ b/tests/examples/test_example.py @@ -21,6 +21,7 @@ from sample_factory.utils.typing import Config from sample_factory.utils.utils import experiment_dir, log from sf_examples.train_custom_env_custom_model import parse_custom_args, register_custom_components +from tests.export_onnx_utils import check_export_onnx def default_test_cfg( @@ -57,6 +58,7 @@ def run_test_env( expected_reward_at_least: float = -EPS, expected_reward_at_most: float = 100, check_envs: bool = False, + check_export: bool = False, 
register_custom_components_func: Callable = register_custom_components, env_name: str = "my_custom_env_v1", ): @@ -81,12 +83,16 @@ def run_test_env( status, avg_reward = enjoy(eval_cfg) log.debug(f"Test reward: {avg_reward:.4f}") - assert isdir(directory) - shutil.rmtree(directory, ignore_errors=True) + try: + assert status == ExperimentStatus.SUCCESS + assert avg_reward >= expected_reward_at_least + assert avg_reward <= expected_reward_at_most - assert status == ExperimentStatus.SUCCESS - assert avg_reward >= expected_reward_at_least - assert avg_reward <= expected_reward_at_most + if check_export: + check_export_onnx(eval_cfg) + finally: + assert isdir(directory) + shutil.rmtree(directory, ignore_errors=True) if cfg.serial_mode and check_envs: # we can directly access the envs and check things in serial mode @@ -164,4 +170,5 @@ def test_full_run(self): eval_cfg, expected_reward_at_least=80, expected_reward_at_most=100, + check_export=True, ) diff --git a/tests/export_onnx_utils.py b/tests/export_onnx_utils.py new file mode 100644 index 000000000..806d447f8 --- /dev/null +++ b/tests/export_onnx_utils.py @@ -0,0 +1,68 @@ +import onnx +import onnxruntime +import torch + +from sample_factory.algo.utils.make_env import BatchedVecEnv +from sample_factory.enjoy import make_env +from sample_factory.export_onnx import OnnxExporter, create_onnx_exporter, export_onnx, generate_args +from sample_factory.utils.typing import Config +from sample_factory.utils.utils import experiment_dir + + +def to_numpy(tensor: torch.Tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + + +def check_onnx_model(filename: str) -> None: + model = onnx.load(filename) + onnx.checker.check_model(model) + + +def check_rnn_inference_result( + env: BatchedVecEnv, model: OnnxExporter, ort_session: onnxruntime.InferenceSession +) -> None: + rnn_states_input = next(input for input in ort_session.get_inputs() if input.name == "rnn_states") + rnn_states = 
torch.zeros(rnn_states_input.shape, dtype=torch.float32) + ort_rnn_states = to_numpy(rnn_states) + + for _ in range(3): + args = generate_args(env.observation_space) + actions, rnn_states = model(**args, rnn_states=rnn_states) + + ort_inputs = {k: to_numpy(v) for k, v in args.items()} + ort_inputs["rnn_states"] = ort_rnn_states + ort_out = ort_session.run(None, ort_inputs) + ort_rnn_states = ort_out[1] + + assert (to_numpy(actions) == ort_out[0]).all() + + +def check_inference_result(env: BatchedVecEnv, model: OnnxExporter, ort_session: onnxruntime.InferenceSession) -> None: + for batch_size in [1, 3]: + args = generate_args(env.observation_space, batch_size) + actions = model(**args) + + ort_inputs = {k: to_numpy(v) for k, v in args.items()} + ort_out = ort_session.run(None, ort_inputs) + + assert len(ort_out[0]) == batch_size + assert (to_numpy(actions) == ort_out[0]).all() + + +def check_export_onnx(cfg: Config) -> None: + cfg.eval_deterministic = True + directory = experiment_dir(cfg=cfg, mkdir=False) + filename = f"{directory}/{cfg.experiment}.onnx" + status = export_onnx(cfg, filename) + assert status == 0 + + check_onnx_model(filename) + + env = make_env(cfg) + model = create_onnx_exporter(cfg, env, enable_jit=True) + ort_session = onnxruntime.InferenceSession(filename, providers=["CPUExecutionProvider"]) + + if cfg.use_rnn: + check_rnn_inference_result(env, model, ort_session) + else: + check_inference_result(env, model, ort_session)