284 changes: 277 additions & 7 deletions tests/tensor_parallel/test_tensor_parallel.py
@@ -15,8 +15,10 @@
# Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py
# Run specific config: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc"
# Run multiple configs: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc or 4Proc"
# Run specific test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_dense_forward_train
# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc -k "forward"
# Run specific test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallelDense2Proc::test_model_dense_forward_train
# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallelDense2Proc -k "forward"
# Run MoE tests only: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "Moe"
# Run dense tests only: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "TestTensorParallelDense2Proc or TestTensorParallelDense4Proc"
import os
import tempfile
import warnings
@@ -381,7 +383,7 @@ def _test_model_dense_save_impl(rank, tmp_dir):
model.save_pretrained(result_dir)


class TestTensorParallelBase(TestCasePlus):
class TestTensorParallelDenseBase(TestCasePlus):
"""Base class for tensor parallel tests. Subclasses must set nproc_per_node."""

nproc_per_node = None
@@ -466,13 +468,281 @@ def test_model_dense_save(self):
del non_tp_tensor, tp_tensor


class TestTensorParallel2Proc(TestTensorParallelBase):
"""Test tensor parallel with 2 processes."""
class TestTensorParallelDense2Proc(TestTensorParallelDenseBase):
"""Test tensor parallel dense model with 2 processes."""

nproc_per_node = 2


class TestTensorParallel4Proc(TestTensorParallelBase):
"""Test tensor parallel with 4 processes."""
class TestTensorParallelDense4Proc(TestTensorParallelDenseBase):
"""Test tensor parallel dense model with 4 processes."""

nproc_per_node = 4


# ====== MOE MODEL TEST FUNCTIONS ======
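# These mirror the dense-model helpers above, but run against a tiny Qwen3 MoE
# checkpoint so the expert-sharding paths of the TP plan are exercised as well.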
def _test_model_moe_forward_impl(rank, mode):
"""Implementation for comparing TP and non-TP MoE model outputs."""
model_id = "hf-internal-testing/tiny-qwen3-moe"

# Use the same random seed so both model loads are reproducible
torch.manual_seed(0)

# Load tokenizer and prepare inputs - same for both models
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt")

# Load TP model first to determine device
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
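# Wait until every rank has finished the sharded load before continuing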
dist.barrier()
if mode == "eval":
model_tp.eval()
else:
model_tp.train()

# Load non-TP model and move to same device as TP model
device = model_tp.device
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
model = model.to(device)

if mode == "eval":
model.eval()
else:
model.train()

# Prepare inputs on the same device
input_ids = inputs.input_ids.to(device)

# Run forward pass on both models
with torch.no_grad():
# Non-TP model output
outputs = model(input_ids)
logits = outputs.logits

# TP model output
outputs_tp = model_tp(input_ids)
logits_tp = outputs_tp.logits

# Compare outputs - they should match
assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
f"TP and non-TP MoE model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
)

dist.barrier()


def _test_model_moe_backward_pass_impl(rank):
"""Implementation for comparing TP and non-TP MoE model backward passes."""
model_id = "hf-internal-testing/tiny-qwen3-moe"

torch.manual_seed(0)

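# Load the TP model first; the non-TP reference is then created on the same device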
model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
dist.barrier()
model_tp.train()

device = model_tp.device
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
model = model.to(device)
model.train()

batch_size, seq_length = 2, 10
torch.manual_seed(42) # Separate seed for inputs so every rank draws identical, deterministic data
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)

outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()

outputs_tp = model_tp(input_ids, labels=labels)
loss_tp = outputs_tp.loss
loss_tp.backward()

assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), (
f"TP and non-TP MoE model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}"
)

# Compare gradients parameter by parameter; named_parameters() yields the same order for both models since they share an architecture
for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()):
if param.grad is not None and param_tp.grad is not None:
grad = param.grad
grad_tp = param_tp.grad

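# A sharded TP parameter only holds a slice of the full gradient, so cut the
# matching shard out of the reference gradient before comparing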
if isinstance(param_tp.data, dist.tensor.DTensor):
placement = param_tp.data.placements[0]
if hasattr(placement, "dim") and placement.dim is not None:
grad_shard = get_tensor_shard(grad, grad, param_tp.data.device_mesh, rank, placement.dim)
else:
grad_shard = grad
else:
grad_shard = grad

grad_tp_local = grad_tp.to_local() if isinstance(grad_tp, dist.tensor.DTensor) else grad_tp

assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), (
f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}"
)

dist.barrier()


def _test_model_moe_forward_compile_impl(rank, mode):
"""Implementation for comparing TP and non-TP MoE model outputs with torch.compile."""
model_id = "hf-internal-testing/tiny-qwen3-moe"

torch.manual_seed(0)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt")

model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
dist.barrier()
if mode == "eval":
model_tp.eval()
else:
model_tp.train()

device = model_tp.device
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
model = model.to(device)

if mode == "eval":
model.eval()
else:
model.train()

# Compile both models; TP and non-TP outputs should still match under torch.compile
model.forward = torch.compile(model.forward)
model_tp.forward = torch.compile(model_tp.forward)

input_ids = inputs.input_ids.to(device)

with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits

outputs_tp = model_tp(input_ids)
logits_tp = outputs_tp.logits

assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
f"TP and non-TP MoE model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
)

dist.barrier()


def _test_model_moe_save_impl(rank, tmp_dir):
"""Implementation of test_model_save for MoE model distributed execution."""
model_id = "hf-internal-testing/tiny-qwen3-moe"

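# When run under init_distributed the process group is initialized, so the TP
# checkpoint is saved; the later non-distributed call saves the reference copy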
if dist.is_initialized():
kwargs = {"tp_plan": "auto"}
result_dir = f"{tmp_dir}/tp"
else:
kwargs = {}
result_dir = f"{tmp_dir}/nontp"

model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", **kwargs)
model.save_pretrained(result_dir)


class TestTensorParallelMoeBase(TestCasePlus):
"""Base class for MoE tensor parallel tests. Subclasses must set nproc_per_node."""

nproc_per_node = None

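# Every test guards itself: it is skipped unless nproc_per_node is set and
# enough accelerators are available for the requested TP degree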
@require_torch_multi_accelerator
def test_model_moe_forward_eval(self):
"""Test that TP and non-TP MoE models produce the same outputs in eval mode."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_impl)("eval")

@require_torch_multi_accelerator
def test_model_moe_forward_train(self):
"""Test that TP and non-TP MoE models produce the same outputs in train mode."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_impl)("train")

@require_torch_multi_accelerator
def test_model_moe_backward_pass(self):
"""Test that TP and non-TP MoE models produce the same gradients."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_moe_backward_pass_impl)()

@require_torch_multi_accelerator
def test_model_moe_forward_compile_eval(self):
"""Test that TP and non-TP MoE models produce the same outputs with torch.compile in eval mode."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_compile_impl)("eval")

@require_torch_multi_accelerator
def test_model_moe_forward_compile_train(self):
"""Test that TP and non-TP MoE models produce the same outputs with torch.compile in train mode."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_compile_impl)("train")

@require_huggingface_hub_greater_or_equal("0.31.4")
@require_torch_multi_accelerator
def test_model_moe_save(self):
"""Test that TP MoE model can be saved and matches non-TP version."""
if self.nproc_per_node is None:
self.skipTest("nproc_per_node not set")
if backend_device_count(torch_device) < self.nproc_per_node:
self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

with tempfile.TemporaryDirectory() as tmp_dir:
# First run with TP (distributed)
init_distributed(tp=self.nproc_per_node)(_test_model_moe_save_impl)(tmp_dir)

# Then run without TP (non-distributed)
_test_model_moe_save_impl(0, tmp_dir)

non_tp_model_path = os.path.join(tmp_dir, "nontp")
tp_model_path = os.path.join(tmp_dir, "tp")

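# The merged TP checkpoint must reproduce the non-TP checkpoint tensor for tensor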
for filename in os.listdir(non_tp_model_path):
if not filename.endswith(".safetensors"):
continue

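# safe_open loads tensors lazily, so shards are compared without holding full checkpoints in memory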
non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt")
tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt")
for non_tp_key in non_tp_model.keys():
non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
tp_tensor = tp_model.get_tensor(non_tp_key)
assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
del non_tp_tensor, tp_tensor


class TestTensorParallelMoe2Proc(TestTensorParallelMoeBase):
"""Test MoE tensor parallel with 2 processes."""

nproc_per_node = 2


class TestTensorParallelMoe4Proc(TestTensorParallelMoeBase):
"""Test MoE tensor parallel with 4 processes."""

nproc_per_node = 4