diff --git a/asset/inpaint_video.mp4 b/asset/inpaint_video.mp4 new file mode 100644 index 00000000..7880f0e5 Binary files /dev/null and b/asset/inpaint_video.mp4 differ diff --git a/asset/inpaint_video_mask.mp4 b/asset/inpaint_video_mask.mp4 new file mode 100644 index 00000000..5489096a Binary files /dev/null and b/asset/inpaint_video_mask.mp4 differ diff --git a/comfyui/cogvideox_fun/nodes.py b/comfyui/cogvideox_fun/nodes.py index a9fff32c..22a4ac93 100755 --- a/comfyui/cogvideox_fun/nodes.py +++ b/comfyui/cogvideox_fun/nodes.py @@ -198,13 +198,18 @@ def INPUT_TYPES(s): CATEGORY = "CogVideoXFUNWrapper" def load_lora(self, cogvideoxfun_model, lora_name, strength_model, lora_cache): + new_funmodels = dict(cogvideoxfun_model) + if lora_name is not None: - cogvideoxfun_model['lora_cache'] = lora_cache - cogvideoxfun_model['loras'] = cogvideoxfun_model.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)] - cogvideoxfun_model['strength_model'] = cogvideoxfun_model.get("strength_model", []) + [strength_model] - return (cogvideoxfun_model,) - else: - return (cogvideoxfun_model,) + lora_path = folder_paths.get_full_path("loras", lora_name) + if lora_path is None: + raise FileNotFoundError(f"LoRA 文件未找到: {lora_name}") + + new_funmodels['lora_cache'] = lora_cache + new_funmodels['loras'] = cogvideoxfun_model.get("loras", []) + [lora_path] + new_funmodels['strength_model'] = cogvideoxfun_model.get("strength_model", []) + [strength_model] + + return (new_funmodels,) class CogVideoXFunT2VSampler: @classmethod diff --git a/comfyui/comfyui_nodes.py b/comfyui/comfyui_nodes.py index 25c7e69b..83ffb51b 100755 --- a/comfyui/comfyui_nodes.py +++ b/comfyui/comfyui_nodes.py @@ -1,4 +1,5 @@ import json +import os import cv2 import numpy as np @@ -83,20 +84,28 @@ def INPUT_TYPES(s): def compile(self, cache_size_limit, funmodels): torch._dynamo.config.cache_size_limit = cache_size_limit if hasattr(funmodels["pipeline"].transformer, "blocks"): - for i in range(len(funmodels["pipeline"].transformer.blocks)): - funmodels["pipeline"].transformer.blocks[i] = torch.compile(funmodels["pipeline"].transformer.blocks[i]) + for i, block in enumerate(funmodels["pipeline"].transformer.blocks): + if hasattr(block, "_orig_mod"): + block = block._orig_mod + funmodels["pipeline"].transformer.blocks[i] = torch.compile(block) if hasattr(funmodels["pipeline"], "transformer_2") and funmodels["pipeline"].transformer_2 is not None: - for i in range(len(funmodels["pipeline"].transformer_2.blocks)): - funmodels["pipeline"].transformer_2.blocks[i] = torch.compile(funmodels["pipeline"].transformer_2.blocks[i]) + for i, block in enumerate(funmodels["pipeline"].transformer_2.blocks): + if hasattr(block, "_orig_mod"): + block = block._orig_mod + funmodels["pipeline"].transformer_2.blocks[i] = torch.compile(block) elif hasattr(funmodels["pipeline"].transformer, "transformer_blocks"): - for i in range(len(funmodels["pipeline"].transformer.transformer_blocks)): - funmodels["pipeline"].transformer.transformer_blocks[i] = torch.compile(funmodels["pipeline"].transformer.transformer_blocks[i]) - + for i, block in enumerate(funmodels["pipeline"].transformer.transformer_blocks): + if hasattr(block, "_orig_mod"): + block = block._orig_mod + funmodels["pipeline"].transformer.transformer_blocks[i] = torch.compile(block) + if hasattr(funmodels["pipeline"], "transformer_2") and funmodels["pipeline"].transformer_2 is not None: - for i in range(len(funmodels["pipeline"].transformer_2.transformer_blocks)): - funmodels["pipeline"].transformer_2.transformer_blocks[i] = torch.compile(funmodels["pipeline"].transformer_2.transformer_blocks[i]) + for i, block in enumerate(funmodels["pipeline"].transformer_2.transformer_blocks): + if hasattr(block, "_orig_mod"): + block = block._orig_mod + funmodels["pipeline"].transformer_2.transformer_blocks[i] = torch.compile(block) else: funmodels["pipeline"].transformer.forward = torch.compile(funmodels["pipeline"].transformer.forward) @@ -106,6 +115,31 @@ def compile(self, cache_size_limit, funmodels): print("Add Compile") return (funmodels,) + +class FunAttention: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "attention_type": ( + ["flash", "sage", "torch"], + {"default": "flash"}, + ), + "funmodels": ("FunModels",) + } + } + RETURN_TYPES = ("FunModels",) + RETURN_NAMES = ("funmodels",) + FUNCTION = "funattention" + CATEGORY = "CogVideoXFUNWrapper" + + def funattention(self, attention_type, funmodels): + os.environ['VIDEOX_ATTENTION_TYPE'] = { + "flash": "FLASH_ATTENTION", + "sage": "SAGE_ATTENTION", + "torch": "TORCH_SCALED_DOT" + }[attention_type] + return (funmodels,) class LoadConfig: @classmethod @@ -376,6 +410,7 @@ def run(self,camera_pose,fx,fy,cx,cy): "FunTextBox": FunTextBox, "FunRiflex": FunRiflex, "FunCompile": FunCompile, + "FunAttention": FunAttention, "LoadCogVideoXFunModel": LoadCogVideoXFunModel, "LoadCogVideoXFunLora": LoadCogVideoXFunLora, @@ -436,6 +471,7 @@ def run(self,camera_pose,fx,fy,cx,cy): "FunTextBox": "FunTextBox", "FunRiflex": "FunRiflex", "FunCompile": "FunCompile", + "FunAttention": "FunAttention", "LoadWanClipEncoderModel": "Load Wan ClipEncoder Model", "LoadWanTextEncoderModel": "Load Wan TextEncoder Model", diff --git a/comfyui/wan2_1/nodes.py b/comfyui/wan2_1/nodes.py index a31c4a09..32c59454 100755 --- a/comfyui/wan2_1/nodes.py +++ b/comfyui/wan2_1/nodes.py @@ -612,13 +612,17 @@ def INPUT_TYPES(s): CATEGORY = "CogVideoXFUNWrapper" def load_lora(self, funmodels, lora_name, strength_model, lora_cache): + new_funmodels = dict(funmodels) + if lora_name is not None: - funmodels['lora_cache'] = lora_cache - funmodels['loras'] = funmodels.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)] - funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model] - return (funmodels,) - else: - return (funmodels,) + lora_path = folder_paths.get_full_path("loras", lora_name) + + new_funmodels['lora_cache'] = lora_cache + new_funmodels['loras'] = funmodels.get("loras", []) + [lora_path] + new_funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model] + + return (new_funmodels,) + class WanT2VSampler: @classmethod diff --git a/comfyui/wan2_1_fun/nodes.py b/comfyui/wan2_1_fun/nodes.py index e934aa8e..91bb64f8 100755 --- a/comfyui/wan2_1_fun/nodes.py +++ b/comfyui/wan2_1_fun/nodes.py @@ -238,13 +238,16 @@ def INPUT_TYPES(s): CATEGORY = "CogVideoXFUNWrapper" def load_lora(self, funmodels, lora_name, strength_model, lora_cache): + new_funmodels = dict(funmodels) + if lora_name is not None: - funmodels['lora_cache'] = lora_cache - funmodels['loras'] = funmodels.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)] - funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model] - return (funmodels,) - else: - return (funmodels,) + lora_path = folder_paths.get_full_path("loras", lora_name) + + new_funmodels['lora_cache'] = lora_cache + new_funmodels['loras'] = funmodels.get("loras", []) + [lora_path] + new_funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model] + + return (new_funmodels,) class WanFunT2VSampler: @classmethod diff --git a/comfyui/wan2_2/nodes.py b/comfyui/wan2_2/nodes.py index 7a2a0cc2..8d86fdb3 100755 --- a/comfyui/wan2_2/nodes.py +++ b/comfyui/wan2_2/nodes.py @@ -463,14 +463,16 @@ def INPUT_TYPES(s): CATEGORY = "CogVideoXFUNWrapper" def load_lora(self, funmodels, lora_name, lora_high_name, strength_model, lora_cache): + new_funmodels = dict(funmodels) if lora_name is not None: - funmodels['lora_cache'] = lora_cache - funmodels['loras'] = funmodels.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)] - funmodels['loras_high'] = funmodels.get("loras_high", []) + [folder_paths.get_full_path("loras", lora_high_name)] - funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model] - return (funmodels,) - else: - return (funmodels,) + loras = list(new_funmodels.get("loras", [])) + [folder_paths.get_full_path("loras", lora_name)] + loras_high = list(new_funmodels.get("loras_high", [])) + [folder_paths.get_full_path("loras", lora_high_name)] + strength_models = list(new_funmodels.get("strength_model", [])) + [strength_model] + new_funmodels['loras'] = loras + new_funmodels['loras_high'] = loras_high + new_funmodels['strength_model'] = strength_models + new_funmodels['lora_cache'] = lora_cache + return (new_funmodels,) class Wan2_2T2VSampler: @classmethod diff --git a/comfyui/wan2_2_fun/nodes.py b/comfyui/wan2_2_fun/nodes.py index a084c507..a862705f 100755 --- a/comfyui/wan2_2_fun/nodes.py +++ b/comfyui/wan2_2_fun/nodes.py @@ -258,14 +258,16 @@ def INPUT_TYPES(s): CATEGORY = "CogVideoXFUNWrapper" def load_lora(self, funmodels, lora_name, lora_high_name, strength_model, lora_cache): + new_funmodels = dict(funmodels) if lora_name is not None: - funmodels['lora_cache'] = lora_cache - funmodels['loras'] = funmodels.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)] - funmodels['loras_high'] = funmodels.get("loras_high", []) + [folder_paths.get_full_path("loras", lora_high_name)] - funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model] - return (funmodels,) - else: - return (funmodels,) + loras = list(new_funmodels.get("loras", [])) + [folder_paths.get_full_path("loras", lora_name)] + loras_high = list(new_funmodels.get("loras_high", [])) + [folder_paths.get_full_path("loras", lora_high_name)] + strength_models = list(new_funmodels.get("strength_model", [])) + [strength_model] + new_funmodels['loras'] = loras + new_funmodels['loras_high'] = loras_high + new_funmodels['strength_model'] = strength_models + new_funmodels['lora_cache'] = lora_cache + return (new_funmodels,) class Wan2_2FunT2VSampler: @classmethod diff --git a/examples/cogvideox_fun/predict_i2v.py b/examples/cogvideox_fun/predict_i2v.py index 112fe025..0475af23 100755 --- a/examples/cogvideox_fun/predict_i2v.py +++ b/examples/cogvideox_fun/predict_i2v.py @@ -202,7 +202,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if partial_video_length is not None: partial_video_length = int((partial_video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -292,7 +292,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/cogvideox_fun/predict_t2v.py b/examples/cogvideox_fun/predict_t2v.py index 5133955e..376615b0 100755 --- a/examples/cogvideox_fun/predict_t2v.py +++ b/examples/cogvideox_fun/predict_t2v.py @@ -194,7 +194,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -232,7 +232,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/cogvideox_fun/predict_v2v.py b/examples/cogvideox_fun/predict_v2v.py index b1d7a435..84d10192 100755 --- a/examples/cogvideox_fun/predict_v2v.py +++ b/examples/cogvideox_fun/predict_v2v.py @@ -201,7 +201,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1 @@ -227,7 +227,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/cogvideox_fun/predict_v2v_control.py b/examples/cogvideox_fun/predict_v2v_control.py index 3668d3bc..1193d33c 100755 --- a/examples/cogvideox_fun/predict_v2v_control.py +++ b/examples/cogvideox_fun/predict_v2v_control.py @@ -188,7 +188,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1 @@ -212,7 +212,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/flux/predict_t2i.py b/examples/flux/predict_t2i.py index bfd6bc77..c8518bf0 100644 --- a/examples/flux/predict_t2i.py +++ b/examples/flux/predict_t2i.py @@ -184,7 +184,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): sample = pipeline( @@ -198,7 +198,7 @@ ).images if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/phantom/predict_s2v.py b/examples/phantom/predict_s2v.py index 89e3b824..0d7e8f4e 100644 --- a/examples/phantom/predict_s2v.py +++ b/examples/phantom/predict_s2v.py @@ -244,7 +244,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -272,7 +272,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/qwenimage/predict_t2i.py b/examples/qwenimage/predict_t2i.py index 27697dfd..51c5a4c3 100644 --- a/examples/qwenimage/predict_t2i.py +++ b/examples/qwenimage/predict_t2i.py @@ -176,7 +176,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): sample = pipeline( @@ -190,7 +190,7 @@ ).images if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1/predict_i2v.py b/examples/wan2.1/predict_i2v.py index 53459801..c2de0184 100755 --- a/examples/wan2.1/predict_i2v.py +++ b/examples/wan2.1/predict_i2v.py @@ -245,7 +245,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -273,7 +273,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1/predict_t2v.py b/examples/wan2.1/predict_t2v.py index 26bd668c..2e96d13b 100755 --- a/examples/wan2.1/predict_t2v.py +++ b/examples/wan2.1/predict_t2v.py @@ -232,7 +232,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -254,7 +254,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1_fun/predict_i2v.py b/examples/wan2.1_fun/predict_i2v.py index 52baeb59..dc86a68b 100755 --- a/examples/wan2.1_fun/predict_i2v.py +++ b/examples/wan2.1_fun/predict_i2v.py @@ -246,7 +246,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -274,7 +274,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1_fun/predict_t2v.py b/examples/wan2.1_fun/predict_t2v.py index fb085fba..bfb27598 100755 --- a/examples/wan2.1_fun/predict_t2v.py +++ b/examples/wan2.1_fun/predict_t2v.py @@ -253,7 +253,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -293,7 +293,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1_fun/predict_v2v_control.py b/examples/wan2.1_fun/predict_v2v_control.py index 6810cd61..85a74340 100755 --- a/examples/wan2.1_fun/predict_v2v_control.py +++ b/examples/wan2.1_fun/predict_v2v_control.py @@ -256,7 +256,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -305,7 +305,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1_fun/predict_v2v_control_camera.py b/examples/wan2.1_fun/predict_v2v_control_camera.py index d66077b4..40adc130 100755 --- a/examples/wan2.1_fun/predict_v2v_control_camera.py +++ b/examples/wan2.1_fun/predict_v2v_control_camera.py @@ -256,7 +256,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -305,7 +305,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1_fun/predict_v2v_control_ref.py b/examples/wan2.1_fun/predict_v2v_control_ref.py index 5f01abc6..b7115443 100755 --- a/examples/wan2.1_fun/predict_v2v_control_ref.py +++ b/examples/wan2.1_fun/predict_v2v_control_ref.py @@ -256,7 +256,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -305,7 +305,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1_vace/predict_i2v.py b/examples/wan2.1_vace/predict_i2v.py index 47667b30..050d4010 100644 --- a/examples/wan2.1_vace/predict_i2v.py +++ b/examples/wan2.1_vace/predict_i2v.py @@ -247,7 +247,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -283,7 +283,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1_vace/predict_s2v.py b/examples/wan2.1_vace/predict_s2v.py index e78a0348..c44c937a 100644 --- a/examples/wan2.1_vace/predict_s2v.py +++ b/examples/wan2.1_vace/predict_s2v.py @@ -247,7 +247,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -283,7 +283,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.1_vace/predict_v2v_control.py b/examples/wan2.1_vace/predict_v2v_control.py index 3f58b9a0..2c09c189 100644 --- a/examples/wan2.1_vace/predict_v2v_control.py +++ b/examples/wan2.1_vace/predict_v2v_control.py @@ -247,7 +247,7 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -283,7 +283,7 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2/predict_i2v.py b/examples/wan2.2/predict_i2v.py index 7369e99c..52c93c2d 100644 --- a/examples/wan2.2/predict_i2v.py +++ b/examples/wan2.2/predict_i2v.py @@ -278,8 +278,8 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -308,8 +308,8 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2/predict_s2v.py b/examples/wan2.2/predict_s2v.py index 98c0a1da..82876ebb 100644 --- a/examples/wan2.2/predict_s2v.py +++ b/examples/wan2.2/predict_s2v.py @@ -302,9 +302,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = video_length // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio if video_length != 1 else 1 @@ -339,8 +339,8 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2/predict_t2v.py b/examples/wan2.2/predict_t2v.py index a68c6e4a..28d56cef 100755 --- a/examples/wan2.2/predict_t2v.py +++ b/examples/wan2.2/predict_t2v.py @@ -273,8 +273,8 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -298,8 +298,8 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2/predict_ti2v.py b/examples/wan2.2/predict_ti2v.py index 548f599c..d039210c 100755 --- a/examples/wan2.2/predict_ti2v.py +++ b/examples/wan2.2/predict_ti2v.py @@ -290,9 +290,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -325,9 +325,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_i2v.py b/examples/wan2.2_fun/predict_i2v.py index 0b1d0379..79181b4d 100644 --- a/examples/wan2.2_fun/predict_i2v.py +++ b/examples/wan2.2_fun/predict_i2v.py @@ -292,9 +292,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -324,9 +324,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_i2v_5b.py b/examples/wan2.2_fun/predict_i2v_5b.py index 3c5468f8..1bbccaa0 100644 --- a/examples/wan2.2_fun/predict_i2v_5b.py +++ b/examples/wan2.2_fun/predict_i2v_5b.py @@ -294,9 +294,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -326,9 +326,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_t2v.py b/examples/wan2.2_fun/predict_t2v.py index 0f7c5bfc..347128ac 100644 --- a/examples/wan2.2_fun/predict_t2v.py +++ b/examples/wan2.2_fun/predict_t2v.py @@ -277,8 +277,8 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -307,8 +307,8 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_t2v_5b.py b/examples/wan2.2_fun/predict_t2v_5b.py index 9f501b39..49c5bb23 100644 --- a/examples/wan2.2_fun/predict_t2v_5b.py +++ b/examples/wan2.2_fun/predict_t2v_5b.py @@ -288,9 +288,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -320,9 +320,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_v2v_control.py b/examples/wan2.2_fun/predict_v2v_control.py index 39ed056f..b707b936 100644 --- a/examples/wan2.2_fun/predict_v2v_control.py +++ b/examples/wan2.2_fun/predict_v2v_control.py @@ -305,9 +305,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -351,9 +351,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_v2v_control_5b.py b/examples/wan2.2_fun/predict_v2v_control_5b.py index a7ebb5e3..16dab5d6 100644 --- a/examples/wan2.2_fun/predict_v2v_control_5b.py +++ b/examples/wan2.2_fun/predict_v2v_control_5b.py @@ -305,9 +305,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -351,9 +351,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_v2v_control_camera.py b/examples/wan2.2_fun/predict_v2v_control_camera.py index eeb6ee15..f112394a 100644 --- a/examples/wan2.2_fun/predict_v2v_control_camera.py +++ b/examples/wan2.2_fun/predict_v2v_control_camera.py @@ -305,9 +305,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -351,9 +351,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_v2v_control_camera_5b.py b/examples/wan2.2_fun/predict_v2v_control_camera_5b.py index d0b9fbba..bc7e4944 100644 --- a/examples/wan2.2_fun/predict_v2v_control_camera_5b.py +++ b/examples/wan2.2_fun/predict_v2v_control_camera_5b.py @@ -305,9 +305,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -351,9 +351,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_v2v_control_ref.py b/examples/wan2.2_fun/predict_v2v_control_ref.py index f5bb16ca..e842870e 100644 --- a/examples/wan2.2_fun/predict_v2v_control_ref.py +++ b/examples/wan2.2_fun/predict_v2v_control_ref.py @@ -305,9 +305,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -351,9 +351,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_fun/predict_v2v_control_ref_5b.py b/examples/wan2.2_fun/predict_v2v_control_ref_5b.py index c4c2b704..bf9bd04e 100644 --- a/examples/wan2.2_fun/predict_v2v_control_ref_5b.py +++ b/examples/wan2.2_fun/predict_v2v_control_ref_5b.py @@ -305,9 +305,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -351,9 +351,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_vace_fun/predict_i2v.py b/examples/wan2.2_vace_fun/predict_i2v.py index 276d7790..45f6065c 100644 --- a/examples/wan2.2_vace_fun/predict_i2v.py +++ b/examples/wan2.2_vace_fun/predict_i2v.py @@ -108,12 +108,15 @@ # Use torch.float16 if GPU does not support torch.bfloat16 # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 -weight_dtype = torch.bfloat16 -control_video = None -start_image = "asset/1.png" -end_image = None -subject_ref_images = None -vace_context_scale = 1.00 +weight_dtype = torch.bfloat16 +control_video = None +start_image = "asset/1.png" +end_image = None +# Use inpaint video instead of start image and end image. +inpaint_video = None +inpaint_video_mask = None +subject_ref_images = None +vace_context_scale = 1.00 # Sometimes, when generating a video from a reference image, white borders appear. # Because the padding is mistakenly treated as part of the image. # If the aspect ratio of the reference image is close to the final output, you can omit the white padding. @@ -306,9 +309,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -323,7 +326,14 @@ subject_ref_images = [get_image_latent(_subject_ref_image, sample_size=sample_size, padding=padding_in_subject_ref_images) for _subject_ref_image in subject_ref_images] subject_ref_images = torch.cat(subject_ref_images, dim=2) - inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=sample_size) + if inpaint_video is not None: + if inpaint_video_mask is None: + raise ValueError("inpaint_video_mask is required when inpaint_video is provided") + inpaint_video, _, _, _ = get_video_to_video_latent(inpaint_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask, _, _, _ = get_video_to_video_latent(inpaint_video_mask, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask = inpaint_video_mask[:, :1] + else: + inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=sample_size) control_video, _, _, _ = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) @@ -347,9 +357,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_vace_fun/predict_s2v.py b/examples/wan2.2_vace_fun/predict_s2v.py index ec0d2103..7d53f2cb 100644 --- a/examples/wan2.2_vace_fun/predict_s2v.py +++ b/examples/wan2.2_vace_fun/predict_s2v.py @@ -108,12 +108,15 @@ # Use torch.float16 if GPU does not support torch.bfloat16 # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 -weight_dtype = torch.bfloat16 -control_video = None -start_image = None -end_image = None -subject_ref_images = ["asset/8.png", "asset/ref_1.png"] -vace_context_scale = 1.00 +weight_dtype = torch.bfloat16 +control_video = None +start_image = None +end_image = None +# Use inpaint video instead of start image and end image. +inpaint_video = None +inpaint_video_mask = None +subject_ref_images = ["asset/8.png", "asset/ref_1.png"] +vace_context_scale = 1.00 # Sometimes, when generating a video from a reference image, white borders appear. # Because the padding is mistakenly treated as part of the image. # If the aspect ratio of the reference image is close to the final output, you can omit the white padding. @@ -306,9 +309,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -323,7 +326,14 @@ subject_ref_images = [get_image_latent(_subject_ref_image, sample_size=sample_size, padding=padding_in_subject_ref_images) for _subject_ref_image in subject_ref_images] subject_ref_images = torch.cat(subject_ref_images, dim=2) - inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=sample_size) + if inpaint_video is not None: + if inpaint_video_mask is None: + raise ValueError("inpaint_video_mask is required when inpaint_video is provided") + inpaint_video, _, _, _ = get_video_to_video_latent(inpaint_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask, _, _, _ = get_video_to_video_latent(inpaint_video_mask, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask = inpaint_video_mask[:, :1] + else: + inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=sample_size) control_video, _, _, _ = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) @@ -347,9 +357,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_vace_fun/predict_v2v_control.py b/examples/wan2.2_vace_fun/predict_v2v_control.py index e679206b..a23f35c5 100644 --- a/examples/wan2.2_vace_fun/predict_v2v_control.py +++ b/examples/wan2.2_vace_fun/predict_v2v_control.py @@ -108,12 +108,15 @@ # Use torch.float16 if GPU does not support torch.bfloat16 # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 -weight_dtype = torch.bfloat16 -control_video = "asset/pose.mp4" -start_image = None -end_image = None -subject_ref_images = None -vace_context_scale = 1.00 +weight_dtype = torch.bfloat16 +control_video = "asset/pose.mp4" +start_image = None +end_image = None +# Use inpaint video instead of start image and end image. +inpaint_video = None +inpaint_video_mask = None +subject_ref_images = None +vace_context_scale = 1.00 # Sometimes, when generating a video from a reference image, white borders appear. # Because the padding is mistakenly treated as part of the image. # If the aspect ratio of the reference image is close to the final output, you can omit the white padding. @@ -306,9 +309,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -323,7 +326,14 @@ subject_ref_images = [get_image_latent(_subject_ref_image, sample_size=sample_size, padding=padding_in_subject_ref_images) for _subject_ref_image in subject_ref_images] subject_ref_images = torch.cat(subject_ref_images, dim=2) - inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=sample_size) + if inpaint_video is not None: + if inpaint_video_mask is None: + raise ValueError("inpaint_video_mask is required when inpaint_video is provided") + inpaint_video, _, _, _ = get_video_to_video_latent(inpaint_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask, _, _, _ = get_video_to_video_latent(inpaint_video_mask, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask = inpaint_video_mask[:, :1] + else: + inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=sample_size) control_video, _, _, _ = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) @@ -347,9 +357,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_vace_fun/predict_v2v_control_ref.py b/examples/wan2.2_vace_fun/predict_v2v_control_ref.py index f846f71d..e0538361 100644 --- a/examples/wan2.2_vace_fun/predict_v2v_control_ref.py +++ b/examples/wan2.2_vace_fun/predict_v2v_control_ref.py @@ -108,12 +108,15 @@ # Use torch.float16 if GPU does not support torch.bfloat16 # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 -weight_dtype = torch.bfloat16 -control_video = "asset/pose.mp4" -start_image = None -end_image = None -subject_ref_images = ["asset/8.png"] -vace_context_scale = 1.00 +weight_dtype = torch.bfloat16 +control_video = "asset/pose.mp4" +start_image = None +end_image = None +# Use inpaint video instead of start image and end image. +inpaint_video = None +inpaint_video_mask = None +subject_ref_images = ["asset/8.png"] +vace_context_scale = 1.00 # Sometimes, when generating a video from a reference image, white borders appear. # Because the padding is mistakenly treated as part of the image. # If the aspect ratio of the reference image is close to the final output, you can omit the white padding. @@ -306,9 +309,9 @@ generator = torch.Generator(device=device).manual_seed(seed) if lora_path is not None: - pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -323,7 +326,14 @@ subject_ref_images = [get_image_latent(_subject_ref_image, sample_size=sample_size, padding=padding_in_subject_ref_images) for _subject_ref_image in subject_ref_images] subject_ref_images = torch.cat(subject_ref_images, dim=2) - inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=sample_size) + if inpaint_video is not None: + if inpaint_video_mask is None: + raise ValueError("inpaint_video_mask is required when inpaint_video is provided") + inpaint_video, _, _, _ = get_video_to_video_latent(inpaint_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask, _, _, _ = get_video_to_video_latent(inpaint_video_mask, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask = inpaint_video_mask[:, :1] + else: + inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=sample_size) control_video, _, _, _ = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) @@ -347,9 +357,9 @@ ).videos if lora_path is not None: - pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) if transformer_2 is not None: - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/examples/wan2.2_vace_fun/predict_v2v_mask.py b/examples/wan2.2_vace_fun/predict_v2v_mask.py new file mode 100644 index 00000000..ec346f9a --- /dev/null +++ b/examples/wan2.2_vace_fun/predict_v2v_mask.py @@ -0,0 +1,387 @@ +import os +import sys + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from omegaconf import OmegaConf +from PIL import Image +from transformers import AutoTokenizer + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.dist import set_multi_gpus_devices, shard_model +from videox_fun.models import (AutoencoderKLWan, AutoencoderKLWan3_8, AutoTokenizer, CLIPModel, + WanT5EncoderModel, VaceWanTransformer3DModel) +from videox_fun.data.dataset_image_video import process_pose_file +from videox_fun.models.cache_utils import get_teacache_coefficients +from videox_fun.pipeline import Wan2_2VaceFunPipeline, WanPipeline +from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from videox_fun.utils.lora_utils import merge_lora, unmerge_lora +from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, + get_video_to_video_latent, + save_videos_grid) +from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler +from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +# GPU memory mode, which can be chosen in [model_full_load, model_cpu_offload_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload]. +# model_full_load means that the entire model will be moved to the GPU. +# +# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory. +# +# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, +# resulting in slower speeds but saving a large amount of GPU memory. +GPU_memory_mode = "sequential_cpu_offload" +# Multi GPUs config +# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used. +# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4. +# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1. +ulysses_degree = 1 +ring_degree = 1 +# Use FSDP to save more GPU memory in multi gpus. +fsdp_dit = False +fsdp_text_encoder = True +# Compile will give a speedup in fixed resolution and need a little GPU memory. +# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload. +compile_dit = False + +# Support TeaCache. +enable_teacache = True +# Recommended to be set between 0.05 and 0.30. A larger threshold can cache more steps, speeding up the inference process, +# but it may cause slight differences between the generated content and the original content. +# # --------------------------------------------------------------------------------------------------- # +# | Model Name | threshold | Model Name | threshold | +# | Wan2.1-VACE-1.3B | 0.05~0.10 | Wan2.1-VACE-14B | 0.10~0.15 | +# # --------------------------------------------------------------------------------------------------- # +teacache_threshold = 0.10 +# The number of steps to skip TeaCache at the beginning of the inference process, which can +# reduce the impact of TeaCache on generated video quality. +num_skip_start_steps = 5 +# Whether to offload TeaCache tensors to cpu to save a little bit of GPU memory. +teacache_offload = False + +# Skip some cfg steps in inference for acceleration +# Recommended to be set between 0.00 and 0.25 +cfg_skip_ratio = 0 + +# Riflex config +enable_riflex = False +# Index of intrinsic frequency +riflex_k = 6 + +# Config and model path +config_path = "config/wan2.2/wan_civitai_t2v.yaml" +# model path +model_name = "models/Diffusion_Transformer/Wan2.2-VACE-Fun-A14B" + +# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++" +sampler_name = "Flow" +# [NOTE]: Noise schedule shift parameter. Affects temporal dynamics. +# Used when the sampler is in "Flow_Unipc", "Flow_DPM++". +shift = 12.0 + +# Load pretrained model if need +# The transformer_path is used for low noise model, the transformer_high_path is used for high noise model. +transformer_path = None +transformer_high_path = None +vae_path = None +# Load lora model if need +# The lora_path is used for low noise model, the lora_high_path is used for high noise model. +lora_path = None +lora_high_path = None + +# Other params +sample_size = [480, 832] +video_length = 81 +fps = 16 + +# Use torch.float16 if GPU does not support torch.bfloat16 +# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 +weight_dtype = torch.bfloat16 +control_video = None +start_image = None +end_image = None +# Use inpaint video instead of start image and end image. +inpaint_video = "asset/inpaint_video.mp4" +inpaint_video_mask = "asset/inpaint_video_mask.mp4" +subject_ref_images = None +vace_context_scale = 1.00 +# Sometimes, when generating a video from a reference image, white borders appear. +# Because the padding is mistakenly treated as part of the image. +# If the aspect ratio of the reference image is close to the final output, you can omit the white padding. +padding_in_subject_ref_images = True + +# 使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性 +# 在neg prompt中添加"安静,固定"等词语可以增加动态性。 +prompt = "一只棕色的兔子舔了一下它的舌头,坐在舒适房间里的浅色沙发上。在兔子的后面,架子上有一幅镶框的画,周围是粉红色的花朵。房间里柔和温暖的灯光营造出舒适的氛围。" +negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + +# Using longer neg prompt such as "Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art." can increase stability +# Adding words such as "quiet, solid" to the neg prompt can increase dynamism. +# prompt = "A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical." +# negative_prompt = "Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." +guidance_scale = 5.0 +seed = 43 +num_inference_steps = 50 +# The lora_weight is used for low noise model, the lora_high_weight is used for high noise model. +lora_weight = 0.55 +lora_high_weight = 0.55 +save_path = "samples/vace-videos-fun" + +device = set_multi_gpus_devices(ulysses_degree, ring_degree) +config = OmegaConf.load(config_path) +boundary = config['transformer_additional_kwargs'].get('boundary', 0.875) + +transformer = VaceWanTransformer3DModel.from_pretrained( + os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) +if config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe": + transformer_2 = VaceWanTransformer3DModel.from_pretrained( + os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, + ) +else: + transformer_2 = None + +if transformer_path is not None: + print(f"From checkpoint: {transformer_path}") + if transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_path) + else: + state_dict = torch.load(transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +if transformer_2 is not None: + if transformer_high_path is not None: + print(f"From checkpoint: {transformer_high_path}") + if transformer_high_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_high_path) + else: + state_dict = torch.load(transformer_high_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer_2.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Get Vae +Chosen_AutoencoderKL = { + "AutoencoderKLWan": AutoencoderKLWan, + "AutoencoderKLWan3_8": AutoencoderKLWan3_8 +}[config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')] +vae = Chosen_AutoencoderKL.from_pretrained( + os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(config['vae_kwargs']), +).to(weight_dtype) + +if vae_path is not None: + print(f"From checkpoint: {vae_path}") + if vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(vae_path) + else: + state_dict = torch.load(vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Get Tokenizer +tokenizer = AutoTokenizer.from_pretrained( + os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), +) + +# Get Text encoder +text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) +text_encoder = text_encoder.eval() + +# Get Scheduler +Chosen_Scheduler = scheduler_dict = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +}[sampler_name] +if sampler_name == "Flow_Unipc" or sampler_name == "Flow_DPM++": + config['scheduler_kwargs']['shift'] = 1 +scheduler = Chosen_Scheduler( + **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs'])) +) + +# Get Pipeline +pipeline = Wan2_2VaceFunPipeline( + transformer=transformer, + transformer_2=transformer_2, + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=scheduler, +) +if ulysses_degree > 1 or ring_degree > 1: + from functools import partial + transformer.enable_multi_gpus_inference() + if transformer_2 is not None: + transformer_2.enable_multi_gpus_inference() + if fsdp_dit: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype) + pipeline.transformer = shard_fn(pipeline.transformer) + if transformer_2 is not None: + pipeline.transformer_2 = shard_fn(pipeline.transformer_2) + print("Add FSDP DIT") + if fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype) + pipeline.text_encoder = shard_fn(pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + +if compile_dit: + for i in range(len(pipeline.transformer.blocks)): + pipeline.transformer.blocks[i] = torch.compile(pipeline.transformer.blocks[i]) + if transformer_2 is not None: + for i in range(len(pipeline.transformer_2.blocks)): + pipeline.transformer_2.blocks[i] = torch.compile(pipeline.transformer_2.blocks[i]) + print("Add Compile") + +if GPU_memory_mode == "sequential_cpu_offload": + replace_parameters_by_name(transformer, ["modulation",], device=device) + transformer.freqs = transformer.freqs.to(device=device) + if transformer_2 is not None: + replace_parameters_by_name(transformer_2, ["modulation",], device=device) + transformer_2.freqs = transformer_2.freqs.to(device=device) + pipeline.enable_sequential_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + if transformer_2 is not None: + convert_model_weight_to_float8(transformer_2, exclude_module_name=["modulation",], device=device) + convert_weight_dtype_wrapper(transformer_2, weight_dtype) + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload": + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + if transformer_2 is not None: + convert_model_weight_to_float8(transformer_2, exclude_module_name=["modulation",], device=device) + convert_weight_dtype_wrapper(transformer_2, weight_dtype) + pipeline.to(device=device) +else: + pipeline.to(device=device) + +coefficients = get_teacache_coefficients(model_name) if enable_teacache else None +if coefficients is not None: + print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") + pipeline.transformer.enable_teacache( + coefficients, num_inference_steps, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload + ) + if transformer_2 is not None: + pipeline.transformer_2.share_teacache(transformer=pipeline.transformer) + +if cfg_skip_ratio is not None: + print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") + pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, num_inference_steps) + if transformer_2 is not None: + pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer) + +generator = torch.Generator(device=device).manual_seed(seed) + +if lora_path is not None: + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device) + if transformer_2 is not None: + pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + +with torch.no_grad(): + video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 + latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1 + + if enable_riflex: + pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames) + if transformer_2 is not None: + pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames) + + if subject_ref_images is not None: + subject_ref_images = [get_image_latent(_subject_ref_image, sample_size=sample_size, padding=padding_in_subject_ref_images) for _subject_ref_image in subject_ref_images] + subject_ref_images = torch.cat(subject_ref_images, dim=2) + + if inpaint_video is not None: + if inpaint_video_mask is None: + raise ValueError("inpaint_video_mask is required when inpaint_video is provided") + inpaint_video, _, _, _ = get_video_to_video_latent(inpaint_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask, _, _, _ = get_video_to_video_latent(inpaint_video_mask, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + inpaint_video_mask = inpaint_video_mask[:, :1] + else: + inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=video_length, sample_size=sample_size) + + control_video, _, _, _ = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + + sample = pipeline( + prompt, + num_frames = video_length, + negative_prompt = negative_prompt, + height = sample_size[0], + width = sample_size[1], + generator = generator, + guidance_scale = guidance_scale, + num_inference_steps = num_inference_steps, + + video = inpaint_video, + mask_video = inpaint_video_mask, + control_video = control_video, + subject_ref_images = subject_ref_images, + boundary = boundary, + shift = shift, + vace_context_scale = vace_context_scale + ).videos + +if lora_path is not None: + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device) + if transformer_2 is not None: + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, sub_transformer_name="transformer_2") + +def save_results(): + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + index = len([path for path in os.listdir(save_path)]) + 1 + prefix = str(index).zfill(8) + if video_length == 1: + video_path = os.path.join(save_path, prefix + ".png") + + image = sample[0, :, 0] + image = image.transpose(0, 1).transpose(1, 2) + image = (image * 255).numpy().astype(np.uint8) + image = Image.fromarray(image) + image.save(video_path) + else: + video_path = os.path.join(save_path, prefix + ".mp4") + save_videos_grid(sample, video_path, fps=fps) + +if ulysses_degree * ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_results() +else: + save_results() \ No newline at end of file diff --git a/scripts/flux/README_TRAIN.md b/scripts/flux/README_TRAIN.md new file mode 100755 index 00000000..b091fffe --- /dev/null +++ b/scripts/flux/README_TRAIN.md @@ -0,0 +1,165 @@ +## Training Code + +We can choose whether to use deepspeed or fsdp in flux, which can save a lot of video memory. + +Some parameters in the sh file can be confusing, and they are explained in this document: + +- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. +- `random_hw_adapt` is used to enable automatic height and width scaling for images. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. + - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` +- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. + +Without deepspeed: + +Training flux without DeepSpeed may result in insufficient GPU memory. +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +With deepspeed zero-2: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +Deepspeed zero-3: + +After training, you can use the following command to get the final model: +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Training shell command is as follows: +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +With FSDP: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap FluxSingleTransformerBlock,FluxTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` \ No newline at end of file diff --git a/scripts/flux/README_TRAIN_LORA.md b/scripts/flux/README_TRAIN_LORA.md new file mode 100755 index 00000000..021c4918 --- /dev/null +++ b/scripts/flux/README_TRAIN_LORA.md @@ -0,0 +1,153 @@ +## Lora Training Code + +We can choose whether to use deepspeed or fsdp in flux, which can save a lot of video memory. + +Some parameters in the sh file can be confusing, and they are explained in this document: + +- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. +- `random_hw_adapt` is used to enable automatic height and width scaling for images. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. + - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` +- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. + +Without deepspeed: + +Training flux without DeepSpeed may result in insufficient GPU memory. +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling +``` + +With deepspeed zero-2: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling +``` + +Deepspeed zero-3: + +After training, you can use the following command to get the final model: +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Training shell command is as follows: +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling +``` + +With FSDP: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-InP" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap FluxSingleTransformerBlock,FluxTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling +``` \ No newline at end of file diff --git a/scripts/flux/train.py b/scripts/flux/train.py index 20e78d4e..6a840b63 100644 --- a/scripts/flux/train.py +++ b/scripts/flux/train.py @@ -1576,7 +1576,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Predict the noise residual with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): - print(noisy_latents.size(), prompt_embeds.size(), pooled_prompt_embeds.size(), text_ids.size(), latent_image_ids.size()) noise_pred = transformer3d( hidden_states=noisy_latents, timestep=timesteps / 1000, diff --git a/scripts/flux/train_lora.py b/scripts/flux/train_lora.py index 9274e11d..1b036736 100644 --- a/scripts/flux/train_lora.py +++ b/scripts/flux/train_lora.py @@ -1571,7 +1571,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Predict the noise residual with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): - print(noisy_latents.size(), prompt_embeds.size(), pooled_prompt_embeds.size(), text_ids.size(), latent_image_ids.size()) noise_pred = transformer3d( hidden_states=noisy_latents, timestep=timesteps / 1000, diff --git a/scripts/wan2.2_vace_fun/README_TRAIN.md b/scripts/wan2.2_vace_fun/README_TRAIN.md new file mode 100755 index 00000000..ca63aea9 --- /dev/null +++ b/scripts/wan2.2_vace_fun/README_TRAIN.md @@ -0,0 +1,257 @@ +## Training Code + +We can choose whether to use deepspeed or fsdp in Wan-Fun, which can save a lot of video memory. + +The metadata_control.json is a little different from normal json in Wan-Fun, you need to add a control_file_path, and [DWPose](https://github.com/IDEA-Research/DWPose) is suggested as tool to generate control file. + +```json +[ + { + "file_path": "train/00000001.mp4", + "control_file_path": "control/00000001.mp4", + "object_file_path": ["object/1.jpg", "object/2.jpg"], + "text": "A group of young men in suits and sunglasses are walking down a city street.", + "type": "video" + }, + { + "file_path": "train/00000002.jpg", + "control_file_path": "control/00000002.jpg", + "object_file_path": ["object/3.jpg", "object/4.jpg"], + "text": "Ba Da Ba Ba Ba Ba.", + "type": "image" + }, + ..... +] +``` + +Some parameters in the sh file can be confusing, and they are explained in this document: + +- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images and videos at the center, but instead, it trains the entire images and videos after grouping them into buckets based on resolution. +- `random_frame_crop` is used for random cropping on video frames to simulate videos with different frame counts. +- `random_hw_adapt` is used to enable automatic height and width scaling for images and videos. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `min(video_sample_size, 512)` as the minimum. For training videos, the height and width will be set to `image_sample_size` as the maximum and `min(video_sample_size, 512)` as the minimum. + - For example, when `random_hw_adapt` is enabled, with `video_sample_n_frames=49`, `video_sample_size=1024`, and `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024`, and the resolution of video inputs for training is `512x512x49` to `1024x1024x49`. + - For example, when `random_hw_adapt` is enabled, with `video_sample_n_frames=49`, `video_sample_size=256`, and `image_sample_size=1024`, the resolution of image inputs for training is `256x256` to `1024x1024`, and the resolution of video inputs for training is `256x256x49`. +- `training_with_video_token_length` specifies training the model according to token length. For training images and videos, the height and width will be set to `image_sample_size` as the maximum and `video_sample_size` as the minimum. + - For example, when `training_with_video_token_length` is enabled, with `video_sample_n_frames=49`, `token_sample_size=1024`, `video_sample_size=256`, and `image_sample_size=1024`, the resolution of image inputs for training is `256x256` to `1024x1024`, and the resolution of video inputs for training is `256x256x49` to `1024x1024x49`. + - For example, when `training_with_video_token_length` is enabled, with `video_sample_n_frames=49`, `token_sample_size=512`, `video_sample_size=256`, and `image_sample_size=1024`, the resolution of image inputs for training is `256x256` to `1024x1024`, and the resolution of video inputs for training is `256x256x49` to `1024x1024x9`. + - The token length for a video with dimensions 512x512 and 49 frames is 13,312. We need to set the `token_sample_size = 512`. + - At 512x512 resolution, the number of video frames is 49 (~= 512 * 512 * 49 / 512 / 512). + - At 768x768 resolution, the number of video frames is 21 (~= 512 * 512 * 49 / 768 / 768). + - At 1024x1024 resolution, the number of video frames is 9 (~= 512 * 512 * 49 / 1024 / 1024). + - These resolutions combined with their corresponding lengths allow the model to generate videos of different sizes. +- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. +- `train_mode` is used to set the training mode. + - The models named `Wan2.1-Fun-*-Control` are trained in the `control_ref` mode. + - The models named `Wan2.1-Fun-*-Control-Camera` are trained in the `control_ref_camera` mode. +- `control_ref_image` is used to specify the type of control image. The available options are `first_frame` and `random`. + - `first_frame` is used in V1.0 because V1.0 supports using a specified start frame as the control image. The Control-Camera models use the first frame as the control image. + - `random` is used in V1.1 because V1.1 supports both using a specified start frame and a reference image as the control image. +- `boundary_type`: The Wan2.2 series includes two distinct models that handle different noise levels, specified via the `boundary_type` parameter. `low`: Corresponds to the **low noise model** (low_noise_model). `high`: Corresponds to the **high noise model**. (high_noise_model). `full`: Corresponds to the ti2v 5B model (single mode). + +When train model with multi machines, please set the params as follows: +```sh +export MASTER_ADDR="your master address" +export MASTER_PORT=10086 +export WORLD_SIZE=1 # The number of machines +export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 +export RANK=0 # The rank of this machine + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +``` + +Wan-Fun-Control without deepspeed: +```sh +export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-VACE-Fun-A14B" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/wan2.2_vace_fun/train.py \ + --config_path="config/wan2.2/wan_civitai_t2v.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=1024 \ + --video_sample_size=256 \ + --token_sample_size=512 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --low_vram \ + --control_ref_image="random" \ + --boundary_type="low" \ + --trainable_modules "vace" +``` + +Wan-Fun-Control with deepspeed: +```sh +export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-VACE-Fun-A14B" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/wan2.2_vace_fun/train.py \ + --config_path="config/wan2.2/wan_civitai_t2v.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=1024 \ + --video_sample_size=256 \ + --token_sample_size=512 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --low_vram \ + --control_ref_image="random" \ + --boundary_type="low" \ + --trainable_modules "vace" +``` + +Wan-Fun-Control with deepspeed zero-3: + +Wan with DeepSpeed Zero-3 is suitable for 14B Wan at high resolutions. After training, you can use the following command to get the final model: +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Training shell command is as follows: +```sh +export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-VACE-Fun-A14B" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/wan2.2_vace_fun/train.py \ + --config_path="config/wan2.2/wan_civitai_t2v.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=1024 \ + --video_sample_size=256 \ + --token_sample_size=512 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --low_vram \ + --control_ref_image="random" \ + --boundary_type="low" \ + --trainable_modules "vace" +``` + +Wan-Fun-Control with FSDP: + +Wan with FSDP is suitable for 14B Wan at high resolutions. Training shell command is as follows: +```sh +export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-VACE-Fun-A14B" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=VaceWanAttentionBlock,BaseWanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.2_vace_fun/train.py \ + --config_path="config/wan2.2/wan_civitai_t2v.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=1024 \ + --video_sample_size=256 \ + --token_sample_size=512 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --low_vram \ + --control_ref_image="random" \ + --boundary_type="low" \ + --trainable_modules "vace" +``` \ No newline at end of file diff --git a/scripts/wan2.2_vace_fun/train.py b/scripts/wan2.2_vace_fun/train.py new file mode 100644 index 00000000..9af58b82 --- /dev/null +++ b/scripts/wan2.2_vace_fun/train.py @@ -0,0 +1,2078 @@ +"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py +""" +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import gc +import logging +import math +import os +import pickle +import random +import shutil +import sys + +import accelerate +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +import transformers +from accelerate import Accelerator, FullyShardedDataParallelPlugin +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import (EMAModel, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3) +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from torch.utils.data import RandomSampler +from torch.utils.tensorboard import SummaryWriter +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.utils import ContextManagers + +import datasets + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.data.bucket_sampler import (ASPECT_RATIO_512, + ASPECT_RATIO_RANDOM_CROP_512, + ASPECT_RATIO_RANDOM_CROP_PROB, + AspectRatioBatchImageVideoSampler, + RandomSampler, get_closest_ratio) +from videox_fun.data.dataset_image_video import (ImageVideoControlDataset, + ImageVideoDataset, + ImageVideoSampler, + get_random_mask, + padding_image, + process_pose_file, + process_pose_params) +from videox_fun.models import (AutoencoderKLWan, CLIPModel, + VaceWanTransformer3DModel, WanT5EncoderModel) +from videox_fun.pipeline import Wan2_2VaceFunPipeline +from videox_fun.utils.discrete_sampler import DiscreteSampling +from videox_fun.utils.lora_utils import (create_network, merge_lora, + unmerge_lora) +from videox_fun.utils.utils import (get_image_to_video_latent, + get_video_to_video_latent, + save_videos_grid) + +if is_wandb_available(): + import wandb + + +def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + +def linear_decay(initial_value, final_value, total_steps, current_step): + if current_step >= total_steps: + return final_value + current_step = max(0, current_step) + step_size = (final_value - initial_value) / total_steps + current_value = initial_value + step_size * current_step + return current_value + +def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None): + u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator) + t = 1 / (1 + torch.exp(-u)) * (high - low) + low + return torch.clip(t.to(torch.int32), low, high - 1) + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(vae, text_encoder, tokenizer, clip_image_encoder, transformer3d, args, config, accelerator, weight_dtype, global_step): + try: + logger.info("Running validation... ") + + transformer3d_val = VaceWanTransformer3DModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict()) + scheduler = FlowMatchEulerDiscreteScheduler( + **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) + ) + + pipeline = Wan2_2VaceFunPipeline( + vae=accelerator.unwrap_model(vae).to(weight_dtype), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + transformer=transformer3d_val, + scheduler=scheduler, + clip_image_encoder=clip_image_encoder, + ) + pipeline = pipeline.to(accelerator.device) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.no_grad(): + with torch.autocast("cuda", dtype=weight_dtype): + video_length = int(args.video_sample_n_frames // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if args.video_sample_n_frames != 1 else 1 + input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(args.validation_paths[i], video_length=video_length, sample_size=[args.video_sample_size, args.video_sample_size]) + sample = pipeline( + args.validation_prompts[i], + num_frames = video_length, + negative_prompt = "bad detailed", + height = args.video_sample_size, + width = args.video_sample_size, + generator = generator, + + control_video = input_video, + ).videos + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-{i}.gif")) + + del pipeline + del transformer3d_val + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + return images + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + print(f"Eval error with info {e}") + return None + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. " + ), + ) + parser.add_argument( + "--train_data_meta", + type=str, + default=None, + help=( + "A csv containing the training data. " + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--use_came", + action="store_true", + help="whether to use came", + ) + parser.add_argument( + "--multi_stream", + action="store_true", + help="whether to use cuda multi-stream", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--vae_mini_batch", type=int, default=32, help="mini batch size for vae." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_model_info", action="store_true", help="Whether or not to report more info about model (such as norm, grad)." + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=2000, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + parser.add_argument( + "--snr_loss", action="store_true", help="Whether or not to use snr_loss." + ) + parser.add_argument( + "--uniform_sampling", action="store_true", help="Whether or not to use uniform_sampling." + ) + parser.add_argument( + "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader." + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets." + ) + parser.add_argument( + "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets." + ) + parser.add_argument( + "--random_frame_crop", action="store_true", help="Whether enable random frame crop sample in datasets." + ) + parser.add_argument( + "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets." + ) + parser.add_argument( + "--training_with_video_token_length", action="store_true", help="The training stage of the model in training.", + ) + parser.add_argument( + "--auto_tile_batch_size", action="store_true", help="Whether to auto tile batch size.", + ) + parser.add_argument( + "--motion_sub_loss", action="store_true", help="Whether enable motion sub loss." + ) + parser.add_argument( + "--motion_sub_loss_ratio", type=float, default=0.25, help="The ratio of motion sub loss." + ) + parser.add_argument( + "--train_sampling_steps", + type=int, + default=1000, + help="Run train_sampling_steps.", + ) + parser.add_argument( + "--keep_all_node_same_token_length", + action="store_true", + help="Reference of the length token.", + ) + parser.add_argument( + "--token_sample_size", + type=int, + default=512, + help="Sample size of the token.", + ) + parser.add_argument( + "--video_sample_size", + type=int, + default=512, + help="Sample size of the video.", + ) + parser.add_argument( + "--image_sample_size", + type=int, + default=512, + help="Sample size of the image.", + ) + parser.add_argument( + "--fix_sample_size", + nargs=2, type=int, default=None, + help="Fix Sample size [height, width] when using bucket and collate_fn." + ) + parser.add_argument( + "--video_sample_stride", + type=int, + default=4, + help="Sample stride of the video.", + ) + parser.add_argument( + "--video_sample_n_frames", + type=int, + default=17, + help="Num frame of video.", + ) + parser.add_argument( + "--video_repeat", + type=int, + default=0, + help="Num of repeat video.", + ) + parser.add_argument( + "--config_path", + type=str, + default=None, + help=( + "The config of the model in training." + ), + ) + parser.add_argument( + "--transformer_path", + type=str, + default=None, + help=("If you want to load the weight from other transformers, input its path."), + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help=("If you want to load the weight from other vaes, input its path."), + ) + + parser.add_argument( + '--trainable_modules', + nargs='+', + help='Enter a list of trainable modules' + ) + parser.add_argument( + '--trainable_modules_low_learning_rate', + nargs='+', + default=[], + help='Enter a list of trainable modules with lower learning rate' + ) + parser.add_argument( + '--tokenizer_max_length', + type=int, + default=512, + help='Max length of tokenizer' + ) + parser.add_argument( + "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed." + ) + parser.add_argument( + "--use_fsdp", action="store_true", help="Whether or not to use fsdp." + ) + parser.add_argument( + "--low_vram", action="store_true", help="Whether enable low_vram mode." + ) + parser.add_argument( + "--boundary_type", + type=str, + default="low", + help=( + 'The format of training data. Support `"low"` and `"high"`' + ), + ) + parser.add_argument( + "--abnormal_norm_clip_start", + type=int, + default=1000, + help=( + 'When do we start doing additional processing on abnormal gradients. ' + ), + ) + parser.add_argument( + "--initial_grad_norm_ratio", + type=int, + default=5, + help=( + 'The initial gradient is relative to the multiple of the max_grad_norm. ' + ), + ) + parser.add_argument( + "--control_ref_image", + type=str, + default="first_frame", + help=( + 'The format of training data. Support `"first_frame"`' + ' (default), `"random"`.' + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + config = OmegaConf.load(args.config_path) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + deepspeed_plugin = accelerator.state.deepspeed_plugin if hasattr(accelerator.state, "deepspeed_plugin") else None + fsdp_plugin = accelerator.state.fsdp_plugin if hasattr(accelerator.state, "fsdp_plugin") else None + if deepspeed_plugin is not None: + zero_stage = int(deepspeed_plugin.zero_stage) + fsdp_stage = 0 + print(f"Using DeepSpeed Zero stage: {zero_stage}") + + args.use_deepspeed = True + if zero_stage == 3: + print(f"Auto set save_state to True because zero_stage == 3") + args.save_state = True + elif fsdp_plugin is not None: + from torch.distributed.fsdp import ShardingStrategy + zero_stage = 0 + if fsdp_plugin.sharding_strategy is ShardingStrategy.FULL_SHARD: + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is None: # The fsdp_plugin.sharding_strategy is None in FSDP 2. + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP: + fsdp_stage = 2 + else: + fsdp_stage = 0 + print(f"Using FSDP stage: {fsdp_stage}") + + args.use_fsdp = True + if fsdp_stage == 3: + print(f"Auto set save_state to True because fsdp_stage == 3") + args.save_state = True + else: + zero_stage = 0 + fsdp_stage = 0 + print("DeepSpeed is not enabled.") + + if accelerator.is_main_process: + writer = SummaryWriter(log_dir=logging_dir) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index)) + torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index) + else: + rng = None + torch_rng = None + index_rng = np.random.default_rng(np.random.PCG64(43)) + print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}") + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Load scheduler, tokenizer and models. + noise_scheduler = FlowMatchEulerDiscreteScheduler( + **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) + ) + + # Get Tokenizer + tokenizer = AutoTokenizer.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # Get Text encoder + text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, + ) + text_encoder = text_encoder.eval() + # Get Vae + vae = AutoencoderKLWan.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(config['vae_kwargs']), + ) + vae.eval() + + # Get Transformer + if args.boundary_type == "low" or args.boundary_type == "full": + sub_path = config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer') + else: + sub_path = config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer') + transformer3d = VaceWanTransformer3DModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + + # Freeze vae and text_encoder and set transformer3d to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + transformer3d.requires_grad_(False) + + if args.transformer_path is not None: + print(f"From checkpoint: {args.transformer_path}") + if args.transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.transformer_path) + else: + state_dict = torch.load(args.transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer3d.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + + if args.vae_path is not None: + print(f"From checkpoint: {args.vae_path}") + if args.vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.vae_path) + else: + state_dict = torch.load(args.vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + + # A good trainable modules is showed below now. + # For 3D Patch: trainable_modules = ['ff.net', 'pos_embed', 'attn2', 'proj_out', 'timepositionalencoding', 'h_position', 'w_position'] + # For 2D Patch: trainable_modules = ['ff.net', 'attn2', 'timepositionalencoding', 'h_position', 'w_position'] + transformer3d.train() + if accelerator.is_main_process: + accelerator.print( + f"Trainable modules '{args.trainable_modules}'." + ) + for name, param in transformer3d.named_parameters(): + for trainable_module_name in args.trainable_modules + args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + param.requires_grad = True + break + + # Create EMA for the transformer3d. + if args.use_ema: + if zero_stage == 3: + raise NotImplementedError("FSDP does not support EMA.") + + ema_transformer3d = VaceWanTransformer3DModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + + ema_transformer3d = EMAModel(ema_transformer3d.parameters(), model_cls=VaceWanTransformer3DModel, model_config=ema_transformer3d.config) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + if fsdp_stage != 0: + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + + safetensor_save_path = os.path.join(output_dir, f"diffusion_pytorch_model.safetensors") + accelerate_state_dict = {k: v.to(dtype=weight_dtype) for k, v in accelerate_state_dict.items()} + save_file(accelerate_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + + elif zero_stage == 3: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + else: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_transformer3d.save_pretrained(os.path.join(output_dir, "transformer_ema")) + + models[0].save_pretrained(os.path.join(output_dir, "transformer")) + if not args.use_deepspeed: + weights.pop() + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + if args.use_ema: + ema_path = os.path.join(input_dir, "transformer_ema") + _, ema_kwargs = VaceWanTransformer3DModel.load_config(ema_path, return_unused_kwargs=True) + load_model = VaceWanTransformer3DModel.from_pretrained( + input_dir, subfolder="transformer_ema", + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']) + ) + load_model = EMAModel(load_model.parameters(), model_cls=VaceWanTransformer3DModel, model_config=load_model.config) + load_model.load_state_dict(ema_kwargs) + + ema_transformer3d.load_state_dict(load_model.state_dict()) + ema_transformer3d.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = VaceWanTransformer3DModel.from_pretrained( + input_dir, subfolder="transformer" + ) + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + transformer3d.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + elif args.use_came: + try: + from came_pytorch import CAME + except: + raise ImportError( + "Please install came_pytorch to use CAME. You can do so by running `pip install came_pytorch`" + ) + + optimizer_cls = CAME + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list(filter(lambda p: p.requires_grad, transformer3d.parameters())) + trainable_params_optim = [ + {'params': [], 'lr': args.learning_rate}, + {'params': [], 'lr': args.learning_rate / 2}, + ] + in_already = [] + for name, param in transformer3d.named_parameters(): + high_lr_flag = False + if name in in_already: + continue + for trainable_module_name in args.trainable_modules: + if trainable_module_name in name: + in_already.append(name) + high_lr_flag = True + trainable_params_optim[0]['params'].append(param) + if accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate}") + break + if high_lr_flag: + continue + for trainable_module_name in args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + in_already.append(name) + trainable_params_optim[1]['params'].append(param) + if accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate / 2}") + break + + if args.use_came: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + # weight_decay=args.adam_weight_decay, + betas=(0.9, 0.999, 0.9999), + eps=(1e-30, 1e-16) + ) + else: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the training dataset + sample_n_frames_bucket_interval = vae.config.temporal_compression_ratio + + if args.fix_sample_size is not None and args.enable_bucket: + args.video_sample_size = max(max(args.fix_sample_size), args.video_sample_size) + args.image_sample_size = max(max(args.fix_sample_size), args.image_sample_size) + args.training_with_video_token_length = False + args.random_hw_adapt = False + + # Get the dataset + train_dataset = ImageVideoControlDataset( + args.train_data_meta, args.train_data_dir, + video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames, + video_repeat=args.video_repeat, + image_sample_size=args.image_sample_size, + enable_bucket=args.enable_bucket, + enable_inpaint=False, + enable_camera_info=False, + enable_subject_info=True + ) + + def worker_init_fn(_seed): + _seed = _seed * 256 + def _worker_init_fn(worker_id): + print(f"worker_init_fn with {_seed + worker_id}") + np.random.seed(_seed + worker_id) + random.seed(_seed + worker_id) + return _worker_init_fn + + if args.enable_bucket: + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = AspectRatioBatchImageVideoSampler( + sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, + batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True, + aspect_ratios=aspect_ratio_sample_size, + ) + + def collate_fn(examples): + def get_length_to_frame_num(token_length): + if args.image_sample_size > args.video_sample_size: + sample_sizes = list(range(args.video_sample_size, args.image_sample_size + 1, 128)) + + if sample_sizes[-1] != args.image_sample_size: + sample_sizes.append(args.image_sample_size) + else: + sample_sizes = [args.image_sample_size] + + length_to_frame_num = { + sample_size: min(token_length / sample_size / sample_size, args.video_sample_n_frames) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 for sample_size in sample_sizes + } + + return length_to_frame_num + + def get_random_downsample_ratio(sample_size, image_ratio=[], + all_choices=False, rng=None): + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + first_element = 0.90 + remaining_sum = 1.0 - first_element + other_elements_value = remaining_sum / (length - 1) + special_list = [first_element] + [other_elements_value] * (length - 1) + return special_list + + if sample_size >= 1536: + number_list = [1, 1.25, 1.5, 2, 2.5, 3] + image_ratio + elif sample_size >= 1024: + number_list = [1, 1.25, 1.5, 2] + image_ratio + elif sample_size >= 768: + number_list = [1, 1.25, 1.5] + image_ratio + elif sample_size >= 512: + number_list = [1] + image_ratio + else: + number_list = [1] + + if all_choices: + return number_list + + number_list_prob = np.array(_create_special_list(len(number_list))) + if rng is None: + return np.random.choice(number_list, p = number_list_prob) + else: + return rng.choice(number_list, p = number_list_prob) + + # Get token length + target_token_length = args.video_sample_n_frames * args.token_sample_size * args.token_sample_size + length_to_frame_num = get_length_to_frame_num(target_token_length) + + # Create new output + new_examples = {} + new_examples["target_token_length"] = target_token_length + new_examples["pixel_values"] = [] + new_examples["text"] = [] + # Used in Control Mode + new_examples["control_pixel_values"] = [] + # Used in Control Ref Mode + new_examples["ref_pixel_values"] = [] + new_examples["clip_pixel_values"] = [] + new_examples["clip_idx"] = [] + + # Used in Inpaint mode + new_examples["mask_pixel_values"] = [] + new_examples["mask"] = [] + + new_examples["subject_images"] = [] + new_examples["subject_flags"] = [] + + # Get downsample ratio in image and videos + pixel_value = examples[0]["pixel_values"] + data_type = examples[0]["data_type"] + f, h, w, c = np.shape(pixel_value) + if data_type == 'image': + random_downsample_ratio = 1 if not args.random_hw_adapt else get_random_downsample_ratio(args.image_sample_size, image_ratio=[args.image_sample_size / args.video_sample_size]) + + aspect_ratio_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()} + + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + else: + if args.random_hw_adapt: + if args.training_with_video_token_length: + local_min_size = np.min(np.array([np.mean(np.array([np.shape(example["pixel_values"])[1], np.shape(example["pixel_values"])[2]])) for example in examples])) + + def get_random_downsample_probability(choice_list, token_sample_size): + length = len(choice_list) + if length == 1: + return [1.0] # If there's only one element, it gets all the probability + + # Find the index of the closest value to token_sample_size + closest_index = min(range(length), key=lambda i: abs(choice_list[i] - token_sample_size)) + + # Assign 50% to the closest index + first_element = 0.50 + remaining_sum = 1.0 - first_element + + # Distribute the remaining 50% evenly among the other elements + other_elements_value = remaining_sum / (length - 1) if length > 1 else 0.0 + + # Construct the probability distribution + probability_list = [other_elements_value] * length + probability_list[closest_index] = first_element + + return probability_list + + choice_list = [length for length in list(length_to_frame_num.keys()) if length < local_min_size * 1.25] + if len(choice_list) == 0: + choice_list = list(length_to_frame_num.keys()) + probabilities = get_random_downsample_probability(choice_list, args.token_sample_size) + local_video_sample_size = np.random.choice(choice_list, p=probabilities) + + random_downsample_ratio = args.video_sample_size / local_video_sample_size + batch_video_length = length_to_frame_num[local_video_sample_size] + else: + random_downsample_ratio = get_random_downsample_ratio(args.video_sample_size) + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + else: + random_downsample_ratio = 1 + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()} + + if args.fix_sample_size is not None: + fix_sample_size = [int(x / 16) * 16 for x in args.fix_sample_size] + elif args.random_ratio_crop: + if rng is None: + random_sample_size = aspect_ratio_random_crop_sample_size[ + np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + else: + random_sample_size = aspect_ratio_random_crop_sample_size[ + rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + random_sample_size = [int(x / 16) * 16 for x in random_sample_size] + else: + closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size) + closest_size = [int(x / 16) * 16 for x in closest_size] + + for example in examples: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + + if args.control_ref_image == "first_frame": + clip_index = 0 + else: + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + first_element = 0.40 + remaining_sum = 1.0 - first_element + other_elements_value = remaining_sum / (length - 1) + special_list = [first_element] + [other_elements_value] * (length - 1) + return special_list + number_list_prob = np.array(_create_special_list(len(pixel_values))) + clip_index = np.random.choice(list(range(len(pixel_values))), p = number_list_prob) + new_examples["clip_idx"].append(clip_index) + + ref_pixel_values = pixel_values[clip_index].permute(1, 2, 0).contiguous() + ref_pixel_values = Image.fromarray(np.uint8(ref_pixel_values * 255)) + ref_pixel_values = padding_image(ref_pixel_values, closest_size[1], closest_size[0]) + ref_pixel_values = (torch.tensor(np.array(ref_pixel_values)).unsqueeze(0).permute(0, 3, 1, 2).contiguous() / 255 - 0.5)/0.5 + new_examples["ref_pixel_values"].append(ref_pixel_values) + + control_pixel_values = torch.from_numpy(example["control_pixel_values"]).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + + _, channel, h, w = pixel_values.size() + new_subject_image = torch.zeros(4, channel, h, w) + num_subject = len(example["subject_image"]) + if num_subject != 0: + subject_image = torch.from_numpy(example["subject_image"]).permute(0, 3, 1, 2).contiguous() + new_subject_image[:num_subject] = subject_image + subject_image = new_subject_image / 255. + subject_flag = torch.from_numpy(np.array([1] * num_subject + [0] * (4 - num_subject))) + + if args.fix_sample_size is not None: + # Get adapt hw for resize + fix_sample_size = list(map(lambda x: int(x), fix_sample_size)) + transform = transforms.Compose([ + transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(fix_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + transform_no_normalize = transforms.Compose([ + transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(fix_sample_size), + ]) + elif args.random_ratio_crop: + # Get adapt hw for resize + b, c, h, w = pixel_values.size() + th, tw = random_sample_size + if th / tw > h / w: + nh = int(th) + nw = int(w / h * nh) + else: + nw = int(tw) + nh = int(h / w * nw) + + transform = transforms.Compose([ + transforms.Resize([nh, nw]), + transforms.CenterCrop([int(x) for x in random_sample_size]), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + transform_no_normalize = transforms.Compose([ + transforms.Resize([nh, nw]), + transforms.CenterCrop([int(x) for x in random_sample_size]), + ]) + else: + # Get adapt hw for resize + closest_size = list(map(lambda x: int(x), closest_size)) + if closest_size[0] / h > closest_size[1] / w: + resize_size = closest_size[0], int(w * closest_size[0] / h) + else: + resize_size = int(h * closest_size[1] / w), closest_size[1] + + transform = transforms.Compose([ + transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(closest_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + transform_no_normalize = transforms.Compose([ + transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(closest_size), + ]) + + new_examples["pixel_values"].append(transform(pixel_values)) + new_examples["control_pixel_values"].append(transform(control_pixel_values)) + + new_examples["text"].append(example["text"]) + # Magvae needs the number of frames to be 4n + 1. + batch_video_length = int( + min( + batch_video_length, + (len(pixel_values) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1, + ) + ) + if batch_video_length == 0: + batch_video_length = 1 + + clip_pixel_values = new_examples["pixel_values"][-1][clip_index].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + new_examples["clip_pixel_values"].append(clip_pixel_values) + + mask = get_random_mask(new_examples["pixel_values"][-1].size()) + mask_pixel_values = new_examples["pixel_values"][-1] * (1 - mask) + + # Wan 2.1 use 0 for masked pixels + # + torch.ones_like(new_examples["pixel_values"][-1]) * -1 * mask + new_examples["mask_pixel_values"].append(mask_pixel_values) + new_examples["mask"].append(mask) + + new_examples["subject_images"].append(transform(subject_image)) + new_examples["subject_flags"].append(subject_flag) + + # Limit the number of frames to the same + new_examples["pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["pixel_values"]]) + new_examples["control_pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["control_pixel_values"]]) + new_examples["ref_pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["ref_pixel_values"]]) + new_examples["clip_pixel_values"] = torch.stack([example for example in new_examples["clip_pixel_values"]]) + new_examples["clip_idx"] = torch.tensor(new_examples["clip_idx"]) + new_examples["mask_pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["mask_pixel_values"]]) + new_examples["mask"] = torch.stack([example[:batch_video_length] for example in new_examples["mask"]]) + new_examples["subject_images"] = torch.stack([example for example in new_examples["subject_images"]]) + new_examples["subject_flags"] = torch.stack([example for example in new_examples["subject_flags"]]) + # Encode prompts when enable_text_encoder_in_dataloader=True + if args.enable_text_encoder_in_dataloader: + prompt_ids = tokenizer( + new_examples['text'], + max_length=args.tokenizer_max_length, + padding="max_length", + add_special_tokens=True, + truncation=True, + return_tensors="pt" + ) + encoder_hidden_states = text_encoder( + prompt_ids.input_ids + )[0] + new_examples['encoder_attention_mask'] = prompt_ids.attention_mask + new_examples['encoder_hidden_states'] = encoder_hidden_states + + return new_examples + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + else: + # DataLoaders creation: + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer3d, optimizer, train_dataloader, lr_scheduler + ) + + if fsdp_stage != 0: + from functools import partial + + from videox_fun.dist import set_multi_gpus_devices, shard_model + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype) + text_encoder = shard_fn(text_encoder) + + # shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype) + # transformer3d = shard_fn(transformer3d) + + if args.use_ema: + ema_transformer3d.to(accelerator.device) + + # Move text_encode and vae to gpu and cast to weight_dtype + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu") + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + tracker_config.pop("trainable_modules") + tracker_config.pop("trainable_modules_low_learning_rate") + tracker_config.pop("fix_sample_size") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + + pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + _, first_epoch = pickle.load(file) + else: + first_epoch = global_step // num_update_steps_per_epoch + print(f"Load pkl from {pkl_path}. Get first_epoch = {first_epoch}.") + + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + if args.multi_stream: + # create extra cuda streams to speedup inpaint vae computation + vae_stream_1 = torch.cuda.Stream() + vae_stream_2 = torch.cuda.Stream() + else: + vae_stream_1 = None + vae_stream_2 = None + + # Calculate the index we need + boundary = config['transformer_additional_kwargs'].get('boundary', 0.900) + split_timesteps = args.train_sampling_steps * boundary + differences = torch.abs(noise_scheduler.timesteps - split_timesteps) + closest_index = torch.argmin(differences).item() + print(f"The boundary is {boundary} and the boundary_type is {args.boundary_type}. The closest_index we calculate is {closest_index}") + if args.boundary_type == "high": + start_num_idx = 0 + train_sampling_steps = closest_index + elif args.boundary_type == "low": + start_num_idx = closest_index + train_sampling_steps = args.train_sampling_steps - closest_index + else: + start_num_idx = 0 + train_sampling_steps = args.train_sampling_steps + idx_sampling = DiscreteSampling(train_sampling_steps, start_num_idx=start_num_idx, uniform_sampling=args.uniform_sampling) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch) + for step, batch in enumerate(train_dataloader): + # Data batch sanity check + if epoch == first_epoch and step == 0: + pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] + control_pixel_values = batch["control_pixel_values"].cpu() + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + control_pixel_values = rearrange(control_pixel_values, "b f c h w -> b c f h w") + os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True) + for idx, (pixel_value, control_pixel_value, text) in enumerate(zip(pixel_values, control_pixel_values, texts)): + pixel_value = pixel_value[None, ...] + control_pixel_value = control_pixel_value[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True) + save_videos_grid(control_pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}_control.gif", rescale=True) + + ref_pixel_values = batch["ref_pixel_values"].cpu() + ref_pixel_values = rearrange(ref_pixel_values, "b f c h w -> b c f h w") + for idx, (ref_pixel_value, text) in enumerate(zip(ref_pixel_values, texts)): + ref_pixel_value = ref_pixel_value[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(ref_pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}_ref.gif", rescale=True) + + subject_images = batch["subject_images"].cpu() + subject_images = rearrange(subject_images, "b f c h w -> b c f h w") + for idx, (subject_image, text) in enumerate(zip(subject_images, texts)): + subject_image = subject_image[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(subject_image, f"{args.output_dir}/sanity_check/{gif_name[:10]}_subject.gif", rescale=True) + + clip_pixel_values, mask_pixel_values, texts = batch['clip_pixel_values'].cpu(), batch['mask_pixel_values'].cpu(), batch['text'] + mask_pixel_values = rearrange(mask_pixel_values, "b f c h w -> b c f h w") + for idx, (clip_pixel_value, pixel_value, text) in enumerate(zip(clip_pixel_values, mask_pixel_values, texts)): + pixel_value = pixel_value[None, ...] + Image.fromarray(np.uint8(clip_pixel_value)).save(f"{args.output_dir}/sanity_check/clip_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.png") + save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/mask_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.gif", rescale=True) + + with accelerator.accumulate(transformer3d): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + control_pixel_values = batch["control_pixel_values"].to(weight_dtype) + + # Increase the batch size when the length of the latent sequence of the current sample is small + if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3: + if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (4, 1, 1, 1, 1)) + control_pixel_values = torch.tile(control_pixel_values, (4, 1, 1, 1, 1)) + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (4, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (4, 1)) + else: + batch['text'] = batch['text'] * 4 + elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (2, 1, 1, 1, 1)) + control_pixel_values = torch.tile(control_pixel_values, (2, 1, 1, 1, 1)) + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (2, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (2, 1)) + else: + batch['text'] = batch['text'] * 2 + + ref_pixel_values = batch["ref_pixel_values"].to(weight_dtype) + clip_pixel_values = batch["clip_pixel_values"] + subject_images = batch["subject_images"].to(weight_dtype) + subject_flags = batch["subject_flags"].to(weight_dtype) + clip_idx = batch["clip_idx"] + mask_pixel_values = batch["mask_pixel_values"].to(weight_dtype) + mask = batch["mask"].to(weight_dtype) + + # Increase the batch size when the length of the latent sequence of the current sample is small + if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3: + if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + clip_pixel_values = torch.tile(clip_pixel_values, (4, 1, 1, 1)) + subject_images = torch.tile(subject_images, (4, 1, 1, 1, 1)) + ref_pixel_values = torch.tile(ref_pixel_values, (4, 1, 1, 1, 1)) + subject_flags = torch.tile(subject_flags, (4, 1)) + clip_idx = torch.tile(clip_idx, (4,)) + mask_pixel_values = torch.tile(mask_pixel_values, (4, 1, 1, 1, 1)) + mask = torch.tile(mask, (4, 1, 1, 1, 1)) + elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + clip_pixel_values = torch.tile(clip_pixel_values, (2, 1, 1, 1)) + subject_images = torch.tile(subject_images, (2, 1, 1, 1, 1)) + ref_pixel_values = torch.tile(ref_pixel_values, (2, 1, 1, 1, 1)) + subject_flags = torch.tile(subject_flags, (2, 1)) + clip_idx = torch.tile(clip_idx, (2,)) + mask_pixel_values = torch.tile(mask_pixel_values, (2, 1, 1, 1, 1)) + mask = torch.tile(mask, (2, 1, 1, 1, 1)) + + if args.random_frame_crop: + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + last_element = 0.90 + remaining_sum = 1.0 - last_element + other_elements_value = remaining_sum / (length - 1) + special_list = [other_elements_value] * (length - 1) + [last_element] + return special_list + select_frames = [_tmp for _tmp in list(range(sample_n_frames_bucket_interval + 1, args.video_sample_n_frames + sample_n_frames_bucket_interval, sample_n_frames_bucket_interval))] + select_frames_prob = np.array(_create_special_list(len(select_frames))) + + if len(select_frames) != 0: + if rng is None: + temp_n_frames = np.random.choice(select_frames, p = select_frames_prob) + else: + temp_n_frames = rng.choice(select_frames, p = select_frames_prob) + else: + temp_n_frames = 1 + + # Magvae needs the number of frames to be 4n + 1. + temp_n_frames = (temp_n_frames - 1) // sample_n_frames_bucket_interval + 1 + + pixel_values = pixel_values[:, :temp_n_frames, :, :] + control_pixel_values = control_pixel_values[:, :temp_n_frames, :, :] + mask_pixel_values = mask_pixel_values[:, :temp_n_frames, :, :] + mask = mask[:, :temp_n_frames, :, :] + + # Keep all node same token length to accelerate the traning when resolution grows. + if args.keep_all_node_same_token_length: + if args.token_sample_size > 256: + numbers_list = list(range(256, args.token_sample_size + 1, 128)) + + if numbers_list[-1] != args.token_sample_size: + numbers_list.append(args.token_sample_size) + else: + numbers_list = [256] + numbers_list = [_number * _number * args.video_sample_n_frames for _number in numbers_list] + + actual_token_length = index_rng.choice(numbers_list) + actual_video_length = (min( + actual_token_length / pixel_values.size()[-1] / pixel_values.size()[-2], args.video_sample_n_frames + ) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 + actual_video_length = int(max(actual_video_length, 1)) + + # Magvae needs the number of frames to be 4n + 1. + actual_video_length = (actual_video_length - 1) // sample_n_frames_bucket_interval + 1 + + pixel_values = pixel_values[:, :actual_video_length, :, :] + control_pixel_values = control_pixel_values[:, :actual_video_length, :, :] + mask_pixel_values = mask_pixel_values[:, :actual_video_length, :, :] + mask = mask[:, :actual_video_length, :, :] + + if args.low_vram: + torch.cuda.empty_cache() + vae.to(accelerator.device) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to("cpu") + + def vace_encode_frames(frames, ref_images, masks=None): + weight_dtype = frames.dtype + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames)[0].mode() + else: + masks = [torch.where(m > 0.5, 1.0, 0.0).to(weight_dtype) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive)[0].mode() + reactive = vae.encode(reactive)[0].mode() + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs)[0].mode() + else: + ref_latent = vae.encode(refs)[0].mode() + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(masks, ref_images=None, vae_stride=[4, 8, 8]): + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + c, depth, height, width = mask.shape + mask_pad = mask.new_zeros(c, length, height, width) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + with torch.no_grad(): + # This way is quicker when batch grows up + def _batch_encode_vae(pixel_values): + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + bs = args.vae_mini_batch + new_pixel_values = [] + for i in range(0, pixel_values.shape[0], bs): + pixel_values_bs = pixel_values[i : i + bs] + pixel_values_bs = vae.encode(pixel_values_bs)[0] + pixel_values_bs = pixel_values_bs.sample() + new_pixel_values.append(pixel_values_bs) + return torch.cat(new_pixel_values, dim = 0) + if vae_stream_1 is not None: + vae_stream_1.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(vae_stream_1): + latents = _batch_encode_vae(pixel_values) + else: + latents = _batch_encode_vae(pixel_values) + + if rng is None: + subject_images_num = np.random.choice([0, 1, 2, 3, 4]) + else: + subject_images_num = rng.choice([0, 1, 2, 3, 4]) + + if rng is None: + use_full_photo_ref_flag = np.random.choice([True, False], p=[0.25, 0.75]) + else: + use_full_photo_ref_flag = rng.choice([True, False], p=[0.25, 0.75]) + + if not use_full_photo_ref_flag: + if subject_images_num == 0: + subject_ref_images = None + else: + subject_ref_images = rearrange(subject_images, "b f c h w -> b c f h w") + subject_ref_images = subject_ref_images[:, :, :subject_images_num] + + bs, c, f, h, w = subject_ref_images.size() + new_subject_ref_images = [] + for i in range(bs): + act_subject_images_num = min(subject_images_num, int(torch.sum(subject_flags[i]))) + + if act_subject_images_num == 0: + new_subject_ref_images.append(None) + else: + new_subject_ref_images.append([]) + for j in range(act_subject_images_num): + new_subject_ref_images[i].append(subject_ref_images[i, :, j:j+1]) + subject_ref_images = new_subject_ref_images + else: + ref_pixel_values = rearrange(ref_pixel_values, "b f c h w -> b c f h w") + + bs, c, f, h, w = ref_pixel_values.size() + new_ref_pixel_values = [] + for i in range(bs): + new_ref_pixel_values.append([]) + for j in range(1): + new_ref_pixel_values[i].append(ref_pixel_values[i, :, j:j+1]) + subject_ref_images = new_ref_pixel_values + + if rng is None: + inpaint_flag = np.random.choice([True, False], p=[0.75, 0.25]) + else: + inpaint_flag = rng.choice([True, False], p=[0.75, 0.25]) + mask = rearrange(mask, "b f c h w -> b c f h w") + mask = torch.tile(mask, [1, 3, 1, 1, 1]) + if inpaint_flag or (control_pixel_values == -1).all(): + if rng is None: + do_not_use_ref_images = np.random.choice([True, False], p=[0.50, 0.50]) + else: + do_not_use_ref_images = rng.choice([True, False], p=[0.50, 0.50]) + if do_not_use_ref_images: + subject_ref_images = None + mask_pixel_values = rearrange(mask_pixel_values, "b f c h w -> b c f h w") + vace_latents = vace_encode_frames(mask_pixel_values, subject_ref_images, mask) + else: + control_pixel_values = rearrange(control_pixel_values, "b f c h w -> b c f h w") + vace_latents = vace_encode_frames(control_pixel_values, subject_ref_images, mask) + mask = torch.ones_like(mask) + + mask_latents = vace_encode_masks(mask, subject_ref_images) + vace_context = torch.stack(vace_latent(vace_latents, mask_latents)) + + if subject_ref_images is not None: + for i in range(len(subject_ref_images)): + if subject_ref_images[i] is not None: + + subject_ref_images[i] = torch.cat( + [subject_ref_image.unsqueeze(0) for subject_ref_image in subject_ref_images[i]], 2 + ) + subject_ref_images[i] = torch.cat( + [vae.encode(subject_ref_images[i][:, :, j:j+1])[0].sample() for j in range(subject_ref_images[i].size(2))], 2 + ) + + if subject_ref_images[0] is not None: + subject_ref_images = torch.cat(subject_ref_images) + latents = torch.cat( + [subject_ref_images, latents], dim=2 + ) + + # wait for latents = vae.encode(pixel_values) to complete + if vae_stream_1 is not None: + torch.cuda.current_stream().wait_stream(vae_stream_1) + + if args.low_vram: + vae.to('cpu') + torch.cuda.empty_cache() + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device) + + if args.enable_text_encoder_in_dataloader: + prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device) + else: + with torch.no_grad(): + prompt_ids = tokenizer( + batch['text'], + padding="max_length", + max_length=args.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt" + ) + text_input_ids = prompt_ids.input_ids + prompt_attention_mask = prompt_ids.attention_mask + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to(latents.device), attention_mask=prompt_attention_mask.to(latents.device))[0] + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + if args.low_vram and not args.enable_text_encoder_in_dataloader: + text_encoder.to('cpu') + torch.cuda.empty_cache() + + bsz, channel, num_frames, height, width = latents.size() + noise = torch.randn( + (bsz, channel, num_frames, height, width), device=latents.device, generator=torch_rng, dtype=weight_dtype) + + if not args.uniform_sampling: + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + else: + # Sample a random timestep for each image + # timesteps = generate_timestep_with_lognorm(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng) + # timesteps = torch.randint(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng) + indices = idx_sampling(bsz, generator=torch_rng, device=latents.device) + indices = indices.long().cpu() + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + + # Add noise + target = noise - latents + + target_shape = (vae.latent_channels, vace_latents[0].size(1), width, height) + seq_len = math.ceil( + (target_shape[2] * target_shape[3]) / + (accelerator.unwrap_model(transformer3d).config.patch_size[1] * accelerator.unwrap_model(transformer3d).config.patch_size[2]) * + target_shape[1] + ) + + # Predict the noise residual + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + noise_pred = transformer3d( + x=noisy_latents, + context=prompt_embeds, + t=timesteps, + seq_len=seq_len, + vace_context=vace_context, + ) + + def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): + noise_pred = noise_pred.float() + target = target.float() + diff = noise_pred - target + mse_loss = F.mse_loss(noise_pred, target, reduction='none') + mask = (diff.abs() <= threshold).float() + + masked_loss = mse_loss * mask + if weighting is not None: + masked_loss = masked_loss * weighting + final_loss = masked_loss.mean() + return final_loss + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + loss = custom_mse_loss(noise_pred.float(), target.float(), weighting.float()) + loss = loss.mean() + + if args.motion_sub_loss and noise_pred.size()[1] > 2: + gt_sub_noise = noise_pred[:, :, 1:].float() - noise_pred[:, :, :-1].float() + pre_sub_noise = target[:, :, 1:].float() - target[:, :, :-1].float() + sub_loss = F.mse_loss(gt_sub_noise, pre_sub_noise, reduction="mean") + loss = loss * (1 - args.motion_sub_loss_ratio) + sub_loss * args.motion_sub_loss_ratio + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + if not args.use_deepspeed and not args.use_fsdp: + trainable_params_grads = [p.grad for p in trainable_params if p.grad is not None] + trainable_params_total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2) for g in trainable_params_grads]), 2) + max_grad_norm = linear_decay(args.max_grad_norm * args.initial_grad_norm_ratio, args.max_grad_norm, args.abnormal_norm_clip_start, global_step) + if trainable_params_total_norm / max_grad_norm > 5 and global_step > args.abnormal_norm_clip_start: + actual_max_grad_norm = max_grad_norm / min((trainable_params_total_norm / max_grad_norm), 10) + else: + actual_max_grad_norm = max_grad_norm + else: + actual_max_grad_norm = args.max_grad_norm + + if not args.use_deepspeed and not args.use_fsdp and args.report_model_info and accelerator.is_main_process: + if trainable_params_total_norm > 1 and global_step > args.abnormal_norm_clip_start: + for name, param in transformer3d.named_parameters(): + if param.requires_grad: + writer.add_scalar(f'gradients/before_clip_norm/{name}', param.grad.norm(), global_step=global_step) + + norm_sum = accelerator.clip_grad_norm_(trainable_params, actual_max_grad_norm) + if not args.use_deepspeed and not args.use_fsdp and args.report_model_info and accelerator.is_main_process: + writer.add_scalar(f'gradients/norm_sum', norm_sum, global_step=global_step) + writer.add_scalar(f'gradients/actual_max_grad_norm', actual_max_grad_norm, global_step=global_step) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + if args.use_ema: + ema_transformer3d.step(transformer3d.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if accelerator.is_main_process: + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_transformer3d.store(transformer3d.parameters()) + ema_transformer3d.copy_to(transformer3d.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + clip_image_encoder, + transformer3d, + args, + config, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original transformer3d parameters. + ema_transformer3d.restore(transformer3d.parameters()) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_transformer3d.store(transformer3d.parameters()) + ema_transformer3d.copy_to(transformer3d.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + clip_image_encoder, + transformer3d, + args, + config, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original transformer3d parameters. + ema_transformer3d.restore(transformer3d.parameters()) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer3d = unwrap_model(transformer3d) + if args.use_ema: + ema_transformer3d.copy_to(transformer3d.parameters()) + + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/scripts/wan2.2_vace_fun/train.sh b/scripts/wan2.2_vace_fun/train.sh new file mode 100644 index 00000000..c207afb6 --- /dev/null +++ b/scripts/wan2.2_vace_fun/train.sh @@ -0,0 +1,43 @@ +export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-VACE-Fun-A14B" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/wan2.2_vace_fun/train.py \ + --config_path="config/wan2.2/wan_civitai_t2v.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=1024 \ + --video_sample_size=256 \ + --token_sample_size=512 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --low_vram \ + --control_ref_image="random" \ + --boundary_type="low" \ + --trainable_modules "vace" \ No newline at end of file diff --git a/videox_fun/models/wan_transformer3d_vace.py b/videox_fun/models/wan_transformer3d_vace.py index e5b619ca..7ca57f10 100644 --- a/videox_fun/models/wan_transformer3d_vace.py +++ b/videox_fun/models/wan_transformer3d_vace.py @@ -3,6 +3,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict +import os import math import torch import torch.cuda.amp as amp @@ -14,6 +15,8 @@ sinusoidal_embedding_1d) +VIDEOX_OFFLOAD_VACE_LATENTS = os.environ.get("VIDEOX_OFFLOAD_VACE_LATENTS", False) + class VaceWanAttentionBlock(WanAttentionBlock): def __init__( self, @@ -45,8 +48,16 @@ def forward(self, c, x, **kwargs): all_c = list(torch.unbind(c)) c = all_c.pop(-1) + if VIDEOX_OFFLOAD_VACE_LATENTS: + c = c.to(x.device) + c = super().forward(c, **kwargs) c_skip = self.after_proj(c) + + if VIDEOX_OFFLOAD_VACE_LATENTS: + c_skip = c_skip.to("cpu") + c = c.to("cpu") + all_c += [c_skip, c] c = torch.stack(all_c) return c @@ -71,7 +82,10 @@ def __init__( def forward(self, x, hints, context_scale=1.0, **kwargs): x = super().forward(x, **kwargs) if self.block_id is not None: - x = x + hints[self.block_id] * context_scale + if VIDEOX_OFFLOAD_VACE_LATENTS: + x = x + hints[self.block_id].to(x.device) * context_scale + else: + x = x + hints[self.block_id] * context_scale return x diff --git a/videox_fun/utils/lora_utils.py b/videox_fun/utils/lora_utils.py index b0766259..9b683c1c 100755 --- a/videox_fun/utils/lora_utils.py +++ b/videox_fun/utils/lora_utils.py @@ -383,6 +383,14 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3 key = key.replace(".self_attn.", "_self_attn_") key = key.replace(".cross_attn.", "_cross_attn_") key = key.replace(".ffn.", "_ffn_") + if "lora_A" in key or "lora_B" in key: + key = "lora_unet__" + key + key = key.replace("blocks.", "blocks_") + key = key.replace(".self_attn.", "_self_attn_") + key = key.replace(".cross_attn.", "_cross_attn_") + key = key.replace(".ffn.", "_ffn_") + key = key.replace(".lora_A.default.", ".lora_down.") + key = key.replace(".lora_B.default.", ".lora_up.") layer, elem = key.split('.', 1) updates[layer][elem] = value @@ -496,6 +504,14 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl key = key.replace(".self_attn.", "_self_attn_") key = key.replace(".cross_attn.", "_cross_attn_") key = key.replace(".ffn.", "_ffn_") + if "lora_A" in key or "lora_B" in key: + key = "lora_unet__" + key + key = key.replace("blocks.", "blocks_") + key = key.replace(".self_attn.", "_self_attn_") + key = key.replace(".cross_attn.", "_cross_attn_") + key = key.replace(".ffn.", "_ffn_") + key = key.replace(".lora_A.default.", ".lora_down.") + key = key.replace(".lora_B.default.", ".lora_up.") layer, elem = key.split('.', 1) updates[layer][elem] = value