diff --git a/.gitignore b/.gitignore index 243eda4..cc6c002 100644 --- a/.gitignore +++ b/.gitignore @@ -146,4 +146,9 @@ envpool .sokoban_cache .build .vscode -k8s_copy/ \ No newline at end of file +k8s_copy/ +Craftax_Baselines + +nsight-profile + +uv.lock diff --git a/.gitmodules b/.gitmodules index 0f5827d..6d3818d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "third_party/gym-sokoban"] path = third_party/gym-sokoban url = https://github.com/AlignmentResearch/gym-sokoban +[submodule "third_party/craftax"] + path = third_party/craftax + url = https://github.com/rhaps0dy/Craftax diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index de9e0ca..aa92093 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.13 + rev: v0.9.7 hooks: # Run the linter. - id: ruff diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..450727c --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,30 @@ +# CLAUDE.md: Development Guidelines + +## Build & Test Commands +- Install: `make local-install` +- Lint: `make lint` +- Format: `make format` +- Typecheck: `make typecheck` +- Run tests: `make mactest` or `pytest -m 'not envpool and not slow'` +- Run single test: `pytest tests/test_file.py::test_function -v` +- Training command: `python -m cleanba.cleanba_impala --from-py-fn=cleanba.config:sokoban_drc33_59` + +## Code Style Guidelines +- **Formatting**: Follow ruff format (127 line length) +- **Imports**: Use isort through ruff, with known third-party libraries like wandb +- **Types**: Use type annotations, checked with pyright +- **Naming**: + - Variables: `snake_case` + - Classes: `PascalCase` + - Constants: `UPPER_CASE` +- **Structure**: Modules organized by functionality (cleanba, experiments, tests) +- **Error handling**: Use asserts for validation in tests, exceptions for runtime errors +- **Documentation**: Include docstrings for public functions and classes +- **JAX/Flax patterns**: Use pure functions and maintain functional style + +## Key Files +- **environments.py**: Contains environment wrappers and config classes for different environments (Sokoban, Boxoban, Craftax). Includes `EpisodeEvalWrapper` for logging episode returns and adapters for different environment backends. + +- **cleanba_impala.py**: Main training loop implementation for the IMPALA algorithm. Contains multi-threaded rollout data collection, parameter synchronization, and training. Uses `WandbWriter` for logging, queues for communication between rollout and learner threads, and implements checkpointing. + +- **impala_loss.py**: Implements V-trace and PPO loss functions. Contains `Rollout` data structure, TD-error computation with V-trace, and policy gradient calculations. Handles truncated episodes specially to provide correct advantage estimates. 
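[Editor's sketch, not part of the diff or of CLAUDE.md] As a quick orientation for the `EpisodeEvalWrapper` and environment configs described above, the following hypothetical usage snippet shows how they fit together. It assumes craftax is installed and that the classes behave as defined later in this diff (`cleanba/environments.py`); the `num_envs`, `seed`, and step-count values are arbitrary.

import numpy as np

from cleanba.environments import CraftaxEnvConfig, EpisodeEvalWrapper

# Build a batched Craftax env and wrap it so per-episode statistics land in `info`.
envs = EpisodeEvalWrapper(CraftaxEnvConfig(max_episode_steps=3000, num_envs=8, seed=1234).make())
obs, info = envs.reset()
for _ in range(16):
    actions = envs.action_space.sample()  # random batched actions, shape (num_envs,)
    obs, reward, terminated, truncated, info = envs.step(actions)
# Running statistics of the most recently finished episodes, as consumed by the rollout thread's logging:
print(np.mean(info["returned_episode_return"]), np.mean(info["returned_episode_length"]))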
\ No newline at end of file diff --git a/Dockerfile b/Dockerfile index c97058e..c77e971 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,36 +1,42 @@ ARG JAX_DATE -FROM ghcr.io/nvidia/jax:base-${JAX_DATE} as envpool-environment +FROM ghcr.io/nvidia/jax:base-${JAX_DATE} AS envpool-environment ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update \ - && apt-get install -y golang-1.18 git \ + && apt-get install -y golang-1.21 git \ # Linters clang-format clang-tidy \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* -ENV PATH=/usr/lib/go-1.18/bin:/root/go/bin:$PATH -RUN go install github.com/bazelbuild/bazelisk@v1.19.0 && ln -sf $HOME/go/bin/bazelisk $HOME/go/bin/bazel -RUN go install github.com/bazelbuild/buildtools/buildifier@v0.0.0-20231115204819-d4c9dccdfbb1 +USER ubuntu +ENV HOME=/home/ubuntu +ENV PATH=/usr/lib/go-1.21/bin:${HOME}/go/bin:$PATH +ENV UID=1000 +ENV GID=1000 +RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} go install github.com/bazelbuild/bazelisk@v1.19.0 && ln -sf $HOME/go/bin/bazelisk $HOME/go/bin/bazel +RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} go install github.com/bazelbuild/buildtools/buildifier@v0.0.0-20231115204819-d4c9dccdfbb1 # Install Go linting tools -RUN go install github.com/google/addlicense@v1.1.1 +RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} go install github.com/google/addlicense@v1.1.1 -ENV USE_BAZEL_VERSION=6.4.0 +ENV USE_BAZEL_VERSION=8.1.0 RUN bazel version WORKDIR /app +# Copy the whole repository +COPY --chown=ubuntu:ubuntu third_party/envpool . + # Install python-based linting dependencies -COPY third_party/envpool/third_party/pip_requirements/requirements-devtools.txt \ +COPY --chown=ubuntu:ubuntu \ + third_party/envpool/third_party/pip_requirements/requirements-devtools.txt \ third_party/pip_requirements/requirements-devtools.txt -RUN pip install -r third_party/pip_requirements/requirements-devtools.txt - -# Copy the whole repository -COPY third_party/envpool . +RUN --mount=type=cache,target=${HOME}/.cache,uid=1000,gid=1000 pip install -r third_party/pip_requirements/requirements-devtools.txt +ENV PATH="${HOME}/.local/bin:${PATH}" # Deal with the fact that envpool is a submodule and has no .git directory RUN rm .git # Copy the .git repository for this submodule -COPY .git/modules/envpool ./.git +COPY --chown=ubuntu:ubuntu .git/modules/envpool ./.git # Remove config line stating that the worktree for this repo is elsewhere RUN sed -e 's/^.*worktree =.*$//' .git/config > .git/config.new && mv .git/config.new .git/config @@ -39,14 +45,14 @@ RUN echo "$(git status --porcelain --ignored=traditional)" \ && if ! { [ -z "$(git status --porcelain --ignored=traditional)" ] \ ; }; then exit 1; fi -FROM envpool-environment as envpool -RUN make bazel-release +FROM envpool-environment AS envpool +RUN --mount=type=cache,target=${HOME}/.cache,uid=1000,gid=1000 make bazel-release && cp bazel-bin/*.whl . 
-FROM ghcr.io/nvidia/jax:jax-${JAX_DATE} as main-pre-pip +FROM ghcr.io/nvidia/jax:jax-${JAX_DATE} AS main-pre-pip ARG APPLICATION_NAME -ARG USERID=1001 -ARG GROUPID=1001 +ARG UID=1001 +ARG GID=1001 ARG USERNAME=dev ENV GIT_URL="https://github.com/AlignmentResearch/${APPLICATION_NAME}" @@ -84,8 +90,8 @@ ENV VIRTUAL_ENV="/opt/venv" ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" RUN python3 -m venv "${VIRTUAL_ENV}" --system-site-packages \ - && addgroup --gid ${GROUPID} ${USERNAME} \ - && adduser --uid ${USERID} --gid ${GROUPID} --disabled-password --gecos '' ${USERNAME} \ + && addgroup --gid ${GID} ${USERNAME} \ + && adduser --uid ${UID} --gid ${GID} --disabled-password --gecos '' ${USERNAME} \ && usermod -aG sudo ${USERNAME} \ && echo "${USERNAME} ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers \ && mkdir -p "/workspace" \ @@ -96,32 +102,38 @@ WORKDIR "/workspace" # Get a pip modern enough that can resolve farconf RUN pip install "pip ==24.0" && rm -rf "${HOME}/.cache" -FROM main-pre-pip as main-pip-tools +FROM main-pre-pip AS main-pip-tools RUN pip install "pip-tools ~=7.4.1" -FROM main-pre-pip as main +FROM main-pre-pip AS main +RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} pip install uv COPY --chown=${USERNAME}:${USERNAME} requirements.txt ./ # Install all dependencies, which should be explicit in `requirements.txt` -RUN pip install --no-deps -r requirements.txt \ - && rm -rf "${HOME}/.cache" \ +RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} \ + uv pip install --no-deps -r requirements.txt \ # Run Pyright so its Node.js package gets installed && pyright . # Install Envpool -ENV ENVPOOL_WHEEL="dist/envpool-0.8.4-cp310-cp310-linux_x86_64.whl" -COPY --from=envpool --chown=${USERNAME}:${USERNAME} "/app/${ENVPOOL_WHEEL}" "./${ENVPOOL_WHEEL}" -RUN pip install "./${ENVPOOL_WHEEL}" && rm -rf "./dist" +ENV ENVPOOL_WHEEL="envpool-0.9.0-cp312-cp312-linux_x86_64.whl" +COPY --from=envpool --chown=${USERNAME}:${USERNAME} "/app/${ENVPOOL_WHEEL}" "${ENVPOOL_WHEEL}" +RUN uv pip install "${ENVPOOL_WHEEL}" && rm "${ENVPOOL_WHEEL}" + +# Cache Craftax textures +RUN python -c "import craftax.craftax.constants" # Copy whole repo COPY --chown=${USERNAME}:${USERNAME} . . -RUN pip install --no-deps -e . -e ./third_party/gym-sokoban/ +RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} \ + uv pip install --no-deps -e . -e ./third_party/gym-sokoban/ # Set git remote URL to https for all sub-repos RUN git remote set-url origin "$(git remote get-url origin | sed 's|git@github.com:|https://github.com/|' )" \ && (cd third_party/envpool && git remote set-url origin "$(git remote get-url origin | sed 's|git@github.com:|https://github.com/|' )" ) # Abort if repo is dirty -RUN echo "$(git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" \ +RUN rm NVIDIA_Deep_Learning_Container_License.pdf \ + && echo "$(git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" \ && echo "$(cd third_party/envpool && git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" \ && echo "$(cd third_party/gym-sokoban && git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" \ && if ! 
{ [ -z "$(git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" ] \ @@ -130,5 +142,5 @@ RUN echo "$(git status --porcelain --ignored=traditional | grep -v '.egg-info/$' ; }; then exit 1; fi -FROM main as atari -RUN pip uninstall -y envpool && pip install envpool && rm -rf "${HOME}/.cache" +FROM main AS atari +RUN uv pip uninstall -y envpool && uv pip install envpool && rm -rf "${HOME}/.cache" diff --git a/Makefile b/Makefile index a27479a..75008a7 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,8 @@ export DOCKERFILE COMMIT_HASH ?= $(shell git rev-parse HEAD) BRANCH_NAME ?= $(shell git branch --show-current) -JAX_DATE=2024-04-08 +JAX_DATE=2025-02-22 +PYTHON_VERSION=3.12 default: release/main @@ -29,12 +30,10 @@ BUILD_PREFIX ?= $(shell git rev-parse --short HEAD) -f "${DOCKERFILE}" . touch ".build/with-reqs/${BUILD_PREFIX}/$*" -# NOTE: --extra=extra is for stable-baselines3 testing. -requirements.txt.new: pyproject.toml ${DOCKERFILE} - docker run -v "${HOME}/.cache:/home/dev/.cache" -v "$(shell pwd):/workspace" "ghcr.io/nvidia/jax:base-${JAX_DATE}" \ - bash -c "pip install pip-tools \ - && cd /workspace \ - && pip-compile --verbose -o requirements.txt.new --extra=dev --extra=launch_jobs pyproject.toml" +requirements.txt.new: pyproject.toml + docker run -v "${HOME}/.cache:/home/dev/.cache" -v "$(shell pwd):/workspace" "ghcr.io/astral-sh/uv:python${PYTHON_VERSION}-alpine" \ + sh -c "cd /workspace \ + && uv pip compile --verbose -o requirements.txt.new --extra=dev pyproject.toml" # To bootstrap `requirements.txt`, comment out this target requirements.txt: requirements.txt.new @@ -42,8 +41,8 @@ requirements.txt: requirements.txt.new .PHONY: local-install local-install: requirements.txt - pip install --no-deps -r requirements.txt - pip install --config-settings editable_mode=compat -e ".[dev-local]" -e ./third_party/gym-sokoban + uv pip install --no-deps -r requirements.txt + uv pip install -e ".[py-tools]" -e ./third_party/gym-sokoban pip install https://github.com/AlignmentResearch/envpool/releases/download/v0.1.0/envpool-0.8.4-cp310-cp310-linux_x86_64.whl @@ -93,18 +92,18 @@ cuda-devbox/%: devbox/% cuda-devbox: cuda-devbox/main .PHONY: envpool-devbox -envpool-devbox: devbox/envpool-ci +envpool-devbox: devbox/envpool .PHONY: docker docker/% docker/%: - docker run -v "$(shell pwd):/workspace" -it "${APPLICATION_URL}:${RELEASE_PREFIX}-$*" /bin/bash + docker run -v "${HOME}/.cache:/home/ubuntu/.cache" -v "$(shell pwd):/workspace" -it "${APPLICATION_URL}:${RELEASE_PREFIX}-$*" /bin/bash docker: docker/main .PHONY: envpool-docker envpool-docker/% envpool-docker/%: - docker run -v "$(shell pwd)/third_party/envpool:/app" -it "${APPLICATION_URL}:${RELEASE_PREFIX}-$*" /bin/bash -envpool-docker: envpool-docker/envpool-ci + docker run -v "${HOME}/.cache:/home/ubuntu/.cache" -v "$(shell pwd)/third_party/envpool:/app" -it "${APPLICATION_URL}:${RELEASE_PREFIX}-$*" /bin/bash +envpool-docker: envpool-docker/envpool # Section 3: project commands diff --git a/cleanba/cleanba_impala.py b/cleanba/cleanba_impala.py index dadcb29..167fb8c 100644 --- a/cleanba/cleanba_impala.py +++ b/cleanba/cleanba_impala.py @@ -1,6 +1,7 @@ import contextlib import dataclasses import json +import logging import math import os import queue @@ -12,9 +13,10 @@ import time import warnings from collections import deque +from ctypes import cdll from functools import partial from pathlib import Path -from typing import Any, Callable, Hashable, Iterator, List, Mapping, Optional +from typing import Any, Callable, Hashable, 
Iterator, List, Mapping, NamedTuple, Optional import chex import databind.core.converter @@ -32,7 +34,7 @@ from cleanba.config import Args from cleanba.convlstm import ConvLSTMConfig -from cleanba.environments import convert_to_cleanba_config, random_seed +from cleanba.environments import EpisodeEvalWrapper, convert_to_cleanba_config, random_seed from cleanba.evaluate import EvalConfig from cleanba.impala_loss import ( SINGLE_DEVICE_UPDATE_DEVICES_AXIS, @@ -43,17 +45,30 @@ from cleanba.network import AgentParams, Policy, PolicyCarryT, label_and_learning_rate_for_params from cleanba.optimizer import rmsprop_pytorch_style -# Make Jax CPU use 1 thread only https://github.com/google/jax/issues/743 -os.environ["XLA_FLAGS"] = ( - os.environ.get("XLA_FLAGS", "") + " --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" -) -# Fix CUDNN non-determinism; https://github.com/google/jax/issues/4823#issuecomment-952835771 -os.environ["TF_XLA_FLAGS"] = ( - os.environ.get("TF_XLA_FLAGS", "") + " --xla_gpu_autotune_level=2 --xla_gpu_deterministic_reductions" -) +log = logging.getLogger(__file__) + + +class ParamsPayload(NamedTuple): + """Structured data for the params queue.""" + + params: Any # device_params + policy_version: int # learner_policy_version -# Fix CUDNN non-determinism; https://github.com/google/jax/issues/4823#issuecomment-952835771 -os.environ["TF_CUDNN_DETERMINISTIC"] = "1" + +class RolloutPayload(NamedTuple): + """Structured data for the rollout queue.""" + + global_step: int + policy_version: int # actor_policy_version + update: int + storage: Rollout # sharded_storage + params_queue_get_time: float + device_thread_id: int + + +libcudart = None +if os.getenv("NSIGHT_ACTIVE", "0") == "1": + libcudart = cdll.LoadLibrary("libcudart.so") def unreplicate(tree): @@ -66,27 +81,36 @@ class WandbWriter: named_save_dir: Path def __init__(self, cfg: "Args", wandb_cfg_extra_data: dict[str, Any] = {}): - wandb_kwargs: dict[str, Any] + wandb_kwargs: dict[str, Any] = dict( + name=os.environ.get("WANDB_JOB_NAME", generate_name(style="hyphen")), + mode=os.environ.get("WANDB_MODE", "online"), + group=os.environ.get("WANDB_RUN_GROUP", "default"), + ) try: - wandb_kwargs = dict( - entity=os.environ["WANDB_ENTITY"], - name=os.environ.get("WANDB_JOB_NAME", generate_name(style="hyphen")), - project=os.environ["WANDB_PROJECT"], - group=os.environ["WANDB_RUN_GROUP"], - mode=os.environ.get("WANDB_MODE", "online"), # Default to online here + wandb_kwargs.update( + dict( + entity=os.environ["WANDB_ENTITY"], + project=os.environ["WANDB_PROJECT"], + ) ) - job_name = wandb_kwargs["name"] except KeyError: # If any of the essential WANDB environment variables are missing, # simply don't upload this run. # It's fine to do this without giving any indication because Wandb already prints that the run is offline. 
- - wandb_kwargs = dict(mode=os.environ.get("WANDB_MODE", "offline"), group="default") - job_name = "develop" + wandb_kwargs["mode"] = os.environ.get("WANDB_MODE", "offline") + job_name = wandb_kwargs["name"] run_dir = cfg.base_run_dir / wandb_kwargs["group"] run_dir.mkdir(parents=True, exist_ok=True) + jax_compile_cache = cfg.base_run_dir / "kernel-cache" + jax_compile_cache.mkdir(exist_ok=True, parents=True) + + jax.config.update("jax_compilation_cache_dir", str(jax_compile_cache)) + jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) + jax.config.update("jax_persistent_cache_min_compile_time_secs", 10) + jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") + old_run_dir_sym = run_dir / "wandb" / job_name run_id = None if old_run_dir_sym.exists() and not cfg.finetune_with_noop_head: @@ -123,16 +147,19 @@ def __init__(self, cfg: "Args", wandb_cfg_extra_data: dict[str, Any] = {}): shutil.move(f, self._save_dir / f.name) self.named_save_dir.unlink() - self.named_save_dir.symlink_to(save_dir_no_local_files, target_is_directory=True) + self.named_save_dir.symlink_to(save_dir_no_local_files.absolute(), target_is_directory=True) self.step_digits = math.ceil(math.log10(cfg.total_timesteps)) def add_scalar(self, name: str, value: int | float, global_step: int): wandb.log({name: value}, step=global_step) + def add_dict(self, metrics: dict[str, int | float], global_step: int): + wandb.log(metrics, step=global_step) + @contextlib.contextmanager def save_dir(self, global_step: int) -> Iterator[Path]: - name = f"cp_{{step:0{self.step_digits}d}}".format(step=global_step) + name = f"cp_{global_step:0{self.step_digits}d}" out = self._save_dir / name out.mkdir() yield out @@ -163,13 +190,13 @@ class RuntimeInformation: def initialize_multi_device(args: Args) -> Iterator[RuntimeInformation]: local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids)) local_minibatch_size = int(local_batch_size // args.num_minibatches) - assert ( - args.local_num_envs % len(args.learner_device_ids) == 0 - ), "local_num_envs must be divisible by len(learner_device_ids)" + assert args.local_num_envs % len(args.learner_device_ids) == 0, ( + "local_num_envs must be divisible by len(learner_device_ids)" + ) - assert ( - int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0 - ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" + assert int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0, ( + "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches" + ) distributed = args.distributed # guard agiainst edits to `args` if args.distributed: @@ -230,9 +257,6 @@ def log_parameter_differences(params) -> dict[str, jax.Array]: @dataclasses.dataclass class LoggingStats: - episode_returns: list[float] - episode_lengths: list[float] - episode_success: list[float] params_queue_get_time: list[float] rollout_time: list[float] create_rollout_time: list[float] @@ -241,8 +265,6 @@ class LoggingStats: env_recv_time: list[float] inference_time: list[float] storage_time: list[float] - device2host_time: list[float] - env_send_time: list[float] update_time: list[float] @classmethod @@ -261,15 +283,26 @@ def avg_and_flush(self) -> dict[str, float]: @contextlib.contextmanager -def time_and_append(stats: list[float]): +def time_and_append(stats: list[float], name: str, 
step_num: int): start_time = time.time() - yield + with jax.named_scope(name): + yield stats.append(time.time() - start_time) +@dataclasses.dataclass(order=True) +class PrioritizedItem: + priority: int + item: Any = dataclasses.field(compare=False) + + @partial(jax.jit, static_argnames=["len_learner_devices"]) def _concat_and_shard_rollout_internal( - storage: List[Rollout], last_obs: jax.Array, last_episode_starts: np.ndarray, len_learner_devices: int + storage: List[Rollout], + last_obs: jax.Array, + last_episode_starts: np.ndarray, + last_value: jax.Array, + len_learner_devices: int, ) -> Rollout: """ Stack the Rollout steps over time, splitting them for every learner device. @@ -297,6 +330,7 @@ def _split_over_batches(x): carry_t=jax.tree.map(lambda x: jnp.expand_dims(_split_over_batches(x), axis=1), storage[0].carry_t), a_t=jnp.stack([_split_over_batches(r.a_t) for r in storage], axis=1), logits_t=jnp.stack([_split_over_batches(r.logits_t) for r in storage], axis=1), + value_t=jnp.stack([*(_split_over_batches(r.value_t) for r in storage), _split_over_batches(last_value)], axis=1), r_t=jnp.stack([_split_over_batches(r.r_t) for r in storage], axis=1), episode_starts_t=jnp.stack( [*(_split_over_batches(r.episode_starts_t) for r in storage), _split_over_batches(last_episode_starts)], axis=1 @@ -307,9 +341,15 @@ def _split_over_batches(x): def concat_and_shard_rollout( - storage: list[Rollout], last_obs: jax.Array, last_episode_starts: np.ndarray, learner_devices: list[jax.Device] + storage: list[Rollout], + last_obs: jax.Array, + last_episode_starts: jax.Array, + last_value: jax.Array, + learner_devices: list[jax.Device], ) -> Rollout: - partitioned_storage = _concat_and_shard_rollout_internal(storage, last_obs, last_episode_starts, len(learner_devices)) + partitioned_storage = _concat_and_shard_rollout_internal( + storage, last_obs, last_episode_starts, last_value, len(learner_devices) + ) sharded_storage = jax.tree.map(lambda x: jax.device_put_sharded(list(x), devices=learner_devices), partitioned_storage) return sharded_storage @@ -321,19 +361,21 @@ def rollout( runtime_info: RuntimeInformation, rollout_queue: queue.Queue, params_queue: queue.Queue, - writer, + metrics_queue: queue.PriorityQueue, learner_devices: list[jax.Device], device_thread_id: int, - actor_device, + actor_device: jax.Device, global_step: int = 0, ): actor_id: int = device_thread_id + args.num_actor_threads * jax.process_index() - envs = dataclasses.replace( - args.train_env, - seed=args.train_env.seed + actor_id, - num_envs=args.local_num_envs, - ).make() + envs = EpisodeEvalWrapper( + dataclasses.replace( + args.train_env, + seed=args.train_env.seed + actor_id, + num_envs=args.local_num_envs, + ).make() + ) eval_envs: list[tuple[str, EvalConfig]] = list(args.eval_envs.items()) # Spread various eval envs among the threads @@ -345,18 +387,13 @@ def rollout( this_thread_eval_keys = list(jax.random.split(eval_keys, len(this_thread_eval_cfg))) len_actor_device_ids = len(args.actor_device_ids) - start_time = time.time() + start_time = None log_stats = LoggingStats.new_empty() - # Counters for episode length and episode return - episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32) - episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32) - returned_episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32) - returned_episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32) - returned_episode_success = np.zeros((args.local_num_envs,), dtype=np.bool_) - + info_t = {} 
actor_policy_version = 0 storage = [] + metrics = {} # Store the first observation obs_t, _ = envs.reset() @@ -365,17 +402,21 @@ def rollout( key, carry_key = jax.random.split(key) policy, carry_t, _ = args.net.init_params(envs, carry_key) episode_starts_t = np.ones(envs.num_envs, dtype=np.bool_) + get_action_fn = jax.jit(partial(policy.apply, method=policy.get_action), static_argnames="temperature") global MUST_STOP_PROGRAM + global libcudart for update in range(initial_update, runtime_info.num_updates + 2): if MUST_STOP_PROGRAM: break param_frequency = args.actor_update_frequency if update <= args.actor_update_cutoff else 1 + if libcudart is not None and update == 4: + libcudart.cudaProfilerStart() - with time_and_append(log_stats.update_time): - with time_and_append(log_stats.params_queue_get_time): + with time_and_append(log_stats.update_time, "update", global_step): + with time_and_append(log_stats.params_queue_get_time, "params_queue_get", global_step): num_steps_with_bootstrap = args.num_steps if args.concurrency: @@ -385,42 +426,44 @@ def rollout( if ((update - 1) % param_frequency == 0 and (update - 1) != param_frequency) or ( (update - 2) == param_frequency ): - params, actor_policy_version = params_queue.get(timeout=args.queue_timeout) + payload = params_queue.get(timeout=args.queue_timeout) # NOTE: block here is important because otherwise this thread will call # the jitted `get_action` function that hangs until the params are ready. # This blocks the `get_action` function in other actor threads. # See https://excalidraw.com/#json=hSooeQL707gE5SWY8wOSS,GeaN1eb2r24PPi75a3n14Q for a visual explanation. - jax.block_until_ready(params) + params, actor_policy_version = jax.block_until_ready(payload.params), payload.policy_version else: if (update - 1) % args.actor_update_frequency == 0: - params, actor_policy_version = params_queue.get(timeout=args.queue_timeout) + payload = params_queue.get(timeout=args.queue_timeout) + params, actor_policy_version = payload.params, payload.policy_version - with time_and_append(log_stats.rollout_time): + with time_and_append(log_stats.rollout_time, "rollout", global_step): for _ in range(1, num_steps_with_bootstrap + 1): global_step += ( args.local_num_envs * args.num_actor_threads * len_actor_device_ids * runtime_info.world_size ) - with time_and_append(log_stats.inference_time): - carry_tplus1, a_t, logits_t, key = get_action_fn(params, carry_t, obs_t, episode_starts_t, key) - - with time_and_append(log_stats.device2host_time): - cpu_action = np.array(a_t) - - with time_and_append(log_stats.env_send_time): - envs.step_async(cpu_action) + with time_and_append(log_stats.inference_time, "inference", global_step): + obs_t, episode_starts_t = jax.device_put((obs_t, episode_starts_t), device=actor_device) + carry_tplus1, a_t, logits_t, value_t, key = get_action_fn( + params, carry_t, obs_t, episode_starts_t, key + ) + assert a_t.shape == (args.local_num_envs,) - with time_and_append(log_stats.env_recv_time): - obs_tplus1, r_t, term_t, trunc_t, info_t = envs.step_wait() + with time_and_append(log_stats.env_recv_time, "step", global_step): + obs_tplus1, r_t, term_t, trunc_t, info_t = envs.step(a_t) done_t = term_t | trunc_t + assert r_t.shape == (args.local_num_envs,) + assert done_t.shape == (args.local_num_envs,) - with time_and_append(log_stats.create_rollout_time): + with time_and_append(log_stats.create_rollout_time, "create_rollout", global_step): storage.append( Rollout( obs_t=obs_t, carry_t=carry_t, a_t=a_t, logits_t=logits_t, + value_t=value_t, 
r_t=r_t, episode_starts_t=episode_starts_t, truncated_t=trunc_t, @@ -430,98 +473,75 @@ def rollout( carry_t = carry_tplus1 episode_starts_t = done_t - # Atari envs clip their reward to [-1, 1], meaning we need to use the reward in `info` to get - # the true return. - non_clipped_reward = info_t.get("reward", r_t) - - episode_returns[:] += non_clipped_reward - log_stats.episode_returns.extend(episode_returns[done_t]) - returned_episode_returns[done_t] = episode_returns[done_t] - episode_returns[:] *= ~done_t - - episode_lengths[:] += 1 - log_stats.episode_lengths.extend(episode_lengths[done_t]) - returned_episode_lengths[done_t] = episode_lengths[done_t] - episode_lengths[:] *= ~done_t - - log_stats.episode_success.extend(map(float, term_t[done_t])) - returned_episode_success[done_t] = term_t[done_t] + with time_and_append(log_stats.storage_time, "storage", global_step): + obs_t, episode_starts_t = jax.device_put((obs_t, episode_starts_t), device=actor_device) + if args.loss.needs_last_value: + # We can't roll this out of the loop. In the next loop iteration, we will use the updated parameters + # to gather rollouts. + _, _, _, value_t, _ = get_action_fn(params, carry_t, obs_t, episode_starts_t, key) + else: + value_t = jnp.full(value_t.shape, jnp.nan, dtype=value_t.dtype, device=value_t.device) - with time_and_append(log_stats.storage_time): - sharded_storage = concat_and_shard_rollout(storage, obs_t, episode_starts_t, learner_devices) + sharded_storage = concat_and_shard_rollout(storage, obs_t, episode_starts_t, value_t, learner_devices) storage.clear() - payload = ( - global_step, - actor_policy_version, - update, - sharded_storage, - np.mean(log_stats.params_queue_get_time), - device_thread_id, + payload = RolloutPayload( + global_step=global_step, + policy_version=actor_policy_version, + update=update, + storage=sharded_storage, + params_queue_get_time=np.mean(log_stats.params_queue_get_time), + device_thread_id=device_thread_id, ) - with time_and_append(log_stats.rollout_queue_put_time): + with time_and_append(log_stats.rollout_queue_put_time, "rollout_queue_put", global_step): rollout_queue.put(payload, timeout=args.queue_timeout) # Log on all rollout threads if update % args.log_frequency == 0: - inner_loop_time = ( - np.sum(log_stats.env_recv_time) - + np.sum(log_stats.create_rollout_time) - + np.sum(log_stats.inference_time) - + np.sum(log_stats.device2host_time) - + np.sum(log_stats.env_send_time) - ) total_rollout_time = np.sum(log_stats.rollout_time) - middle_loop_time = ( - total_rollout_time - + np.sum(log_stats.storage_time) - + np.sum(log_stats.params_queue_get_time) - + np.sum(log_stats.rollout_queue_put_time) - ) - outer_loop_time = np.sum(log_stats.update_time) - stats_dict: dict[str, float] = log_stats.avg_and_flush() - steps_per_second = global_step / (time.time() - start_time) - print( - f"{update=} {device_thread_id=}, SPS={steps_per_second:.2f}, {global_step=}, avg_episode_returns={stats_dict['avg_episode_returns']:.2f}, avg_episode_length={stats_dict['avg_episode_lengths']:.2f}, avg_rollout_time={stats_dict['avg_rollout_time']:.5f}" - ) - for k, v in stats_dict.items(): - if k.endswith("_time"): - writer.add_scalar(f"stats/{device_thread_id}/{k}", v, global_step) - else: - writer.add_scalar(f"charts/{device_thread_id}/{k}", v, global_step) + if start_time is None: + steps_per_second = 0 + start_time = time.time() + else: + steps_per_second = global_step / (time.time() - start_time) - writer.add_scalar(f"charts/{device_thread_id}/instant_avg_episode_length", 
np.mean(episode_lengths), global_step) - writer.add_scalar(f"charts/{device_thread_id}/instant_avg_episode_return", np.mean(episode_returns), global_step) - writer.add_scalar( - f"charts/{device_thread_id}/returned_avg_episode_length", np.mean(returned_episode_lengths), global_step - ) - writer.add_scalar( - f"charts/{device_thread_id}/returned_avg_episode_return", np.mean(returned_episode_returns), global_step - ) - writer.add_scalar( - f"charts/{device_thread_id}/returned_avg_episode_success", np.mean(returned_episode_success), global_step + charts_dict = jax.tree.map(jnp.mean, {k: v for k, v in info_t.items() if k.startswith("returned")}) + print( + f"{update=} {device_thread_id=}, SPS={steps_per_second:.2f}, {global_step=}, ep_returns={charts_dict['returned_episode_return']:.2f}, ep_length={charts_dict['returned_episode_length']:.2f}, avg_rollout_time={stats_dict['avg_rollout_time']:.5f}" ) - writer.add_scalar( - f"stats/{device_thread_id}/inner_time_efficiency", inner_loop_time / total_rollout_time, global_step - ) - writer.add_scalar( - f"stats/{device_thread_id}/middle_time_efficiency", middle_loop_time / outer_loop_time, global_step + # Perf: Time performance metrics + metrics.update( + { + f"Perf/{device_thread_id}/rollout_total": total_rollout_time, + f"Perf/{device_thread_id}/SPS": steps_per_second, + f"policy_versions/{device_thread_id}/actor": actor_policy_version, + } ) - writer.add_scalar(f"charts/{device_thread_id}/SPS", steps_per_second, global_step) + for k, v in stats_dict.items(): + metrics[f"Perf/{device_thread_id}/{k}"] = v - writer.add_scalar(f"policy_versions/actor_{device_thread_id}", actor_policy_version, global_step) + # Charts: RL performance-related metrics + for k, v in charts_dict.items(): + metrics[f"Charts/{device_thread_id}/{k}"] = v.item() + # Evaluate whenever configured to if update in args.eval_at_steps: for i, (eval_name, env_config) in enumerate(this_thread_eval_cfg): print("Evaluating ", eval_name) this_thread_eval_keys[i], eval_key = jax.random.split(this_thread_eval_keys[i], 2) log_dict = env_config.run(policy, get_action_fn, params, key=eval_key) - for k, v in log_dict.items(): - if k.endswith("_all_episode_info"): - continue - writer.add_scalar(f"{eval_name}/{k}", v, global_step) + + metrics.update({f"{eval_name}/{k}": v for k, v in log_dict.items() if not k.endswith("_all_episode_info")}) + + if metrics: + # Flush the metrics at most once per global_step. This way, in the learner we can check that all actor + # threads have sent the metrics by simply counting. 
+ metrics_queue.put(PrioritizedItem(global_step, metrics), timeout=args.queue_timeout) + metrics = {} + if libcudart is not None: + libcudart.cudaProfilerStop() def linear_schedule( @@ -721,7 +741,8 @@ def train( num_batches=args.num_minibatches * args.gradient_accumulation_steps, get_logits_and_value=partial(policy.apply, method=policy.get_logits_and_value), impala_cfg=args.loss, - ) + ), + donate_argnames=("agent_state", "key"), ), axis_name=SINGLE_DEVICE_UPDATE_DEVICES_AXIS, devices=runtime_info.global_learner_devices, @@ -729,15 +750,20 @@ def train( params_queues = [] rollout_queues = [] + metrics_queue = queue.PriorityQueue() unreplicated_params = agent_state.params key, *actor_keys = jax.random.split(key, 1 + len(args.actor_device_ids)) for d_idx, d_id in enumerate(args.actor_device_ids): - device_params = jax.device_put(unreplicated_params, runtime_info.local_devices[d_id]) + # Copy device_params so we can donate the agent_state in the multi_device_update + device_params = jax.tree.map( + partial(jnp.array, copy=True), + jax.device_put(unreplicated_params, runtime_info.local_devices[d_id]), + ) for thread_id in range(args.num_actor_threads): params_queues.append(queue.Queue(maxsize=1)) rollout_queues.append(queue.Queue(maxsize=1)) - params_queues[-1].put((device_params, args.learner_policy_version)) + params_queues[-1].put(ParamsPayload(params=device_params, policy_version=args.learner_policy_version)) threading.Thread( target=rollout, args=( @@ -747,7 +773,7 @@ def train( runtime_info, rollout_queues[-1], params_queues[-1], - writer, + metrics_queue, runtime_info.learner_devices, d_idx * args.num_actor_threads + thread_id, runtime_info.local_devices[d_id], @@ -755,7 +781,7 @@ def train( ), ).start() - rollout_queue_get_time = deque(maxlen=10) + rollout_queue_get_time = deque(maxlen=20) agent_state = jax.device_put_replicated(agent_state, devices=runtime_info.global_learner_devices) actor_policy_version = 0 @@ -769,32 +795,33 @@ def train( sharded_storages = [] for d_idx, d_id in enumerate(args.actor_device_ids): for thread_id in range(args.num_actor_threads): - ( - global_step, - actor_policy_version, - update, - sharded_storage, - avg_params_queue_get_time, - device_thread_id, - ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get(timeout=args.queue_timeout) - sharded_storages.append(sharded_storage) + payload = rollout_queues[d_idx * args.num_actor_threads + thread_id].get(timeout=args.queue_timeout) + global_step = payload.global_step + actor_policy_version = payload.policy_version + update = payload.update + sharded_storages.append(payload.storage) rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) training_time_start = time.time() - for _ in range(args.train_epochs): - ( - agent_state, - metrics_dict, - ) = multi_device_update( - agent_state, - sharded_storages, - ) + + key, *epoch_keys = jax.random.split(key, 1 + args.train_epochs) + permutation_key = jax.random.split(epoch_keys[0], len(runtime_info.global_learner_devices)) + (agent_state, metrics_dict) = multi_device_update(agent_state, sharded_storages, permutation_key) + for epoch in range(1, args.train_epochs): + permutation_key = jax.random.split(epoch_keys[epoch], len(runtime_info.global_learner_devices)) + (agent_state, metrics_dict) = multi_device_update(agent_state, sharded_storages, permutation_key) + unreplicated_params = unreplicate(agent_state.params) if update > args.actor_update_cutoff or update % args.actor_update_frequency == 0: for d_idx, d_id in 
enumerate(args.actor_device_ids): - device_params = jax.device_put(unreplicated_params, runtime_info.local_devices[d_id]) + # Copy device_params so we can donate the agent_state in the multi_device_update + device_params = jax.tree.map( + partial(jnp.array, copy=True), + jax.device_put(unreplicated_params, runtime_info.local_devices[d_id]), + ) for thread_id in range(args.num_actor_threads): params_queues[d_idx * args.num_actor_threads + thread_id].put( - (device_params, args.learner_policy_version), timeout=args.queue_timeout + ParamsPayload(params=device_params, policy_version=args.learner_policy_version), + timeout=args.queue_timeout, ) # Copy the parameters from the first device to all other learner devices @@ -810,36 +837,54 @@ def train( # record rewards for plotting purposes if args.learner_policy_version % args.log_frequency == 0: - writer.add_scalar( - "stats/rollout_queue_get_time", - np.mean(rollout_queue_get_time), - global_step, - ) - writer.add_scalar( - "stats/rollout_params_queue_get_time_diff", - np.mean(rollout_queue_get_time) - avg_params_queue_get_time, - global_step, - ) - writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step) - writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step) - writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step) + metrics = { + "Perf/rollout_queue_get_time": np.mean(rollout_queue_get_time), + "Perf/training_time": time.time() - training_time_start, + "Perf/rollout_queue_size": rollout_queues[-1].qsize(), + "Perf/params_queue_size": params_queues[-1].qsize(), + "losses/value_loss": metrics_dict.pop("v_loss")[0].item(), + "losses/policy_loss": metrics_dict.pop("pg_loss")[0].item(), + "losses/entropy": metrics_dict.pop("ent_loss")[0].item(), + "losses/loss": metrics_dict.pop("loss")[0].item(), + "policy_versions/learner": args.learner_policy_version, + } + metrics.update({k: v[0].item() for k, v in metrics_dict.items()}) + + lr = unreplicate(agent_state.opt_state.hyperparams["learning_rate"]) + assert lr is not None + metrics["losses/learning_rate"] = lr + + # Receive actors' metrics from the metrics_queue, and once we have all of them plot them together + # + # If we get metrics from a future step, we just put them back in the queue for next time. + # If it is a previous step, we regretfully throw them away. + add_back_later_metrics = [] + num_actor_metrics = 0 + while num_actor_metrics < len(rollout_queues): + item = metrics_queue.get(timeout=args.queue_timeout) + actor_global_step, actor_metrics = item.priority, item.item + print(f"Got metrics from {actor_global_step=}") + + if actor_global_step == global_step: + metrics.update( + {k: (v.item() if isinstance(v, jnp.ndarray) else v) for (k, v) in actor_metrics.items()} + ) + num_actor_metrics += 1 + elif actor_global_step > global_step: + add_back_later_metrics.append(item) + else: + log.warning( + f"Had to throw away metrics for global_step {actor_global_step}, which is less than the current {global_step=}. {actor_metrics}" + ) + # We're done. Write metrics and add back the ones for the future. 
+ writer.add_dict(metrics, global_step=global_step) + for item in add_back_later_metrics: + metrics_queue.put(item) + print( global_step, f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={args.learner_policy_version}, training time: {time.time() - training_time_start}s", ) - writer.add_scalar("losses/value_loss", metrics_dict.pop("v_loss")[0].item(), global_step) - writer.add_scalar("losses/policy_loss", metrics_dict.pop("pg_loss")[0].item(), global_step) - writer.add_scalar("losses/entropy", metrics_dict.pop("ent_loss")[0].item(), global_step) - writer.add_scalar("losses/loss", metrics_dict.pop("loss")[0].item(), global_step) - - for name, value in metrics_dict.items(): - writer.add_scalar(name, value[0].item(), global_step) - - writer.add_scalar("policy_versions/learner", args.learner_policy_version, global_step) - - lr = unreplicate(agent_state.opt_state.hyperparams["learning_rate"]) - assert lr is not None - writer.add_scalar("losses/learning_rate", lr, global_step) if args.save_model and args.learner_policy_version in args.eval_at_steps: print("Learner thread entering save barrier (should be last)") @@ -918,11 +963,13 @@ def load_train_state( pass # must be already unreplicated if isinstance(args.net, ConvLSTMConfig): for i in range(args.net.n_recurrent): - train_state.params["params"]["network_params"][f"cell_list_{i}"]["fence"]["kernel"] = np.sum( - train_state.params["params"]["network_params"][f"cell_list_{i}"]["fence"]["kernel"], - axis=2, - keepdims=True, - ) + this_cell = train_state.params["params"]["network_params"][f"cell_list_{i}"] + if "fence" in this_cell: + this_cell["fence"]["kernel"] = jnp.sum( + this_cell["fence"]["kernel"], + axis=2, + keepdims=True, + ) if finetune_with_noop_head: loaded_head = train_state.params["params"]["actor_params"]["Output"] @@ -940,5 +987,4 @@ def load_train_state( if __name__ == "__main__": args = farconf.parse_cli(sys.argv[1:], Args) pprint(args) - train(args) diff --git a/cleanba/config.py b/cleanba/config.py index adfd197..cb2836d 100644 --- a/cleanba/config.py +++ b/cleanba/config.py @@ -3,13 +3,11 @@ from pathlib import Path from typing import List, Optional -from cleanba.convlstm import ConvConfig, ConvLSTMCellConfig, ConvLSTMConfig -from cleanba.environments import AtariEnv, EnvConfig, EnvpoolBoxobanConfig, random_seed +from cleanba.convlstm import ConvConfig, ConvLSTMCellConfig, ConvLSTMConfig, LSTMConfig +from cleanba.environments import AtariEnv, CraftaxEnvConfig, EnvConfig, EnvpoolBoxobanConfig, random_seed from cleanba.evaluate import EvalConfig -from cleanba.impala_loss import ( - ImpalaLossConfig, -) -from cleanba.network import AtariCNNSpec, GuezResNetConfig, IdentityNorm, PolicySpec, SokobanResNetConfig +from cleanba.impala_loss import ActorCriticLossConfig, ImpalaLossConfig, PPOLossConfig +from cleanba.network import AtariCNNSpec, GuezResNetConfig, IdentityNorm, MLPConfig, PolicySpec, SokobanResNetConfig @dataclasses.dataclass @@ -45,7 +43,7 @@ class Args: base_run_dir: Path = Path("/tmp/cleanba") - loss: ImpalaLossConfig = ImpalaLossConfig() + loss: ActorCriticLossConfig = ImpalaLossConfig() net: PolicySpec = AtariCNNSpec(channels=(16, 32, 32), mlp_hiddens=(256,)) @@ -310,3 +308,119 @@ def sokoban_drc33_59() -> Args: head_scale=1.0, ) return out + + +def craftax_drc() -> Args: + num_envs = 512 + return Args( + train_env=CraftaxEnvConfig(max_episode_steps=3000, num_envs=num_envs, seed=1234), + eval_envs={}, + log_frequency=1, + net=ConvLSTMConfig( + embed=[ConvConfig(128, (3, 3), 
(1, 1), "SAME", True), ConvConfig(64, (3, 3), (1, 1), "SAME", True)], + recurrent=ConvLSTMCellConfig( + ConvConfig(64, (3, 3), (1, 1), "SAME", True), pool_and_inject="horizontal", fence_pad="no" + ), + n_recurrent=3, + mlp_hiddens=(512,), + repeats_per_step=3, + skip_final=True, + residual=True, + norm=IdentityNorm(), + ), + loss=PPOLossConfig( + gae_lambda=0.8, + gamma=0.99, + ent_coef=0.01, + vf_coef=0.25, + normalize_advantage=True, + ), + actor_update_cutoff=0, + sync_frequency=20000000000, + num_minibatches=8, + rmsprop_eps=1e-8, + local_num_envs=num_envs, + total_timesteps=3000000, + base_run_dir=Path("/training/craftax"), + learning_rate=2e-4, + final_learning_rate=1e-5, + optimizer="adam", + adam_b1=0.9, + rmsprop_decay=0.999, + base_fan_in=1, + anneal_lr=True, + max_grad_norm=1.0, + num_actor_threads=1, + num_steps=64, + train_epochs=4, + ) + + +def craftax_lstm(n_recurrent: int = 3, num_repeats: int = 1) -> Args: + num_envs = 512 + return Args( + train_env=CraftaxEnvConfig(max_episode_steps=3000, num_envs=num_envs, seed=1234, obs_flat=True), + eval_envs={}, + log_frequency=1, + net=LSTMConfig( + embed_hiddens=(1024,), + recurrent_hidden=1024, + n_recurrent=3, + repeats_per_step=1, + norm=IdentityNorm(), + mlp_hiddens=(512,), + ), + actor_update_cutoff=0, + sync_frequency=200, + num_minibatches=8, + rmsprop_eps=1e-8, + local_num_envs=num_envs, + total_timesteps=3000000, + base_run_dir=Path("/training/craftax"), + learning_rate=2e-4, + final_learning_rate=1e-5, + optimizer="adam", + adam_b1=0.9, + rmsprop_decay=0.999, + base_fan_in=1, + anneal_lr=True, + max_grad_norm=1.0, + num_actor_threads=1, + num_steps=64, + train_epochs=1, + ) + + +def craftax_mlp() -> Args: + num_envs = 512 + return Args( + train_env=CraftaxEnvConfig(max_episode_steps=3000, num_envs=num_envs, seed=1234, obs_flat=True), + eval_envs={}, + log_frequency=1, + net=MLPConfig(hiddens=(512, 512, 512), norm=IdentityNorm(), yang_init=False, activation="tanh", head_scale=0.01), + loss=PPOLossConfig( + gae_lambda=0.8, + gamma=0.99, + ent_coef=0.01, + vf_coef=0.25, + normalize_advantage=True, + ), + actor_update_cutoff=0, + sync_frequency=200, + num_minibatches=8, + rmsprop_eps=1e-8, + local_num_envs=num_envs, + total_timesteps=3000000, + base_run_dir=Path("/training/craftax"), + learning_rate=2e-4, + final_learning_rate=1e-5, + optimizer="adam", + adam_b1=0.9, + rmsprop_decay=0.999, + base_fan_in=1, + anneal_lr=True, + max_grad_norm=1.0, + num_actor_threads=1, + num_steps=64, + train_epochs=1, + ) diff --git a/cleanba/convlstm.py b/cleanba/convlstm.py index 69ad804..b385a5f 100644 --- a/cleanba/convlstm.py +++ b/cleanba/convlstm.py @@ -51,17 +51,16 @@ class BaseLSTMConfig(PolicySpec): mlp_hiddens: Tuple[int, ...] = (256,) skip_final: bool = True residual: bool = False + use_relu: bool = False @abc.abstractmethod - def make(self) -> "BaseLSTM": - ... + def make(self) -> "BaseLSTM": ... @dataclasses.dataclass(frozen=True) class ConvLSTMConfig(BaseLSTMConfig): embed: List[ConvConfig] = dataclasses.field(default_factory=list) recurrent: ConvLSTMCellConfig = ConvLSTMCellConfig(ConvConfig(32, (3, 3), (1, 1), "SAME", True)) - use_relu: bool = True def make(self) -> "ConvLSTM": return ConvLSTM(self) @@ -112,8 +111,7 @@ def setup(self): self.dense_list = [nn.Dense(hidden) for hidden in self.cfg.mlp_hiddens] @abc.abstractmethod - def _compress_input(self, x: jax.Array) -> jax.Array: - ... + def _compress_input(self, x: jax.Array) -> jax.Array: ... 
@nn.nowrap def initialize_carry(self, rng, input_shape) -> LSTMState: @@ -124,7 +122,9 @@ def apply_cells_once(self, carry: LSTMState, inputs: jax.Array) -> tuple[LSTMSta """ Applies all cells in `self.cell_list` once. `Inputs` gets passed as the input to every cell """ - assert len(inputs.shape) == 4 + assert len(inputs.shape) == 4 or len(inputs.shape) == 2, ( + f"inputs shape must be [batch, c, h, w] or [batch, c] but is {inputs.shape=}" + ) carry = list(carry) # copy # Top-down skip connection from previous time step @@ -150,7 +150,9 @@ def _apply_cells(self, carry: LSTMState, inputs: jax.Array, episode_starts: jax. Applies all cells in `self.cell_list`, several times: `self.cfg.repeats_per_step` times. Preprocesses the carry so it gets zeroed at the start of an episode """ - assert len(inputs.shape) == 4 + assert len(inputs.shape) == 4 or len(inputs.shape) == 2, ( + f"inputs shape must be [batch, c, h, w] or [batch, c] but is {inputs.shape=}" + ) assert len(episode_starts.shape) == 1 not_reset = ~episode_starts @@ -223,30 +225,6 @@ def initialize_carry(self, rng, input_shape) -> LSTMState: return super().initialize_carry(rng, (n, h, w, c)) -class LSTM(BaseLSTM): - cfg: LSTMConfig - - def setup(self): - super().setup() - self.compress_list = [nn.Dense(hidden) for hidden in self.cfg.embed_hiddens] - self.cell_list = [] # LSTMCell(self.cfg.cell, features=self.cfg.recurrent_hidden) for _ in range(self.cfg.n_recurrent)] - - def _compress_input(self, x: jax.Array) -> jax.Array: - assert len(x.shape) == 4, f"observations shape must be [batch, c, h, w] but is {x.shape=}" - - # Flatten input - x = jnp.reshape(x, (x.shape[0], math.prod(x.shape[1:]))) - - for c in self.compress_list: - x = c(x) - x = nn.relu(x) - return x - - @nn.nowrap - def initialize_carry(self, rng, input_shape) -> LSTMState: - return super().initialize_carry(rng, (input_shape[0], self.cfg.embed_hiddens[-1])) - - class ConvLSTMCell(nn.RNNCellBase): cfg: ConvLSTMCellConfig @@ -354,3 +332,58 @@ def initialize_carry(self, rng: jax.Array, input_shape: tuple[int, ...]) -> LSTM def num_feature_axes(self) -> int: return 3 + + +class LSTMCell(nn.Module): + features: int + + @nn.compact + def __call__( + self, carry: LSTMCellState, inputs: jax.Array, prev_layer_hidden: jax.Array + ) -> tuple[LSTMCellState, jax.Array]: + # Concatenate inputs with prev_layer_hidden + combined_inputs = jnp.concatenate([inputs, prev_layer_hidden], axis=-1) + + # Use Flax's built-in LSTM implementation + lstm = nn.LSTMCell(features=self.features) + # Convert our state format to Flax's format + flax_carry = (carry.c, carry.h) + # Apply the LSTM + (new_c, new_h), out = lstm(flax_carry, combined_inputs) + # Convert back to our state format + return LSTMCellState(c=new_c, h=new_h), out + + @nn.nowrap + def initialize_carry(self, rng: jax.Array, input_shape: tuple[int, ...]) -> LSTMCellState: + # Initialize with zeros like the ConvLSTMCell + shape = (*input_shape[:-1], self.features) + c_rng, h_rng = jax.random.split(rng, 2) + return LSTMCellState(c=nn.zeros_init()(c_rng, shape), h=nn.zeros_init()(h_rng, shape)) + + +class LSTM(BaseLSTM): + cfg: LSTMConfig + + def setup(self): + super().setup() + self.compress_list = [nn.Dense(hidden) for hidden in self.cfg.embed_hiddens] + self.cell_list = [LSTMCell(features=self.cfg.recurrent_hidden) for _ in range(self.cfg.n_recurrent)] + + def _compress_input(self, x: jax.Array) -> jax.Array: + assert len(x.shape) == 4 or len(x.shape) == 2, ( + f"observations shape must be [batch, c, h, w] or [batch, c] but is {x.shape=}" + 
) + if len(x.shape) == 4: + x = jnp.reshape(x, (x.shape[0], math.prod(x.shape[1:]))) + + for c in self.compress_list: + x = c(x) + if self.cfg.use_relu: + x = nn.relu(x) + return x + + @nn.nowrap + def initialize_carry(self, rng, input_shape) -> LSTMState: + batch_size = input_shape[0] + shape = (batch_size, self.cfg.recurrent_hidden) + return super().initialize_carry(rng, shape) diff --git a/cleanba/env_trivial.py b/cleanba/env_trivial.py index 0af40f1..5557816 100644 --- a/cleanba/env_trivial.py +++ b/cleanba/env_trivial.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Iterable, List, Optional, SupportsFloat, Union import gymnasium as gym +import jax import jax.numpy as jnp import numpy as np from numpy.typing import NDArray @@ -105,6 +106,9 @@ def reset_async(self, seed: Optional[Union[int, List[int]]] = None, options: Opt seed = self._seeds return super().reset_async(seed, options) + def step(self, actions: np.ndarray | jax.Array) -> tuple[Any, np.ndarray, np.ndarray, np.ndarray, dict[str, Any]]: + return super().step(np.asarray(actions)) + @dataclasses.dataclass class MockSokobanEnvConfig(EnvConfig): diff --git a/cleanba/environments.py b/cleanba/environments.py index dae8148..5ac76b9 100644 --- a/cleanba/environments.py +++ b/cleanba/environments.py @@ -5,14 +5,201 @@ import warnings from functools import partial from pathlib import Path -from typing import Any, Callable, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Literal, Optional, Self, Tuple, Union +import flax.struct import gym_sokoban # noqa: F401 import gymnasium as gym +import jax +import jax.experimental.compilation_cache +import jax.numpy as jnp import numpy as np -from gymnasium.vector.utils.spaces import batch_space +from gymnasium.vector.utils import batch_space from numpy.typing import NDArray +if TYPE_CHECKING: + from craftax.craftax.craftax_state import EnvParams + from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv + + +class EpisodeEvalState(flax.struct.PyTreeNode): + episode_length: jax.Array + episode_success: jax.Array + episode_others: dict[str, jax.Array] + + returned_episode_length: jax.Array + returned_episode_success: jax.Array + returned_episode_others: dict[str, jax.Array] + + @classmethod + def new(cls: type[Self], num_envs: int, others: Iterable[str]) -> Self: + zero_float = jnp.zeros(()) + zero_int = jnp.zeros((), dtype=jnp.int32) + zero_bool = jnp.zeros((), dtype=jnp.bool) + others = set(others) | {"episode_return"} + return jax.tree.map( + partial(jnp.repeat, repeats=num_envs), + cls( + zero_int, + zero_bool, + {o: zero_float for o in others}, + zero_int, + zero_bool, + {o: zero_float for o in others}, + ), + ) + + @jax.jit + def update( + self: Self, reward: jnp.ndarray, terminated: jnp.ndarray, truncated: jnp.ndarray, others: dict[str, jnp.ndarray] + ) -> Self: + done = terminated | truncated + + new_episode_success = terminated + new_episode_length = self.episode_length + 1 + + # Populate things to do tree.map + _episode_others = {k: jnp.zeros(new_episode_length.shape) for k in others.keys()} + _episode_others.update(self.episode_others) + _returned_episode_others = {k: jnp.zeros(new_episode_length.shape) for k in others.keys()} + _returned_episode_others.update(self.returned_episode_others) + + new_others = jax.tree.map(lambda a, b: a + b, _episode_others, {"episode_return": reward, **others}) + + new_state = self.__class__( + episode_length=new_episode_length * (1 - done), + episode_success=new_episode_success * 
(1 - done), + episode_others=jax.tree.map(lambda x: x * (1 - done), new_others), + returned_episode_length=jax.lax.select(done, new_episode_length, self.returned_episode_length), + returned_episode_success=jax.lax.select(done, new_episode_success, self.returned_episode_success), + returned_episode_others=jax.tree.map(partial(jax.lax.select, done), new_others, _returned_episode_others), + ) + return new_state + + def update_info(self) -> dict[str, Any]: + return { + "returned_episode_length": self.returned_episode_length, + "returned_episode_success": self.returned_episode_success, + **{f"returned_{k}": v for k, v in self.returned_episode_others.items()}, + } + + +class EpisodeEvalWrapper(gym.vector.VectorEnvWrapper): + """Log the episode returns and lengths.""" + + state: EpisodeEvalState + + def __init__(self, env: gym.vector.VectorEnv): + super().__init__(env) + self._env = env + + @staticmethod + def _info_achievements(info: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in info.items() if "Achievement" in k} + + def reset(self, seed: Optional[Union[int, List[int]]] = None, options: Optional[dict] = None) -> Tuple[jnp.ndarray, dict]: + obs, info = self._env.reset() + self.state = EpisodeEvalState.new(self._env.num_envs, self._info_achievements(info).keys()) + return obs, {**info, **self.state.update_info()} + + def step(self, actions: jnp.ndarray) -> Tuple[Any, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict]: + obs, reward, terminated, truncated, info = self._env.step(actions) + # Atari envs clip their reward to [-1, 1], meaning we need to use the reward in `info` to get + # the true return. + non_clipped_rewards = info.get("reward", reward) + self.state = self.state.update(non_clipped_rewards, terminated, truncated, self._info_achievements(info)) + return obs, reward, terminated, truncated, {**info, **self.state.update_info()} + + +class CraftaxVectorEnv(gym.vector.VectorEnv): + """ + Craftax environment with a generic VectorEnv interface. 
+ """ + + cfg: "CraftaxEnvConfig" + env: "CraftaxSymbolicEnv" + rng_keys: jnp.ndarray + state: Any + obs: jnp.ndarray + env_params: "EnvParams" + + def __init__(self, cfg: "CraftaxEnvConfig"): + from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv + + self.cfg = cfg + self.env = CraftaxSymbolicEnv() + + obs_shape = (8268,) if cfg.obs_flat else (134, 9, 11) # My guess is it should be (9, 11, 134) should be reversed + single_observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=obs_shape, dtype=np.float32) + single_action_space = gym.spaces.Discrete(self.env.action_space().n) + super().__init__(cfg.num_envs, single_observation_space, single_action_space) + + self.env_params = self.env.default_params + self.closed = False + + self.device, *_ = jax.devices(cfg.jit_backend) + + # set rng_keys, state, obs + self.reset(self.cfg.seed) + + def _process_obs(self, obs_flat): + if self.cfg.obs_flat: + return obs_flat + expected_size = 8268 + assert obs_flat.shape[0] == expected_size, ( + f"Observation size mismatch: got {obs_flat.shape[0]}, expected {expected_size}" + ) + + mapobs = obs_flat[:8217].reshape(9, 11, 83) + invobs = obs_flat[8217:].reshape(51) + invobs_spatial = invobs.reshape(1, 1, 51).repeat(9, axis=0).repeat(11, axis=1) + obs_nhwc = jnp.concatenate([mapobs, invobs_spatial], axis=-1) # (9, 11, 134) + obs_nchw = jnp.transpose(obs_nhwc, (2, 0, 1)) # (134, 9, 11) + + return obs_nchw + + @partial(jax.jit, static_argnames=("self",)) + @partial(jax.vmap, in_axes=(None, 0)) + def _reset_wait_pure(self, key: jnp.ndarray) -> Tuple[jnp.ndarray, Any, jnp.ndarray]: + key, reset_key = jax.random.split(key) + obs_flat, state = self.env.reset_env(reset_key, self.env_params) + obs_processed = self._process_obs(obs_flat) + return obs_processed, state, key + + def reset(self, seed: Optional[Union[int, List[int]]] = None, options: Optional[dict] = None) -> Tuple[jnp.ndarray, dict]: + """Reset the environment.""" + if isinstance(seed, int): + self.rng_keys = jax.random.split(jax.random.PRNGKey(seed), self.num_envs) + elif isinstance(seed, list): + assert len(seed) == self.num_envs + self.rng_keys = jax.jit(jax.vmap(jax.random.PRNGKey))(jnp.asarray(seed)) + self.rng_keys = jax.device_put(self.rng_keys, self.device) + self.obs, self.state, self.rng_keys = self._reset_wait_pure(self.rng_keys) + return self.obs, {} + + @partial(jax.jit, static_argnames=("self",)) + @partial(jax.vmap, in_axes=(None, 0, 0, 0)) + def _step_pure(self, key, state, action): + key, step_key = jax.random.split(key) + obs_flat, state, rewards, dones, info = self.env.step(step_key, state, action) + terminated = dones + # assume no truncation (basically true as agent does not survive long enough) + truncated = jnp.zeros_like(dones, dtype=bool) + assert terminated.dtype == truncated.dtype + obs = self._process_obs(obs_flat) + return key, obs, state, rewards, terminated, truncated, info + + def step(self, actions: jnp.ndarray) -> Tuple[Any, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict]: + """Execute one step in the environment.""" + actions = jax.device_put(actions, self.device) + self.rng_keys, self.obs, self.state, rewards, terminated, truncated, info = self._step_pure( + self.rng_keys, self.state, actions + ) + return self.obs, rewards, terminated, truncated, info + + def close(self, **kwargs): + self.closed = True + def random_seed() -> int: return random.randint(0, 2**31 - 2) @@ -27,6 +214,7 @@ class EnvConfig(abc.ABC): @property @abc.abstractmethod def make(self) -> Callable[[], gym.vector.VectorEnv]: + 
"""Create a vector environment.""" ... @@ -76,23 +264,35 @@ def __init__(self, num_envs: int, envs_fn: Callable[[], Any], remove_last_action super().__init__(num_envs=num_envs, observation_space=envs.observation_space, action_space=envs.action_space) self.envs = envs - def step_async(self, actions: np.ndarray): - self.envs.send(actions) - - def step_wait(self, **kwargs) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]: - return self.envs.recv(**kwargs) + def step(self, actions: np.ndarray) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]: + """Execute one step in the environment.""" + self.envs.send(np.array(actions)) + return self.envs.recv() - def reset_async(self, seed: Optional[Union[int, List[int]]] = None, options: Optional[dict] = None): + def reset(self, seed: Optional[Union[int, List[int]]] = None, options: Optional[dict] = None) -> Tuple[Any, dict]: + """Reset the environment.""" assert seed is None assert not options self.envs.async_reset() - - def reset_wait(self, seed: Optional[Union[int, List[int]]] = None, options: Optional[dict] = None): - assert seed is None - assert not options return self.envs.recv(reset=True, return_info=self.envs.config["gym_reset_return_info"]) +@dataclasses.dataclass +class CraftaxEnvConfig(EnvConfig): + """Configuration class for integrating Craftax with IMPALA.""" + + max_episode_steps: int + num_envs: int = 1 + seed: int = dataclasses.field(default_factory=random_seed) + obs_flat: bool = False + jit_backend: str = dataclasses.field(default_factory=lambda: jax.devices()[0].platform) + + @property + def make(self) -> Callable[[], CraftaxVectorEnv]: # type: ignore + # This property returns a function that creates the Craftax environment wrapper. + return lambda: CraftaxVectorEnv(self) + + @dataclasses.dataclass class EnvpoolBoxobanConfig(EnvpoolEnvConfig): env_id: str = "Sokoban-v0" @@ -189,8 +389,9 @@ def env_reward_kwargs(self): class VectorNHWCtoNCHWWrapper(gym.vector.VectorEnvWrapper): - def __init__(self, env: gym.vector.VectorEnv, remove_last_action: bool = False): + def __init__(self, env: gym.vector.VectorEnv, nn_without_noop: bool = False, use_np_arrays: bool = False): super().__init__(env) + self.use_np_arrays = use_np_arrays obs_space = env.single_observation_space if isinstance(obs_space, gym.spaces.Box): shape = (obs_space.shape[2], *obs_space.shape[:2], *obs_space.shape[3:]) @@ -203,24 +404,28 @@ def __init__(self, env: gym.vector.VectorEnv, remove_last_action: bool = False): self.num_envs = env.num_envs self.observation_space = batch_space(self.single_observation_space, n=self.num_envs) - if remove_last_action: + if nn_without_noop: assert isinstance(env.single_action_space, gym.spaces.Discrete) env.single_action_space = gym.spaces.Discrete(env.single_action_space.n - 1) env.action_space = batch_space(env.single_action_space, n=self.num_envs) self.single_action_space = env.single_action_space self.action_space = env.action_space - def reset_wait(self, **kwargs) -> tuple[Any, dict]: - obs, info = super().reset_wait(**kwargs) - return np.moveaxis(obs, 3, 1), info + def reset(self, **kwargs) -> tuple[Any, dict]: + obs, info = super().reset(**kwargs) + return jnp.moveaxis(obs, 3, 1), info - def step_wait(self) -> tuple[Any, NDArray, NDArray, NDArray, dict]: - obs, reward, terminated, truncated, info = super().step_wait() - return np.moveaxis(obs, 3, 1), reward, terminated, truncated, info + def step(self, actions: jnp.ndarray) -> tuple[Any, jnp.ndarray, jnp.ndarray, jnp.ndarray, dict]: + if self.use_np_arrays: + 
actions = np.asarray(actions) + obs, reward, terminated, truncated, info = super().step(actions) + return jnp.moveaxis(obs, 3, 1), reward, terminated, truncated, info @classmethod - def from_fn(cls, fn: Callable[[], gym.vector.VectorEnv], nn_without_noop) -> gym.vector.VectorEnv: - return cls(fn(), nn_without_noop) + def from_fn( + cls, fn: Callable[[], gym.vector.VectorEnv], nn_without_noop: bool, use_np_arrays: bool + ) -> gym.vector.VectorEnv: + return cls(fn(), nn_without_noop=nn_without_noop, use_np_arrays=use_np_arrays) @dataclasses.dataclass @@ -248,6 +453,7 @@ def make(self) -> Callable[[], gym.vector.VectorEnv]: **self.env_reward_kwargs(), ), self.nn_without_noop, + use_np_arrays=True, ) return make_fn @@ -286,6 +492,7 @@ def make(self) -> Callable[[], gym.vector.VectorEnv]: **self.env_reward_kwargs(), ), self.nn_without_noop, + use_np_arrays=True, # TODO: use the XLA interface for envpool and set this to false ) return make_fn diff --git a/cleanba/evaluate.py b/cleanba/evaluate.py index ea20eba..1f8b250 100644 --- a/cleanba/evaluate.py +++ b/cleanba/evaluate.py @@ -57,7 +57,7 @@ def run(self, policy: Policy, get_action_fn, params, *, key: jnp.ndarray) -> dic # Update the carry with the initial observation many times for think_step in range(steps_to_think): - carry, _, _, key = get_action_fn( + carry, _, _, _, key = get_action_fn( params, carry, obs, episode_starts_no, key, temperature=self.temperature ) @@ -76,7 +76,7 @@ def run(self, policy: Policy, get_action_fn, params, *, key: jnp.ndarray) -> dic while not np.all(eps_done): if i >= self.safeguard_max_episode_steps: break - carry, action, _, key = get_action_fn( + carry, action, _, _, key = get_action_fn( params, carry, obs, episode_starts_no, key, temperature=self.temperature ) diff --git a/cleanba/impala_loss.py b/cleanba/impala_loss.py index b93329a..9346a5c 100644 --- a/cleanba/impala_loss.py +++ b/cleanba/impala_loss.py @@ -1,6 +1,7 @@ +import abc import dataclasses from functools import partial -from typing import Any, Callable, List, Literal, NamedTuple +from typing import Any, Callable, ClassVar, List, Literal, NamedTuple, Self import jax import jax.numpy as jnp @@ -11,12 +12,51 @@ from numpy.typing import NDArray +class Rollout(NamedTuple): + obs_t: jax.Array + carry_t: Any + a_t: jax.Array + logits_t: jax.Array + value_t: jax.Array + r_t: jax.Array | NDArray + episode_starts_t: jax.Array | NDArray + truncated_t: jax.Array | NDArray + + +GetLogitsAndValueFn = Callable[ + [Any, Any, jax.Array, jax.Array | NDArray], tuple[Any, jax.Array, jax.Array, dict[str, jax.Array]] +] + + @dataclasses.dataclass(frozen=True) -class ImpalaLossConfig: +class ActorCriticLossConfig(abc.ABC): + needs_last_value: ClassVar[bool] + gamma: float = 0.99 # the discount factor gamma ent_coef: float = 0.01 # coefficient of the entropy vf_coef: float = 0.25 # coefficient of the value function + normalize_advantage: bool = False + + @abc.abstractmethod + def loss( + self: Self, + params: Any, + get_logits_and_value: GetLogitsAndValueFn, + minibatch: Rollout, + ) -> tuple[jax.Array, dict[str, jax.Array]]: ... 
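A condensed sketch of how this abstract `loss` interface is consumed, mirroring the `update_minibatch` body in `single_device_update` further down. The time convention it assumes is the one documented in the loss implementations: `obs_t` and `carry_t` cover T+1 steps, the remaining `Rollout` fields cover T steps.

import jax

def learner_step(agent_state, loss_cfg, get_logits_and_value, minibatch):
    # loss_cfg is a concrete ActorCriticLossConfig, e.g. ImpalaLossConfig(...) or PPOLossConfig(...)
    (loss, metrics), grads = jax.value_and_grad(loss_cfg.loss, has_aux=True)(
        agent_state.params, get_logits_and_value, minibatch
    )
    return agent_state.apply_gradients(grads=grads), {"loss": loss, **metrics}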
+ + def maybe_normalize_advantage(self, adv_t: jax.Array) -> jax.Array: + def _norm_advantage(): + return (adv_t - jnp.mean(adv_t)) / (jnp.std(adv_t, ddof=1) + 1e-8) + + return jax.lax.cond(self.normalize_advantage, _norm_advantage, lambda: adv_t) + + +@dataclasses.dataclass(frozen=True) +class ImpalaLossConfig(ActorCriticLossConfig): + needs_last_value: ClassVar[bool] = False + # Interpolate between VTrace (1.0) and monte-carlo function (0.0) estimates, for the estimate of targets, used in # both the value and policy losses. It's the parameter in Remark 2 of Espeholt et al. # (https://arxiv.org/pdf/1802.01561.pdf) @@ -30,8 +70,6 @@ class ImpalaLossConfig: # (https://arxiv.org/pdf/1802.01561.pdf) clip_pg_rho_threshold: float = 1.0 - normalize_advantage: bool = False - logit_l2_coef: float = 0.0 weight_l2_coef: float = 0.0 @@ -63,147 +101,194 @@ def adv_multiplier(self, vtrace_errors: jax.Array) -> jax.Array | float: else: raise ValueError(f"{self.advantage_multiplier=}") + # The reason this loss function is peppered with `del` statements is so we don't accidentally use the wrong + # (time-shifted) variable when coding + def loss( + self: Self, + params: Any, + get_logits_and_value: GetLogitsAndValueFn, + minibatch: Rollout, + ) -> tuple[jax.Array, dict[str, jax.Array]]: + # If the episode has actually terminated, the outgoing state's value is known to be zero. + # + # If the episode was truncated or terminated, we don't want the value estimation for future steps to influence the + # value estimation for the current one (or previous once). I.e., in the VTrace (or GAE) recurrence, we want to stop + # the value at time t+1 from influencing the value at time t and before. + # + # Both of these aims can be served by setting the discount to zero when the episode is terminated or truncated. + # + # done_t = truncated_t | terminated_t + done_t = minibatch.episode_starts_t[1:] + discount_t = (~done_t) * self.gamma + del done_t + + _final_carry, nn_logits_from_obs, nn_value_from_obs, nn_metrics = get_logits_and_value( + params, jax.tree.map(lambda x: x[0], minibatch.carry_t), minibatch.obs_t, minibatch.episode_starts_t + ) + del _final_carry + + # There's one extra timestep at the end for `obs_t` than logits and the rest of objects in `minibatch`, so we need + # to cut these values to size. + # + # For the logits, we discard the last one, which makes the time-steps of `nn_logits_from_obs` match exactly with the + # time-steps from `minibatch.logits_t` + nn_logits_t = nn_logits_from_obs[:-1] + + ## Remark 1: + # v_t does not enter the gradient in any way, because + # 1. it's stop_grad()-ed in the `vtrace_td_error_and_advantage.errors` + # 2. it intervenes in `vtrace_td_error_and_advantage.pg_advantage`, but that's stop_grad() ed by the pg loss. + # + # so we don't actually need to call stop_grad here. + # + ## Remark 2: + # If we followed normal RL conventions, v_t corresponds to V(s_{t+1}) and v_tm1 corresponds to V(s_{t}). This can be + # gleaned from looking at the implementation of the TD error in `rlax.vtrace`. + # + # We keep the name error from the `rlax` library for consistence. + v_t = nn_value_from_obs[1:] + v_tm1 = nn_value_from_obs[:-1] + del nn_value_from_obs + + # If the episode has been truncated, the value of the next state (after truncation) would be some non-zero amount. + # But we don't have access to that state, because the resetting code just throws it away. 
To compensate, we'll + # actually truncate 1 step earlier than the time limit, use the value of the state we know, and just discard the + # transition that actually caused truncation. That is, we receive: + # + # s0, --r0-> s1 --r1-> s2 --r2-> s3 --r3-> ... + # + # and say the episode was truncated at s3. We don't know s4, so we can't calculate V(s4), which we need for the + # objective. So instead we'll discard r3 and treat s3 as the final state. Now we can calculate V(s3). + # + # We could get the correct TD error just by ignoring the loss at the truncated steps. However, VTrace propagates + # errors backward, so the truncated-episode error would propagate backward anyways. To solve this, we set the reward + # at truncated timesteps to be equal to v_tm1. The discount for those steps also has to be 0, that's determined by + # `discount_t` defined above. + mask_t = jnp.float32(~minibatch.truncated_t) + r_t = jnp.where(minibatch.truncated_t, jax.lax.stop_gradient(v_tm1), minibatch.r_t) + + rhos_tm1 = rlax.categorical_importance_sampling_ratios(nn_logits_t, minibatch.logits_t, minibatch.a_t) + + vtrace_td_error_and_advantage = jax.vmap( + partial( + rlax.vtrace_td_error_and_advantage, + lambda_=self.vtrace_lambda, + clip_rho_threshold=self.clip_rho_threshold, + clip_pg_rho_threshold=self.clip_pg_rho_threshold, + stop_target_gradients=True, + ), + in_axes=1, + out_axes=1, + ) -class Rollout(NamedTuple): - obs_t: jax.Array - carry_t: Any - a_t: jax.Array - logits_t: jax.Array - r_t: jax.Array | NDArray - episode_starts_t: jax.Array | NDArray - truncated_t: jax.Array | NDArray - - -GetLogitsAndValueFn = Callable[ - [Any, Any, jax.Array, jax.Array | NDArray], tuple[Any, jax.Array, jax.Array, dict[str, jax.Array]] -] - - -# The reason this loss function is peppered with `del` statements is so we don't accidentally use the wrong -# (time-shifted) variable when coding -def impala_loss( - params: Any, - get_logits_and_value: GetLogitsAndValueFn, - args: ImpalaLossConfig, - minibatch: Rollout, -) -> tuple[jax.Array, dict[str, jax.Array]]: - # If the episode has actually terminated, the outgoing state's value is known to be zero. - # - # If the episode was truncated or terminated, we don't want the value estimation for future steps to influence the - # value estimation for the current one (or previous once). I.e., in the VTrace (or GAE) recurrence, we want to stop - # the value at time t+1 from influencing the value at time t and before. - # - # Both of these aims can be served by setting the discount to zero when the episode is terminated or truncated. - # - # done_t = truncated_t | terminated_t - done_t = minibatch.episode_starts_t[1:] - discount_t = (~done_t) * args.gamma - del done_t - - _final_carry, nn_logits_from_obs, nn_value_from_obs, nn_metrics = get_logits_and_value( - params, jax.tree.map(lambda x: x[0], minibatch.carry_t), minibatch.obs_t, minibatch.episode_starts_t - ) - del _final_carry - - # There's one extra timestep at the end for `obs_t` than logits and the rest of objects in `minibatch`, so we need - # to cut these values to size. - # - # For the logits, we discard the last one, which makes the time-steps of `nn_logits_from_obs` match exactly with the - # time-steps from `minibatch.logits_t` - nn_logits_t = nn_logits_from_obs[:-1] - - ## Remark 1: - # v_t does not enter the gradient in any way, because - # 1. it's stop_grad()-ed in the `vtrace_td_error_and_advantage.errors` - # 2. 
it intervenes in `vtrace_td_error_and_advantage.pg_advantage`, but that's stop_grad() ed by the pg loss. - # - # so we don't actually need to call stop_grad here. - # - ## Remark 2: - # If we followed normal RL conventions, v_t corresponds to V(s_{t+1}) and v_tm1 corresponds to V(s_{t}). This can be - # gleaned from looking at the implementation of the TD error in `rlax.vtrace`. - # - # We keep the name error from the `rlax` library for consistence. - v_t = nn_value_from_obs[1:] - v_tm1 = nn_value_from_obs[:-1] - del nn_value_from_obs - - # If the episode has been truncated, the value of the next state (after truncation) would be some non-zero amount. - # But we don't have access to that state, because the resetting code just throws it away. To compensate, we'll - # actually truncate 1 step earlier than the time limit, use the value of the state we know, and just discard the - # transition that actually caused truncation. That is, we receive: - # - # s0, --r0-> s1 --r1-> s2 --r2-> s3 --r3-> ... - # - # and say the episode was truncated at s3. We don't know s4, so we can't calculate V(s4), which we need for the - # objective. So instead we'll discard r3 and treat s3 as the final state. Now we can calculate V(s3). - # - # We could get the correct TD error just by ignoring the loss at the truncated steps. However, VTrace propagates - # errors backward, so the truncated-episode error would propagate backward anyways. To solve this, we set the reward - # at truncated timesteps to be equal to v_tm1. The discount for those steps also has to be 0, that's determined by - # `discount_t` defined above. - mask_t = jnp.float32(~minibatch.truncated_t) - r_t = jnp.where(minibatch.truncated_t, jax.lax.stop_gradient(v_tm1), minibatch.r_t) - - rhos_tm1 = rlax.categorical_importance_sampling_ratios(nn_logits_t, minibatch.logits_t, minibatch.a_t) - - vtrace_td_error_and_advantage = jax.vmap( - partial( - rlax.vtrace_td_error_and_advantage, - lambda_=args.vtrace_lambda, - clip_rho_threshold=args.clip_rho_threshold, - clip_pg_rho_threshold=args.clip_pg_rho_threshold, - stop_target_gradients=True, - ), - in_axes=1, - out_axes=1, - ) + vtrace_returns = vtrace_td_error_and_advantage(v_tm1, v_t, r_t, discount_t, rhos_tm1) + + # We're going to multiply advantages by this value, so the policy doesn't change too much in situations where the + # value error is large. + adv_multiplier = self.adv_multiplier(vtrace_returns.errors) + + # Policy-gradient loss: stop_grad(advantage) * log_p(actions), with importance ratios. The importance ratios here + # are implicit in `pg_advs`. 
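A quick numeric sanity check of the truncation trick described in the comments above, with arbitrary example values: at a truncated step the reward is overwritten with v_tm1 and the discount is forced to 0, so the (un-clipped) V-trace TD error rho * (r + gamma * v_next - v_prev) collapses to zero and nothing propagates backwards through that step.

rho, v_prev, v_next, gamma_t = 1.3, 0.7, 5.0, 0.0  # gamma_t is 0 because the step is marked done
r_t = v_prev                                       # reward replaced at the truncated step
td_error = rho * (r_t + gamma_t * v_next - v_prev)
assert td_error == 0.0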
+ norm_advantage = self.maybe_normalize_advantage(vtrace_returns.pg_advantage) + pg_advs = jax.lax.stop_gradient(adv_multiplier * norm_advantage) + pg_loss = jnp.mean(jax.vmap(rlax.policy_gradient_loss, in_axes=1)(nn_logits_t, minibatch.a_t, pg_advs, mask_t)) + + # Value loss: MSE/Huber loss of VTrace-estimated errors + ## Errors should be zero where mask_t is False, but we multiply anyways + v_loss = jnp.mean(self.vf_loss_fn(vtrace_returns.errors) * mask_t) + + # Entropy loss: negative average entropy of the policy across timesteps and environments + ent_loss = jnp.mean(jax.vmap(rlax.entropy_loss, in_axes=1)(nn_logits_t, mask_t)) + + total_loss = pg_loss + total_loss += self.vf_coef * v_loss + total_loss += self.ent_coef * ent_loss + total_loss += self.logit_l2_coef * jnp.sum(jnp.square(nn_logits_from_obs)) + + actor_params = jax.tree.leaves(params.get("params", {}).get("actor_params", {})) + critic_params = jax.tree.leaves(params.get("params", {}).get("critic_params", {})) + + total_loss += self.weight_l2_coef * sum(jnp.sum(jnp.square(p)) for p in [*actor_params, *critic_params]) + + # Useful metrics to know + targets_tm1 = vtrace_returns.errors + v_tm1 + metrics_dict = dict( + pg_loss=pg_loss, + v_loss=v_loss, + ent_loss=ent_loss, + var_explained=1 - jnp.var(vtrace_returns.errors, ddof=1) / jnp.var(targets_tm1, ddof=1), + proportion_of_boxes=jnp.mean(minibatch.r_t > 0), + **nn_metrics, + adv_multiplier=jnp.mean(adv_multiplier), + ) + return total_loss, metrics_dict - vtrace_returns = vtrace_td_error_and_advantage(v_tm1, v_t, r_t, discount_t, rhos_tm1) - # We're going to multiply advantages by this value, so the policy doesn't change too much in situations where the - # value error is large. - adv_multiplier = args.adv_multiplier(vtrace_returns.errors) +@dataclasses.dataclass(frozen=True) +class PPOLossConfig(ActorCriticLossConfig): + needs_last_value: ClassVar[bool] = True + + gae_lambda: float = 0.8 + clip_eps: float = 0.2 + vf_clip_eps: float = 0.2 + + def loss( + self: Self, params: Any, get_logits_and_value: GetLogitsAndValueFn, minibatch: Rollout + ) -> tuple[jax.Array, dict[str, jax.Array]]: + done_t = minibatch.episode_starts_t[1:] + discount_t = (~done_t) * self.gamma + del done_t + + _final_carry, nn_logits_from_obs, nn_value_from_obs, nn_metrics = get_logits_and_value( + params, jax.tree.map(lambda x: x[0], minibatch.carry_t), minibatch.obs_t, minibatch.episode_starts_t + ) + del _final_carry + + # There's one extra timestep at the end for `obs_t` than logits and the rest of objects in `minibatch`, so we need + # to cut the logits to size. + nn_logits_t = nn_logits_from_obs[:-1] + # We keep the name error (t vs tm1) from the `rlax` library for consistence. + nn_value_tm1 = nn_value_from_obs[:-1] + minibatch_value_tm1 = jax.lax.stop_gradient(minibatch.value_t[:-1]) + + # Ignore truncated steps using the same technique as before + mask_t = jnp.float32(~minibatch.truncated_t) + # This r_t cancels out exactly at truncated steps in the GAE calculation + r_t = jnp.where(minibatch.truncated_t, jax.lax.stop_gradient(nn_value_from_obs[:-1]), minibatch.r_t) + + # Compute advantage and clipped value loss + gae = jax.vmap(rlax.truncated_generalized_advantage_estimation, in_axes=(1, 1, None, 1, None), out_axes=1)( + r_t, discount_t, self.gae_lambda, minibatch.value_t, True + ) + value_targets = gae + minibatch_value_tm1 - # Policy-gradient loss: stop_grad(advantage) * log_p(actions), with importance ratios. The importance ratios here - # are implicit in `pg_advs`. 
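For contrast with the V-trace objective, a minimal sketch of the clipped surrogate that `PPOLossConfig.loss` assembles (truncation mask and advantage normalization omitted here); `eps` plays the role of `clip_eps`, and `ratio` is the categorical importance-sampling ratio between the current and behavior policies.

import jax.numpy as jnp

def clipped_pg_loss(ratio, adv, eps=0.2):
    unclipped = ratio * adv
    clipped = jnp.clip(ratio, 1.0 - eps, 1.0 + eps) * adv
    return -jnp.mean(jnp.minimum(unclipped, clipped))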
- norm_advantage = (vtrace_returns.pg_advantage - jnp.mean(vtrace_returns.pg_advantage)) / ( - jnp.std(vtrace_returns.pg_advantage, ddof=1) + 1e-8 - ) - pg_advs = jax.lax.stop_gradient( # Just in case - adv_multiplier * jax.lax.select(args.normalize_advantage, norm_advantage, vtrace_returns.pg_advantage) - ) - pg_loss = jnp.mean(jax.vmap(rlax.policy_gradient_loss, in_axes=1)(nn_logits_t, minibatch.a_t, pg_advs, mask_t)) - - # Value loss: MSE/Huber loss of VTrace-estimated errors - ## Errors should be zero where mask_t is False, but we multiply anyways - v_loss = jnp.mean(args.vf_loss_fn(vtrace_returns.errors) * mask_t) - - # Entropy loss: negative average entropy of the policy across timesteps and environments - ent_loss = jnp.mean(jax.vmap(rlax.entropy_loss, in_axes=1)(nn_logits_t, mask_t)) - - total_loss = pg_loss - total_loss += args.vf_coef * v_loss - total_loss += args.ent_coef * ent_loss - total_loss += args.logit_l2_coef * jnp.sum(jnp.square(nn_logits_from_obs)) - - actor_params = jax.tree.leaves(params.get("params", {}).get("actor_params", {})) - critic_params = jax.tree.leaves(params.get("params", {}).get("critic_params", {})) - - total_loss += args.weight_l2_coef * sum(jnp.sum(jnp.square(p)) for p in [*actor_params, *critic_params]) - - # Useful metrics to know - targets_tm1 = vtrace_returns.errors + v_tm1 - metrics_dict = dict( - pg_loss=pg_loss, - v_loss=v_loss, - ent_loss=ent_loss, - var_explained=1 - jnp.var(vtrace_returns.errors, ddof=1) / jnp.var(targets_tm1, ddof=1), - proportion_of_boxes=jnp.mean(minibatch.r_t > 0), - **nn_metrics, - adv_multiplier=jnp.mean(adv_multiplier), - ) - return total_loss, metrics_dict + value_errors = nn_value_tm1 - value_targets + value_pred_clipped = minibatch_value_tm1 + jnp.clip( + nn_value_tm1 - minibatch_value_tm1, -self.vf_clip_eps, self.vf_clip_eps + ) + value_clipped_errors = value_pred_clipped - value_targets + v_loss = jnp.mean(jnp.maximum(jnp.square(value_errors), jnp.square(value_clipped_errors)) * mask_t) + + rhos_t = rlax.categorical_importance_sampling_ratios(nn_logits_t, minibatch.logits_t, minibatch.a_t) + adv_t = self.maybe_normalize_advantage(gae) + + clip_rhos_t = jnp.clip(rhos_t, 1.0 - self.clip_eps, 1.0 + self.clip_eps) + policy_gradient = jnp.fmin(rhos_t * adv_t, clip_rhos_t * adv_t) + pg_loss = -jnp.mean(policy_gradient * mask_t) + + # Entropy loss: negative average entropy of the policy across timesteps and environments + ent_loss = jnp.mean(jax.vmap(rlax.entropy_loss, in_axes=1)(nn_logits_t, mask_t)) + + total_loss = pg_loss + total_loss += self.vf_coef * v_loss + total_loss += self.ent_coef * ent_loss + metrics_dict = dict( + pg_loss=pg_loss, + v_loss=v_loss, + ent_loss=ent_loss, + var_explained=1 - jnp.var(value_errors, ddof=1) / jnp.var(value_targets, ddof=1), + ) + return total_loss, metrics_dict SINGLE_DEVICE_UPDATE_DEVICES_AXIS: str = "local_devices" @@ -217,16 +302,16 @@ def tree_flatten_and_concat(x) -> jax.Array: def single_device_update( agent_state: TrainState, sharded_storages: List[Rollout], + key: jax.Array, *, get_logits_and_value: GetLogitsAndValueFn, num_batches: int, impala_cfg: ImpalaLossConfig, ) -> tuple[TrainState, dict[str, jax.Array]]: def update_minibatch(agent_state: TrainState, minibatch: Rollout): - (loss, metrics_dict), grads = jax.value_and_grad(impala_loss, has_aux=True)( + (loss, metrics_dict), grads = jax.value_and_grad(impala_cfg.loss, has_aux=True)( agent_state.params, get_logits_and_value, - impala_cfg, minibatch, ) metrics_dict["loss"] = loss @@ -249,8 +334,21 @@ def 
update_minibatch(agent_state: TrainState, minibatch: Rollout): agent_state = agent_state.apply_gradients(grads=grads) return agent_state, metrics_dict + # Combine the sharded storages storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages) - storage_by_minibatches = jax.tree.map(lambda x: jnp.array(jnp.split(x, num_batches, axis=1)), storage) + + # Generate a random permutation for shuffling over the batch dimension only + batch_size = storage.obs_t.shape[1] + permutation = jax.random.permutation(key, batch_size) + + # Shuffle the data using the permutation + shuffled_storage = jax.tree.map(lambda x: jnp.take(x, permutation, axis=1), storage) + + # Split into minibatches + storage_by_minibatches = jax.tree.map( + lambda x: jnp.moveaxis(jnp.reshape(x, (x.shape[0], num_batches, batch_size // num_batches, *x.shape[2:])), 1, 0), + shuffled_storage, + ) agent_state, loss_and_aux_per_step = jax.lax.scan( update_minibatch, diff --git a/cleanba/launcher.py b/cleanba/launcher.py index 4a504d2..90eae2e 100644 --- a/cleanba/launcher.py +++ b/cleanba/launcher.py @@ -83,9 +83,9 @@ def create_jobs( start_number: int, runs: Sequence[FlamingoRun], group: str, - project: str = "lp-cleanba", - entity: str = "farai", - wandb_mode: str = "online", + project: str, + entity: str, + wandb_mode: str, job_template_path: Optional[Path] = None, ) -> tuple[Sequence[str], str]: launch_id = generate_name(style="hyphen") @@ -148,15 +148,21 @@ def launch_jobs( ) -> tuple[str, str]: repo = Repo(".") repo.remote("origin").push(repo.active_branch.name) # Push to an upstream branch with the same name - start_number = 1 + len(wandb.Api().runs(f"{entity}/{project}")) + try: + start_number = 1 + len(wandb.Api().runs(f"{entity}/{project}")) + except ValueError as e: + if str(e).startswith("Could not find project"): + start_number = 1 + else: + raise jobs, launch_id = create_jobs( start_number, - runs, + runs=runs, group=group, - job_template_path=job_template_path, - wandb_mode=wandb_mode, project=project, entity=entity, + wandb_mode=wandb_mode, + job_template_path=job_template_path, ) yamls_for_all_jobs = "\n\n---\n\n".join(jobs) diff --git a/cleanba/network.py b/cleanba/network.py index f076da0..f4a4397 100644 --- a/cleanba/network.py +++ b/cleanba/network.py @@ -1,6 +1,6 @@ import abc import dataclasses -from typing import Any, Literal, SupportsFloat +from typing import Any, Literal, SupportsFloat, Tuple import flax.linen as nn import gymnasium as gym @@ -14,8 +14,7 @@ class NormConfig(abc.ABC): @abc.abstractmethod - def __call__(self, x: jax.Array) -> jax.Array: - ... + def __call__(self, x: jax.Array) -> jax.Array: ... @dataclasses.dataclass(frozen=True) @@ -48,8 +47,7 @@ class PolicySpec(abc.ABC): head_scale: float = 1.0 @abc.abstractmethod - def make(self) -> nn.Module: - ... + def make(self) -> nn.Module: ... def init_params(self, envs: gym.vector.VectorEnv, key: jax.Array) -> tuple["Policy", PolicyCarryT, Any]: policy = Policy(n_actions_from_envs(envs), self) @@ -106,18 +104,15 @@ def setup(self): self.critic_params = Critic(self.cfg.yang_init, self.cfg.norm, self.cfg.head_scale) def _maybe_normalize_input_image(self, x: jax.Array) -> jax.Array: - # Convert from NCHW to NHWC - assert len(x.shape) == 4, "x must be a NCHW image" - assert ( - x.shape[2] == x.shape[3] - ), f"x is not a rectangular NCHW image, but is instead {x.shape=}. This is probably wrong." 
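A shape walk-through of the shuffling and minibatching added to `single_device_update` above, using made-up sizes (T=32 timesteps, B=64 batched envs, 4 minibatches) and the Craftax observation shape as an example leaf; the permutation shuffles along the environment axis only, never along time.

import jax
import jax.numpy as jnp

T, B, num_batches = 32, 64, 4
x = jnp.zeros((T, B, 134, 9, 11))                        # one leaf of the stacked Rollout
perm = jax.random.permutation(jax.random.PRNGKey(0), B)  # shuffle envs, not time
x = jnp.take(x, perm, axis=1)
x = jnp.reshape(x, (T, num_batches, B // num_batches, *x.shape[2:]))
x = jnp.moveaxis(x, 1, 0)                                # (num_batches, T, B // num_batches, ...)
assert x.shape == (4, 32, 16, 134, 9, 11)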
- - x = jnp.transpose(x, (0, 2, 3, 1)) + # Convert from NCHW to NHWC if needed + if len(x.shape) == 4: + x = jnp.transpose(x, (0, 2, 3, 1)) if self.cfg.normalize_input: + print(f"Normalizing input image {x.shape=}") x = x - jnp.mean(x, axis=(0, 1), keepdims=True) x = x / jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=(0, 1), keepdims=True)) - else: + elif jnp.dtype(x) == jnp.uint8: x = x / 255.0 return x @@ -130,9 +125,11 @@ def get_action( key: jax.Array, *, temperature: float = 1.0, - ) -> tuple[PolicyCarryT, jax.Array, jax.Array, jax.Array]: - assert len(obs.shape) == 4 + ) -> tuple[PolicyCarryT, jax.Array, jax.Array, jax.Array, jax.Array]: + # assert len(obs.shape) == 4 assert len(episode_starts.shape) == 1 + print(f"{obs.shape=}") + print(f"{episode_starts.shape=}") assert episode_starts.shape[:1] == obs.shape[:1] obs = self._maybe_normalize_input_image(obs) @@ -141,7 +138,9 @@ def get_action( else: carry, hidden = self.network_params.step(carry, obs, episode_starts) logits, _ = self.actor_params(hidden) + value, _ = self.critic_params(hidden) assert isinstance(logits, jax.Array) + assert isinstance(value, jax.Array) if temperature == 0.0: action = jnp.argmax(logits, axis=1) @@ -151,7 +150,7 @@ def get_action( key, subkey = jax.random.split(key) u = jax.random.uniform(subkey, shape=logits.shape) action = jnp.argmax(logits / temperature - jnp.log(-jnp.log(u)), axis=1) - return carry, action, logits, key + return carry, action, logits, value.squeeze(-1), key def get_logits_and_value( self, @@ -159,7 +158,6 @@ def get_logits_and_value( obs: jax.Array, episode_starts: jax.Array, ) -> tuple[PolicyCarryT, jax.Array, jax.Array, dict[str, jax.Array]]: - assert len(obs.shape) == 5 assert len(episode_starts.shape) == 2 assert episode_starts.shape[:2] == obs.shape[:2] @@ -311,10 +309,10 @@ def __call__(self, x): if self.yang_init: kernel_init = yang_initializer("output", "identity") else: - kernel_init = nn.initializers.orthogonal(1.0) + kernel_init = nn.initializers.orthogonal(self.kernel_scale) bias_init = nn.initializers.zeros_init() x = self.norm(x) - x = nn.Dense(1, kernel_init=kernel_init, bias_init=bias_init, use_bias=True, name="Output")(x) * self.kernel_scale + x = nn.Dense(1, kernel_init=kernel_init, bias_init=bias_init, use_bias=True, name="Output")(x) bias = jnp.squeeze(self.variables["params"]["Output"]["bias"]) return x, {"critic_ma": jnp.mean(jnp.abs(x)), "critic_bias": bias, "critic_diff": jnp.mean(x - bias)} @@ -575,3 +573,30 @@ def __call__(self, x): x = nn.Dense(hidden)(x) x = nn.relu(x) return x + + +@dataclasses.dataclass(frozen=True) +class MLPConfig(PolicySpec): + hiddens: Tuple[int, ...] 
= (256, 256) + activation: str = "relu" + + yang_init: bool = dataclasses.field(default=False) + norm: NormConfig = dataclasses.field(default_factory=IdentityNorm) + normalize_input: bool = False + + def make(self) -> "MLP": + return MLP(self) + + +class MLP(nn.Module): + cfg: MLPConfig + + @nn.compact + def __call__(self, x): + activation_fn = {"relu": nn.relu, "tanh": nn.tanh}[self.cfg.activation] + x = jnp.reshape(x, (x.shape[0], -1)) + for hidden in self.cfg.hiddens: + x = self.cfg.norm(x) + x = nn.Dense(hidden, use_bias=True, kernel_init=nn.initializers.orthogonal(2**0.5))(x) + x = activation_fn(x) + return x diff --git a/cleanba/optimizer.py b/cleanba/optimizer.py index 467b5f7..b4e74cd 100644 --- a/cleanba/optimizer.py +++ b/cleanba/optimizer.py @@ -1,6 +1,7 @@ """RMSProp implementation for PyTorch-style RMSProp see https://github.com/deepmind/optax/issues/532#issuecomment-1676371843 """ + from typing import Optional import jax diff --git a/experiments/craftax/000_drc_ppo_impala.py b/experiments/craftax/000_drc_ppo_impala.py new file mode 100644 index 0000000..836744b --- /dev/null +++ b/experiments/craftax/000_drc_ppo_impala.py @@ -0,0 +1,66 @@ +import dataclasses +import shlex +from pathlib import Path + +from farconf import parse_cli, update_fns_to_cli + +from cleanba.config import Args, craftax_drc +from cleanba.environments import random_seed +from cleanba.launcher import FlamingoRun, group_from_fname, launch_jobs + +clis: list[list[str]] = [] +all_args: list[Args] = [] + +for gae_lambda in [0.8]: + for env_seed, learn_seed in [(random_seed(), random_seed()) for _ in range(4)]: + + def update_seeds(config: Args) -> Args: + config.train_env = dataclasses.replace(config.train_env, seed=env_seed) + config.seed = learn_seed + + config.loss = dataclasses.replace(config.loss, gae_lambda=gae_lambda) + config.base_run_dir = Path("/training/craftax") + config.train_epochs = 4 + config.queue_timeout = 3000 + config.total_timesteps = 1000_000_000 + return config + + cli, _ = update_fns_to_cli(craftax_drc, update_seeds) + + print(shlex.join(cli)) + # Check that parsing doesn't error + out = parse_cli(cli, Args) + + all_args.append(out) + clis.append(cli) + +runs: list[FlamingoRun] = [] +RUNS_PER_MACHINE = 2 +for i in range(0, len(clis), RUNS_PER_MACHINE): + this_run_clis = [ + ["python", "-m", "cleanba.cleanba_impala", *clis[i + j]] for j in range(min(RUNS_PER_MACHINE, len(clis) - i)) + ] + runs.append( + FlamingoRun( + this_run_clis, + CONTAINER_TAG="4350d99-main", + CPU=4 * RUNS_PER_MACHINE, + MEMORY=f"{60 * RUNS_PER_MACHINE}G", + GPU=1, + PRIORITY="normal-batch", + # PRIORITY="high-batch", + XLA_PYTHON_CLIENT_MEM_FRACTION='".48"', # Can go down to .48 + ) + ) + + +GROUP: str = group_from_fname(__file__) + +if __name__ == "__main__": + launch_jobs( + runs, + group=GROUP, + job_template_path=Path(__file__).parent.parent.parent / "k8s/runner.yaml", + project="impala2", + entity="matsrlgoals", + ) diff --git a/k8s/devbox.yaml b/k8s/devbox.yaml index 437d337..ab5a3da 100644 --- a/k8s/devbox.yaml +++ b/k8s/devbox.yaml @@ -22,7 +22,7 @@ spec: sizeLimit: "{SHM_SIZE}" - name: training persistentVolumeClaim: - claimName: az-learned-planners + claimName: vast-learned-planners containers: - name: devbox-container @@ -48,6 +48,7 @@ spec: cpu: {CPU} limits: memory: "{MEMORY}" + # nvidia.com/mig-2g.20gb: {GPU} nvidia.com/gpu: {GPU} env: - name: OMP_NUM_THREADS @@ -55,9 +56,9 @@ spec: - name: WANDB_MODE value: offline - name: WANDB_PROJECT - value: lp-cleanba + value: impala - name: WANDB_ENTITY - 
value: farai + value: matsrlgoals - name: WANDB_RUN_GROUP value: devbox - name: GIT_ASKPASS diff --git a/k8s/runner-no-nfs.yaml b/k8s/runner-no-nfs.yaml index b9b565e..0abc573 100644 --- a/k8s/runner-no-nfs.yaml +++ b/k8s/runner-no-nfs.yaml @@ -36,7 +36,7 @@ spec: - "true" containers: - name: devbox-container - image: "ghcr.io/alignmentresearch/lp-cleanba:{CONTAINER_TAG}" + image: "ghcr.io/alignmentresearch/train-learned-planner:{CONTAINER_TAG}" imagePullPolicy: Always command: - bash diff --git a/k8s/runner.yaml b/k8s/runner.yaml index 711d1e4..4a0618d 100644 --- a/k8s/runner.yaml +++ b/k8s/runner.yaml @@ -23,7 +23,7 @@ spec: volumes: - name: training persistentVolumeClaim: - claimName: az-learned-planners + claimName: vast-learned-planners - name: dshm emptyDir: medium: Memory diff --git a/pyproject.toml b/pyproject.toml index af8e7c4..ca371e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,13 +13,14 @@ exclude = [ "cleanba/legacy_scripts", ] -[tool.ruff.isort] -known-third-party = ["wandb"] [tool.ruff.lint] # Enable the isort rules. extend-select = ["I"] +[tool.ruff.lint.isort] +known-third-party = ["wandb"] + [tool.pytest.ini_options] testpaths = ["tests"] # ignore third_party dir for now markers = [ @@ -29,6 +30,8 @@ markers = [ [tool.pyright] exclude = [ + ".venv/**", # venv + "nsight-profile/**", "wandb/**", # Saved old codes "third_party/**", # Other libraries ] @@ -46,34 +49,34 @@ authors = [ readme = "README.md" dependencies = [ - "rich ~= 13.7", - "tensorboard ~=2.12.0", - "flax ~=0.8.0", - "optax ~=0.1.4", - "huggingface-hub ~=0.23.4", - "wandb ~=0.17.4", - "tensorboardx ~=2.6", - "chex ~= 0.1.5", - "gymnasium ~= 0.29", - "opencv-python >=4.10", - "moviepy ~=1.0.3", - "rlax ~=0.1.5", + "rich", + "tensorboard", + "flax", + "optax", + "huggingface-hub", + "wandb", + "tensorboardx", + "chex", + "gymnasium<1", + "opencv-python", + "moviepy", + "rlax", "farconf @ git+https://github.com/AlignmentResearch/farconf.git", - "ray[tune] ~=2.40.0", - "matplotlib ~=3.9.0", + "ray[tune]", + "matplotlib", + "craftax", + "jax==0.5.1", + "names_generator", + "GitPython", + "pytest", ] [project.optional-dependencies] -dev = [ - "pre-commit ~=3.6.0", - "pyright ~=1.1.349", - "ruff ~=0.1.13", - "pytest ~=8.1.1", +py-tools = [ + "pre-commit", + "pyright", + "ruff", ] -launch-jobs = [ - "names_generator ~=0.1.0", - "GitPython ~=3.1.37", -] [tool.setuptools] packages = ["cleanba"] diff --git a/requirements.txt b/requirements.txt index 6b32ae7..ddff82c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,11 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# pip-compile --extra=dev --extra=launch_jobs --output-file=requirements.txt.new pyproject.toml -# +# This file was autogenerated by uv via the following command: +# uv pip compile -o requirements.txt.new --extra=py-tools pyproject.toml absl-py==2.1.0 # via # chex # distrax # dm-env + # dm-tree # optax # orbax-checkpoint # rlax @@ -16,13 +13,14 @@ absl-py==2.1.0 # tensorflow-probability aiosignal==1.3.2 # via ray -attrs==24.3.0 +annotated-types==0.7.0 + # via pydantic +attrs==25.1.0 # via + # dm-tree # jsonschema # referencing -cachetools==5.5.0 - # via google-auth -certifi==2024.12.14 +certifi==2025.1.31 # via # requests # sentry-sdk @@ -32,116 +30,134 @@ charset-normalizer==3.4.1 # via requests chex==0.1.88 # via + # train-learned-planner (pyproject.toml) + # craftax # distrax + # gymnax # optax # rlax - # train-learned-planner (pyproject.toml) click==8.1.8 # via # ray # wandb 
-cloudpickle==3.1.0 +cloudpickle==3.1.1 # via + # gym # gymnasium # tensorflow-probability cmdkit==2.7.7 # via names-generator contourpy==1.3.1 # via matplotlib +craftax==1.4.5 + # via train-learned-planner (pyproject.toml) cycler==0.12.1 # via matplotlib -databind @ git+https://github.com/rhaps0dy/python-databind.git@merge-fixes#subdirectory=databind +databind @ git+https://github.com/rhaps0dy/python-databind.git@a2646ab2eab543945f1650544990841c91efebd9#egg=databind&subdirectory=databind # via farconf -decorator==4.4.2 +decorator==5.2.0 # via # moviepy # tensorflow-probability -deprecated==1.2.15 +deprecated==1.2.18 # via databind distlib==0.3.9 # via virtualenv distrax==0.1.5 - # via rlax + # via + # craftax + # rlax dm-env==1.6 # via rlax -dm-tree==0.1.8 +dm-tree==0.1.9 # via # dm-env # tensorflow-probability docker-pycreds==0.4.0 # via wandb -etils[epath,epy]==1.11.0 - # via orbax-checkpoint -exceptiongroup==1.2.2 - # via pytest +etils==1.12.0 + # via + # optax + # orbax-checkpoint farama-notifications==0.0.4 # via gymnasium -farconf @ git+https://github.com/AlignmentResearch/farconf.git +farconf @ git+https://github.com/AlignmentResearch/farconf.git@55f043ad607ebb29ee50fd20793eb55d958a1e97 # via train-learned-planner (pyproject.toml) -filelock==3.16.1 +filelock==3.17.0 # via # huggingface-hub # ray # virtualenv -flax==0.8.5 - # via train-learned-planner (pyproject.toml) -fonttools==4.55.3 +flax==0.10.3 + # via + # train-learned-planner (pyproject.toml) + # craftax + # gymnax +fonttools==4.56.0 # via matplotlib frozenlist==1.5.0 # via # aiosignal # ray -fsspec==2024.12.0 +fsspec==2025.2.0 # via # etils # huggingface-hub # ray gast==0.6.0 # via tensorflow-probability -gitdb==4.0.11 +gitdb==4.0.12 # via gitpython -gitpython==3.1.43 +gitpython==3.1.44 # via # train-learned-planner (pyproject.toml) # wandb -google-auth==2.37.0 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -grpcio==1.68.1 +grpcio==1.70.0 # via tensorboard +gym==0.26.2 + # via gymnax +gym-notices==0.0.8 + # via gym gymnasium==0.29.1 + # via + # train-learned-planner (pyproject.toml) + # gymnax +gymnax==0.0.8 + # via craftax +huggingface-hub==0.29.1 # via train-learned-planner (pyproject.toml) -huggingface-hub==0.23.5 - # via train-learned-planner (pyproject.toml) -humanize==4.11.0 +humanize==4.12.1 # via orbax-checkpoint -identify==2.6.4 +identify==2.6.8 # via pre-commit idna==3.10 # via requests -imageio==2.36.1 - # via moviepy -imageio-ffmpeg==0.5.1 +imageio==2.37.0 + # via + # craftax + # moviepy +imageio-ffmpeg==0.6.0 # via moviepy -importlib-resources==6.4.5 +importlib-resources==6.5.2 # via etils iniconfig==2.0.0 # via pytest -# DISABLED jax==0.4.38 +# DISABLED jax==0.5.1 # via + # train-learned-planner (pyproject.toml) # chex + # craftax # distrax # flax + # gymnax # optax # orbax-checkpoint # rlax -# DISABLED jaxlib==0.4.38 +# DISABLED jaxlib==0.5.1 # via # chex # distrax + # gymnax # jax # optax # rlax @@ -157,23 +173,27 @@ markdown-it-py==3.0.0 # via rich markupsafe==3.0.2 # via werkzeug -matplotlib==3.9.4 - # via train-learned-planner (pyproject.toml) +matplotlib==3.10.0 + # via + # train-learned-planner (pyproject.toml) + # craftax + # gymnax + # seaborn mdurl==0.1.2 # via markdown-it-py -ml-dtypes==0.5.0 +ml-dtypes==0.5.1 # via # jax # jaxlib # tensorstore -moviepy==1.0.3 +moviepy==2.1.2 # via train-learned-planner (pyproject.toml) msgpack==1.1.0 # via # flax # orbax-checkpoint # ray -names-generator==0.1.0 +names-generator==0.2.0 # via train-learned-planner 
(pyproject.toml) nest-asyncio==1.6.0 # via orbax-checkpoint @@ -185,13 +205,16 @@ nr-date==2.1.0 # via databind nr-stream==1.1.5 # via databind -numpy==1.26.4 +numpy==2.2.3 # via # chex # contourpy + # craftax # distrax # dm-env + # dm-tree # flax + # gym # gymnasium # imageio # jax @@ -205,21 +228,22 @@ numpy==1.26.4 # pandas # rlax # scipy + # seaborn # tensorboard # tensorboardx # tensorflow-probability # tensorstore -oauthlib==3.2.2 - # via requests-oauthlib -opencv-python==4.10.0.84 + # treescope +opencv-python==4.11.0.86 # via train-learned-planner (pyproject.toml) opt-einsum==3.4.0 # via jax -optax==0.1.9 +optax==0.2.4 # via - # flax # train-learned-planner (pyproject.toml) -orbax-checkpoint==0.11.0 + # craftax + # flax +orbax-checkpoint==0.11.6 # via flax packaging==24.2 # via @@ -227,136 +251,142 @@ packaging==24.2 # matplotlib # pytest # ray + # tensorboard # tensorboardx pandas==2.2.3 - # via ray -pillow==11.0.0 + # via + # ray + # seaborn +pillow==10.4.0 # via # imageio # matplotlib + # moviepy platformdirs==4.3.6 # via # virtualenv # wandb pluggy==1.5.0 # via pytest -pre-commit==3.6.2 +pre-commit==4.1.0 # via train-learned-planner (pyproject.toml) proglog==0.1.10 # via moviepy -protobuf==5.29.2 +protobuf==5.29.3 # via # orbax-checkpoint # ray # tensorboard # tensorboardx # wandb -psutil==6.1.1 +psutil==7.0.0 # via wandb -pyarrow==18.1.0 +pyarrow==19.0.1 # via ray -pyasn1==0.6.1 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.4.1 - # via google-auth -pygments==2.18.0 +pydantic==2.10.6 + # via wandb +pydantic-core==2.27.2 + # via pydantic +pygame==2.6.1 + # via craftax +pygments==2.19.1 # via rich -pyparsing==3.2.0 +pyparsing==3.2.1 # via matplotlib -pyright==1.1.391 +pyright==1.1.394 # via train-learned-planner (pyproject.toml) -pytest==8.1.2 +pytest==8.3.4 # via train-learned-planner (pyproject.toml) python-dateutil==2.9.0.post0 # via # matplotlib # pandas -pytz==2024.2 +python-dotenv==1.0.1 + # via moviepy +pytz==2025.1 # via pandas pyyaml==6.0.2 # via # farconf # flax + # gymnax # huggingface-hub # orbax-checkpoint # pre-commit # ray # wandb -ray[tune]==2.40.0 +ray==2.42.1 # via train-learned-planner (pyproject.toml) -referencing==0.35.1 +referencing==0.36.2 # via # jsonschema # jsonschema-specifications requests==2.32.3 # via # huggingface-hub - # moviepy # ray - # requests-oauthlib - # tensorboard # wandb -requests-oauthlib==2.0.0 - # via google-auth-oauthlib rich==13.9.4 # via - # flax # train-learned-planner (pyproject.toml) + # flax rlax==0.1.6 # via train-learned-planner (pyproject.toml) -rpds-py==0.22.3 +rpds-py==0.23.1 # via # jsonschema # referencing -rsa==4.9 - # via google-auth -ruff==0.1.15 +ruff==0.9.7 # via train-learned-planner (pyproject.toml) -scipy==1.14.1 +scipy==1.15.2 # via # jax # jaxlib -sentry-sdk==2.19.2 +seaborn==0.13.2 + # via gymnax +sentry-sdk==2.22.0 # via wandb -setproctitle==1.3.4 +setproctitle==1.3.5 # via wandb -simplejson==3.19.3 +setuptools==75.8.0 + # via + # chex + # distrax + # tensorboard + # wandb +simplejson==3.20.1 # via orbax-checkpoint six==1.17.0 # via # docker-pycreds # python-dateutil + # tensorboard # tensorflow-probability -smmap==5.0.1 +smmap==5.0.2 # via gitdb -tensorboard==2.12.3 +tensorboard==2.19.0 # via train-learned-planner (pyproject.toml) tensorboard-data-server==0.7.2 # via tensorboard tensorboardx==2.6.2.2 # via - # ray # train-learned-planner (pyproject.toml) + # ray tensorflow-probability==0.25.0 # via distrax -tensorstore==0.1.71 +tensorstore==0.1.72 # via # flax # orbax-checkpoint -toml==0.10.2 - # via 
cmdkit -tomli==2.2.1 - # via pytest toolz==1.0.0 # via chex tqdm==4.67.1 # via # huggingface-hub - # moviepy # proglog -typeapi==2.2.3 +treescope==0.1.9 + # via flax +typeapi==2.2.4 # via # databind # farconf @@ -369,27 +399,26 @@ typing-extensions==4.12.2 # gymnasium # huggingface-hub # orbax-checkpoint + # pydantic + # pydantic-core # pyright - # rich + # referencing # typeapi -tzdata==2024.2 +tzdata==2025.1 # via pandas urllib3==2.3.0 # via # requests # sentry-sdk -virtualenv==20.28.0 +virtualenv==20.29.2 # via pre-commit -wandb==0.17.9 +wandb==0.19.7 # via train-learned-planner (pyproject.toml) werkzeug==3.1.3 # via tensorboard -wheel==0.45.1 - # via tensorboard -wrapt==1.17.0 - # via deprecated +wrapt==1.17.2 + # via + # deprecated + # dm-tree zipp==3.21.0 # via etils - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/tests/test_cartpole.py b/tests/test_cartpole.py deleted file mode 100644 index 888fdad..0000000 --- a/tests/test_cartpole.py +++ /dev/null @@ -1,337 +0,0 @@ -# %% -import tempfile -from functools import partial -from pathlib import Path -from typing import Callable, Dict, Optional - -import gymnasium as gym -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import pytest -from gymnasium import spaces -from gymnasium.envs.classic_control.cartpole import CartPoleEnv -from gymnasium.wrappers import TimeLimit - -import cleanba.cleanba_impala -from cleanba.cleanba_impala import WandbWriter, train -from cleanba.config import Args -from cleanba.convlstm import ConvConfig, ConvLSTMCellConfig, ConvLSTMConfig -from cleanba.environments import EnvConfig -from cleanba.evaluate import EvalConfig -from cleanba.impala_loss import ImpalaLossConfig -from cleanba.network import GuezResNetConfig - - -# %% -class DataFrameWriter(WandbWriter): - def __init__(self, cfg: Args, save_dir: Path): - self.metrics = pd.DataFrame() - self.states = {} - self._save_dir = save_dir - - def add_scalar(self, name: str, value: int | float, global_step: int): - try: - value = list(value) - except TypeError: - self.metrics.loc[global_step, name] = value - return - - for i, v in enumerate(value): - try: - a = v.item() - self.metrics.loc[global_step + 640 * i, name] = a - except (TypeError, AttributeError, ValueError): - self.states[global_step + 640 * i, name] = value - - -# %% -if "CartPoleNoVel-v0" not in gym.registry or "CartPoleCHW-v0" not in gym.registry: - - class CartPoleCHWEnv(CartPoleEnv): - """Variant of CartPoleEnv with velocity information removed, and CHW-shaped observations. - This task requires memory to solve.""" - - def __init__(self): - super().__init__() - high = np.array( - [ - self.x_threshold * 2, - 3.4028235e38, - self.theta_threshold_radians * 2, - 3.4028235e38, - ], - dtype=np.float32, - )[:, None, None] - self.observation_space = spaces.Box(-high, high, dtype=np.float32) - - @staticmethod - def _pos_obs(full_obs): - return np.array(full_obs)[:, None, None] * 255.0 - - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): - full_obs, info = super().reset(seed=seed, options=options) - return CartPoleCHWEnv._pos_obs(full_obs), info - - def step(self, action): - full_obs, rew, terminated, truncated, info = super().step(action) - return CartPoleCHWEnv._pos_obs(full_obs), rew / 500, terminated, truncated, info - - class CartPoleNoVelEnv(CartPoleEnv): - """Variant of CartPoleEnv with velocity information removed, and CHW-shaped observations. 
- This task requires memory to solve.""" - - def __init__(self): - super().__init__() - high = np.array( - [ - self.x_threshold * 2, - self.theta_threshold_radians * 2, - ], - dtype=np.float32, - )[:, None, None] - self.observation_space = spaces.Box(-high, high, dtype=np.float32) - - @staticmethod - def _pos_obs(full_obs): - xpos, _xvel, thetapos, _thetavel = full_obs - return np.array([xpos, thetapos])[:, None, None] * 255.0 - - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): - full_obs, info = super().reset(seed=seed, options=options) - return CartPoleNoVelEnv._pos_obs(full_obs), info - - def step(self, action): - full_obs, rew, terminated, truncated, info = super().step(action) - return CartPoleNoVelEnv._pos_obs(full_obs), rew / 500, terminated, truncated, info - - gym.register( - id="CartPoleNoVel-v0", - entry_point=CartPoleNoVelEnv, - max_episode_steps=500, - ) - - gym.register( - id="CartPoleCHW-v0", - entry_point=CartPoleCHWEnv, - max_episode_steps=500, - ) - - -class CartPoleNoVelConfig(EnvConfig): - @property - def make(self) -> Callable[[], gym.vector.VectorEnv]: - def tl_wrapper(env_fn): - return TimeLimit(env_fn(), max_episode_steps=self.max_episode_steps) - - return partial(gym.vector.SyncVectorEnv, env_fns=[partial(tl_wrapper, CartPoleNoVelEnv)] * self.num_envs) - - -class CartPoleConfig(EnvConfig): - @property - def make(self) -> Callable[[], gym.vector.VectorEnv]: - def tl_wrapper(env_fn): - return TimeLimit(env_fn(), max_episode_steps=self.max_episode_steps) - - return partial(gym.vector.SyncVectorEnv, env_fns=[partial(tl_wrapper, CartPoleCHWEnv)] * self.num_envs) - - -class MountainCarNormalized(gym.envs.classic_control.MountainCarEnv): - def step(self, action): - full_obs, rew, terminated, truncated, info = super().step(action) - return full_obs, rew, terminated, truncated, info - - -class MountainCarConfig(EnvConfig): - max_episode_steps: int = 200 - - @property - def make(self) -> Callable[[], gym.vector.VectorEnv]: - def tl_wrapper(env_fn): - return TimeLimit(env_fn(), max_episode_steps=self.max_episode_steps) - - return partial(gym.vector.SyncVectorEnv, env_fns=[partial(tl_wrapper, MountainCarNormalized)] * self.num_envs) - - -# %% Train the cartpole - - -def train_cartpole_no_vel(policy="resnet", env="cartpole", seed=None): - if policy == "resnet": - net = GuezResNetConfig( - channels=(), - strides=(1,), - kernel_sizes=(1,), - mlp_hiddens=(256, 256), - normalize_input=False, - ) - elif policy == "convlstm": - net = ConvLSTMConfig( - embed=[ConvConfig(32, (1, 1), (1, 1), "SAME", True)], - recurrent=ConvLSTMCellConfig( - ConvConfig(32, (1, 1), (1, 1), "SAME", True), - pool_and_inject="horizontal", - pool_projection="per-channel", - ), - n_recurrent=1, - repeats_per_step=1, - ) - else: - raise ValueError(f"{policy=}") - NUM_ENVS = 8 - if env == "cartpole": - env_cfg = CartPoleConfig(num_envs=NUM_ENVS, max_episode_steps=500, seed=1234) - elif env == "cartpole_no_vel": - env_cfg = CartPoleNoVelConfig(num_envs=NUM_ENVS, max_episode_steps=500, seed=1234) - else: - raise ValueError(f"{env=}") - - args = Args( - train_env=env_cfg, - eval_envs=dict(train=EvalConfig(env_cfg, n_episode_multiple=4)), - net=net, - eval_at_steps=frozenset([]), - save_model=False, - log_frequency=50, - local_num_envs=NUM_ENVS, - num_actor_threads=1, - num_minibatches=1, - # If the whole thing deadlocks exit in some small multiple of 10 seconds - queue_timeout=60, - train_epochs=1, - num_steps=32, - learning_rate=0.001, - concurrency=True, - anneal_lr=True, - 
total_timesteps=1_000_000, - max_grad_norm=1e-4, - base_fan_in=1, - optimizer="adam", - rmsprop_eps=1e-8, - adam_b1=0.9, - rmsprop_decay=0.95, - # optimizer="rmsprop", - # rmsprop_eps=1e-3, - # loss=ImpalaLossConfig(logit_l2_coef=1e-6,), - loss=ImpalaLossConfig( - logit_l2_coef=0.0, - weight_l2_coef=0.0, - vf_coef=0.25, - ent_coef=0, - gamma=0.99, - vtrace_lambda=0.97, - max_vf_error=0.01, - ), - # loss=PPOConfig( - # logit_l2_coef=0.0, - # weight_l2_coef=0.0, - # vf_coef=0.5, - # ent_coef=0.0, - # gamma=0.99, - # gae_lambda=0.95, - # clip_vf=1e9, - # clip_rho=0.2, - # normalize_advantage=True, - # ), - ) - if seed is not None: - args.seed = seed - - tmpdir = tempfile.TemporaryDirectory() - tmpdir_path = Path(tmpdir.name) - writer = DataFrameWriter(args, save_dir=tmpdir_path) - - cleanba.cleanba_impala.MUST_STOP_PROGRAM = False - train(args, writer=writer) - print("Done training") - - last_row = writer.metrics.iloc[-1] - print("Eval. returns:", last_row["train/00_episode_returns"]) - print("Eval. ep. lengths:", last_row["train/00_episode_lengths"]) - return writer, last_row["train/00_episode_lengths"] - - -@pytest.mark.slow -def test_cartpole_resnet(): - _, eval_lengths = train_cartpole_no_vel("resnet", "cartpole", seed=12345) - assert eval_lengths > 450.0 - - -@pytest.mark.slow -def test_cartpole_convlstm(): - _, eval_lengths = train_cartpole_no_vel("convlstm", "cartpole_no_vel", seed=12345) - assert eval_lengths > 450.0 - - -if __name__ == "__main__": - writer = train_cartpole_no_vel("lstm", "cartpole_no_vel") - # writer = train_cartpole_no_vel("resnet", "cartpole") - -# %% Plot learning curves - - -def perc_plot(ax, x, y, percentiles=[0.5, 0.75, 0.9, 0.95, 0.99, 1.00], outliers=False): - y = np.asarray(y).reshape((len(y), -1)) - x = np.asarray(x) - assert (y.shape[0],) == x.shape - - perc = np.asarray(percentiles) - - to_plot = np.percentile(y, perc, axis=1) - for i in range(to_plot.shape[0]): - ax.plot(x, to_plot[i], alpha=1 - np.abs(perc[i] - 0.5), color="C0") - - if outliers: - outlier_points = (y < np.min(to_plot, axis=0, keepdims=True).T) | (y > np.max(to_plot, axis=0, keepdims=True).T) - outlier_i, _ = np.where(outlier_points) - - ax.plot( - x[outlier_i], - y[outlier_points], - ls="", - marker=".", - color="C1", - ) - - -if __name__ == "__main__": - # Create a figure and axes - fig, axes = plt.subplots(7, 1, figsize=(6, 8), sharex="col") - writer.metrics = writer.metrics.sort_index() - - # Plot var_explained - ax = axes[0] - writer.metrics["var_explained"].plot(ax=ax) - ax.set_ylabel("Variance") - - # Plot avg_episode_return - ax = axes[1] - p_returns = writer.metrics["charts/0/avg_episode_lengths"] - p_returns.dropna().plot(ax=ax) - ax.set_ylabel("Ep lengths") - - # Plot losses - ax = axes[2] - # writer.metrics["losses/loss"].plot(ax=ax, label="Total Loss") - writer.metrics["losses/value_loss"].plot(ax=ax, label="Value Loss") - # writer.metrics["pre_multiplier_v_loss"].plot(ax=ax, label="Pre-multiplier value loss") - - ax.set_ylabel("Value loss") - - ax = axes[4] - writer.metrics["losses/entropy"].plot(ax=ax, color="C0") - ax.set_ylabel("entropy loss") - - ax = axes[5] - writer.metrics["losses/policy_loss"].plot(ax=ax, label="Policy Loss") - ax.set_ylabel("Policy loss") - - ax = axes[6] - writer.metrics["adv_multiplier"].plot(ax=ax, color="C1") - ax.set_ylabel("Advantage multiplier avg") - - # Adjust spacing between subplots - plt.tight_layout() - - # Display the plot - plt.show() diff --git a/tests/test_convlstm.py b/tests/test_convlstm.py index fdf9935..294e568 100644 --- 
a/tests/test_convlstm.py
+++ b/tests/test_convlstm.py
@@ -195,14 +195,16 @@ def test_policy_scan_correct(net: ConvLSTMConfig):
     inputs_nchw = jax.random.uniform(k1, (time_steps, num_envs, 3, *dim_room), maxval=255)
     episode_starts = jax.random.uniform(k2, (time_steps, num_envs)) < 0.4

-    scan_carry, scan_logits, _, _ = b_policy.get_logits_and_value(carry, inputs_nchw, episode_starts)
+    scan_carry, scan_logits, scan_values, _ = b_policy.get_logits_and_value(carry, inputs_nchw, episode_starts)

     logits: list[Any] = [None] * time_steps
+    values: list[Any] = [None] * time_steps
     for t in range(time_steps):
-        carry, _, logits[t], key = b_policy.get_action(carry, inputs_nchw[t], episode_starts[t], key)
+        carry, _, logits[t], values[t], key = b_policy.get_action(carry, inputs_nchw[t], episode_starts[t], key)

     assert jax.tree.all(jax.tree.map(partial(jnp.allclose, atol=1e-5), carry, scan_carry))
     assert jnp.allclose(scan_logits, jnp.stack(logits), atol=1e-5)
+    assert jnp.allclose(scan_values, jnp.stack(values), atol=1e-5)


 @pytest.mark.parametrize("net", CONVLSTM_CONFIGS)
@@ -215,12 +217,13 @@ def test_convlstm_forward(net: ConvLSTMConfig):
     obs = envs.observation_space.sample()
     assert obs is not None

-    out_carry, actions, logits, _key = jax.jit(partial(policy.apply, method=policy.get_action))(
+    out_carry, actions, logits, values, _key = jax.jit(partial(policy.apply, method=policy.get_action))(
         params, carry, obs, jnp.zeros(envs.num_envs, dtype=jnp.bool_), k2
     )
     assert jax.tree.all(jax.tree.map(lambda x, y: x.shape == y.shape, carry, out_carry)), "Carries don't have the same shape"
     assert actions.shape == (envs.num_envs,)
     assert logits.shape == (envs.num_envs, n_actions_from_envs(envs))
+    assert values.shape == (envs.num_envs,)
     assert _key.shape == k2.shape

     timesteps = 4
diff --git a/tests/test_environments.py b/tests/test_environments.py
index 15a1c61..1b0967c 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -7,7 +7,14 @@
 from cleanba.config import sokoban_drc33_59
 from cleanba.env_trivial import MockSokobanEnv, MockSokobanEnvConfig
-from cleanba.environments import BoxobanConfig, EnvConfig, EnvpoolBoxobanConfig, SokobanConfig
+from cleanba.environments import (
+    BoxobanConfig,
+    CraftaxEnvConfig,
+    EnvConfig,
+    EnvpoolBoxobanConfig,
+    EpisodeEvalWrapper,
+    SokobanConfig,
+)


 def sokoban_has_reset(tile_size: int, old_obs: np.ndarray, new_obs: np.ndarray) -> np.ndarray:
@@ -116,16 +123,14 @@ def test_environment_basics(cfg: EnvConfig, shape: tuple[int, int]):
     assert envs.single_observation_space.shape == (3, *shape)
     assert envs.observation_space.shape == (NUM_ENVS, 3, *shape)

-    envs.reset_async()
-    next_obs, info = envs.reset_wait()
+    next_obs, info = envs.reset()
     assert next_obs.shape == (NUM_ENVS, 3, *shape), "jax.lax convs are NCHW but you sent NHWC"

     assert (action_shape := envs.action_space.shape) is not None
     for i in range(50):
         prev_obs = next_obs
         actions = np.zeros(action_shape, dtype=np.int64)
-        envs.step_async(actions)
-        next_obs, next_reward, terminated, truncated, info = envs.step_wait()
+        next_obs, next_reward, terminated, truncated, info = envs.step(actions)

         assert next_obs.shape == (NUM_ENVS, 3, *shape)

@@ -139,6 +144,17 @@ def test_environment_basics(cfg: EnvConfig, shape: tuple[int, int]):
         assert np.array_equal(truncated, sokoban_has_reset(tile_size, prev_obs, next_obs))


+def test_craftax_environment_basics():
+    cfg = CraftaxEnvConfig(max_episode_steps=20, num_envs=2, obs_flat=False)
+    envs = EpisodeEvalWrapper(cfg.make())
+    next_obs, info = envs.reset()
+
+    assert (action_shape := envs.action_space.shape) is not None
+    for i in range(50):
+        actions = np.zeros(action_shape, dtype=np.int64)
+        envs.step(actions)
+
+
 @pytest.mark.parametrize("gamma", [1.0, 0.9])
 def test_mock_sokoban_returns(gamma: float, num_envs: int = 7):
     max_episode_steps = 10
@@ -204,25 +220,22 @@ def test_loading_network_without_noop_action(cfg: EnvConfig, nn_without_noop: bo
     cfg.nn_without_noop = nn_without_noop
     envs = cfg.make()

-    envs.reset_async()
-    next_obs, info = envs.reset_wait()
+    next_obs, info = envs.reset()
     assert next_obs.shape == (cfg.num_envs, 3, 10, 10), "jax.lax convs are NCHW but you sent NHWC"

     args = sokoban_drc33_59()
     key = jax.random.PRNGKey(42)
     key, agent_params_subkey, carry_key = jax.random.split(key, 3)
     policy, _, agent_params = args.net.init_params(envs, agent_params_subkey)
-    assert agent_params["params"]["actor_params"]["Output"]["kernel"].shape[1] == 4 + (
-        not nn_without_noop
-    ), "NOOP action not set correctly"
+    assert agent_params["params"]["actor_params"]["Output"]["kernel"].shape[1] == 4 + (not nn_without_noop), (
+        "NOOP action not set correctly"
+    )

     carry = policy.apply(agent_params, carry_key, next_obs.shape, method=policy.initialize_carry)
     episode_starts_no = jnp.zeros(cfg.num_envs, dtype=jnp.bool_)
     assert envs.action_space.shape is not None
     # actions = np.zeros(action_shape, dtype=np.int64)
-    carry, actions, _, key = policy.apply(agent_params, carry, next_obs, episode_starts_no, key, method=policy.get_action)
-    actions = np.asarray(actions)
-    envs.step_async(actions)
-    next_obs, next_reward, terminated, truncated, info = envs.step_wait()
+    carry, actions, _, _, key = policy.apply(agent_params, carry, next_obs, episode_starts_no, key, method=policy.get_action)
+    next_obs, next_reward, terminated, truncated, info = envs.step(actions)

     assert next_obs.shape == (cfg.num_envs, 3, 10, 10)
diff --git a/tests/test_impala_loss.py b/tests/test_impala_loss.py
index 7db6de9..55b398b 100644
--- a/tests/test_impala_loss.py
+++ b/tests/test_impala_loss.py
@@ -13,8 +13,9 @@
 import rlax

 import cleanba.cleanba_impala as cleanba_impala
+from cleanba.cleanba_impala import ParamsPayload
 from cleanba.env_trivial import MockSokobanEnv, MockSokobanEnvConfig
-from cleanba.impala_loss import ImpalaLossConfig, Rollout, impala_loss
+from cleanba.impala_loss import ActorCriticLossConfig, ImpalaLossConfig, PPOLossConfig, Rollout
 from cleanba.network import Policy, PolicySpec

@@ -48,10 +49,39 @@ def test_vtrace_alignment(gamma: float, num_timesteps: int, last_value: float):
     assert np.allclose(vtrace_error, np.zeros(num_timesteps))


+@pytest.mark.parametrize("gamma", [0.0, 0.9, 1.0])
+@pytest.mark.parametrize("gae_lambda", [0.0, 0.8, 1.0])
+@pytest.mark.parametrize("num_timesteps", [20, 2, 1])
+@pytest.mark.parametrize("last_value", [0.0, 1.0])
+def test_gae_alignment(gamma: float, gae_lambda: float, num_timesteps: int, last_value: float):
+    np_rng = np.random.default_rng(1234)
+
+    rewards = np_rng.uniform(0.1, 2.0, size=num_timesteps)
+    correct_returns = np.zeros(len(rewards) + 1)
+
+    # Discount is gamma everywhere, except once in the middle of the episode
+    discount = np.ones_like(rewards) * gamma
+    if num_timesteps > 2:
+        discount[num_timesteps // 2] = last_value
+
+    # There are no more returns after the last step
+    correct_returns[-1] = 0.0
+    # Bellman equation to compute the correct returns
+    for i in range(len(rewards) - 1, -1, -1):
+        correct_returns[i] = rewards[i] + discount[i] * correct_returns[i + 1]
+
+    gae = rlax.truncated_generalized_advantage_estimation(rewards, discount, gae_lambda, correct_returns)
+
+    assert np.allclose(gae, np.zeros(num_timesteps))
+
+
+@pytest.mark.parametrize("cls", [ImpalaLossConfig, PPOLossConfig])
 @pytest.mark.parametrize("gamma", [0.0, 0.9, 1.0])
 @pytest.mark.parametrize("num_timesteps", [20, 2])  # Note: with 1 timesteps we get zero-length arrays
 @pytest.mark.parametrize("last_value", [0.0, 1.0])
-def test_impala_loss_zero_when_accurate(gamma: float, num_timesteps: int, last_value: float, batch_size: int = 5):
+def test_impala_loss_zero_when_accurate(
+    cls: type[ActorCriticLossConfig], gamma: float, num_timesteps: int, last_value: float, batch_size: int = 5
+):
     np_rng = np.random.default_rng(1234)
     rewards = np_rng.uniform(0.1, 2.0, size=(num_timesteps, batch_size))
     correct_returns = np.zeros((num_timesteps + 1, batch_size))
@@ -70,7 +100,7 @@ def test_impala_loss_zero_when_accurate(gamma: float, num_timesteps: int, last_v
     obs_t = correct_returns  # Mimic how actual rollouts collect observations
     logits_t = jnp.zeros((num_timesteps, batch_size, 1))
     a_t = jnp.zeros((num_timesteps, batch_size), dtype=jnp.int32)
-    (total_loss, metrics_dict) = impala_loss(
+    (total_loss, metrics_dict) = cls(gamma=gamma).loss(
         params={},
         get_logits_and_value=lambda params, carry, obs, episode_starts: (
             carry,
@@ -78,7 +108,6 @@ def test_impala_loss_zero_when_accurate(gamma: float, num_timesteps: int, last_v
             obs,
             {},
         ),
-        args=ImpalaLossConfig(gamma=gamma),
         minibatch=Rollout(
             obs_t=jnp.array(obs_t),
             carry_t=(),
@@ -86,14 +115,15 @@ def test_impala_loss_zero_when_accurate(gamma: float, num_timesteps: int, last_v
             truncated_t=np.zeros_like(done_tm1),
             a_t=a_t,
             logits_t=logits_t,
+            value_t=jnp.array(obs_t),
             r_t=rewards,
         ),
     )

-    assert np.allclose(metrics_dict["pg_loss"], 0.0)
-    assert np.allclose(metrics_dict["v_loss"], 0.0)
+    assert np.allclose(metrics_dict["pg_loss"], 0.0, atol=2e-7)
+    assert np.allclose(metrics_dict["v_loss"], 0.0, atol=1e-7)
     assert np.allclose(metrics_dict["ent_loss"], 0.0)
-    assert np.allclose(total_loss, 0.0)
+    assert np.allclose(total_loss, 0.0, atol=2e-7)


 class TrivialEnvPolicy(Policy):
@@ -105,7 +135,7 @@ def get_action(
         key: jax.Array,
         *,
         temperature: float = 1.0,
-    ) -> tuple[tuple[()], jax.Array, jax.Array, jax.Array]:
+    ) -> tuple[tuple[()], jax.Array, jax.Array, jax.Array, jax.Array]:
         actions = jnp.zeros(obs.shape[0], dtype=jnp.int32)
         logits = jnp.stack(
             [
@@ -114,7 +144,8 @@ def get_action(
             ],
             axis=1,
         )
-        return (), actions, logits, key
+        value = MockSokobanEnv.compute_return(obs)
+        return (), actions, logits, value, key

     def get_logits_and_value(
         self,
@@ -122,7 +153,7 @@ def get_logits_and_value(
         obs: jax.Array,
         episode_starts: jax.Array,
     ) -> tuple[tuple[()], jax.Array, jax.Array, dict[str, jax.Array]]:
-        carry, actions, logits, key = jax.vmap(self.get_action, in_axes=(None, 0, None, None))(
+        carry, actions, logits, _, key = jax.vmap(self.get_action, in_axes=(None, 0, None, None))(
             carry,
             obs,
             None,  # type: ignore
@@ -143,8 +174,11 @@ def init_params(self, envs: gym.vector.VectorEnv, key: jax.Array) -> tuple["Poli
         return policy, (), {}


+@pytest.mark.parametrize("cls", [ImpalaLossConfig, PPOLossConfig])
 @pytest.mark.parametrize("min_episode_steps", (10, 7))
-def test_loss_of_rollout(min_episode_steps: int, num_envs: int = 5, gamma: float = 1.0, num_timesteps: int = 30):
+def test_loss_of_rollout(
+    cls: type[ActorCriticLossConfig], min_episode_steps: int, num_envs: int = 5, gamma: float = 1.0, num_timesteps: int = 30
+):
     np.random.seed(1234)

     args = cleanba_impala.Args(
@@ -153,10 +187,7 @@ def test_loss_of_rollout(min_episode_steps: int, num_envs: int = 5, gamma: float
         ),
         eval_envs={},
         net=ZeroActionNetworkSpec(),
-        loss=ImpalaLossConfig(
-            gamma=0.9,
-            vtrace_lambda=1.0,
-        ),
+        loss=cls(gamma=0.9),
         num_steps=num_timesteps,
         concurrency=True,
         local_num_envs=num_envs,
@@ -168,9 +199,10 @@ def test_loss_of_rollout(min_episode_steps: int, num_envs: int = 5, gamma: float
     params_queue = queue.Queue(maxsize=5)
     for _ in range(5):
-        params_queue.put((params, 1))
+        params_queue.put(ParamsPayload(params=params, policy_version=1))

     rollout_queue = queue.Queue(maxsize=5)
+    metrics_queue = queue.PriorityQueue()
     key = jax.random.PRNGKey(seed=1234)
     cleanba_impala.rollout(
         initial_update=1,
@@ -179,7 +211,7 @@ def test_loss_of_rollout(min_episode_steps: int, num_envs: int = 5, gamma: float
         runtime_info=cleanba_impala.RuntimeInformation(0, [], 0, 1, 0, 0, 0, 0, 0, [], []),
         rollout_queue=rollout_queue,
         params_queue=params_queue,
-        writer=None,  # OK because device_thread_id != 0
+        metrics_queue=metrics_queue,
         learner_devices=jax.local_devices(),
         device_thread_id=1,
         actor_device=None,  # Currently unused
@@ -187,14 +219,13 @@ def test_loss_of_rollout(min_episode_steps: int, num_envs: int = 5, gamma: float

     for iteration in range(100):
         try:
-            (
-                global_step,
-                actor_policy_version,
-                update,
-                sharded_transition,
-                params_queue_get_time,
-                device_thread_id,
-            ) = rollout_queue.get(timeout=1e-5)
+            payload = rollout_queue.get(timeout=1e-5)
+            global_step = payload.global_step
+            actor_policy_version = payload.policy_version
+            update = payload.update
+            sharded_transition = payload.storage
+            params_queue_get_time = payload.params_queue_get_time
+            device_thread_id = payload.device_thread_id
         except queue.Empty:
             break  # we're done

@@ -235,14 +266,14 @@ def test_loss_of_rollout(min_episode_steps: int, num_envs: int = 5, gamma: float
             carry_t=transition.carry_t,
             a_t=transition.a_t,
             logits_t=transition.logits_t,
+            value_t=transition.value_t,
             r_t=transition.r_t.at[transition.truncated_t].set(9999.9),
             episode_starts_t=transition.episode_starts_t,
             truncated_t=transition.truncated_t,
         )
-        (total_loss, metrics_dict) = impala_loss(
+        (total_loss, metrics_dict) = cls(gamma=gamma).loss(
             params=params,
             get_logits_and_value=get_logits_and_value_fn,
-            args=ImpalaLossConfig(gamma=gamma, logit_l2_coef=0.0),
             minibatch=transition,
         )
         logit_negentropy = -jnp.mean(distrax.Categorical(transition.logits_t).entropy() * (~transition.truncated_t))
diff --git a/tests/test_training.py b/tests/test_training.py
index cafb3e9..3d5e6fe 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -71,8 +71,14 @@ def add_scalar(self, name: str, value: int | float, global_step: int):
             self.eval_events[name].set()
         self.eval_metrics[name] = value

+    def add_dict(self, metrics: dict[str, int | float], global_step: int):
+        print(f"Adding {metrics=} at {global_step=}")
+        for k, v in metrics.items():
+            self.add_scalar(k, v, global_step)
+
     @contextlib.contextmanager
     def save_dir(self, global_step: int) -> Iterator[Path]:
+        print(f"Saving at {global_step=}")
         for event in self.eval_events.values():
             event.wait(timeout=5)
@@ -80,9 +86,9 @@ def save_dir(self, global_step: int) -> Iterator[Path]:
             yield dir

         assert self.last_global_step == global_step, "we want to save with the same step as last metrics"
-        assert all(
-            k in self.eval_metrics for k in self.eval_keys
-        ), f"One of {self.eval_keys=} not present in {list(self.eval_metrics.keys())=}"
+        assert all(k in self.eval_metrics for k in self.eval_keys), (
+            f"One of {self.eval_keys=} not present in {list(self.eval_metrics.keys())=}"
+        )

         # Clear for the next saving
         for event in self.eval_events.values():
@@ -107,16 +113,6 @@ def save_dir(self, global_step: int) -> Iterator[Path]:
            mlp_hiddens=(16,),
            normalize_input=False,
        ),
-        ConvLSTMConfig(
-            embed=[ConvConfig(3, (4, 4), (1, 1), "SAME", True)],
-            recurrent=ConvLSTMCellConfig(ConvConfig(3, (3, 3), (1, 1), "SAME", True), pool_and_inject="horizontal"),
-            repeats_per_step=2,
-        ),
-        ConvLSTMConfig(
-            embed=[ConvConfig(3, (4, 4), (1, 1), "SAME", True)],
-            recurrent=ConvLSTMCellConfig(ConvConfig(3, (3, 3), (1, 1), "SAME", True), pool_and_inject="horizontal"),
-            repeats_per_step=2,
-        ),
     ],
 )
 def test_save_model_step(tmpdir: Path, net: PolicySpec):
@@ -144,7 +140,7 @@ def test_save_model_step(tmpdir: Path, net: PolicySpec):
         num_steps=2,
         num_minibatches=1,
         # If the whole thing deadlocks exit in some small multiple of 10 seconds
-        queue_timeout=4,
+        queue_timeout=10,
     )
     args.total_timesteps = args.num_steps * args.num_actor_threads * args.local_num_envs * eval_frequency

@@ -178,6 +174,7 @@ def test_concat_and_shard_rollout_internal():
     time = 4

     obs_t, _ = envs.reset()
+    value_t = jnp.zeros((obs_t.shape[0]))
     episode_starts_t = np.ones((envs.num_envs,), dtype=np.bool_)
     carry_t = [LSTMCellState(obs_t, obs_t)]
@@ -186,17 +183,18 @@ def test_concat_and_shard_rollout_internal():
         a_t = envs.action_space.sample()
         logits_t = jnp.zeros((*a_t.shape, 2), dtype=jnp.float32)
         obs_tplus1, r_t, term_t, trunc_t, _ = envs.step(a_t)
-        storage.append(Rollout(obs_t, carry_t, a_t, logits_t, r_t, episode_starts_t, trunc_t))
+        storage.append(Rollout(obs_t, carry_t, a_t, logits_t, value_t, r_t, episode_starts_t, trunc_t))
         obs_t = obs_tplus1
         episode_starts_t = term_t | trunc_t

-    out = _concat_and_shard_rollout_internal(storage, obs_t, episode_starts_t, len_learner_devices)
+    out = _concat_and_shard_rollout_internal(storage, obs_t, episode_starts_t, value_t, len_learner_devices)

     assert isinstance(out, Rollout)
     assert out.obs_t[0].shape == (time + 1, batch // len_learner_devices, *storage[0].obs_t.shape[1:])
     assert out.a_t[0].shape == (time, batch // len_learner_devices)
     assert out.logits_t[0].shape == (time, batch // len_learner_devices, storage[0].logits_t.shape[1])
+    assert out.value_t[0].shape == (time + 1, batch // len_learner_devices)
     assert out.r_t[0].shape == (time, batch // len_learner_devices)
     assert out.episode_starts_t[0].shape == (time + 1, batch // len_learner_devices)
     assert out.truncated_t[0].shape == (time, batch // len_learner_devices)
diff --git a/third_party/craftax b/third_party/craftax
new file mode 160000
index 0000000..ece3c00
--- /dev/null
+++ b/third_party/craftax
@@ -0,0 +1 @@
+Subproject commit ece3c0027afeeabcec70e0b25520b0cf7db99cab
diff --git a/third_party/envpool b/third_party/envpool
index ae30e34..dfd9308 160000
--- a/third_party/envpool
+++ b/third_party/envpool
@@ -1 +1 @@
-Subproject commit ae30e34c8ec64a8d5a5a254f0a528bd75c3cf00f
+Subproject commit dfd9308a6da42a6425ec79d92c0919f1ae79eb7c