Binary file added asset/inpaint_video.mp4
Binary file not shown.
Binary file added asset/inpaint_video_mask.mp4
Binary file not shown.
17 changes: 11 additions & 6 deletions comfyui/cogvideox_fun/nodes.py
@@ -198,13 +198,18 @@ def INPUT_TYPES(s):
     CATEGORY = "CogVideoXFUNWrapper"
 
     def load_lora(self, cogvideoxfun_model, lora_name, strength_model, lora_cache):
+        new_funmodels = dict(cogvideoxfun_model)
+
         if lora_name is not None:
-            cogvideoxfun_model['lora_cache'] = lora_cache
-            cogvideoxfun_model['loras'] = cogvideoxfun_model.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)]
-            cogvideoxfun_model['strength_model'] = cogvideoxfun_model.get("strength_model", []) + [strength_model]
-            return (cogvideoxfun_model,)
-        else:
-            return (cogvideoxfun_model,)
+            lora_path = folder_paths.get_full_path("loras", lora_name)
+            if lora_path is None:
+                raise FileNotFoundError(f"LoRA file not found: {lora_name}")
+
+            new_funmodels['lora_cache'] = lora_cache
+            new_funmodels['loras'] = cogvideoxfun_model.get("loras", []) + [lora_path]
+            new_funmodels['strength_model'] = cogvideoxfun_model.get("strength_model", []) + [strength_model]
+
+        return (new_funmodels,)
 
 class CogVideoXFunT2VSampler:
     @classmethod
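Note: load_lora now copies the incoming dict instead of mutating it. Because ComfyUI caches node outputs between queue runs, appending to the cached dict in place would stack the same LoRA path again on every re-run. A minimal standalone sketch of the pattern, with get_full_path stubbed as a hypothetical stand-in for folder_paths.get_full_path:

def get_full_path(kind, name):
    # hypothetical stub for illustration only
    return f"/models/{kind}/{name}" if name else None

def load_lora(funmodels, lora_name, strength_model, lora_cache=False):
    new_funmodels = dict(funmodels)  # shallow copy; caller's dict stays untouched
    if lora_name is not None:
        lora_path = get_full_path("loras", lora_name)
        if lora_path is None:
            raise FileNotFoundError(f"LoRA file not found: {lora_name}")
        new_funmodels['lora_cache'] = lora_cache
        # build new lists rather than appending to the cached ones in place
        new_funmodels['loras'] = funmodels.get("loras", []) + [lora_path]
        new_funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model]
    return (new_funmodels,)

base = {"pipeline": None}
(out,) = load_lora(base, "style.safetensors", 0.8)
assert "loras" not in base  # the upstream node's cached output is unchanged
assert out['loras'] == ["/models/loras/style.safetensors"]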
54 changes: 45 additions & 9 deletions comfyui/comfyui_nodes.py
@@ -1,4 +1,5 @@
 import json
+import os
 
 import cv2
 import numpy as np
@@ -83,20 +84,28 @@ def INPUT_TYPES(s):
     def compile(self, cache_size_limit, funmodels):
         torch._dynamo.config.cache_size_limit = cache_size_limit
         if hasattr(funmodels["pipeline"].transformer, "blocks"):
-            for i in range(len(funmodels["pipeline"].transformer.blocks)):
-                funmodels["pipeline"].transformer.blocks[i] = torch.compile(funmodels["pipeline"].transformer.blocks[i])
+            for i, block in enumerate(funmodels["pipeline"].transformer.blocks):
+                if hasattr(block, "_orig_mod"):
+                    block = block._orig_mod
+                funmodels["pipeline"].transformer.blocks[i] = torch.compile(block)
 
             if hasattr(funmodels["pipeline"], "transformer_2") and funmodels["pipeline"].transformer_2 is not None:
-                for i in range(len(funmodels["pipeline"].transformer_2.blocks)):
-                    funmodels["pipeline"].transformer_2.blocks[i] = torch.compile(funmodels["pipeline"].transformer_2.blocks[i])
+                for i, block in enumerate(funmodels["pipeline"].transformer_2.blocks):
+                    if hasattr(block, "_orig_mod"):
+                        block = block._orig_mod
+                    funmodels["pipeline"].transformer_2.blocks[i] = torch.compile(block)
 
         elif hasattr(funmodels["pipeline"].transformer, "transformer_blocks"):
-            for i in range(len(funmodels["pipeline"].transformer.transformer_blocks)):
-                funmodels["pipeline"].transformer.transformer_blocks[i] = torch.compile(funmodels["pipeline"].transformer.transformer_blocks[i])
-
+            for i, block in enumerate(funmodels["pipeline"].transformer.transformer_blocks):
+                if hasattr(block, "_orig_mod"):
+                    block = block._orig_mod
+                funmodels["pipeline"].transformer.transformer_blocks[i] = torch.compile(block)
+
             if hasattr(funmodels["pipeline"], "transformer_2") and funmodels["pipeline"].transformer_2 is not None:
-                for i in range(len(funmodels["pipeline"].transformer_2.transformer_blocks)):
-                    funmodels["pipeline"].transformer_2.transformer_blocks[i] = torch.compile(funmodels["pipeline"].transformer_2.transformer_blocks[i])
+                for i, block in enumerate(funmodels["pipeline"].transformer_2.transformer_blocks):
+                    if hasattr(block, "_orig_mod"):
+                        block = block._orig_mod
+                    funmodels["pipeline"].transformer_2.transformer_blocks[i] = torch.compile(block)
 
         else:
             funmodels["pipeline"].transformer.forward = torch.compile(funmodels["pipeline"].transformer.forward)
@@ -106,6 +115,31 @@ def compile(self, cache_size_limit, funmodels):
         print("Add Compile")
         return (funmodels,)
 
+class FunAttention:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "attention_type": (
+                    ["flash", "sage", "torch"],
+                    {"default": "flash"},
+                ),
+                "funmodels": ("FunModels",)
+            }
+        }
+    RETURN_TYPES = ("FunModels",)
+    RETURN_NAMES = ("funmodels",)
+    FUNCTION = "funattention"
+    CATEGORY = "CogVideoXFUNWrapper"
+
+    def funattention(self, attention_type, funmodels):
+        os.environ['VIDEOX_ATTENTION_TYPE'] = {
+            "flash": "FLASH_ATTENTION",
+            "sage": "SAGE_ATTENTION",
+            "torch": "TORCH_SCALED_DOT"
+        }[attention_type]
+        return (funmodels,)
+
 class LoadConfig:
     @classmethod
@@ -376,6 +410,7 @@ def run(self,camera_pose,fx,fy,cx,cy):
     "FunTextBox": FunTextBox,
     "FunRiflex": FunRiflex,
     "FunCompile": FunCompile,
+    "FunAttention": FunAttention,
 
     "LoadCogVideoXFunModel": LoadCogVideoXFunModel,
     "LoadCogVideoXFunLora": LoadCogVideoXFunLora,
@@ -436,6 +471,7 @@ def run(self,camera_pose,fx,fy,cx,cy):
     "FunTextBox": "FunTextBox",
     "FunRiflex": "FunRiflex",
     "FunCompile": "FunCompile",
+    "FunAttention": "FunAttention",
 
     "LoadWanClipEncoderModel": "Load Wan ClipEncoder Model",
     "LoadWanTextEncoderModel": "Load Wan TextEncoder Model",
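Note on the _orig_mod check in FunCompile: torch.compile wraps an nn.Module in an OptimizedModule that keeps the original module under _orig_mod, so unwrapping first keeps a second FunCompile run from compiling an already-compiled wrapper. A minimal sketch of the idempotent pattern, assuming torch >= 2.0:

import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])

def compile_blocks(blocks):
    for i, block in enumerate(blocks):
        # torch.compile returns an OptimizedModule holding the original
        # module in _orig_mod; unwrap so wrappers never stack.
        if hasattr(block, "_orig_mod"):
            block = block._orig_mod
        blocks[i] = torch.compile(block)

compile_blocks(blocks)
compile_blocks(blocks)  # second run re-wraps the originals, not the wrappers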
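The new FunAttention node only exports the chosen backend through the VIDEOX_ATTENTION_TYPE environment variable; the attention layers inside VideoX-Fun are expected to read it when dispatching. A hypothetical consumer is sketched below (the real dispatcher in the library may differ; only the torch path is exercised here so the sketch stays runnable):

import os
import torch
import torch.nn.functional as F

def dispatch_attention(q, k, v):
    # Illustrative dispatch on the env var set by FunAttention.
    kind = os.environ.get("VIDEOX_ATTENTION_TYPE", "FLASH_ATTENTION")
    if kind == "TORCH_SCALED_DOT":
        return F.scaled_dot_product_attention(q, k, v)
    # "FLASH_ATTENTION" / "SAGE_ATTENTION" would route to the flash-attn or
    # SageAttention kernels when those packages are installed; fall back to
    # the torch kernel in this sketch.
    return F.scaled_dot_product_attention(q, k, v)

os.environ["VIDEOX_ATTENTION_TYPE"] = "TORCH_SCALED_DOT"
q = k = v = torch.randn(1, 4, 16, 32)  # (batch, heads, seq, head_dim)
out = dispatch_attention(q, k, v)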
16 changes: 10 additions & 6 deletions comfyui/wan2_1/nodes.py
@@ -612,13 +612,17 @@ def INPUT_TYPES(s):
     CATEGORY = "CogVideoXFUNWrapper"
 
     def load_lora(self, funmodels, lora_name, strength_model, lora_cache):
+        new_funmodels = dict(funmodels)
+
         if lora_name is not None:
-            funmodels['lora_cache'] = lora_cache
-            funmodels['loras'] = funmodels.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)]
-            funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model]
-            return (funmodels,)
-        else:
-            return (funmodels,)
+            lora_path = folder_paths.get_full_path("loras", lora_name)
+
+            new_funmodels['lora_cache'] = lora_cache
+            new_funmodels['loras'] = funmodels.get("loras", []) + [lora_path]
+            new_funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model]
+
+        return (new_funmodels,)
+
 
 class WanT2VSampler:
     @classmethod
15 changes: 9 additions & 6 deletions comfyui/wan2_1_fun/nodes.py
@@ -238,13 +238,16 @@ def INPUT_TYPES(s):
     CATEGORY = "CogVideoXFUNWrapper"
 
     def load_lora(self, funmodels, lora_name, strength_model, lora_cache):
+        new_funmodels = dict(funmodels)
+
         if lora_name is not None:
-            funmodels['lora_cache'] = lora_cache
-            funmodels['loras'] = funmodels.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)]
-            funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model]
-            return (funmodels,)
-        else:
-            return (funmodels,)
+            lora_path = folder_paths.get_full_path("loras", lora_name)
+
+            new_funmodels['lora_cache'] = lora_cache
+            new_funmodels['loras'] = funmodels.get("loras", []) + [lora_path]
+            new_funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model]
+
+        return (new_funmodels,)
 
 class WanFunT2VSampler:
     @classmethod
16 changes: 9 additions & 7 deletions comfyui/wan2_2/nodes.py
@@ -463,14 +463,16 @@ def INPUT_TYPES(s):
     CATEGORY = "CogVideoXFUNWrapper"
 
     def load_lora(self, funmodels, lora_name, lora_high_name, strength_model, lora_cache):
+        new_funmodels = dict(funmodels)
         if lora_name is not None:
-            funmodels['lora_cache'] = lora_cache
-            funmodels['loras'] = funmodels.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)]
-            funmodels['loras_high'] = funmodels.get("loras_high", []) + [folder_paths.get_full_path("loras", lora_high_name)]
-            funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model]
-            return (funmodels,)
-        else:
-            return (funmodels,)
+            loras = list(new_funmodels.get("loras", [])) + [folder_paths.get_full_path("loras", lora_name)]
+            loras_high = list(new_funmodels.get("loras_high", [])) + [folder_paths.get_full_path("loras", lora_high_name)]
+            strength_models = list(new_funmodels.get("strength_model", [])) + [strength_model]
+            new_funmodels['loras'] = loras
+            new_funmodels['loras_high'] = loras_high
+            new_funmodels['strength_model'] = strength_models
+            new_funmodels['lora_cache'] = lora_cache
+        return (new_funmodels,)
 
 class Wan2_2T2VSampler:
     @classmethod
16 changes: 9 additions & 7 deletions comfyui/wan2_2_fun/nodes.py
@@ -258,14 +258,16 @@ def INPUT_TYPES(s):
     CATEGORY = "CogVideoXFUNWrapper"
 
     def load_lora(self, funmodels, lora_name, lora_high_name, strength_model, lora_cache):
+        new_funmodels = dict(funmodels)
         if lora_name is not None:
-            funmodels['lora_cache'] = lora_cache
-            funmodels['loras'] = funmodels.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)]
-            funmodels['loras_high'] = funmodels.get("loras_high", []) + [folder_paths.get_full_path("loras", lora_high_name)]
-            funmodels['strength_model'] = funmodels.get("strength_model", []) + [strength_model]
-            return (funmodels,)
-        else:
-            return (funmodels,)
+            loras = list(new_funmodels.get("loras", [])) + [folder_paths.get_full_path("loras", lora_name)]
+            loras_high = list(new_funmodels.get("loras_high", [])) + [folder_paths.get_full_path("loras", lora_high_name)]
+            strength_models = list(new_funmodels.get("strength_model", [])) + [strength_model]
+            new_funmodels['loras'] = loras
+            new_funmodels['loras_high'] = loras_high
+            new_funmodels['strength_model'] = strength_models
+            new_funmodels['lora_cache'] = lora_cache
+        return (new_funmodels,)
 
 class Wan2_2FunT2VSampler:
     @classmethod
4 changes: 2 additions & 2 deletions examples/cogvideox_fun/predict_i2v.py
@@ -202,7 +202,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 if partial_video_length is not None:
     partial_video_length = int((partial_video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
@@ -292,7 +292,7 @@
     ).videos
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):
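The same dtype=weight_dtype change repeats across all the example scripts below: it threads the pipeline's compute dtype through merge_lora/unmerge_lora so fused weights land back in bf16/fp16 rather than fp32. merge_lora's internals are not part of this diff; the sketch below is only the standard LoRA fusion W' = W + alpha * (up @ down) with an explicit output dtype, to illustrate why the parameter matters:

import torch

def fuse_lora_weight(weight, lora_down, lora_up, alpha, device, dtype):
    # Accumulate in fp32 for accuracy, then cast to the pipeline's dtype
    # so a bf16 model does not silently end up with fp32 parameters.
    delta = lora_up.to(device, torch.float32) @ lora_down.to(device, torch.float32)
    return (weight.to(device, torch.float32) + alpha * delta).to(dtype)

w = torch.randn(16, 16)
down, up = torch.randn(4, 16), torch.randn(16, 4)
merged = fuse_lora_weight(w, down, up, alpha=0.8, device="cpu", dtype=torch.bfloat16)
assert merged.dtype == torch.bfloat16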
4 changes: 2 additions & 2 deletions examples/cogvideox_fun/predict_t2v.py
@@ -194,7 +194,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 with torch.no_grad():
     video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
@@ -232,7 +232,7 @@
     ).videos
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):
4 changes: 2 additions & 2 deletions examples/cogvideox_fun/predict_v2v.py
@@ -201,7 +201,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
 latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1
@@ -227,7 +227,7 @@
     ).videos
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):
4 changes: 2 additions & 2 deletions examples/cogvideox_fun/predict_v2v_control.py
@@ -188,7 +188,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
 latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1
@@ -212,7 +212,7 @@
     ).videos
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):
4 changes: 2 additions & 2 deletions examples/flux/predict_t2i.py
@@ -184,7 +184,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 with torch.no_grad():
     sample = pipeline(
@@ -198,7 +198,7 @@
     ).images
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):
4 changes: 2 additions & 2 deletions examples/phantom/predict_s2v.py
@@ -244,7 +244,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 with torch.no_grad():
     video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
@@ -272,7 +272,7 @@
     ).videos
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):
4 changes: 2 additions & 2 deletions examples/qwenimage/predict_t2i.py
@@ -176,7 +176,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 with torch.no_grad():
     sample = pipeline(
@@ -190,7 +190,7 @@
     ).images
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):
4 changes: 2 additions & 2 deletions examples/wan2.1/predict_i2v.py
@@ -245,7 +245,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 with torch.no_grad():
     video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
@@ -273,7 +273,7 @@
     ).videos
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):
4 changes: 2 additions & 2 deletions examples/wan2.1/predict_t2v.py
@@ -232,7 +232,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 with torch.no_grad():
     video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
@@ -254,7 +254,7 @@
     ).videos
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):
4 changes: 2 additions & 2 deletions examples/wan2.1_fun/predict_i2v.py
@@ -246,7 +246,7 @@
 generator = torch.Generator(device=device).manual_seed(seed)
 
 if lora_path is not None:
-    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 with torch.no_grad():
     video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
@@ -274,7 +274,7 @@
     ).videos
 
 if lora_path is not None:
-    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device)
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
 
 def save_results():
     if not os.path.exists(save_path):