diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py
index 05ec7e1a8d07..df87b3dfd326 100644
--- a/tests/tensor_parallel/test_tensor_parallel.py
+++ b/tests/tensor_parallel/test_tensor_parallel.py
@@ -15,8 +15,10 @@
 # Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py
 # Run specific config: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc"
 # Run multiple configs: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc or 4Proc"
-# Run spefic test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_dense_forward_train
-# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc -k "forward"
+# Run specific test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallelDense2Proc::test_model_dense_forward_train
+# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallelDense2Proc -k "forward"
+# Run MoE tests only: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "Moe"
+# Run dense tests only: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "TestTensorParallelDense2Proc or TestTensorParallelDense4Proc"
 import os
 import tempfile
 import warnings
@@ -381,7 +383,7 @@ def _test_model_dense_save_impl(rank, tmp_dir):
     model.save_pretrained(result_dir)
 
 
-class TestTensorParallelBase(TestCasePlus):
+class TestTensorParallelDenseBase(TestCasePlus):
     """Base class for tensor parallel tests. Subclasses must set nproc_per_node."""
 
     nproc_per_node = None
@@ -466,13 +468,281 @@ def test_model_dense_save(self):
                     del non_tp_tensor, tp_tensor
 
 
-class TestTensorParallel2Proc(TestTensorParallelBase):
-    """Test tensor parallel with 2 processes."""
+class TestTensorParallelDense2Proc(TestTensorParallelDenseBase):
+    """Test tensor parallel dense model with 2 processes."""
 
     nproc_per_node = 2
 
 
-class TestTensorParallel4Proc(TestTensorParallelBase):
-    """Test tensor parallel with 4 processes."""
+class TestTensorParallelDense4Proc(TestTensorParallelDenseBase):
+    """Test tensor parallel dense model with 4 processes."""
+
+    nproc_per_node = 4
+
+
+# ====== MOE MODEL TEST FUNCTIONS ======
+def _test_model_moe_forward_impl(rank, mode):
+    """Implementation for comparing TP and non-TP MoE model outputs."""
+    model_id = "hf-internal-testing/tiny-qwen3-moe"
+
+    # Ensure same random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Load tokenizer and prepare inputs - same for both models
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
+    prompt = "Can I help"
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    # Load TP model first to determine device
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    dist.barrier()
+    if mode == "eval":
+        model_tp.eval()
+    else:
+        model_tp.train()
+
+    # Load non-TP model and move to same device as TP model
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = model.to(device)
+
+    if mode == "eval":
+        model.eval()
+    else:
+        model.train()
+
+    # Prepare inputs on the same device
+    input_ids = inputs.input_ids.to(device)
+
+    # Run forward pass on both models
+    with torch.no_grad():
+        # Non-TP model output
+        outputs = model(input_ids)
+        logits = outputs.logits
+
+        # TP model output
+        outputs_tp = model_tp(input_ids)
+        logits_tp = outputs_tp.logits
+
+    # Compare outputs - they should match
+    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
+        f"TP and non-TP MoE model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
+    )
+
+    dist.barrier()
+
+
+def _test_model_moe_backward_pass_impl(rank):
+    """Implementation for comparing TP and non-TP MoE model backward passes."""
+    model_id = "hf-internal-testing/tiny-qwen3-moe"
+
+    torch.manual_seed(0)
+
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    dist.barrier()
+    model_tp.train()
+
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = model.to(device)
+    model.train()
+
+    batch_size, seq_length = 2, 10
+    torch.manual_seed(42)  # Different seed for inputs to ensure they're deterministic
+    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
+    labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
+
+    outputs = model(input_ids, labels=labels)
+    loss = outputs.loss
+    loss.backward()
+
+    outputs_tp = model_tp(input_ids, labels=labels)
+    loss_tp = outputs_tp.loss
+    loss_tp.backward()
+
+    assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), (
+        f"TP and non-TP MoE model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}"
+    )
+
+    # Compare gradients for matching parameters
+    for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()):
+        if param.grad is not None and param_tp.grad is not None:
+            grad = param.grad
+            grad_tp = param_tp.grad
+
+            if isinstance(param_tp.data, dist.tensor.DTensor):
+                placement = param_tp.data.placements[0]
+                if hasattr(placement, "dim") and placement.dim is not None:
+                    grad_shard = get_tensor_shard(grad, grad, param_tp.data.device_mesh, rank, placement.dim)
+                else:
+                    grad_shard = grad
+            else:
+                grad_shard = grad
+
+            grad_tp_local = grad_tp.to_local() if isinstance(grad_tp, dist.tensor.DTensor) else grad_tp
+
+            assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), (
+                f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}"
+            )
+
+    dist.barrier()
+
+
+def _test_model_moe_forward_compile_impl(rank, mode):
+    """Implementation for comparing TP and non-TP MoE model outputs with torch.compile."""
+    model_id = "hf-internal-testing/tiny-qwen3-moe"
+
+    torch.manual_seed(0)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
+    prompt = "Can I help"
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    dist.barrier()
+    if mode == "eval":
+        model_tp.eval()
+    else:
+        model_tp.train()
+
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = model.to(device)
+
+    if mode == "eval":
+        model.eval()
+    else:
+        model.train()
+
+    # Compile both models
+    model.forward = torch.compile(model.forward)
+    model_tp.forward = torch.compile(model_tp.forward)
+
+    input_ids = inputs.input_ids.to(device)
+
+    with torch.no_grad():
+        outputs = model(input_ids)
+        logits = outputs.logits
+
+        outputs_tp = model_tp(input_ids)
+        logits_tp = outputs_tp.logits
+
+    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
+        f"TP and non-TP MoE model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
+    )
+
+    dist.barrier()
+
+
+def _test_model_moe_save_impl(rank, tmp_dir):
+    """Implementation of test_model_save for MoE model distributed execution."""
+    model_id = "hf-internal-testing/tiny-qwen3-moe"
+
+    if dist.is_initialized():
+        kwargs = {"tp_plan": "auto"}
+        result_dir = f"{tmp_dir}/tp"
+    else:
+        kwargs = {}
+        result_dir = f"{tmp_dir}/nontp"
+
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", **kwargs)
+    model.save_pretrained(result_dir)
+
+
+class TestTensorParallelMoeBase(TestCasePlus):
+    """Base class for MoE tensor parallel tests. Subclasses must set nproc_per_node."""
+
+    nproc_per_node = None
+
+    @require_torch_multi_accelerator
+    def test_model_moe_forward_eval(self):
+        """Test that TP and non-TP MoE models produce the same outputs in eval mode."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_impl)("eval")
+
+    @require_torch_multi_accelerator
+    def test_model_moe_forward_train(self):
+        """Test that TP and non-TP MoE models produce the same outputs in train mode."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_impl)("train")
+
+    @require_torch_multi_accelerator
+    def test_model_moe_backward_pass(self):
+        """Test that TP and non-TP MoE models produce the same gradients."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_moe_backward_pass_impl)()
+
+    @require_torch_multi_accelerator
+    def test_model_moe_forward_compile_eval(self):
+        """Test that TP and non-TP MoE models produce the same outputs with torch.compile in eval mode."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_compile_impl)("eval")
+
+    @require_torch_multi_accelerator
+    def test_model_moe_forward_compile_train(self):
+        """Test that TP and non-TP MoE models produce the same outputs with torch.compile in train mode."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_moe_forward_compile_impl)("train")
+
+    @require_huggingface_hub_greater_or_equal("0.31.4")
+    @require_torch_multi_accelerator
+    def test_model_moe_save(self):
+        """Test that TP MoE model can be saved and matches non-TP version."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # First run with TP (distributed)
+            init_distributed(tp=self.nproc_per_node)(_test_model_moe_save_impl)(tmp_dir)
+
+            # Then run without TP (non-distributed)
+            _test_model_moe_save_impl(0, tmp_dir)
+
+            non_tp_model_path = os.path.join(tmp_dir, "nontp")
+            tp_model_path = os.path.join(tmp_dir, "tp")
+
+            for filename in os.listdir(non_tp_model_path):
+                if not filename.endswith(".safetensors"):
+                    continue
+
+                non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt")
+                tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt")
+                for non_tp_key in non_tp_model.keys():
+                    non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
+                    tp_tensor = tp_model.get_tensor(non_tp_key)
+                    assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
+                    del non_tp_tensor, tp_tensor
+
+
+class TestTensorParallelMoe2Proc(TestTensorParallelMoeBase):
+    """Test MoE tensor parallel with 2 processes."""
+
+    nproc_per_node = 2
+
+
+class TestTensorParallelMoe4Proc(TestTensorParallelMoeBase):
+    """Test MoE tensor parallel with 4 processes."""
 
     nproc_per_node = 4