Merged

Changes from all commits · 52 commits
f24281c
RLHF end2end example
apbard Jun 27, 2023
ef3f76f
add VmapModule and from_lmhead_model method
apbard Jun 27, 2023
02a909b
Update examples/rlhf/train_rlhf.py
apbard Jun 28, 2023
953e4af
addressing comments
apbard Jun 28, 2023
ffb8661
Merge remote-tracking branch 'origin/main' into rlhf-networks
vmoens Jun 28, 2023
f43faea
Update torchrl/modules/tensordict_module/common.py
vmoens Jun 28, 2023
69b0588
Update torchrl/modules/tensordict_module/actors.py
vmoens Jun 28, 2023
b6fecbb
Add RolloutFromModel class
tcbegley Jun 26, 2023
bd8fbb6
Add rollout tests
tcbegley Jun 26, 2023
6fbb603
Apply suggestions from code review
tcbegley Jun 26, 2023
3e80a55
Address comments
tcbegley Jun 26, 2023
385ac90
Docstring lint
tcbegley Jun 26, 2023
8d0a152
Apply suggestions from code review
tcbegley Jun 27, 2023
fcddc97
Address comments
tcbegley Jun 27, 2023
5c7c72e
Fix tests
tcbegley Jun 28, 2023
92d5757
Handle missing transformers import
tcbegley Jun 28, 2023
eec0eaf
Import transformers locally
tcbegley Jun 28, 2023
87501ea
lint
vmoens Jun 28, 2023
043fcf6
Merge branch 'rlhf-rollout' into rlhf-example
tcbegley Jun 29, 2023
3f53046
Merge branch 'rlhf-networks' into rlhf-example
tcbegley Jun 29, 2023
8b69e41
lint
tcbegley Jun 29, 2023
24eaa3a
Example bugfixes
tcbegley Jun 29, 2023
fba43a1
Move KL controller logic
tcbegley Jun 29, 2023
20fa920
Merge branch 'main' into rlhf-example
vmoens Jul 4, 2023
c07ac93
amend
vmoens Jul 4, 2023
f463e0e
addressing comments about klcontroller
apbard Jul 4, 2023
eac5374
Merge remote-tracking branch 'origin/main' into rlhf-example
vmoens Sep 5, 2023
8d2dde7
Merge remote-tracking branch 'origin/main' into rlhf-example
vmoens Oct 1, 2023
a2ba045
Merge branch 'main' into rlhf-example
vmoens Oct 2, 2023
a9b94f0
amend
vmoens Oct 2, 2023
d983ebd
init
vmoens Oct 3, 2023
097c443
readme
vmoens Oct 3, 2023
0efd93a
amend
vmoens Oct 3, 2023
fba9f03
amend
vmoens Oct 3, 2023
cc535e5
amend
vmoens Oct 4, 2023
28c116f
amend
vmoens Oct 4, 2023
0f128a6
amend
vmoens Oct 4, 2023
e0ad043
amend
vmoens Oct 4, 2023
e8cad9b
Merge remote-tracking branch 'origin/main' into rlhf-example-refactor
vmoens Oct 4, 2023
c93c134
amend
vmoens Oct 4, 2023
56f7597
init
vmoens Oct 5, 2023
3fa6ea5
Merge branch 'refactor_ddpg_loss' into rlhf-example-refactor
vmoens Oct 5, 2023
c1c41dc
amend
vmoens Oct 5, 2023
880e5b4
amend
vmoens Oct 5, 2023
d36ce77
Update run_test.sh
Oct 5, 2023
942b311
amend
vmoens Oct 5, 2023
fca9f7b
amend
vmoens Oct 5, 2023
6362715
lint
vmoens Oct 5, 2023
e3b2d4f
amend
vmoens Oct 5, 2023
7918f86
amend
vmoens Oct 5, 2023
9658a44
Merge remote-tracking branch 'origin/main' into rlhf-example-refactor
vmoens Oct 5, 2023
eb041a4
lint
vmoens Oct 5, 2023
4 changes: 3 additions & 1 deletion .github/unittest/linux_examples/scripts/run_test.sh
@@ -282,8 +282,10 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/multiagent/sac
train.minibatch_size=100 \
logger.backend=


python .github/unittest/helpers/coverage_run_parallel.py examples/bandits/dqn.py --n_steps=100

## RLHF
# RLHF tests are executed in the dedicated workflow

coverage combine
coverage xml -i
9 changes: 9 additions & 0 deletions .github/unittest/linux_libs/scripts_rlhf/run_test.sh
@@ -22,5 +22,14 @@ conda deactivate && conda activate ./env
python -c "import transformers, datasets"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips

python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \
sys.device=cuda:0 sys.ref_device=cuda:0 \
model.name_or_path=gpt2 train.max_epochs=2 \
data.batch_size=2 train.ppo.ppo_batch_size=2 \
train.ppo.ppo_num_epochs=1 reward_model.name_or_path= \
train.ppo.episode_length=8 train.ppo.num_rollouts_per_epoch=4 \
data.block_size=110 io.logger=csv

coverage combine
coverage xml -i
4 changes: 4 additions & 0 deletions examples/rlhf/.gitignore
@@ -0,0 +1,4 @@
*.png
*.bin
*.pt
*.json
57 changes: 57 additions & 0 deletions examples/rlhf/README.md
@@ -0,0 +1,57 @@
# RLHF example

This example uses RLHF (Reinforcement Learning from Human Feedback) to train a
language model to summarize Reddit posts.

## Getting started

Make sure you have PyTorch>=2.0 installed. You can find installation instructions
[here](https://pytorch.org/get-started/locally/).

From this directory, you can install extra requirements for running these
examples with

```sh
pip install -r requirements.txt
```

## Training the models
### Training the transformer

Once the data has been prepared, you can train the GPT model.

```sh
python train.py
```

The default configuration can be found in `config/train.yaml`. The scripts are
driven by Hydra, so any option can be overridden from the command line; for
example, to run the training script with a different batch size:

```sh
python train.py data.batch_size=128
```
> **_NOTE:_** Users on Apple Silicon MacBooks should set `sys.device=mps`
> and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback.
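
For instance, a minimal sketch of such an invocation (assuming the Hydra
override style used elsewhere in this example):

```sh
PYTORCH_ENABLE_MPS_FALLBACK=1 python train.py sys.device=mps
```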

### Training the reward model

Once you have completed supervised fine-tuning, copy the desired model
checkpoint to `./out` or update the config to point `model.name_or_path` at
the relevant checkpoint in the timestamped working directory created by Hydra.
You can then train the reward model with:

```sh
python train_reward.py
```
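
If you prefer not to copy the checkpoint, a sketch of the alternative (the Hydra
run-directory path below is a placeholder, not an actual output of this PR):

```sh
python train_reward.py model.name_or_path=./outputs/2023-10-03/12-00-00/out
```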

### Training the final model with RLHF

Once again, make sure you have either updated the configuration so that
`reward_model.name_or_path` points at the relevant timestamped working directory,
or copied the checkpoint to `./out_reward`.
You can then train the final model by running

```sh
python train_rlhf.py
```
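
As with the earlier stages, any option can be overridden on the command line. For
instance, a sketch mirroring the override keys used by the CI smoke test above
(the values here are illustrative):

```sh
python train_rlhf.py sys.device=cuda:0 sys.ref_device=cuda:0 \
  train.ppo.num_rollouts_per_epoch=32 train.ppo.ppo_batch_size=16
```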
30 changes: 30 additions & 0 deletions examples/rlhf/config/train.yaml
@@ -0,0 +1,30 @@
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: gpt2 # gpt2 for pre-trained, local path for checkpoint
  out_dir: ./out
  dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+
train:
  grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0
  max_iters: 5000 # total number of training iterations
  gradient_accumulation_steps: 2 # used to simulate larger batch sizes
  always_save_checkpoint: False # if True, always save a checkpoint after each evaluation in out_dir
  decay_lr: True # whether to decay the learning rate
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 1.0e-5
    weight_decay: 1.0e-1
    betas: [0.9, 0.95]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 5000 # maximum number of iterations
    eta_min: 1.0e-6 # minimum learning rate
sys:
  device: cuda # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on macbooks
  dtype: bfloat16 # float32, bfloat16, or float16, the latter will auto implement a GradScaler
  compile: True # use PyTorch 2.0 to compile the model to be faster
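
To illustrate what `train.grad_clip` and `train.gradient_accumulation_steps`
configure, here is a minimal, self-contained sketch of the accumulate-then-clip-then-step
pattern (illustrative only; the example's actual loop lives in `train.py`):

```python
import torch

grad_accum_steps = 2  # train.gradient_accumulation_steps
grad_clip = 1.0       # train.grad_clip; 0.0 would disable clipping

model = torch.nn.Linear(4, 1)  # stand-in for the transformer
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1.0e-5, weight_decay=1.0e-1, betas=(0.9, 0.95)
)
micro_batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(4)]

for i, (x, y) in enumerate(micro_batches):
    loss = torch.nn.functional.mse_loss(model(x), y) / grad_accum_steps
    loss.backward()  # gradients accumulate across micro-batches
    if (i + 1) % grad_accum_steps == 0:
        if grad_clip > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()
```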
32 changes: 32 additions & 0 deletions examples/rlhf/config/train_reward.yaml
@@ -0,0 +1,32 @@
io:
  eval_interval: 200
  log_interval: 50
  eval_iters: 100
data:
  batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
model:
  name_or_path: ./out
  dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  out_dir: ./out_reward
  init_from: scratch # 'scratch' or 'resume' - if "resume" model will be loaded from out_dir_reward
train:
  grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0
  max_iters: 20000 # total number of training iterations
  gradient_accumulation_steps: 2 # used to simulate larger batch sizes
  always_save_checkpoint: False # if True, always save a checkpoint after each eval
  decay_lr: False # whether to decay the learning rate
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 1.0e-5
    weight_decay: 1.0e-1
    betas: [0.9, 0.95]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 20000
    eta_min: 1.0e-6
sys:
  device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
  dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
  compile: True # use PyTorch 2.0 to compile the model to be faster
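
The `sys.dtype` comment applies to all three configs: a `GradScaler` is only
needed for `float16`. A sketch of the assumed interpretation (the example's exact
setup code may differ):

```python
import torch

dtype = "bfloat16"  # sys.dtype: float32, bfloat16 or float16
torch_dtype = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}[dtype]
# Loss scaling is only required for float16; the scaler is a no-op when disabled.
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
```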
39 changes: 39 additions & 0 deletions examples/rlhf/config/train_rlhf.yaml
@@ -0,0 +1,39 @@
io:
  eval_interval: 6
  log_interval: 1
  eval_iters: 10
  logger: wandb
data:
  batch_size: 4 # if gradient_accumulation_steps > 1, this is the micro-batch size
  block_size: 550
  num_workers: 1
model:
  name_or_path: ./out
  out_dir: ./out_rlhf
  dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+
reward_model:
  name_or_path: ./out_reward
train:
  grad_clip: 1.0
  max_epochs: 1000 # total number of training iterations
  always_save_checkpoint: True # if True, always save a checkpoint after each eval
  decay_lr: True
  optimizer:
    # keyword arguments for torch.optim.AdamW
    lr: 5.0e-5
    weight_decay: 0.0 # 01
    betas: [0.9, 0.999]
  scheduler:
    # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR
    T_max: 3000 # max_epochs * num_rollouts / ppo_batch_size
    eta_min: 5.0e-6
  ppo:
    episode_length: 50
    ppo_batch_size: 16
    ppo_num_epochs: 3
    num_rollouts_per_epoch: 32
sys:
  device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
  ref_device: cuda:1 # device of reference model
  dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
  compile: False # use PyTorch 2.0 to compile the model to be faster
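
The `train.ppo` block sets the data-collection/optimisation schedule. A rough,
purely illustrative sketch of the loop shape these options imply (the real logic
in `train_rlhf.py` is built from TorchRL components such as `RolloutFromModel`):

```python
max_epochs = 1000            # train.max_epochs
num_rollouts_per_epoch = 32  # generations collected per epoch
episode_length = 50          # max new tokens per generation
ppo_num_epochs = 3           # PPO passes over each batch of rollouts
ppo_batch_size = 16          # minibatch size for each PPO update

for epoch in range(max_epochs):
    # 1) collect `num_rollouts_per_epoch` rollouts of up to `episode_length` tokens
    rollouts = list(range(num_rollouts_per_epoch))
    # 2) optimise the policy/value networks on minibatches of those rollouts
    for _ in range(ppo_num_epochs):
        for start in range(0, len(rollouts), ppo_batch_size):
            minibatch = rollouts[start:start + ppo_batch_size]
            # a PPO update on `minibatch` would happen here
```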
3 changes: 3 additions & 0 deletions examples/rlhf/data/__init__.py
@@ -0,0 +1,3 @@
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr

__all__ = ["get_prompt_dataloader_tldr"]
4 changes: 4 additions & 0 deletions examples/rlhf/models/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
35 changes: 35 additions & 0 deletions examples/rlhf/models/actor_critic.py
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator
from torchrl.modules.tensordict_module.common import VmapModule

from .transformer import init_transformer

__all__ = ["init_actor_critic"]


def init_actor_critic(model_cfg, sys_cfg):
    transformer_name_or_path = model_cfg.name_or_path
    dropout = model_cfg.dropout

    device = sys_cfg.device
    compile_model = sys_cfg.compile
    base_model = init_transformer(
        transformer_name_or_path,
        dropout,
        device,
        as_tensordictmodule=False,
        compile_model=compile_model,
        inference=True,
    )
    model = LMHeadActorValueOperator(base_model)
    model.to(device)
    model.eval()
    actor = model.get_policy_operator()
    critic = model.get_value_operator()
    critic_head = model.get_value_head()

    return actor, VmapModule(critic), critic_head, base_model
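
A hedged usage sketch for `init_actor_critic` (the config objects below are
stand-ins for the Hydra config sections; assumes it is run from `examples/rlhf/`
and that GPT-2 weights can be downloaded):

```python
from types import SimpleNamespace

from models.actor_critic import init_actor_critic

model_cfg = SimpleNamespace(name_or_path="gpt2", dropout=0.1)
sys_cfg = SimpleNamespace(device="cpu", compile=False)

actor, critic, critic_head, base_model = init_actor_critic(model_cfg, sys_cfg)
# `actor` and `critic` are TensorDict modules built on a shared GPT-2 backbone;
# `critic` comes wrapped in VmapModule, as returned above.
```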
41 changes: 41 additions & 0 deletions examples/rlhf/models/reward.py
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings

import torch
from tensordict.nn import TensorDictModule

from torchrl.modules.models.rlhf import GPT2RewardModel


def init_reward_model(
    transformer_path=None, reward_model_path=None, device=None, compile_model=False
):
    if transformer_path is None and reward_model_path is None:
        warnings.warn(
            "You did not provide a path to the reward model, a naive reward model will be used instead."
        )
        model = GPT2RewardModel()
    else:
        if not ((transformer_path is None) ^ (reward_model_path is None)):
            raise ValueError(
                "Exactly one of transformer_path or reward_model_path should be specified."
            )
        if transformer_path is not None:
            model = GPT2RewardModel(transformer_path)
        else:
            model = GPT2RewardModel.from_pretrained(reward_model_path)

    model.to(device)
    if compile_model:
        print("Compiling the reward model...")
        model = torch.compile(model)

    model = TensorDictModule(
        model,
        in_keys=["input_ids", "attention_mask"],
        out_keys=["rewards", "end_scores"],
    )
    return model
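
A hedged usage sketch for `init_reward_model` (assumes it is run from
`examples/rlhf/` and that the default GPT-2 backbone can be downloaded; with no
paths given, the naive-reward-model branch above is exercised):

```python
import torch
from tensordict import TensorDict

from models.reward import init_reward_model

reward_model = init_reward_model(device="cpu")  # naive reward model, warns as above

td = TensorDict(
    {
        "input_ids": torch.randint(0, 50257, (1, 16)),
        "attention_mask": torch.ones(1, 16, dtype=torch.int64),
    },
    batch_size=[1],
)
td = reward_model(td)  # fills the "rewards" and "end_scores" entries
print(td["end_scores"].shape)
```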
44 changes: 44 additions & 0 deletions examples/rlhf/models/transformer.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from tensordict.nn import TensorDictModule
from transformers import GPT2LMHeadModel


def init_transformer(
    name_or_path,
    dropout,
    device,
    compile_model,
    as_tensordictmodule=True,
    inference=False,
):
    model_kwargs = {
        "resid_pdrop": dropout,
        "embd_pdrop": dropout,
        "attn_pdrop": dropout,
        "summary_first_dropout": dropout,
    }
    model = GPT2LMHeadModel.from_pretrained(
        name_or_path, return_dict=False, **model_kwargs
    )
    model.to(device)

    if compile_model:
        # TODO: logging instead of printing?
        print("Compiling transformer model...")
        model = torch.compile(model)

    if as_tensordictmodule:
        model = TensorDictModule(
            model,
            in_keys={
                "input_ids": "input_ids",
                "attention_mask": "attention_mask",
                "labels": "labels",
            },
            out_keys=["logits"] if inference else ["loss", "logits"],
        )
    return model
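
A hedged usage sketch for `init_transformer` with `as_tensordictmodule=False`
(assumes it is run from `examples/rlhf/` and that GPT-2 weights and tokenizer can
be downloaded; because `return_dict=False`, the raw model returns a tuple whose
first two elements are the loss and logits):

```python
from transformers import GPT2Tokenizer

from models.transformer import init_transformer

model = init_transformer(
    "gpt2", dropout=0.1, device="cpu", compile_model=False, as_tensordictmodule=False
)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokens = tokenizer("TorchRL makes RLHF examples easier.", return_tensors="pt")

loss, logits = model(
    input_ids=tokens["input_ids"],
    attention_mask=tokens["attention_mask"],
    labels=tokens["input_ids"],
)[:2]
print(loss.item(), logits.shape)
```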
11 changes: 11 additions & 0 deletions examples/rlhf/requirements.txt
@@ -0,0 +1,11 @@
datasets
hydra-core
matplotlib
numpy
PyYAML
requests
tiktoken
tqdm
transformers
git+https://github.com/pytorch/rl
git+https://github.com/pytorch-labs/tensordict