diff --git a/README.md b/README.md
index 9391b3e..6f78ae3 100644
--- a/README.md
+++ b/README.md
@@ -58,23 +58,36 @@ MetaCLIP uses 500,000 queries as [metadata](metadata.json) to align the training
 We change OpenCLIP to match the default CLIP training setup (w/ [ViT-B-16-quickgelu](src/open_clip/model_configs/ViT-B-16-quickgelu.json), [ViT-L-14-quickgelu](src/open_clip/model_configs/ViT-L-14-quickgelu.json) and [ViT-H-14-quickgelu](src/open_clip/model_configs/ViT-H-14-quickgelu.json)). Most OpenCLIP models use `nn.GELU`, whereas vanilla CLIP uses `quickgelu`. We hope this helps research w/ controlled experiments in the "CLIP era of ImageNet".
 
 ```python
+# Import the required libraries
 import torch
 from PIL import Image
 import open_clip
 
+# Create the OpenCLIP model and its preprocessing transforms, loading a pretrained MetaCLIP checkpoint
 model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='metaclip/b32_400m.pt')
 
+# Load an image and apply the preprocessing transforms, adding a batch dimension
 image = preprocess(Image.open("CLIP.png")).unsqueeze(0)
+
+# Tokenize a list of textual descriptions
 text = open_clip.tokenize(["a diagram", "a dog", "a cat"])
 
+# Encode the image and text without tracking gradients
 with torch.no_grad():
+    # Encode the image
     image_features = model.encode_image(image)
+
+    # Encode the text
     text_features = model.encode_text(text)
+
+    # L2-normalize the image and text features
     image_features /= image_features.norm(dim=-1, keepdim=True)
     text_features /= text_features.norm(dim=-1, keepdim=True)
 
+    # Compute label probabilities via softmax over the scaled image-text similarities
     text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
 
+# Print the label probabilities
 print("Label probs:", text_probs)
 ```
 
@@ -98,24 +111,36 @@ We have a [demo notebook](demo.ipynb) to show how the proposed algorithm works.
 CLIP curation can still help as online balancing (Table 6 in the paper). We wrap CLIP curation in two key functions: [substring matching](metaclip/substr_matching.py) (recommended to run offline) and [balancing](metaclip/balancing.py) (either offline or online; see `metaclip.balancing:main`).
 
 ```python
+# Import the required libraries
 import json
 import numpy as np
 from metaclip.substr_matching import substr_matching
 from metaclip.balancing import balance_sampling
 
+# Load the metadata (500,000 query entries) from JSON
 with open("metadata.json") as f:
   metadata = json.load(f)
-# entry counts for our 1.6B(pool) -> 400M(curated); please check balance_sampling:main and substr match and count on your own data.
+
+# Load entry counts from our 1.6B (pool) -> 400M (curated) run; recompute them on your own data (see `balance_sampling:main`)
 with open("metaclip/entry_counts_400m.json") as f:
   entry_count_json = json.load(f)
-entry_count = np.array([entry_count_json[entry] for entry in metadata], dtype=np.uint64)  # uint64 to be safe for scaling.
+# Convert entry counts to a NumPy array; uint64 to be safe for scaling
+entry_count = np.array([entry_count_json[entry] for entry in metadata], dtype=np.uint64)
+
+# Set the balancing threshold t; counts below t are floored to t, so tail entries are always kept
 t = 20000
 entry_count[entry_count < t] = t
+
+# Per-entry sampling probability: head (frequent) entries are down-sampled, tail entries keep probability 1
 entry_prob = t / entry_count
 
+# Curate a few example texts
 for text in ["jacksons chameleon", "battery plate"]:
+  # Find the metadata entry IDs whose queries appear as substrings of the text
   matched_entry_ids = substr_matching(text, metadata)
+
+  # Probabilistically keep the text based on the sampling probabilities of its matched entries
   if balance_sampling(matched_entry_ids, entry_prob):
     print(f"'{text}' curated")
 ```
 
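A note on the `-quickgelu` configs pinned in the first hunk: vanilla CLIP's "QuickGELU" is a sigmoid-based approximation of GELU, `x * sigmoid(1.702 * x)`, which is why the MetaCLIP configs select it explicitly rather than `nn.GELU`. A minimal sketch of the difference, using only plain PyTorch (the `QuickGELU` class below is a local illustration, not an import from `open_clip`):

```python
import torch
import torch.nn as nn

class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU used by vanilla CLIP: x * sigmoid(1.702 * x)."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)

# The two activations are close but not identical, so mixing them between
# training and inference shifts results slightly.
x = torch.linspace(-4.0, 4.0, steps=101)
gap = (QuickGELU()(x) - nn.GELU()(x)).abs().max()
print(f"max |QuickGELU - GELU| on [-4, 4]: {gap.item():.4f}")
```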
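The curation snippet in the second hunk treats `balance_sampling` as a black box. Conceptually, the balancing described in the paper keeps a text if sampling fires for its matched entries, so tail entries (floored counts, probability 1.0) always survive while head entries are down-sampled. Below is a minimal sketch of that accept-if-any-matched-entry-fires behavior; it is an assumption about the semantics for illustration, not a copy of `metaclip/balancing.py`:

```python
import random
import numpy as np

def balance_sampling_sketch(matched_entry_ids, entry_prob):
    # Keep the text if sampling fires for any matched entry; tail entries
    # (entry_prob == 1.0) guarantee acceptance, head entries are down-sampled.
    return any(random.random() < entry_prob[i] for i in matched_entry_ids)

# Toy check with two entries: a head entry kept ~1% of the time and a tail
# entry that is always kept.
entry_prob = np.array([0.01, 1.0])
print(balance_sampling_sketch([1], entry_prob))  # always True (tail entry)
print(balance_sampling_sketch([0], entry_prob))  # True ~1% of the time
```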