codeflash-ai bot commented Oct 28, 2025

📄 9% (0.09x) speedup for reconstruct_multicond_batch in modules/prompt_parser.py

⏱️ Runtime: 1.73 milliseconds → 1.60 milliseconds (best of 224 runs)

📝 Explanation and details

The optimized code achieves an 8% speedup through two main improvements:

**1. Smarter tensor padding in `stack_conds`** (see the sketch after this list):

- **Early exit optimization**: Only performs padding logic when tensors actually have different shapes, using an `any(tc != token_count for tc in token_counts)` check
- **Memory-efficient padding**: Replaces `tensor[-1:].repeat([rows_to_add, 1])` with `tensor[-1:].expand(rows_to_add, -1)` - `expand` creates a view without copying data, while `repeat` allocates new memory
- **Cleaner tensor handling**: Creates a new list instead of modifying the original `tensors` in-place, avoiding potential memory fragmentation
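
A minimal sketch of this padding strategy, assuming a helper that receives a list of 2-D `(tokens, dim)` tensors (function and variable names here are illustrative rather than the exact code in `modules/prompt_parser.py`):

```python
import torch

def stack_conds_sketch(tensors):
    # Find the longest token dimension across all conds.
    token_counts = [t.shape[0] for t in tensors]
    token_count = max(token_counts)

    # Early exit: skip all padding work when every tensor already matches.
    if any(tc != token_count for tc in token_counts):
        padded = []
        for t in tensors:
            rows_to_add = token_count - t.shape[0]
            if rows_to_add > 0:
                # expand() yields a view of the last row broadcast to
                # rows_to_add rows - no copy, unlike repeat().
                last_rows = t[-1:].expand(rows_to_add, -1)
                t = torch.vstack([t, last_rows])
            padded.append(t)  # new list; the caller's tensors stay untouched
        tensors = padded

    return torch.stack(tensors)
```

Because `expand` only creates a view, the single real copy happens inside `vstack`/`stack`, which is what keeps the padded path cheap.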

**2. Micro-optimizations in schedule lookup** (see the sketch after this list):

- **Reduced attribute access**: Pre-stores `schedules` and `n_schedules` variables to avoid repeated attribute lookups in the tight loop
- **Explicit fallback**: Adds a proper fallback case when no schedule matches (rare in practice)
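
A hedged sketch of the per-prompt schedule selection (attribute names mirror the test mocks further down; the real loop in `reconstruct_multicond_batch` may differ in detail):

```python
def select_cond_sketch(composable_prompt, current_step):
    # Hoist attribute lookups out of the hot loop.
    schedules = composable_prompt.schedules
    n_schedules = len(schedules)

    # Explicit fallback: if no entry's end_at_step covers current_step,
    # use the last schedule.
    target_index = n_schedules - 1
    for i in range(n_schedules):
        if current_step <= schedules[i].end_at_step:
            target_index = i
            break

    return schedules[target_index].cond, composable_prompt.weight
```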

The optimizations are particularly effective for:

- **Large batch scenarios** (500+ items): 23% speedup due to reduced memory allocations
- **Tensor padding cases**: 7-8% speedup when tensors need length normalization
- **Mixed tensor sizes**: Up to 26% speedup when significant padding is required

The performance gains come from reducing memory allocations and copies, especially beneficial when processing large batches or when tensor shapes vary significantly, which are common in prompt conditioning workflows.

Correctness verification report:

| Test | Status |
|------|--------|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 33 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
```python
from __future__ import annotations

# imports
import pytest  # used for our unit tests
import torch
from modules.prompt_parser import reconstruct_multicond_batch


# Supporting classes for the function
class ScheduleEntry:
    def __init__(self, cond, end_at_step):
        self.cond = cond
        self.end_at_step = end_at_step

class ComposablePrompt:
    def __init__(self, schedules, weight):
        self.schedules = schedules  # list of ScheduleEntry
        self.weight = weight

class MulticondLearnedConditioning:
    def __init__(self, batch):
        self.batch = batch  # list of list of ComposablePrompt

class DictWithShape(dict):
    def __init__(self, d, shape):
        super().__init__(d)
        self.shape = shape
from modules.prompt_parser import reconstruct_multicond_batch

# unit tests

# ----------- Basic Test Cases -----------

def test_single_batch_single_prompt_single_schedule():
    # Single batch, single prompt, single schedule
    cond = torch.ones((2, 4))
    schedules = [ScheduleEntry(cond, end_at_step=10)]
    prompt = ComposablePrompt(schedules, weight=1.0)
    c = MulticondLearnedConditioning(batch=[[prompt]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 15.5μs -> 14.7μs (5.54% faster)

def test_single_batch_multiple_prompts_single_schedule():
    # Single batch, multiple prompts, each with one schedule
    cond1 = torch.ones((2, 4))
    cond2 = torch.zeros((2, 4))
    schedules1 = [ScheduleEntry(cond1, end_at_step=10)]
    schedules2 = [ScheduleEntry(cond2, end_at_step=10)]
    prompt1 = ComposablePrompt(schedules1, weight=0.7)
    prompt2 = ComposablePrompt(schedules2, weight=0.3)
    c = MulticondLearnedConditioning(batch=[[prompt1, prompt2]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 16.7μs -> 15.8μs (5.76% faster)

def test_multiple_batches_single_prompt():
    # Multiple batches, each with one prompt
    cond1 = torch.ones((2, 4))
    cond2 = torch.full((2, 4), 2.0)
    schedules1 = [ScheduleEntry(cond1, end_at_step=10)]
    schedules2 = [ScheduleEntry(cond2, end_at_step=10)]
    prompt1 = ComposablePrompt(schedules1, weight=1.0)
    prompt2 = ComposablePrompt(schedules2, weight=1.0)
    c = MulticondLearnedConditioning(batch=[[prompt1], [prompt2]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 16.6μs -> 15.5μs (7.40% faster)

def test_multiple_batches_multiple_prompts():
    # Multiple batches, each with multiple prompts
    conds = [torch.ones((2, 4)), torch.zeros((2, 4)), torch.full((2, 4), 3.0), torch.full((2, 4), 4.0)]
    schedules = [[ScheduleEntry(conds[0], 10)], [ScheduleEntry(conds[1], 10)],
                 [ScheduleEntry(conds[2], 10)], [ScheduleEntry(conds[3], 10)]]
    prompts = [ComposablePrompt(schedules[i], weight=0.5+i) for i in range(4)]
    c = MulticondLearnedConditioning(batch=[[prompts[0], prompts[1]], [prompts[2], prompts[3]]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 17.6μs -> 16.6μs (6.32% faster)
    for i in range(4):
        pass

def test_schedule_selection_by_step():
    # Each prompt has multiple schedules, select correct one by current_step
    conds = [torch.ones((2, 4)), torch.zeros((2, 4))]
    schedules = [ScheduleEntry(conds[0], 5), ScheduleEntry(conds[1], 10)]
    prompt = ComposablePrompt(schedules, weight=1.0)
    c = MulticondLearnedConditioning(batch=[[prompt]])
    # At step 3, should pick first schedule
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=3) # 16.1μs -> 15.5μs (3.91% faster)
    # At step 7, should pick second schedule
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=7) # 5.25μs -> 5.52μs (4.91% slower)

# ----------- Edge Test Cases -----------

def test_empty_batch():
    # Edge: Empty batch
    c = MulticondLearnedConditioning(batch=[])
    with pytest.raises(IndexError):
        reconstruct_multicond_batch(c, current_step=0) # 1.08μs -> 1.07μs (1.69% faster)

def test_empty_prompts_in_batch():
    # Edge: Batch with empty prompt list
    c = MulticondLearnedConditioning(batch=[[]])
    with pytest.raises(IndexError):
        reconstruct_multicond_batch(c, current_step=0) # 1.07μs -> 1.11μs (3.24% slower)

def test_empty_schedules_in_prompt():
    # Edge: Prompt with empty schedules
    prompt = ComposablePrompt(schedules=[], weight=1.0)
    c = MulticondLearnedConditioning(batch=[[prompt]])
    with pytest.raises(IndexError):
        reconstruct_multicond_batch(c, current_step=0) # 1.17μs -> 1.10μs (6.18% faster)

def test_schedule_step_beyond_all_end_at_steps():
    # Edge: current_step greater than all end_at_step values
    conds = [torch.ones((2, 4)), torch.zeros((2, 4))]
    schedules = [ScheduleEntry(conds[0], 5), ScheduleEntry(conds[1], 8)]
    prompt = ComposablePrompt(schedules, weight=1.0)
    c = MulticondLearnedConditioning(batch=[[prompt]])
    # Should pick last schedule
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=100) # 21.3μs -> 21.1μs (0.975% faster)

def test_tensor_padding():
    # Edge: tensors of different sequence lengths should be padded
    cond1 = torch.ones((2, 3))
    cond2 = torch.zeros((4, 3))
    schedules1 = [ScheduleEntry(cond1, 10)]
    schedules2 = [ScheduleEntry(cond2, 10)]
    prompt1 = ComposablePrompt(schedules1, weight=1.0)
    prompt2 = ComposablePrompt(schedules2, weight=1.0)
    c = MulticondLearnedConditioning(batch=[[prompt1, prompt2]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=0) # 51.8μs -> 48.2μs (7.47% faster)

def test_dict_cond_input():
    # Edge: cond is a dict with multiple keys
    cond_dict1 = {'crossattn': torch.ones((2, 4)), 'other': torch.full((2, 2), 5.0)}
    cond_dict2 = {'crossattn': torch.zeros((2, 4)), 'other': torch.full((2, 2), 7.0)}
    schedules1 = [ScheduleEntry(cond_dict1, 10)]
    schedules2 = [ScheduleEntry(cond_dict2, 10)]
    prompt1 = ComposablePrompt(schedules1, weight=1.0)
    prompt2 = ComposablePrompt(schedules2, weight=1.0)
    c = MulticondLearnedConditioning(batch=[[prompt1, prompt2]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=0) # 20.7μs -> 20.5μs (1.29% faster)

def test_tensor_device_and_dtype():
    # Edge: tensors should preserve device and dtype
    cond = torch.ones((2, 4), dtype=torch.float64)
    schedules = [ScheduleEntry(cond, end_at_step=10)]
    prompt = ComposablePrompt(schedules, weight=1.0)
    c = MulticondLearnedConditioning(batch=[[prompt]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=0) # 14.9μs -> 14.1μs (5.64% faster)

# ----------- Large Scale Test Cases -----------

def test_large_batch():
    # Large batch size, single prompt per batch
    batch_size = 500
    cond = torch.ones((2, 4))
    schedules = [ScheduleEntry(cond, end_at_step=10)]
    prompts = [ComposablePrompt(schedules, weight=1.0) for _ in range(batch_size)]
    c = MulticondLearnedConditioning(batch=[[p] for p in prompts])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=0) # 268μs -> 217μs (23.2% faster)
    for i in range(batch_size):
        pass

def test_large_prompt_count():
    # Single batch, large number of prompts
    prompt_count = 500
    conds = [torch.full((2, 4), float(i)) for i in range(prompt_count)]
    schedules_list = [[ScheduleEntry(conds[i], 10)] for i in range(prompt_count)]
    prompts = [ComposablePrompt(schedules_list[i], weight=1.0) for i in range(prompt_count)]
    c = MulticondLearnedConditioning(batch=[prompts])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=0) # 241μs -> 195μs (23.3% faster)
    for i in range(prompt_count):
        pass

def test_large_tensor_size():
    # Large tensor size, but <100MB
    # Each float32 element is 4 bytes, so 100MB/4 = 25,000,000 elements
    # Let's use (100, 2500) = 250,000 elements per tensor, 10 tensors = 2,500,000 elements = 10MB
    conds = [torch.full((100, 2500), float(i)) for i in range(10)]
    schedules_list = [[ScheduleEntry(conds[i], 10)] for i in range(10)]
    prompts = [ComposablePrompt(schedules_list[i], weight=1.0) for i in range(10)]
    c = MulticondLearnedConditioning(batch=[prompts])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=0) # 612μs -> 612μs (0.119% faster)
    for i in range(10):
        pass

def test_large_dict_cond():
    # Large batch with dict conds
    batch_size = 50
    conds = [{'crossattn': torch.ones((10, 20)), 'other': torch.full((10, 5), float(i))}
             for i in range(batch_size)]
    schedules_list = [[ScheduleEntry(conds[i], 10)] for i in range(batch_size)]
    prompts = [ComposablePrompt(schedules_list[i], weight=1.0) for i in range(batch_size)]
    c = MulticondLearnedConditioning(batch=[prompts])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=0) # 65.6μs -> 58.2μs (12.8% faster)
    for i in range(batch_size):
        pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest  # used for our unit tests
import torch
from modules.prompt_parser import reconstruct_multicond_batch

# Mocks and helpers for the function's dependencies

class ScheduleEntry:
    """
    Mock for composable_prompt.schedules[x]
    """
    def __init__(self, cond, end_at_step):
        self.cond = cond
        self.end_at_step = end_at_step

class ComposablePrompt:
    """
    Mock for composable_prompt
    """
    def __init__(self, schedules, weight):
        self.schedules = schedules
        self.weight = weight

class MulticondLearnedConditioning:
    """
    Mock for c argument
    """
    def __init__(self, batch):
        self.batch = batch

class DictWithShape(dict):
    """
    Mock for DictWithShape used in reconstruct_multicond_batch
    """
    def __init__(self, d, shape):
        super().__init__(d)
        self.shape = shape
from modules.prompt_parser import reconstruct_multicond_batch

# ---- UNIT TESTS ----

# Basic Test Cases

def test_single_batch_single_prompt_single_schedule():
    # Test with one batch, one prompt, one schedule
    cond = torch.ones(2, 4)
    schedules = [ScheduleEntry(cond, end_at_step=10)]
    weight = 1.0
    prompt = ComposablePrompt(schedules, weight)
    c = MulticondLearnedConditioning([[prompt]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 21.1μs -> 20.2μs (4.45% faster)

def test_multiple_batch_multiple_prompts():
    # Two batches, each with two prompts, each with one schedule
    cond1 = torch.ones(2, 4)
    cond2 = torch.zeros(2, 4)
    cond3 = torch.full((2, 4), 2.0)
    cond4 = torch.full((2, 4), 3.0)
    p1 = ComposablePrompt([ScheduleEntry(cond1, 10)], 1.0)
    p2 = ComposablePrompt([ScheduleEntry(cond2, 10)], 2.0)
    p3 = ComposablePrompt([ScheduleEntry(cond3, 10)], 3.0)
    p4 = ComposablePrompt([ScheduleEntry(cond4, 10)], 4.0)
    c = MulticondLearnedConditioning([[p1, p2], [p3, p4]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 19.0μs -> 18.0μs (5.39% faster)

def test_schedule_selection_by_step():
    # Test that the correct schedule is selected by current_step
    cond_a = torch.ones(2, 4)
    cond_b = torch.full((2, 4), 5.0)
    schedules = [
        ScheduleEntry(cond_a, end_at_step=2),
        ScheduleEntry(cond_b, end_at_step=10)
    ]
    prompt = ComposablePrompt(schedules, 1.0)
    c = MulticondLearnedConditioning([[prompt]])
    # Step 1 should select cond_a
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=1) # 16.2μs -> 15.4μs (5.13% faster)
    # Step 5 should select cond_b
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 5.33μs -> 5.51μs (3.34% slower)

def test_stack_conds_padding():
    # Test that stack_conds pads shorter tensors
    cond_short = torch.ones(1, 4)
    cond_long = torch.zeros(3, 4)
    schedules1 = [ScheduleEntry(cond_short, 10)]
    schedules2 = [ScheduleEntry(cond_long, 10)]
    p1 = ComposablePrompt(schedules1, 1.0)
    p2 = ComposablePrompt(schedules2, 2.0)
    c = MulticondLearnedConditioning([[p1, p2]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 48.4μs -> 44.5μs (8.77% faster)

def test_dict_cond_handling():
    # Test that dict conds are stacked correctly
    cond1 = {'crossattn': torch.ones(2, 4), 'other': torch.zeros(2, 4)}
    cond2 = {'crossattn': torch.full((2, 4), 2.0), 'other': torch.full((2, 4), 3.0)}
    schedules1 = [ScheduleEntry(cond1, 10)]
    schedules2 = [ScheduleEntry(cond2, 10)]
    p1 = ComposablePrompt(schedules1, 1.0)
    p2 = ComposablePrompt(schedules2, 2.0)
    c = MulticondLearnedConditioning([[p1, p2]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 19.6μs -> 19.8μs (1.14% slower)

# Edge Test Cases

def test_empty_batch():
    # Test with empty batch
    c = MulticondLearnedConditioning([])
    with pytest.raises(IndexError):
        reconstruct_multicond_batch(c, current_step=0) # 1.09μs -> 1.10μs (1.27% slower)

def test_empty_prompts_in_batch():
    # Test with batch containing empty prompts list
    c = MulticondLearnedConditioning([[]])
    with pytest.raises(IndexError):
        reconstruct_multicond_batch(c, current_step=0) # 1.05μs -> 1.05μs (0.285% slower)

def test_empty_schedules_in_prompt():
    # Test with prompt containing empty schedules
    p = ComposablePrompt([], 1.0)
    c = MulticondLearnedConditioning([[p]])
    with pytest.raises(IndexError):
        reconstruct_multicond_batch(c, current_step=0) # 1.10μs -> 1.06μs (3.78% faster)

def test_schedule_end_at_step_edge():
    # Test when current_step is exactly at end_at_step
    cond_a = torch.ones(2, 4)
    cond_b = torch.full((2, 4), 5.0)
    schedules = [
        ScheduleEntry(cond_a, end_at_step=2),
        ScheduleEntry(cond_b, end_at_step=10)
    ]
    prompt = ComposablePrompt(schedules, 1.0)
    c = MulticondLearnedConditioning([[prompt]])
    # Step 2 should select cond_a
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=2) # 21.4μs -> 20.8μs (2.85% faster)
    # Step 3 should select cond_b
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=3) # 5.73μs -> 5.66μs (1.31% faster)


def test_non_tensor_cond():
    # Test with cond that is not a tensor or dict
    schedules = [ScheduleEntry("not_a_tensor", 10)]
    p = ComposablePrompt(schedules, 1.0)
    c = MulticondLearnedConditioning([[p]])
    with pytest.raises(AttributeError):
        reconstruct_multicond_batch(c, current_step=0) # 4.05μs -> 3.95μs (2.48% faster)

def test_large_token_count_padding():
    # Test padding with a large token count difference
    cond_short = torch.ones(1, 4)
    cond_long = torch.zeros(100, 4)
    schedules1 = [ScheduleEntry(cond_short, 10)]
    schedules2 = [ScheduleEntry(cond_long, 10)]
    p1 = ComposablePrompt(schedules1, 1.0)
    p2 = ComposablePrompt(schedules2, 2.0)
    c = MulticondLearnedConditioning([[p1, p2]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 58.5μs -> 54.3μs (7.85% faster)

# Large Scale Test Cases


def test_large_token_and_feature_dim():
    # Test with large token count and feature dim
    batch_size = 2
    prompt_per_batch = 2
    token_count = 500
    feature_dim = 50
    conds = [torch.full((token_count, feature_dim), float(i)) for i in range(batch_size * prompt_per_batch)]
    prompts = [
        [ComposablePrompt([ScheduleEntry(conds[batch_size * i + j], 100)], float(j))
         for j in range(prompt_per_batch)]
        for i in range(batch_size)
    ]
    c = MulticondLearnedConditioning(prompts)
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=50) # 40.8μs -> 40.0μs (1.87% faster)
    # Check a few values for correctness
    for idx in [0, batch_size * prompt_per_batch - 1]:
        pass

def test_large_dict_cond():
    # Test with large dict conds
    batch_size = 10
    token_count = 100
    feature_dim = 8
    conds = [
        {'crossattn': torch.full((token_count, feature_dim), float(i)),
         'other': torch.full((token_count, feature_dim), float(i+1))}
        for i in range(batch_size)
    ]
    prompts = [[ComposablePrompt([ScheduleEntry(conds[i], 100)], 1.0)] for i in range(batch_size)]
    c = MulticondLearnedConditioning(prompts)
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=50) # 34.2μs -> 32.6μs (4.70% faster)
    for idx in [0, batch_size - 1]:
        pass

def test_padding_and_dtype_device_preserved():
    # Test that dtype and device are preserved in output
    cond1 = torch.ones(2, 4, dtype=torch.float32, device='cpu')
    cond2 = torch.zeros(3, 4, dtype=torch.float32, device='cpu')
    schedules1 = [ScheduleEntry(cond1, 10)]
    schedules2 = [ScheduleEntry(cond2, 10)]
    p1 = ComposablePrompt(schedules1, 1.0)
    p2 = ComposablePrompt(schedules2, 2.0)
    c = MulticondLearnedConditioning([[p1, p2]])
    conds_list, stacked = reconstruct_multicond_batch(c, current_step=5) # 48.0μs -> 37.9μs (26.6% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
```

To edit these changes, `git checkout codeflash/optimize-reconstruct_multicond_batch-mh9zb8d9` and push.

Codeflash

codeflash-ai bot requested a review from mashraf-222 October 28, 2025 03:00
codeflash-ai bot added the ⚡️ codeflash and 🎯 Quality: High labels Oct 28, 2025