diff --git a/preprocess.py b/preprocess.py
index 8d5230d..e3b1543 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -12,7 +12,7 @@
 from tqdm import tqdm
 from PIL import Image
 from decord import VideoReader, cpu
-from transformers import Blip2Processor, Blip2ForConditionalGeneration
+
 
 decord.bridge.set_bridge('torch')
 
@@ -22,6 +22,7 @@ def __init__(
         config_name,
         config_save_name,
         video_directory,
+        low_ram_model,
         random_start_frame,
         clip_frame_data,
         max_frames,
@@ -35,6 +36,7 @@ def __init__(
         # Paramaters for parsing videos
         self.prompt_amount = prompt_amount
         self.video_directory = video_directory
+        self.low_ram_model = low_ram_model
         self.random_start_frame = random_start_frame
         self.clip_frame_data = clip_frame_data
         self.max_frames = max_frames
@@ -84,15 +86,23 @@ def build_video_data(self, frame_index: int, prompt: str):
 
     # Load BLIP2 for processing
     def load_blip(self):
         print("Loading BLIP2")
+        if self.low_ram_model:
+            from lavis.models import load_model_and_preprocess
+            model, self.processor, _ = load_model_and_preprocess(
+                name="blip2_opt", model_type="caption_coco_opt2.7b", is_eval=True, device=self.device
+            )
+            model.to(self.device)
+            self.blip_model = model
+        else:
+            from transformers import Blip2Processor, Blip2ForConditionalGeneration
+            processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+            model = Blip2ForConditionalGeneration.from_pretrained(
+                "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
+            )
+            model.to(self.device)
-        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        model = Blip2ForConditionalGeneration.from_pretrained(
-            "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
-        )
-        model.to(self.device)
-
-        self.processor = processor
-        self.blip_model = model
+            self.processor = processor
+            self.blip_model = model
 
     # Process the frames to get the length and image.
     # The limit parameter ensures we don't get near the max frame length.
@@ -115,6 +125,10 @@ def get_frame_range(self, derterministic):
         return range(self.prompt_amount) if self.random_start_frame else derterministic
 
     def process_blip(self, image: Image):
+        if self.low_ram_model:
+            input_img = self.processor["eval"](image).unsqueeze(0).to(self.device)
+            generated_text = self.blip_model.generate({"image": input_img})[0]
+        else:
-        inputs = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
-        generated_ids = self.blip_model.generate(
-            **inputs,
+            inputs = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
+            generated_ids = self.blip_model.generate(
+                **inputs,
@@ -126,7 +140,7 @@
-            generated_ids,
-            skip_special_tokens=True)[0].strip()
+                generated_ids,
+                skip_special_tokens=True)[0].strip()
 
         return generated_text
 
     def get_out_paths(self, prompt, frame_number):
         out_name= f"{prompt}_{str(frame_number)}"
@@ -234,6 +248,7 @@ def process_videos(self):
 parser.add_argument('--config_name', help="The name of the configuration.", type=str, default='My Config')
 parser.add_argument('--config_save_name', help="The name of the config file that's saved.", type=str, default='my_config')
 parser.add_argument('--video_directory', help="The directory where your videos are located.", type=str, default='./videos')
+parser.add_argument('--low_ram_model', help="Use the lower-memory LAVIS BLIP2 model for captioning.", action='store_true', default=False)
 parser.add_argument(
     '--random_start_frame',
     help="Use random start frame when processing videos. Good for long videos where frames have different scenes and meanings.",
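
For reviewers unfamiliar with LAVIS, here is a minimal standalone sketch of the captioning path the `low_ram_model` branch relies on. It assumes `salesforce-lavis` is installed and a sample frame exists at `./frame.jpg`; both are assumptions for illustration, not part of this patch. It shows why the patch stores the second return value of `load_model_and_preprocess` in `self.processor` and calls `self.processor["eval"]`, and why `generate` needs a `[0]` index to yield a plain string.

```python
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load_model_and_preprocess returns (model, vis_processors, txt_processors);
# the patch keeps vis_processors, a dict keyed by "train"/"eval".
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt", model_type="caption_coco_opt2.7b", is_eval=True, device=device
)

# "./frame.jpg" is a hypothetical sample image for this sketch.
raw_image = Image.open("./frame.jpg").convert("RGB")
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

# generate() returns a list of caption strings, hence the [0] in process_blip.
captions = model.generate({"image": image})
print(captions[0])
```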