8 changes: 8 additions & 0 deletions README.md
@@ -4,6 +4,14 @@ Note: This repository is currently under heavy development.

torchtrain contains PyTorch native parallelisms, tools and utilities to train large models.

+## Design Principles
+
+TorchTrain is a native PyTorch library providing various training techniques. While it uses the PyTorch ecosystem for components such as data loading (e.g. HuggingFace datasets), the core functionality is written in PyTorch.
+
+* Designed to be easy to understand, use, and extend for different training purposes.
+* Minimal changes to the model code when applying 1D, 2D, or 3D parallelisms.
+* Modular components instead of a monolithic codebase.
+
# Installation

Install PyTorch from source or install the latest pytorch nightly, then install requirements by
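As a side note on the ecosystem point in the added Design Principles text (data loading via HuggingFace datasets, everything else plain PyTorch), the pattern can be pictured with a short sketch. This is not code from this PR: the dataset id, field name, and collate function are illustrative assumptions chosen to echo the Alpaca loader referenced elsewhere in this diff.

```python
# Sketch only -- illustrative, not part of this PR. A Hugging Face dataset is
# wrapped in a standard torch.utils.data.DataLoader; only the data source
# comes from the ecosystem, the loading loop stays plain PyTorch.
from datasets import load_dataset          # HuggingFace `datasets` library
from torch.utils.data import DataLoader    # core PyTorch

# Dataset id and field names are assumptions, not taken from the repo.
ds = load_dataset("tatsu-lab/alpaca", split="train")

def collate(batch):
    # Real code would tokenize and pad here; this sketch just pulls out the
    # raw instruction strings.
    return [example["instruction"] for example in batch]

loader = DataLoader(ds, batch_size=8, shuffle=True, collate_fn=collate)
for texts in loader:
    ...  # feed into the tokenizer / model
    break
```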
1 change: 0 additions & 1 deletion torchtrain/datasets/__init__.py
@@ -2,7 +2,6 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from torchtrain.datasets.alpaca import build_alpaca_data_loader
-from torchtrain.datasets.pad_batch_sequence import pad_batch_to_longest_seq
from torchtrain.datasets.tokenizer import create_tokenizer

__all__ = ["build_alpaca_data_loader", "create_tokenizer", "pad_batch_to_longest_seq"]
77 changes: 0 additions & 77 deletions torchtrain/datasets/pad_batch_sequence.py

This file was deleted.

6 changes: 0 additions & 6 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -130,12 +130,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
            "feed_forward.w3": ColwiseParallel(),
        }
-        # if layer_id == 0:
-        #     # in first transformer block we need to shard the input
-        #     layer_plan[""] = PrepareModuleInput(
-        #         input_layouts=(Replicate(), None),
-        #         desired_input_layouts=(Shard(0), None),
-        #     )

        # adjust num_heads in attention layer to local heads
        attn_layer = transformer_block.attention
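For readers of the hunk above: the `layer_plan` entries are tensor-parallel styles from `torch.distributed.tensor.parallel`, applied to each block with `parallelize_module`. Below is a hedged, self-contained sketch of that pattern, not the file's actual surrounding code; the module, mesh size, and dimensions are assumptions, and import paths for `Shard`/`init_device_mesh` vary slightly between PyTorch 2.x versions.

```python
# Self-contained sketch (assumptions, not this repo's code) of applying a
# layer_plan like the one in the hunk above via parallelize_module.
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class FeedForward(nn.Module):
    """Llama-style SwiGLU MLP with the w1/w2/w3 names used in the plan."""
    def __init__(self, dim: int = 4096, hidden: int = 11008):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.feed_forward = FeedForward()

# Assumes torch.distributed is initialized with 8 ranks forming a 1D TP mesh.
tp_mesh = init_device_mesh("cuda", (8,))

layer_plan = {
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
    "feed_forward.w3": ColwiseParallel(),
}

# parallelize_module matches the fully qualified names in the plan and shards
# each Linear's weight as a DTensor; the model definition itself is untouched,
# which is the "minimal changes to the model code" principle from the README.
block = parallelize_module(TransformerBlock(), tp_mesh, layer_plan)
# (The real parallelize_llama then also adjusts the attention layer's num_heads
# to the local head count, as the trailing comment in the hunk notes.)
```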