3 changes: 1 addition & 2 deletions config/metaclip_2_5b.py
@@ -11,8 +11,7 @@

@dataclass
class b32_fullcc(Config):
-    one_iter=True
-    inmem=True
+    gpu_trans=True
    engine="train_one_epoch_ex"
    eval_steps=5000
    save_frequency=1
2 changes: 1 addition & 1 deletion config/metaclip_400m.py
@@ -11,7 +11,7 @@

@dataclass
class b32_400m(Config):
-    inmem=True
+    gpu_trans=True
    engine="train_one_epoch_ex"
    eval_steps=5000
    save_frequency=1
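The two configs above rename the `inmem` flag to `gpu_trans`. Configs written before the rename may still only define `inmem`, so downstream code reads the new flag defensively (as main.py does later in this PR). A minimal sketch of that defensive read; `wants_gpu_transform` is a hypothetical helper, not part of this PR:

```python
# Minimal sketch: read the renamed flag with a safe default so configs that
# predate the rename simply fall back to the CPU transform path.
def wants_gpu_transform(args) -> bool:
    return getattr(args, "gpu_trans", False)
```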
74 changes: 22 additions & 52 deletions src/open_clip/factory.py
@@ -16,6 +16,7 @@
from src.open_clip.openai import load_openai_model
from src.open_clip.pretrained import get_pretrained_url, download_pretrained
from src.open_clip.transform import image_transform
+from src.training.checkpoint import load_checkpoint


_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
@@ -50,46 +51,15 @@ def _rescan_model_configs():
_rescan_model_configs() # initial populate of model config registry


-def unwrap_model(model):
-    if hasattr(model, 'module'):
-        return model.module
-    else:
-        return model
-
-
-def unwrap_state_dict(sd):
-    if next(iter(sd.items()))[0].startswith('_orig_mod'):
-        sd = {k[len('_orig_mod.'):]: v for k, v in sd.items()}
-    if next(iter(sd.items()))[0].startswith('module'):
-        sd = {k[len('module.'):]: v for k, v in sd.items()}
-    return sd
-
-
-def load_state_dict(checkpoint_path: str, map_location='cpu'):
-    checkpoint = torch.load(checkpoint_path, map_location=map_location)
-    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-        state_dict = checkpoint['state_dict']
-    else:
-        state_dict = checkpoint
-    return unwrap_state_dict(state_dict)
-
-
-def load_checkpoint(model, checkpoint_path, strict=True):
-    state_dict = load_state_dict(checkpoint_path)
-    resize_pos_embed(state_dict, model)
-    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
-    return incompatible_keys


def create_model(
-    model_name: str,
-    pretrained: str = '',
-    precision: str = 'fp32',
-    device: torch.device = torch.device('cpu'),
-    jit: bool = False,
-    force_quick_gelu: bool = False,
-    pretrained_image: bool = False,
-    clip_model: str = "CLIP",
+    model_name: str,
+    pretrained: str = '',
+    precision: str = 'fp32',
+    device: torch.device = torch.device('cpu'),
+    jit: bool = False,
+    force_quick_gelu: bool = False,
+    pretrained_image: bool = False,
+    clip_model: str = "CLIP",
):
    model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names

@@ -141,7 +111,7 @@ def create_model(

        if checkpoint_path:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
-            load_checkpoint(model, checkpoint_path)
+            load_checkpoint(checkpoint_path, model, resize_pos_embed=True)
        else:
            logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
            raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
@@ -158,25 +128,25 @@


def create_model_and_transforms(
-    model_name: str,
-    pretrained: str = '',
-    precision: str = 'fp32',
-    device: torch.device = torch.device('cpu'),
-    jit: bool = False,
-    force_quick_gelu: bool = False,
-    pretrained_image: bool = False,
-    mean: Optional[Tuple[float, ...]] = None,
-    std: Optional[Tuple[float, ...]] = None,
-    inmem = False,
-    clip_model: str = "CLIP",
+    model_name: str,
+    pretrained: str = '',
+    precision: str = 'fp32',
+    device: torch.device = torch.device('cpu'),
+    jit: bool = False,
+    force_quick_gelu: bool = False,
+    pretrained_image: bool = False,
+    mean: Optional[Tuple[float, ...]] = None,
+    std: Optional[Tuple[float, ...]] = None,
+    gpu_trans = False,
+    clip_model: str = "CLIP",
):
    model = create_model(
        model_name, pretrained, precision, device, jit,
        force_quick_gelu=force_quick_gelu,
        pretrained_image=pretrained_image,
        clip_model=clip_model,
    )
-    preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=mean, std=std, inmem=inmem)
+    preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=mean, std=std, gpu_trans=gpu_trans)
    preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=mean, std=std)
    return model, preprocess_train, preprocess_val

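For reference, a minimal sketch of calling the updated factory with the renamed flag. The model name, pretrained tag, and device are illustrative, and the exact effect of `gpu_trans` on the returned train transform is defined in transform.py below:

```python
import torch
from src.open_clip.factory import create_model_and_transforms

# illustrative names; any registered model config / pretrained tag would do
model, preprocess_train, preprocess_val = create_model_and_transforms(
    "ViT-B-32-quickgelu",
    pretrained="metaclip_400m",
    precision="amp",
    device=torch.device("cuda"),
    gpu_trans=True,   # presumably lightens the CPU-side train transform (see transform.py)
    clip_model="CLIP",
)
```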
4 changes: 2 additions & 2 deletions src/open_clip/transform.py
@@ -54,7 +54,7 @@ def image_transform(
    std: Optional[Tuple[float, ...]] = None,
    resize_longest_max: bool = False,
    fill_color: int = 0,
-    inmem = False
+    gpu_trans = False
):
    mean = mean or (0.48145466, 0.4578275, 0.40821073)  # OpenAI dataset mean
    std = std or (0.26862954, 0.26130258, 0.27577711)  # OpenAI dataset std
@@ -64,7 +64,7 @@

    normalize = Normalize(mean=mean, std=std)
    if is_train:
-        if inmem:
+        if gpu_trans:
            return Compose([
                RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
                _convert_to_rgb,
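The `gpu_trans` branch above is truncated in this view, but the flag name and the config rename suggest that part of the preprocessing (at least normalization) moves onto the GPU. A minimal sketch of what such a GPU-side normalization step could look like, reusing the OpenAI mean/std defaults from `image_transform`; the function name and the assumption that batches arrive as uint8 tensors are illustrative, not part of this PR:

```python
import torch

# OpenAI CLIP statistics, matching image_transform's defaults above
_MEAN = torch.tensor((0.48145466, 0.4578275, 0.40821073)).view(1, 3, 1, 1)
_STD = torch.tensor((0.26862954, 0.26130258, 0.27577711)).view(1, 3, 1, 1)

def gpu_normalize(images: torch.Tensor) -> torch.Tensor:
    # images: NCHW uint8 batch already moved to the GPU
    images = images.float().div_(255.0)
    return (images - _MEAN.to(images.device)) / _STD.to(images.device)
```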
90 changes: 90 additions & 0 deletions src/training/checkpoint.py
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import torch
import logging

from src.open_clip.model import resize_pos_embed as _resize_pos_embed
from src.training.distributed import world_info_from_env


def unwrap_model(model):
    if hasattr(model, 'module'):
        return model.module
    else:
        return model


def unwrap_state_dict(sd):
    if next(iter(sd.items()))[0].startswith('_orig_mod'):
        sd = {k[len('_orig_mod.'):]: v for k, v in sd.items()}
    if next(iter(sd.items()))[0].startswith('module'):
        sd = {k[len('module.'):]: v for k, v in sd.items()}
    return sd


def load_checkpoint(checkpoint_path, model, map_location='cpu', resize_pos_embed=False, strict=True, optimizer=None, scaler=None):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    step, positions = -1, None

    if isinstance(checkpoint, dict):
        state_dict = unwrap_state_dict(checkpoint["state_dict"])
        if resize_pos_embed:
            _resize_pos_embed(state_dict, model)

        model.load_state_dict(state_dict, strict=strict)

        if optimizer is not None and "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if scaler is not None and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])

        if "step" in checkpoint:
            step = checkpoint["step"]

        if "positions" in checkpoint:
            positions = checkpoint["positions"]
    else:
        # loading a bare (model only) checkpoint for fine-tune or evaluation
        model.load_state_dict(unwrap_state_dict(checkpoint))
    return step, positions


def save_checkpoint(checkpoint_path, model, optimizer=None, scaler=None, step=None, positions_dict=None):
    checkpoint_dict = {
        "step": step,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }

    if scaler is not None:
        checkpoint_dict["scaler"] = scaler.state_dict()

    if positions_dict is not None:
        checkpoint_dict["positions"] = positions_dict

    # save the checkpoint; callers use eval_steps to decide when to call this.
    torch.save(checkpoint_dict, checkpoint_path)


def agg_positions(positions, worker_ids, shard_ids):
    if positions is None or worker_ids is None or shard_ids is None:
        return None
    assert sum(worker_ids) == worker_ids[0] * worker_ids.shape[0]  # each batch should come from a single dataloader worker
    positions[worker_ids[0]] = shard_ids.max()
    return positions


def collect_positions(args, positions):
    if positions is None:
        return None
    if args.distributed:
        import torch.distributed as dist

        _, _, world_size = world_info_from_env()

        gathered_tensors = [torch.zeros_like(positions, device=args.device) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, positions.to(args.device))
    else:
        gathered_tensors = [positions]
    gathered_tensors = [gathered_tensor.cpu() for gathered_tensor in gathered_tensors]
    positions = {f"{rank}_{worker_id}": shard_id for rank, gathered_tensor in enumerate(gathered_tensors) for worker_id, shard_id in enumerate(gathered_tensor)}
    return positions
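A minimal round-trip sketch for the new helpers, assuming `model`, `optimizer`, and `scaler` already exist; the paths, step value, and positions entry are illustrative:

```python
from src.training.checkpoint import save_checkpoint, load_checkpoint

# save: bundle step, model/optimizer state, optional scaler state, and the
# per-worker dataloader positions so a resumed run can skip already-seen shards
save_checkpoint(
    "logs/run1/checkpoints/step_5000.pt",   # illustrative path
    model,
    optimizer=optimizer,
    scaler=scaler,
    step=5000,
    positions_dict={"0_0": 12},             # rank 0, worker 0 last touched shard 12 (illustrative)
)

# resume: restores model/optimizer/scaler in place and returns (step, positions)
step, positions = load_checkpoint(
    "logs/run1/checkpoints/step_5000.pt",
    model,
    optimizer=optimizer,
    scaler=scaler,
)

# loading a pretrained checkpoint for evaluation/fine-tuning (as factory.create_model now does),
# optionally resizing the positional embedding to the current model's grid
load_checkpoint("/path/to/pretrained.pt", model, resize_pos_embed=True)
```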
35 changes: 8 additions & 27 deletions src/training/main.py
@@ -22,37 +22,15 @@
tensorboard = None


-from src.open_clip.factory import create_model_and_transforms, unwrap_state_dict
+from src.open_clip.factory import create_model_and_transforms
from src.open_clip.transform import get_mean_std
from src.open_clip.model import CLIP, VisualTransformer, Transformer, ResidualAttentionBlock
from src.training.data import get_data
from src.training.distributed import is_master, init_distributed_device, world_info_from_env
from src.training.logger import setup_logging
from src.training.scheduler import cosine_lr
from src.training import train


-# huxu: move to src/training/checkpoint.py
-def resume_checkpoint(args, checkpoint_fn, model, optimizer, scaler):
-    checkpoint = torch.load(checkpoint_fn, map_location='cpu')
-    step, positions = 0, None
-    if isinstance(checkpoint, dict):
-        state_dict = unwrap_state_dict(args, checkpoint["state_dict"])
-        model.load_state_dict(state_dict)
-
-        if optimizer is not None:
-            optimizer.load_state_dict(checkpoint["optimizer"])
-        if scaler is not None and 'scaler' in checkpoint:
-            scaler.load_state_dict(checkpoint['scaler'])
-        step = checkpoint["step"]
-        logging.info(f"=> resuming checkpoint '{checkpoint_fn}' (step {step})")
-        if "positions" in checkpoint:
-            positions = checkpoint["positions"]
-    else:
-        # loading a bare (model only) checkpoint for fine-tune or evaluation
-        model.load_state_dict(checkpoint)
-        logging.info(f"=> loaded checkpoint '{checkpoint_fn}'")
-    return step, positions
+from src.training.checkpoint import load_checkpoint


def random_seed(seed=42, rank=0):
@@ -137,7 +115,7 @@ def main(args):
        force_quick_gelu=args.force_quick_gelu,
        pretrained_image=args.pretrained_image,
        mean=mean, std=std,
-        inmem=hasattr(args, "inmem"),
+        gpu_trans=hasattr(args, "gpu_trans") and args.gpu_trans,
        clip_model=args.clip_model,
    )
    random_seed(args.seed, args.rank)
@@ -151,7 +129,9 @@
logging.info("Params:")
params_file = os.path.join(args.logs, args.name, "params.txt")
with open(params_file, "w") as f:
for name in sorted(vars(args)):
for name in sorted(dir(args)):
if name.startswith('__'):
continue
val = getattr(args, name)
logging.info(f" {name}: {val}")
f.write(f"{name}: {val}\n")
@@ -192,7 +172,8 @@ def main(args):
    if args.resume is not None:
        if os.path.isfile(args.resume):
            model_to_load = model
-            step, positions = resume_checkpoint(args, args.resume, model_to_load, optimizer, scaler)
+            step, positions = load_checkpoint(args.resume, model_to_load, optimizer=optimizer, scaler=scaler)
+            logging.info(f"=> resuming checkpoint '{checkpoint_path}' (step {step})")
        else:
            logging.info("=> no checkpoint found at '{}'".format(args.resume))

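One note on the params-logging change above: the config classes in this PR (e.g. `b32_400m`) assign their fields without type annotations, so under `@dataclass` those names stay class attributes rather than instance fields, and `vars(args)` can miss them entirely; `dir(args)` sees them but also returns dunder names, which is presumably why the new loop filters on `startswith('__')`. A minimal illustration with a hypothetical `Cfg` class:

```python
from dataclasses import dataclass

@dataclass
class Cfg:
    gpu_trans = True            # no annotation -> class attribute, not a dataclass field

cfg = Cfg()
print(vars(cfg))                # {} -- the flag is invisible to vars()
print("gpu_trans" in dir(cfg))  # True -- dir() also walks the class
```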
6 changes: 6 additions & 0 deletions src/training/metaclip_wds.py
@@ -34,6 +34,12 @@ class IterativeWebDataset(torch.utils.data.IterableDataset):
uuid2.json
uuid2.jpeg
```
+    Each json has a `text` field with a list of texts associated with the image (uuid):
+    [
+        ['alt', 'this is a caption.'],
+        ['alt', 'this is another caption for the same image.'],
+        ...
+    ]
"""

def __init__(self, args, transform, tokenize):
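A minimal sketch of consuming one (json, jpeg) pair in the layout the docstring describes; the `text` field and its [kind, caption] pair format come from the docstring, while the file names and the random caption choice are illustrative:

```python
import json
import random

from PIL import Image

with open("uuid2.json") as f:
    meta = json.load(f)

# meta["text"] is a list of [kind, caption] pairs, e.g. ["alt", "this is a caption."]
captions = [caption for _, caption in meta["text"]]
text = random.choice(captions)

image = Image.open("uuid2.jpeg")
```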
16 changes: 16 additions & 0 deletions src/training/precision.py
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates


import torch

from contextlib import suppress


def get_autocast(precision):
    if precision == 'amp':
        return torch.cuda.amp.autocast
    elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress
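A minimal sketch of how the new helper slots into a training step; `model`, `images`, `texts`, and `loss_fn` are illustrative stand-ins for the real training objects:

```python
from src.training.precision import get_autocast

autocast = get_autocast("amp_bfloat16")   # 'amp' -> fp16 autocast; anything else -> no-op context

with autocast():
    outputs = model(images, texts)        # illustrative CLIP forward under autocast when enabled
    loss = loss_fn(*outputs)
```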