ldm/generate.py (14 changes: 9 additions & 5 deletions)

@@ -146,7 +146,7 @@ def __init__(
             gfpgan=None,
             codeformer=None,
             esrgan=None,
-            free_gpu_mem=False,
+            free_gpu_mem: bool=False,
             safety_checker:bool=False,
             max_loaded_models:int=2,
             # these are deprecated; if present they override values in the conf file
@@ -460,10 +460,13 @@ def process_image(image,seed):
         init_image = None
         mask_image = None
 
-
-        if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device:
-            self.model.cond_stage_model.device = self.model.device
-            self.model.cond_stage_model.to(self.model.device)
+        try:
+            if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device:
+                self.model.cond_stage_model.device = self.model.device
+                self.model.cond_stage_model.to(self.model.device)
+        except AttributeError:
+            print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.")
+            pass
 
         try:
             uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
@@ -531,6 +534,7 @@ def process_image(image,seed):
                 inpaint_height = inpaint_height,
                 inpaint_width = inpaint_width,
                 enable_image_debugging = enable_image_debugging,
+                free_gpu_mem=self.free_gpu_mem,
             )
 
         if init_color:
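Note (illustrative sketch, not part of the diff): the try/except added above works because a diffusers-based pipeline has no cond_stage_model attribute, so the device check raises AttributeError rather than anything that needs to be probed up front. The self-contained sketch below shows how the guarded move behaves for both model types; the stub class names are hypothetical stand-ins introduced only for this illustration.

    # Illustrative sketch only -- stubs stand in for a legacy ckpt-style model
    # (which has cond_stage_model) and a diffusers-style pipeline (which does not).

    class StubCondStage:
        def __init__(self):
            self.device = "cpu"
        def to(self, device):
            self.device = device

    class StubCkptModel:                 # checkpoint-style model: exposes cond_stage_model
        def __init__(self):
            self.device = "cuda"
            self.cond_stage_model = StubCondStage()

    class StubDiffusersModel:            # diffusers pipeline: no cond_stage_model attribute
        device = "cuda"

    def move_cond_stage_back(model, free_gpu_mem=True):
        # mirrors the guarded block added to process_image()
        try:
            if free_gpu_mem and model.cond_stage_model.device != model.device:
                model.cond_stage_model.device = model.device
                model.cond_stage_model.to(model.device)
        except AttributeError:
            print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.")

    move_cond_stage_back(StubCkptModel())       # moves the conditioning model back to "cuda"
    move_cond_stage_back(StubDiffusersModel())  # prints the warning and continues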
ldm/invoke/ckpt_generator/base.py (2 changes: 2 additions & 0 deletions)

@@ -56,9 +56,11 @@ def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  attention_maps_callback = None,
+                 free_gpu_mem: bool=False,
                  **kwargs):
         scope = choose_autocast(self.precision)
         self.safety_checker = safety_checker
+        self.free_gpu_mem = free_gpu_mem
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
         make_image = self.get_make_image(
ldm/invoke/generator/base.py (2 changes: 2 additions & 0 deletions)

@@ -62,9 +62,11 @@ def set_variation(self, seed, variation_amount, with_variations):
     def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
+                 free_gpu_mem: bool=False,
                  **kwargs):
         scope = nullcontext
         self.safety_checker = safety_checker
+        self.free_gpu_mem = free_gpu_mem
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
         make_image = self.get_make_image(
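Note (hedged sketch, not part of the diff): both generator base classes now accept free_gpu_mem through generate() and store it on self, so later code can consult the flag. The snippet below is a minimal illustration of that pattern under assumed names (ExampleGenerator, StubModel); the CPU offload shown is an assumption about how the flag might be used, since this PR only plumbs it through.

    # Minimal illustration of the new plumbing; ExampleGenerator and StubModel are
    # hypothetical, and the offload step is an assumed use of the stored flag.

    class StubModel:
        def to(self, device):
            print(f"model moved to {device}")
            return self

    class ExampleGenerator:
        def __init__(self, model):
            self.model = model
            self.free_gpu_mem = False

        def generate(self, make_image, free_gpu_mem: bool = False, **kwargs):
            self.free_gpu_mem = free_gpu_mem          # mirrors the change in both base.py files
            image = make_image(**kwargs)
            if self.free_gpu_mem:
                self.model = self.model.to("cpu")     # assumed offload between generations
            return image

    g = ExampleGenerator(StubModel())
    g.generate(lambda **kw: "image", free_gpu_mem=True)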