46 changes: 45 additions & 1 deletion ldm/modules/encoders/modules.py
@@ -1,10 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F # used by RoPEAttentionWrapper below
from functools import partial
import clip
import open_clip as clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.modules.rope_utils import build_rope_cache, apply_rope


from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test

@@ -140,10 +142,17 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_l
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
# === Inject RoPE into attention layers ===
# NOTE: transformers' CLIPTextModel uses CLIPAttention rather than nn.MultiheadAttention,
# so this isinstance check may need to target that class for anything to be wrapped.
for name, module in self.transformer.named_modules():
if isinstance(module, torch.nn.MultiheadAttention):
# Replace on the direct parent; setattr with a dotted name would not swap nested modules.
parent_name, _, child_name = name.rpartition(".")
parent = self.transformer.get_submodule(parent_name) if parent_name else self.transformer
setattr(parent, child_name, RoPEAttentionWrapper(module))
print(f"[RoPE] Wrapped attention module: {name}")

self.device = device
self.max_length = max_length
self.freeze()


def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
@@ -227,6 +236,41 @@ def forward(self, x):
# x is assumed to be in range [-1,1]
return self.model.encode_image(self.preprocess(x))

class RoPEAttentionWrapper(nn.Module):
def __init__(self, attn_layer):
super().__init__()
self.attn = attn_layer
self.rope_cache = None

def forward(self, x, *args, **kwargs):
# NOTE: x is assumed batch-first; any attention mask passed via *args/**kwargs is ignored here.
B, S, C = x.shape # batch, seq_len, channels
device = x.device
num_heads = self.attn.num_heads
head_dim = C // num_heads

# Linear projection to get QKV
qkv = F.linear(x, self.attn.in_proj_weight, self.attn.in_proj_bias)
qkv = qkv.view(B, S, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

# Build rope cache if not existing
if self.rope_cache is None or self.rope_cache[0].shape[2] != S:
self.rope_cache = build_rope_cache(S, head_dim, device)

# Apply RoPE
q = apply_rope(q, self.rope_cache)
k = apply_rope(k, self.rope_cache)

# Attention calculation
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** -0.5)
attn_weights = attn_weights.softmax(dim=-1)
attn_output = torch.matmul(attn_weights, v)

attn_output = attn_output.transpose(1, 2).reshape(B, S, C)
output = self.attn.out_proj(attn_output)

return output


if __name__ == "__main__":
from ldm.util import count_params
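As a quick sanity check, RoPEAttentionWrapper can be exercised on a plain nn.MultiheadAttention to confirm that shapes are preserved; a minimal sketch, with widths chosen to match the ViT-L/14 text encoder (768 channels, 12 heads):

import torch
from ldm.modules.encoders.modules import RoPEAttentionWrapper

mha = torch.nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
wrapped = RoPEAttentionWrapper(mha)
x = torch.randn(2, 77, 768) # (batch, seq_len, channels)
out = wrapped(x)
print(out.shape) # expected: torch.Size([2, 77, 768])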
20 changes: 20 additions & 0 deletions ldm/modules/rope_utils.py
@@ -0,0 +1,20 @@
# ldm/modules/rope_utils.py

import torch

def build_rope_cache(seq_len, head_dim, device):
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, inv_freq) # (seq_len, head_dim/2)
# Keep the half-width table: apply_rope rotates (even, odd) channel pairs of size head_dim/2,
# so sin/cos must broadcast against tensors of that width.
sin_emb = freqs.sin()[None, None, :, :] # (1, 1, seq_len, head_dim/2)
cos_emb = freqs.cos()[None, None, :, :]
return sin_emb, cos_emb

def apply_rope(x, rope_cache):
sin_emb, cos_emb = rope_cache
# Rotate the (even, odd) channel pairs; sin/cos broadcast over the batch and head dims.
x1 = x[..., ::2]
x2 = x[..., 1::2]
# The rotated halves are concatenated rather than re-interleaved; this is fine for attention
# because the same rearrangement is applied to both q and k, so their dot products are unchanged.
x_out = torch.cat([x1 * cos_emb - x2 * sin_emb,
x1 * sin_emb + x2 * cos_emb], dim=-1)
return x_out
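For reference, a minimal shape check of the two helpers, assuming the batch-first (batch, heads, seq_len, head_dim) layout used by RoPEAttentionWrapper:

import torch
from ldm.modules.rope_utils import build_rope_cache, apply_rope

sin_emb, cos_emb = build_rope_cache(seq_len=77, head_dim=64, device=torch.device("cpu"))
q = torch.randn(2, 12, 77, 64) # (batch, heads, seq_len, head_dim)
q_rot = apply_rope(q, (sin_emb, cos_emb))
print(sin_emb.shape, q_rot.shape) # (1, 1, 77, 32) and (2, 12, 77, 64)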
109 changes: 109 additions & 0 deletions scripts/finetune_encoder.py
@@ -0,0 +1,109 @@
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import precision_recall_fscore_support
import torch.nn.functional as F

from ldm.modules.encoders.modules import FrozenCLIPEmbedder

# === Config ===
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
epochs = 3
lr = 1e-5
max_length = 77
save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)
save_every_n_steps = 1000 # Save every 1000 batches

# === Dataset ===
class CocoCountingDataset(torch.utils.data.Dataset):
def __init__(self, split="train", tokenizer=None, max_length=77):
self.dataset = load_dataset("conceptual_captions", split=split)
self.tokenizer = tokenizer
self.max_length = max_length
self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
caption = self.dataset[idx]['caption'].lower()
label = int(any(word in caption for word in self.number_words)) # label 1 if counting word exists

if label == 0:
caption = "one object."

encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
input_ids = encoding["input_ids"].squeeze(0)
attention_mask = encoding["attention_mask"].squeeze(0)
return input_ids, attention_mask, label
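# Illustration of the labeling rule above (hypothetical captions):
#   "two dogs on a beach" -> label 1, caption kept as-is
#   "a dog on a beach"    -> label 0, caption replaced with "one object."
# Note the substring match also fires inside words such as "someone" ("one") or "often" ("ten").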

# === Model ===
model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length)

for param in model.transformer.parameters():
param.requires_grad = True

model = model.to(device)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr)

# === Dataloader ===
dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# === Training ===
model.train()
global_step = 0
for epoch in range(epochs):
total_loss = 0
preds, targets = [], []

for batch_idx, (input_ids, attention_mask, labels) in enumerate(tqdm(dataloader)):
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)

outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)
embeddings = outputs.last_hidden_state

# Placeholder objective: shrinks token-embedding norms; it does not use the counting labels.
loss = torch.mean(torch.norm(embeddings, dim=-1))

optimizer.zero_grad()
loss.backward()
optimizer.step()

total_loss += loss.item()

# Mock "classification" for precision/recall: use embedding norm as pseudo-score
scores = torch.norm(embeddings[:, 0, :], dim=-1) # norm of the first (BOS) token embedding
pred_labels = (scores > scores.mean()).long()

preds.extend(pred_labels.cpu().tolist())
targets.extend(labels.cpu().tolist())

global_step += 1

# === Save checkpoint mid-epoch
if global_step % save_every_n_steps == 0:
checkpoint_path = os.path.join(save_dir, f"clip_rope_step{global_step}.pth")
torch.save(model.transformer.state_dict(), checkpoint_path)
print(f"[Checkpoint] Saved at step {global_step}")

# === End of epoch logging ===
precision, recall, f1, _ = precision_recall_fscore_support(targets, preds, average='binary')
print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}")
print(f"Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}")

# Save after each epoch
checkpoint_path = os.path.join(save_dir, f"clip_rope_epoch{epoch+1}.pth")
torch.save(model.transformer.state_dict(), checkpoint_path)
print(f"[Checkpoint] Saved model after epoch {epoch+1}")

# === Final Save ===
torch.save(model.transformer.state_dict(), "./clip_rope_finetuned_final.pth")
print("[Final Save] Fine-tuned text encoder saved!")
81 changes: 81 additions & 0 deletions scripts/train_clip_rope.py
@@ -0,0 +1,81 @@
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_dataset

from ldm.modules.encoders.modules import FrozenCLIPEmbedder

# === Config ===
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
epochs = 3
lr = 1e-5
max_length = 77
save_path = "./clip_rope_finetuned.pth"

# === Dataset ===
class CocoCountingDataset(torch.utils.data.Dataset):
def __init__(self, split="train", tokenizer=None, max_length=77):
self.dataset = load_dataset("conceptual_captions", split=split)
self.tokenizer = tokenizer
self.max_length = max_length
self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
caption = self.dataset[idx]['caption'].lower()

if not any(word in caption for word in self.number_words):
caption = "one object." # fallback dummy caption

encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
input_ids = encoding["input_ids"].squeeze(0)
attention_mask = encoding["attention_mask"].squeeze(0)
return input_ids, attention_mask

# === Model ===
model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length)
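# Optional sanity check (a sketch): count how many attention modules were actually wrapped;
# 0 means the isinstance match in FrozenCLIPEmbedder did not fire and the encoder
# is running without rotary embeddings.
from ldm.modules.encoders.modules import RoPEAttentionWrapper
n_wrapped = sum(isinstance(m, RoPEAttentionWrapper) for m in model.transformer.modules())
print(f"[RoPE] wrapped attention modules: {n_wrapped}")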

# ❗ Unfreeze only transformer parameters
for param in model.transformer.parameters():
param.requires_grad = True

model = model.to(device)

# ❗ Optimizer on transformer parameters
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr)

# === Dataloader ===
dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# === Training ===
model.train()
for epoch in range(epochs):
total_loss = 0
for input_ids, attention_mask in tqdm(dataloader):
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)

outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)
embeddings = outputs.last_hidden_state

# Placeholder objective: mean L2 norm of the token embeddings (not a task-specific loss)
loss = torch.mean(torch.norm(embeddings, dim=-1))

optimizer.zero_grad()
loss.backward()
optimizer.step()

total_loss += loss.item()

print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}")

# === Save the fine-tuned transformer
torch.save(model.transformer.state_dict(), save_path)
print(f"Fine-tuned text encoder saved to {save_path}")