     "v1.4": "CompVis/stable-diffusion-v1-4",
 }
 
+# CLIP has two variants, with a max token length of either 77 or 64.
+model_clip_max_length = 64 if args.max_length == 64 else 77
+
 model_input = {
     "v2.1": {
-        "clip": (torch.randint(1, 2, (2, 77)),),
+        "clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
         "vae": (torch.randn(1, 4, 96, 96),),
         "unet": (
             torch.randn(1, 4, 96, 96),  # latents
             torch.tensor([1]).to(torch.float32),  # timestep
-            torch.randn(2, 77, 1024),  # embedding
+            torch.randn(2, model_clip_max_length, 1024),  # embedding
             torch.tensor(1).to(torch.float32),  # guidance_scale
         ),
     },
     "v2.1base": {
-        "clip": (torch.randint(1, 2, (2, 77)),),
+        "clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
         "vae": (torch.randn(1, 4, 64, 64),),
         "unet": (
             torch.randn(1, 4, 64, 64),  # latents
             torch.tensor([1]).to(torch.float32),  # timestep
-            torch.randn(2, 77, 1024),  # embedding
+            torch.randn(2, model_clip_max_length, 1024),  # embedding
             torch.tensor(1).to(torch.float32),  # guidance_scale
         ),
     },
     "v1.4": {
-        "clip": (torch.randint(1, 2, (2, 77)),),
+        "clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
         "vae": (torch.randn(1, 4, 64, 64),),
         "unet": (
             torch.randn(1, 4, 64, 64),  # latents
             torch.tensor([1]).to(torch.float32),  # timestep
-            torch.randn(2, 77, 768),  # embedding
+            torch.randn(2, model_clip_max_length, 768),  # embedding
             torch.tensor(1).to(torch.float32),  # guidance_scale
         ),
     },
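For context, here is a minimal, self-contained sketch of how the new flag might be wired up. The diff only shows that `args.max_length` exists; the `--max_length` argparse option below (its default, choices, and help text) is an assumption for illustration, not part of this change:

```python
import argparse

import torch

parser = argparse.ArgumentParser()
# Hypothetical flag definition; the diff only references `args.max_length`.
parser.add_argument(
    "--max_length",
    type=int,
    default=77,
    choices=[64, 77],
    help="Max token length for the CLIP text encoder (77 or 64).",
)
args = parser.parse_args()

# Mirrors the diff: any value other than 64 falls back to 77.
model_clip_max_length = 64 if args.max_length == 64 else 77

# The dummy CLIP input then has shape (batch=2, seq_len=model_clip_max_length),
# matching the token-id tensors built in `model_input` above.
clip_input = (torch.randint(1, 2, (2, model_clip_max_length)),)
print(clip_input[0].shape)  # torch.Size([2, 77]) by default
```

Note that the unet "embedding" dummy inputs change in lockstep: the text encoder emits one embedding per token, so the sequence dimension of the `(2, seq_len, hidden_dim)` embedding tensor must equal the CLIP max token length.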