Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions lit_nlp/examples/dalle_mini/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
class DallePrompts(lit_dataset.Dataset):
  """Dataset of free-text prompts for the DALL-E mini demo.

  Each example is a dict with a single "prompt" key holding the raw prompt
  text, matching the model's input spec.
  """

  def __init__(self, prompts: list[str]):
    """Wraps each prompt string into a {"prompt": ...} example dict.

    Args:
      prompts: raw prompt strings, one example per string.
    """
    # Stored privately; exposed read-only via the `examples` property below.
    self._examples = [{"prompt": prompt} for prompt in prompts]

  def spec(self) -> lit_types.Spec:
    """Declares the single text field that every example carries."""
    return {"prompt": lit_types.TextSegment()}

  def __iter__(self):
    return iter(self._examples)

  @property
  def examples(self):
    # Property (not a plain attribute) so callers cannot rebind the list.
    return self._examples
34 changes: 31 additions & 3 deletions lit_nlp/examples/dalle_mini/demo.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
r"""Example for dalle-mini demo model.

First, run the following command to install the required packages:
  pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt

To run locally with a small number of examples:
python -m lit_nlp.examples.dalle_mini.demo


By default, this module uses the "cuda" device for image generation.
The `requirements.txt` file installs a CUDA-enabled version of PyTorch for GPU acceleration.

If you are running on a machine without a compatible GPU or CUDA drivers,
you must switch the device to "cpu" and reinstall the CPU-only version of PyTorch.

Usage:
- Default: device="cuda"
- On CPU-only machines:
1. Set device="cpu" during model initialization
2. Uninstall the CUDA version of PyTorch:
pip uninstall torch
3. Install the CPU-only version:
pip install torch==2.1.2+cpu --extra-index-url https://download.pytorch.org/whl/cpu

Example:
>>> model = MinDalle(..., device="cpu")

Check CUDA availability:
>>> import torch
>>> torch.cuda.is_available()
False # if no GPU support is present

Error Handling:
- If CUDA is selected but unsupported, you will see:
AssertionError: Torch not compiled with CUDA enabled
- To fix this, either install the correct CUDA-enabled PyTorch or switch to CPU mode.

Then navigate to localhost:5432 to access the demo UI.
"""

Expand All @@ -26,8 +56,6 @@
_FLAGS.set_default("development_demo", True)
_FLAGS.set_default("default_layout", "DALLE_LAYOUT")

_FLAGS.DEFINE_integer("grid_size", 4, "The grid size to use for the model.")

_MODELS = (["dalle-mini"],)

_CANNED_PROMPTS = ["I have a dream", "I have a shiba dog named cola"]
Expand Down
9 changes: 2 additions & 7 deletions lit_nlp/examples/dalle_mini/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,8 @@ def tensor_to_pil_image(tensor):
return images

def input_spec(self):
  """Returns the model's input spec: one free-text "prompt" field per example.

  Generation parameters (grid size, temperature, etc.) are configured on the
  model itself rather than supplied per-example, so they are not part of the
  input spec.
  """
  return {"prompt": lit_types.TextSegment()}

def output_spec(self):
return {
"image": lit_types.ImageBytesList(),
Expand Down
2 changes: 2 additions & 0 deletions lit_nlp/examples/dalle_mini/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@

# Dalle-Mini dependencies
min_dalle==0.4.11
torch==2.1.2+cu118
--extra-index-url https://download.pytorch.org/whl/cu118