35 changes: 25 additions & 10 deletions preprocess.py
@@ -12,7 +12,7 @@
from tqdm import tqdm
from PIL import Image
from decord import VideoReader, cpu
from transformers import Blip2Processor, Blip2ForConditionalGeneration


decord.bridge.set_bridge('torch')

@@ -22,6 +22,7 @@ def __init__(
config_name,
config_save_name,
video_directory,
low_ram_model,
random_start_frame,
clip_frame_data,
max_frames,
@@ -35,6 +36,7 @@ def __init__(
# Parameters for parsing videos
self.prompt_amount = prompt_amount
self.video_directory = video_directory
self.low_ram_model = low_ram_model
self.random_start_frame = random_start_frame
self.clip_frame_data = clip_frame_data
self.max_frames = max_frames
@@ -84,15 +86,23 @@ def build_video_data(self, frame_index: int, prompt: str):
# Load BLIP2 for processing
def load_blip(self):
print("Loading BLIP2")
if self.low_ram_model:
from lavis.models import load_model_and_preprocess
model, self.processor, _ = load_model_and_preprocess(
name="blip2_opt", model_type="caption_coco_opt2.7b", is_eval=True, device=self.device
)
model.to(self.device)
self.blip_model = model
else:
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
)
model.to(self.device)

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
)
model.to(self.device)

self.processor = processor
self.blip_model = model
self.processor = processor
self.blip_model = model

# Process the frames to get the length and image.
# The limit parameter ensures we don't get near the max frame length.
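
For reviewers comparing the two paths added in load_blip, a minimal standalone sketch of how each backend produces a caption, assuming both lavis and transformers are installed and a CUDA device is available; the test image path and the max_new_tokens value are placeholders, not taken from this PR.

# Standalone sketch of the two captioning backends (not the PR's exact code).
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
image = Image.open("frame.jpg").convert("RGB")  # placeholder test frame

def caption_with_lavis(image, device):
    # Low-RAM path: LAVIS loads the model and its matching image preprocessor together.
    from lavis.models import load_model_and_preprocess
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip2_opt", model_type="caption_coco_opt2.7b", is_eval=True, device=device
    )
    img = vis_processors["eval"](image).unsqueeze(0).to(device)
    return model.generate({"image": img})[0]  # generate() returns a list of caption strings

def caption_with_transformers(image, device):
    # Default path: Hugging Face BLIP2 in float16 (assumes a GPU).
    from transformers import Blip2Processor, Blip2ForConditionalGeneration
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
    ).to(device)
    inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
    ids = model.generate(**inputs, max_new_tokens=30)  # placeholder generation length
    return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
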
@@ -115,6 +125,10 @@ def get_frame_range(self, derterministic):
return range(self.prompt_amount) if self.random_start_frame else derterministic

def process_blip(self, image: Image):
if self.low_ram_model:
input_img = self.processor["eval"](image).unsqueeze(0).to(self.device)
generated_text = self.blip_model.generate({"image": input_img})[0]  # LAVIS generate() returns a list of captions
else:
inputs = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
generated_ids = self.blip_model.generate(
**inputs,
@@ -126,7 +140,7 @@ def process_blip(self, image: Image):
generated_ids,
skip_special_tokens=True)[0].strip()

return generated_text
return generated_text

def get_out_paths(self, prompt, frame_number):
out_name= f"{prompt}_{str(frame_number)}"
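
As a usage sketch for the captioning step (paths and names below are placeholders, not from this PR), a frame can be pulled with decord and handed to process_blip as a PIL image, matching the torch bridge set at the top of the file; preprocessor stands in for whatever object load_blip configured.

import decord
from decord import VideoReader, cpu
from PIL import Image

decord.bridge.set_bridge('torch')

vr = VideoReader("videos/example.mp4", ctx=cpu(0))  # placeholder video path
frame = vr[0]                                       # HWC uint8 torch tensor via the torch bridge
image = Image.fromarray(frame.numpy())              # convert to PIL for either BLIP2 processor
# caption = preprocessor.process_blip(image)        # hypothetical instance holding load_blip/process_blip
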
@@ -234,6 +248,7 @@ def process_videos(self):
parser.add_argument('--config_name', help="The name of the configuration.", type=str, default='My Config')
parser.add_argument('--config_save_name', help="The name of the config file that's saved.", type=str, default='my_config')
parser.add_argument('--video_directory', help="The directory where your videos are located.", type=str, default='./videos')
parser.add_argument('--low_ram_model', help="Use the lower-RAM LAVIS BLIP2 model for captioning instead of the transformers BLIP2 model.", action='store_true')
parser.add_argument(
'--random_start_frame',
help="Use random start frame when processing videos. Good for long videos where frames have different scenes and meanings.",
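
Since --low_ram_model uses argparse's store_true action, a quick sketch of the flag's behaviour and of the pitfall with type=bool (any non-empty string, including "False", converts to True):

import argparse

p = argparse.ArgumentParser()
p.add_argument('--low_ram_model', action='store_true')  # present -> True, absent -> False

print(p.parse_args([]).low_ram_model)                    # False
print(p.parse_args(['--low_ram_model']).low_ram_model)   # True
# With type=bool instead, bool("False") is True, so "--low_ram_model False" would still enable it.

A typical low-RAM run would then look like: python preprocess.py --video_directory ./videos --config_save_name my_config --low_ram_model (the paths and names here are placeholders).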