56 commits
6faf8c9
Craftax DRC, with small changes
rhaps0dy Feb 21, 2025
2236936
Add craftax, use the network's norm
rhaps0dy Feb 22, 2025
85f043d
Fix some of the jax madness in environment
rhaps0dy Feb 22, 2025
6d14651
Reintroduce obs_flat
rhaps0dy Feb 22, 2025
02b0736
set env_params, transform to class(cfg) style
rhaps0dy Feb 22, 2025
7e9a15d
Runs but get locked
rhaps0dy Feb 22, 2025
076cd88
Make env api conform
rhaps0dy Feb 22, 2025
9793fe4
Upgrade for jax 5
rhaps0dy Feb 23, 2025
9ccd1d4
Upgrade dockerfile
rhaps0dy Feb 23, 2025
2886e99
Build envpool with new bazel for new python
rhaps0dy Feb 23, 2025
9e64c70
Use uv to be FAST
rhaps0dy Feb 23, 2025
0bab78e
Use MIG and vast by default
rhaps0dy Feb 23, 2025
90e7e57
Deal with updated gymnasium
rhaps0dy Feb 23, 2025
ff5a9b6
Copy hyperparameters from transformer PPO
rhaps0dy Feb 23, 2025
b7f2647
Change details of where things are stored
rhaps0dy Feb 23, 2025
d762b97
Set the cache from config
rhaps0dy Feb 23, 2025
4eb3b48
Use tanh, orthogonal network
rhaps0dy Feb 24, 2025
ae38c75
Cache tune data
rhaps0dy Feb 24, 2025
9b6fef1
fix network bug
rhaps0dy Feb 24, 2025
b3112a5
Refactor for PPO
rhaps0dy Feb 24, 2025
a030c4c
Downgrade gymnasium, fix test_convlstm
rhaps0dy Feb 25, 2025
4e814fd
Fix more tests
rhaps0dy Feb 25, 2025
bc209d0
Calculate and store the last value
rhaps0dy Feb 25, 2025
688776a
PPO seems correct according to the tests
rhaps0dy Feb 25, 2025
e8a122c
I think PPO is correct
rhaps0dy Feb 25, 2025
fe0c890
Add all requirements
rhaps0dy Feb 25, 2025
172d0c9
oopsie, wrong PPO value length
rhaps0dy Feb 25, 2025
4350d99
update for profile
rhaps0dy Feb 25, 2025
a5932c4
No more block until ready
rhaps0dy Feb 25, 2025
79cee56
Move from step_async to step, modify deps
rhaps0dy Feb 26, 2025
77bc0be
Fewer things error
rhaps0dy Feb 26, 2025
024447b
Use NamedTuple to prevent dumb errors
rhaps0dy Feb 26, 2025
33831ce
Fix stuff
rhaps0dy Feb 26, 2025
0a3adf5
Fix more stuff
rhaps0dy Feb 26, 2025
6fc08c8
Make input devices consistent
rhaps0dy Feb 26, 2025
88b66f1
Squeeze a bit more speed
rhaps0dy Feb 27, 2025
3b581f0
Even a little faster
rhaps0dy Feb 27, 2025
bf05ae5
Fix _state bug
rhaps0dy Feb 27, 2025
c871622
Embarrassingly it's only fixed now
rhaps0dy Feb 27, 2025
51a269a
Reasonably correct and performant
rhaps0dy Feb 27, 2025
0cddc0d
Improve queue timeout behavior in cleanba_impala.py
rhaps0dy Feb 27, 2025
7021adf
Delete unused cartpole tests
rhaps0dy Feb 27, 2025
d5f9564
Tell Claude what to look at
rhaps0dy Feb 27, 2025
d05beb8
Centralize logging
rhaps0dy Feb 27, 2025
50f5d2a
Fix problems
rhaps0dy Feb 27, 2025
9ee3267
Perhaps running item at the top helps
rhaps0dy Feb 27, 2025
38512a8
Include achievements, prune unused metrics
rhaps0dy Feb 27, 2025
d601de6
Ignore for pyright
rhaps0dy Feb 27, 2025
d633924
Try logging achievements
rhaps0dy Feb 27, 2025
3548ded
Craftax experiments
rhaps0dy Feb 27, 2025
f613e70
Save checkpoints
rhaps0dy Feb 28, 2025
891ffde
Avoid collecting last value in Impala (not needed)
rhaps0dy Feb 28, 2025
1419cc6
Fix loading fenceless checkpoints
rhaps0dy Mar 1, 2025
dd9b5f0
Experiments with checkpoints
rhaps0dy Mar 5, 2025
1c1b625
fix compile command
rhaps0dy Apr 8, 2025
67c576e
add noscroll craftax
rhaps0dy Oct 13, 2025
7 changes: 6 additions & 1 deletion .gitignore
@@ -146,4 +146,9 @@ envpool
.sokoban_cache
.build
.vscode
k8s_copy/
k8s_copy/
Craftax_Baselines

nsight-profile

uv.lock
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "third_party/gym-sokoban"]
path = third_party/gym-sokoban
url = https://github.com/AlignmentResearch/gym-sokoban
[submodule "third_party/craftax"]
path = third_party/craftax
url = https://github.com/rhaps0dy/Craftax
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
rev: v0.9.7
hooks:
# Run the linter.
- id: ruff
30 changes: 30 additions & 0 deletions CLAUDE.md
@@ -0,0 +1,30 @@
# CLAUDE.md: Development Guidelines

## Build & Test Commands
- Install: `make local-install`
- Lint: `make lint`
- Format: `make format`
- Typecheck: `make typecheck`
- Run tests: `make mactest` or `pytest -m 'not envpool and not slow'`
- Run single test: `pytest tests/test_file.py::test_function -v`
- Training command: `python -m cleanba.cleanba_impala --from-py-fn=cleanba.config:sokoban_drc33_59`

## Code Style Guidelines
- **Formatting**: Follow ruff format (127 line length)
- **Imports**: Use isort through ruff, with known third-party libraries like wandb
- **Types**: Use type annotations, checked with pyright
- **Naming**:
  - Variables: `snake_case`
  - Classes: `PascalCase`
  - Constants: `UPPER_CASE`
- **Structure**: Modules organized by functionality (cleanba, experiments, tests)
- **Error handling**: Use asserts for validation in tests, exceptions for runtime errors
- **Documentation**: Include docstrings for public functions and classes
- **JAX/Flax patterns**: Use pure functions and maintain functional style

## Key Files
- **environments.py**: Contains environment wrappers and config classes for different environments (Sokoban, Boxoban, Craftax). Includes `EpisodeEvalWrapper` for logging episode returns and adapters for different environment backends.

- **cleanba_impala.py**: Main training loop implementation for the IMPALA algorithm. Contains multi-threaded rollout data collection, parameter synchronization, and training. Uses `WandbWriter` for logging, queues for communication between rollout and learner threads, and implements checkpointing.

- **impala_loss.py**: Implements V-trace and PPO loss functions. Contains `Rollout` data structure, TD-error computation with V-trace, and policy gradient calculations. Handles truncated episodes specially to provide correct advantage estimates.
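
A minimal sketch of the V-trace target computation that `impala_loss.py` is described as implementing. It follows the standard IMPALA formulation, not this repository's code. The function name `vtrace_targets`, its arguments, and the use of plain Python lists are assumptions made for illustration, and the special handling of truncated episodes mentioned above is not reproduced here.

```python
import math


def vtrace_targets(log_rhos, rewards, values, bootstrap_value, discounts,
                   clip_rho=1.0, clip_c=1.0):
    """Compute V-trace value targets for one trajectory of length T.

    log_rhos[t]   : log(pi(a_t|x_t) / mu(a_t|x_t)), target vs. behaviour policy
    rewards[t]    : reward received after acting at step t
    values[t]     : value estimate V(x_t)
    bootstrap_value: value estimate V(x_T) for the state after the last step
    discounts[t]  : gamma, or 0.0 where the episode terminated at step t
    """
    num_steps = len(rewards)
    next_values = list(values[1:]) + [bootstrap_value]
    targets = [0.0] * num_steps
    correction = 0.0  # running value of (vs_{t+1} - V(x_{t+1}))
    for t in reversed(range(num_steps)):
        ratio = math.exp(log_rhos[t])
        rho = min(clip_rho, ratio)  # clipped importance weight for the TD error
        c = min(clip_c, ratio)      # clipped trace coefficient
        td_error = rho * (rewards[t] + discounts[t] * next_values[t] - values[t])
        correction = td_error + discounts[t] * c * correction
        targets[t] = values[t] + correction
    return targets
```

With `log_rhos` all zero (on-policy data) and both clips at 1, the targets reduce to discounted n-step returns bootstrapped from `bootstrap_value`.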
74 changes: 43 additions & 31 deletions Dockerfile
@@ -1,36 +1,42 @@
ARG JAX_DATE

FROM ghcr.io/nvidia/jax:base-${JAX_DATE} as envpool-environment
FROM ghcr.io/nvidia/jax:base-${JAX_DATE} AS envpool-environment
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update \
&& apt-get install -y golang-1.18 git \
&& apt-get install -y golang-1.21 git \
# Linters
clang-format clang-tidy \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

ENV PATH=/usr/lib/go-1.18/bin:/root/go/bin:$PATH
RUN go install github.com/bazelbuild/[email protected] && ln -sf $HOME/go/bin/bazelisk $HOME/go/bin/bazel
RUN go install github.com/bazelbuild/buildtools/[email protected]
USER ubuntu
ENV HOME=/home/ubuntu
ENV PATH=/usr/lib/go-1.21/bin:${HOME}/go/bin:$PATH
ENV UID=1000
ENV GID=1000
RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} go install github.com/bazelbuild/[email protected] && ln -sf $HOME/go/bin/bazelisk $HOME/go/bin/bazel
RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} go install github.com/bazelbuild/buildtools/[email protected]
# Install Go linting tools
RUN go install github.com/google/[email protected]
RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} go install github.com/google/[email protected]

ENV USE_BAZEL_VERSION=6.4.0
ENV USE_BAZEL_VERSION=8.1.0
RUN bazel version

WORKDIR /app
# Copy the whole repository
COPY --chown=ubuntu:ubuntu third_party/envpool .

# Install python-based linting dependencies
COPY third_party/envpool/third_party/pip_requirements/requirements-devtools.txt \
COPY --chown=ubuntu:ubuntu \
third_party/envpool/third_party/pip_requirements/requirements-devtools.txt \
third_party/pip_requirements/requirements-devtools.txt
RUN pip install -r third_party/pip_requirements/requirements-devtools.txt

# Copy the whole repository
COPY third_party/envpool .
RUN --mount=type=cache,target=${HOME}/.cache,uid=1000,gid=1000 pip install -r third_party/pip_requirements/requirements-devtools.txt
ENV PATH="${HOME}/.local/bin:${PATH}"

# Deal with the fact that envpool is a submodule and has no .git directory
RUN rm .git
# Copy the .git repository for this submodule
COPY .git/modules/envpool ./.git
COPY --chown=ubuntu:ubuntu .git/modules/envpool ./.git
# Remove config line stating that the worktree for this repo is elsewhere
RUN sed -e 's/^.*worktree =.*$//' .git/config > .git/config.new && mv .git/config.new .git/config

@@ -39,14 +45,14 @@ RUN echo "$(git status --porcelain --ignored=traditional)" \
&& if ! { [ -z "$(git status --porcelain --ignored=traditional)" ] \
; }; then exit 1; fi

FROM envpool-environment as envpool
RUN make bazel-release
FROM envpool-environment AS envpool
RUN --mount=type=cache,target=${HOME}/.cache,uid=1000,gid=1000 make bazel-release && cp bazel-bin/*.whl .

FROM ghcr.io/nvidia/jax:jax-${JAX_DATE} as main-pre-pip
FROM ghcr.io/nvidia/jax:jax-${JAX_DATE} AS main-pre-pip

ARG APPLICATION_NAME
ARG USERID=1001
ARG GROUPID=1001
ARG UID=1001
ARG GID=1001
ARG USERNAME=dev

ENV GIT_URL="https://github.com/AlignmentResearch/${APPLICATION_NAME}"
@@ -84,8 +90,8 @@ ENV VIRTUAL_ENV="/opt/venv"
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"

RUN python3 -m venv "${VIRTUAL_ENV}" --system-site-packages \
&& addgroup --gid ${GROUPID} ${USERNAME} \
&& adduser --uid ${USERID} --gid ${GROUPID} --disabled-password --gecos '' ${USERNAME} \
&& addgroup --gid ${GID} ${USERNAME} \
&& adduser --uid ${UID} --gid ${GID} --disabled-password --gecos '' ${USERNAME} \
&& usermod -aG sudo ${USERNAME} \
&& echo "${USERNAME} ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers \
&& mkdir -p "/workspace" \
@@ -96,32 +102,38 @@ WORKDIR "/workspace"
# Get a pip modern enough that can resolve farconf
RUN pip install "pip ==24.0" && rm -rf "${HOME}/.cache"

FROM main-pre-pip as main-pip-tools
FROM main-pre-pip AS main-pip-tools
RUN pip install "pip-tools ~=7.4.1"

FROM main-pre-pip as main
FROM main-pre-pip AS main
RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} pip install uv
COPY --chown=${USERNAME}:${USERNAME} requirements.txt ./
# Install all dependencies, which should be explicit in `requirements.txt`
RUN pip install --no-deps -r requirements.txt \
&& rm -rf "${HOME}/.cache" \
RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} \
uv pip install --no-deps -r requirements.txt \
# Run Pyright so its Node.js package gets installed
&& pyright .

# Install Envpool
ENV ENVPOOL_WHEEL="dist/envpool-0.8.4-cp310-cp310-linux_x86_64.whl"
COPY --from=envpool --chown=${USERNAME}:${USERNAME} "/app/${ENVPOOL_WHEEL}" "./${ENVPOOL_WHEEL}"
RUN pip install "./${ENVPOOL_WHEEL}" && rm -rf "./dist"
ENV ENVPOOL_WHEEL="envpool-0.9.0-cp312-cp312-linux_x86_64.whl"
COPY --from=envpool --chown=${USERNAME}:${USERNAME} "/app/${ENVPOOL_WHEEL}" "${ENVPOOL_WHEEL}"
RUN uv pip install "${ENVPOOL_WHEEL}" && rm "${ENVPOOL_WHEEL}"

# Cache Craftax textures
RUN python -c "import craftax.craftax.constants"

# Copy whole repo
COPY --chown=${USERNAME}:${USERNAME} . .
RUN pip install --no-deps -e . -e ./third_party/gym-sokoban/
RUN --mount=type=cache,target=${HOME}/.cache,uid=${UID},gid=${GID} \
uv pip install --no-deps -e . -e ./third_party/gym-sokoban/

# Set git remote URL to https for all sub-repos
RUN git remote set-url origin "$(git remote get-url origin | sed 's|git@github.com:|https://github.com/|' )" \
&& (cd third_party/envpool && git remote set-url origin "$(git remote get-url origin | sed 's|git@github.com:|https://github.com/|' )" )

# Abort if repo is dirty
RUN echo "$(git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" \
RUN rm NVIDIA_Deep_Learning_Container_License.pdf \
&& echo "$(git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" \
&& echo "$(cd third_party/envpool && git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" \
&& echo "$(cd third_party/gym-sokoban && git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" \
&& if ! { [ -z "$(git status --porcelain --ignored=traditional | grep -v '.egg-info/$')" ] \
@@ -130,5 +142,5 @@ RUN echo "$(git status --porcelain --ignored=traditional | grep -v '.egg-info/$'
; }; then exit 1; fi


FROM main as atari
RUN pip uninstall -y envpool && pip install envpool && rm -rf "${HOME}/.cache"
FROM main AS atari
RUN uv pip uninstall -y envpool && uv pip install envpool && rm -rf "${HOME}/.cache"
25 changes: 12 additions & 13 deletions Makefile
@@ -8,7 +8,8 @@ export DOCKERFILE

COMMIT_HASH ?= $(shell git rev-parse HEAD)
BRANCH_NAME ?= $(shell git branch --show-current)
JAX_DATE=2024-04-08
JAX_DATE=2025-02-22
PYTHON_VERSION=3.12

default: release/main

@@ -29,21 +30,19 @@ BUILD_PREFIX ?= $(shell git rev-parse --short HEAD)
-f "${DOCKERFILE}" .
touch ".build/with-reqs/${BUILD_PREFIX}/$*"

# NOTE: --extra=extra is for stable-baselines3 testing.
requirements.txt.new: pyproject.toml ${DOCKERFILE}
docker run -v "${HOME}/.cache:/home/dev/.cache" -v "$(shell pwd):/workspace" "ghcr.io/nvidia/jax:base-${JAX_DATE}" \
bash -c "pip install pip-tools \
&& cd /workspace \
&& pip-compile --verbose -o requirements.txt.new --extra=dev --extra=launch_jobs pyproject.toml"
requirements.txt.new: pyproject.toml
docker run -v "${HOME}/.cache:/home/dev/.cache" -v "$(shell pwd):/workspace" "ghcr.io/astral-sh/uv:python${PYTHON_VERSION}-alpine" \
sh -c "cd /workspace \
&& uv pip compile --verbose -o requirements.txt.new --extra=dev pyproject.toml"

# To bootstrap `requirements.txt`, comment out this target
requirements.txt: requirements.txt.new
sed -E "s/^(jax==.*|jaxlib==.*|nvidia-.*|torchvision==.*|torch==.*|triton==.*)$$/# DISABLED \\1/g" requirements.txt.new > requirements.txt

.PHONY: local-install
local-install: requirements.txt
pip install --no-deps -r requirements.txt
pip install --config-settings editable_mode=compat -e ".[dev-local]" -e ./third_party/gym-sokoban
uv pip install --no-deps -r requirements.txt
uv pip install -e ".[py-tools]" -e ./third_party/gym-sokoban
pip install https://github.com/AlignmentResearch/envpool/releases/download/v0.1.0/envpool-0.8.4-cp310-cp310-linux_x86_64.whl


@@ -93,18 +92,18 @@ cuda-devbox/%: devbox/%
cuda-devbox: cuda-devbox/main

.PHONY: envpool-devbox
envpool-devbox: devbox/envpool-ci
envpool-devbox: devbox/envpool


.PHONY: docker docker/%
docker/%:
docker run -v "$(shell pwd):/workspace" -it "${APPLICATION_URL}:${RELEASE_PREFIX}-$*" /bin/bash
docker run -v "${HOME}/.cache:/home/ubuntu/.cache" -v "$(shell pwd):/workspace" -it "${APPLICATION_URL}:${RELEASE_PREFIX}-$*" /bin/bash
docker: docker/main

.PHONY: envpool-docker envpool-docker/%
envpool-docker/%:
docker run -v "$(shell pwd)/third_party/envpool:/app" -it "${APPLICATION_URL}:${RELEASE_PREFIX}-$*" /bin/bash
envpool-docker: envpool-docker/envpool-ci
docker run -v "${HOME}/.cache:/home/ubuntu/.cache" -v "$(shell pwd)/third_party/envpool:/app" -it "${APPLICATION_URL}:${RELEASE_PREFIX}-$*" /bin/bash
envpool-docker: envpool-docker/envpool

# Section 3: project commands
