diff --git a/training/tensor_parallel/README.md b/training/tensor_parallel/README.md
index 9db8ac53a..1149a8ce3 100644
--- a/training/tensor_parallel/README.md
+++ b/training/tensor_parallel/README.md
@@ -1,9 +1,15 @@
-# tensor parallel example
-This project is adapted from https://github.com/tatsu-lab/stanford_alpaca.
-We only modified the ds_config to enable tensor parallelism and more detailed logging, as an example use case.
+# AutoTP Training Examples
-**Script**
-
-``` bash run.sh ``` or ```bash run.sh MODE```
+
+This folder groups AutoTP training examples at different complexity levels.
+
+## Contents
+- [Basic example](basic_example): minimal AutoTP + ZeRO-2 example with synthetic tokens. It also shows that AutoTP recognizes typical parameter patterns and automatically applies proper partitioning.
+- [HuggingFace integration](hf_integration): Hugging Face Trainer example (adapted from Stanford Alpaca).
+- [Custom partitioning patterns](custom_patterns): AutoTP example with custom layer patterns and a simple
+  text dataset that uses a DP-rank random sampler. It shows how to define
+  parameter partitioning easily for custom models with non-standard parameter
+  definitions.
+
+## Related references
+- [AutoTP training docs](https://deepspeed.readthedocs.io/en/latest/training.html)
+- [AutoTP training tutorial](https://github.com/deepspeedai/DeepSpeed/blob/master/docs/_tutorials/autotp-training.md)
diff --git a/training/tensor_parallel/basic_example/README.md b/training/tensor_parallel/basic_example/README.md
new file mode 100644
index 000000000..33bdbf360
--- /dev/null
+++ b/training/tensor_parallel/basic_example/README.md
@@ -0,0 +1,78 @@
+# AutoTP training (Tensor Parallel)
+
+This directory documents the AutoTP training API for tensor-parallel sharding
+during training. AutoTP recognizes typical parameter patterns and
+automatically applies proper partitioning.
+
+## Overview
+
+This example provides a compact AutoTP + ZeRO-2 training script,
+`autotp_example.py`. It focuses on the AutoTP + ZeRO-2 flow and keeps only the
+pieces required to launch AutoTP:
+
+- create TP/DP process groups
+- configure AutoTP with `tensor_parallel.autotp_size`
+- initialize DeepSpeed with the AutoTP config
+
+The example feeds synthetic token batches (broadcast within each TP group) so
+you can validate the AutoTP setup without extra dataset plumbing.
+
+AutoTP recognizes supported model architectures (for example, Llama) and
+automatically partitions parameters, so you do not need to specify any manual
+partitioning rules for those models. If your model is not supported by AutoTP,
+refer to the
+[custom layer pattern guide](../custom_patterns/)
+for custom layer pattern configuration.
+
+## Key code (AutoTP path)
+The core setup mirrors the verification script but is trimmed down:
+
+```python
+model = AutoModelForCausalLM.from_pretrained(args.model_name)
+
+ds_config = {
+    "train_batch_size": args.batch_size * args.dp_size,
+    "train_micro_batch_size_per_gpu": args.batch_size,
+    "zero_optimization": {"stage": args.zero_stage},
+    "tensor_parallel": {"autotp_size": args.tp_size},
+    "data_parallel_size": args.dp_size,
+}
+
+mpu = ModelParallelUnit(tp_group, dp_group, args.tp_size, args.dp_size, tp_rank, dp_rank)
+engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config, mpu=mpu)
+```
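+
+The `mpu` argument wraps the TP/DP process groups that AutoTP uses. The layout
+below is a condensed sketch of the `build_tp_dp_groups` helper in
+`autotp_example.py` (TP ranks are contiguous; DP peers are strided by
+`tp_size`):
+
+```python
+# global rank -> (tp_rank, dp_rank): tp_rank varies fastest
+tp_rank = rank % tp_size
+dp_rank = rank // tp_size
+# TP group of DP replica d: ranks [d * tp_size, ..., (d + 1) * tp_size - 1]
+# DP group of TP slice t:   ranks [t, t + tp_size, t + 2 * tp_size, ...]
+```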
+
+## How to run
+Pick a world size where `tp_size * dp_size = world_size`.
+
+```bash
+# 8 GPUs: TP=4, DP=2 (AutoTP + ZeRO-2)
+deepspeed --num_gpus 8 autotp_example.py \
+    --model_name meta-llama/Llama-3.1-8B \
+    --tp_size 4 \
+    --dp_size 2 \
+    --zero_stage 2 \
+    --batch_size 1 \
+    --seq_length 1024 \
+    --num_steps 10
+```
+
+`torchrun` works as well if you prefer the PyTorch launcher.
+
+For a smaller test, reduce the world size and TP/DP sizes together:
+
+```bash
+deepspeed --num_gpus 2 autotp_example.py \
+    --model_name meta-llama/Llama-3.1-8B \
+    --tp_size 2 \
+    --dp_size 1 \
+    --num_steps 5
+```
+
+## Backward Compatibility
+
+Historically, AutoTP training required calling `set_autotp_mode(training=True)`
+and `deepspeed.tp_model_init(...)` before initialization. The traditional path
+is preserved for reference in
+[`autotp_memory_compare.py`](autotp_memory_compare.py) (see the `--mode traditional`
+branch), alongside the config-driven path in the same script.
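+
+A rough sketch of that legacy sequence is shown below. Treat it as
+illustrative: the exact import location of `set_autotp_mode` varies across
+DeepSpeed versions, so check your installed release.
+
+```python
+# legacy AutoTP path: enable training mode, shard the model, then initialize.
+set_autotp_mode(training=True)  # import path depends on the DeepSpeed version
+model = deepspeed.tp_model_init(model, tp_size=tp_size, dtype=dtype)
+engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config, mpu=mpu)
+```
+
+With the config-driven path shown above, none of these calls are needed;
+setting `tensor_parallel.autotp_size` in the DeepSpeed config is enough.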
diff --git a/training/tensor_parallel/basic_example/autotp_example.py b/training/tensor_parallel/basic_example/autotp_example.py
new file mode 100644
index 000000000..885aeba89
--- /dev/null
+++ b/training/tensor_parallel/basic_example/autotp_example.py
@@ -0,0 +1,137 @@
+import argparse
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+import deepspeed
+from transformers import AutoModelForCausalLM
+
+
+@dataclass
+class ModelParallelUnit:
+    """Minimal MPU for DeepSpeed TP+DP."""
+
+    tp_group: dist.ProcessGroup
+    dp_group: dist.ProcessGroup
+    tp_size: int
+    dp_size: int
+    tp_rank: int
+    dp_rank: int
+
+    def get_data_parallel_group(self):
+        return self.dp_group
+
+    def get_model_parallel_group(self):
+        return self.tp_group
+
+    def get_data_parallel_world_size(self):
+        return self.dp_size
+
+    def get_model_parallel_world_size(self):
+        return self.tp_size
+
+    def get_data_parallel_rank(self):
+        return self.dp_rank
+
+    def get_model_parallel_rank(self):
+        return self.tp_rank
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="AutoTP training example (distilled from verify_autotp).")
+    parser.add_argument("--local_rank", type=int, default=-1, help="Passed by deepspeed/torchrun.")
+    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-8B")
+    parser.add_argument("--tp_size", type=int, default=4)
+    parser.add_argument("--dp_size", type=int, default=2)
+    parser.add_argument("--zero_stage", type=int, default=2)
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--seq_length", type=int, default=1024)
+    parser.add_argument("--num_steps", type=int, default=10)
+    parser.add_argument("--learning_rate", type=float, default=2e-5)
+    parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+    return parser.parse_args()
+
+
+def build_tp_dp_groups(rank, world_size, tp_size, dp_size):
+    if tp_size * dp_size != world_size:
+        raise ValueError(f"tp_size ({tp_size}) * dp_size ({dp_size}) must equal world_size ({world_size})")
+
+    tp_rank = rank % tp_size
+    dp_rank = rank // tp_size
+
+    tp_group = None
+    dp_group = None
+
+    for dp_idx in range(dp_size):
+        tp_ranks = list(range(dp_idx * tp_size, (dp_idx + 1) * tp_size))
+        group = dist.new_group(tp_ranks)
+        if rank in tp_ranks:
+            tp_group = group
+
+    for tp_idx in range(tp_size):
+        dp_ranks = [tp_idx + dp_idx * tp_size for dp_idx in range(dp_size)]
+        group = dist.new_group(dp_ranks)
+        if rank in dp_ranks:
+            dp_group = group
+
+    return tp_group, dp_group, tp_rank, dp_rank
+
+
+def broadcast_inputs(input_ids, labels, tp_group, tp_src_rank):
+    dist.broadcast(input_ids, src=tp_src_rank, group=tp_group)
+    dist.broadcast(labels, src=tp_src_rank, group=tp_group)
+
+
+def main():
+    args = parse_args()
+    deepspeed.init_distributed()
+
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
+    tp_group, dp_group, tp_rank, dp_rank = build_tp_dp_groups(
+        rank, world_size, args.tp_size, args.dp_size
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(args.model_name)
+    model = model.to(device)
+
+    # AutoTP is enabled via the DeepSpeed config.
+    ds_config = {
+        "train_batch_size": args.batch_size * args.dp_size,
+        "train_micro_batch_size_per_gpu": args.batch_size,
+        "zero_optimization": {"stage": args.zero_stage},
+        "tensor_parallel": {"autotp_size": args.tp_size},
+        "data_parallel_size": args.dp_size,
+    }
+    if args.precision == "bf16":
+        ds_config["bf16"] = {"enabled": True}
+    elif args.precision == "fp16":
+        ds_config["fp16"] = {"enabled": True}
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
+    mpu = ModelParallelUnit(tp_group, dp_group, args.tp_size, args.dp_size, tp_rank, dp_rank)
+    engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config, mpu=mpu)
+
+    vocab_size = model.config.vocab_size
+    for _ in range(args.num_steps):
+        if tp_rank == 0:
+            input_ids = torch.randint(0, vocab_size, (args.batch_size, args.seq_length), device=device)
+            labels = input_ids.clone()
+        else:
+            input_ids = torch.empty((args.batch_size, args.seq_length), dtype=torch.long, device=device)
+            labels = torch.empty((args.batch_size, args.seq_length), dtype=torch.long, device=device)
+
+        tp_src_rank = dp_rank * args.tp_size
+        broadcast_inputs(input_ids, labels, tp_group, tp_src_rank)
+        outputs = engine(input_ids=input_ids, labels=labels)
+        engine.backward(outputs.loss)
+        engine.step()
+
+    if rank == 0:
+        print("AutoTP example completed.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/training/tensor_parallel/basic_example/autotp_memory_compare.py b/training/tensor_parallel/basic_example/autotp_memory_compare.py
new file mode 100644
index 000000000..1de030314
--- /dev/null
+++ b/training/tensor_parallel/basic_example/autotp_memory_compare.py
@@ -0,0 +1,187 @@
+import argparse
+from dataclasses import dataclass
+import os
+
+import torch
+import torch.distributed as dist
+import deepspeed
+from transformers import AutoModelForCausalLM
+
+
+@dataclass
+class ModelParallelUnit:
+    """Minimal MPU for DeepSpeed TP+DP."""
+
+    tp_group: dist.ProcessGroup
+    dp_group: dist.ProcessGroup
+    tp_size: int
+    dp_size: int
+    tp_rank: int
+    dp_rank: int
+
+    def get_data_parallel_group(self):
+        return self.dp_group
+
+    def get_model_parallel_group(self):
+        return self.tp_group
+
+    def get_data_parallel_world_size(self):
+        return self.dp_size
+
+    def get_model_parallel_world_size(self):
+        return self.tp_size
+
+    def get_data_parallel_rank(self):
+        return self.dp_rank
+
+    def get_model_parallel_rank(self):
+        return self.tp_rank
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compare AutoTP memory usage between init paths.")
+    parser.add_argument("--local_rank", type=int, default=-1, help="Passed by deepspeed/torchrun.")
+    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-8B")
+    parser.add_argument("--tp_size", type=int, default=4)
+    parser.add_argument("--dp_size", type=int, default=2)
+    parser.add_argument("--zero_stage", type=int, default=1)
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--seq_length", type=int, default=1024)
+    parser.add_argument("--num_steps", type=int, default=1)
+    parser.add_argument("--learning_rate", type=float, default=2e-5)
+    parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+    parser.add_argument("--mode",
+                        type=str,
+                        default="config",
+                        choices=["config", "traditional"],
+                        help="config = config-driven path, traditional = call tp_model_init")
+    return parser.parse_args()
+
+
+def build_tp_dp_groups(rank, world_size, tp_size, dp_size):
+    if tp_size * dp_size != world_size:
+        raise ValueError(f"tp_size ({tp_size}) * dp_size ({dp_size}) must equal world_size ({world_size})")
+
+    tp_rank = rank % tp_size
+    dp_rank = rank // tp_size
+
+    tp_group = None
+    dp_group = None
+
+    for dp_idx in range(dp_size):
+        tp_ranks = list(range(dp_idx * tp_size, (dp_idx + 1) * tp_size))
+        group = dist.new_group(tp_ranks)
+        if rank in tp_ranks:
+            tp_group = group
+
+    for tp_idx in range(tp_size):
+        dp_ranks = [tp_idx + dp_idx * tp_size for dp_idx in range(dp_size)]
+        group = dist.new_group(dp_ranks)
+        if rank in dp_ranks:
+            dp_group = group
+
+    return tp_group, dp_group, tp_rank, dp_rank
+
+
+def broadcast_inputs(input_ids, labels, tp_group, tp_src_rank):
+    dist.broadcast(input_ids, src=tp_src_rank, group=tp_group)
+    dist.broadcast(labels, src=tp_src_rank, group=tp_group)
+
+
+def get_precision_dtype(precision):
+    if precision == "bf16":
+        return torch.bfloat16
+    if precision == "fp16":
+        return torch.float16
+    return torch.float32
+
+
+def summarize(values):
+    return min(values), max(values), sum(values) / len(values)
+
+
+def gather_and_print(tag, device, rank, world_size):
+    stats = torch.tensor(
+        [torch.cuda.max_memory_allocated(device), torch.cuda.max_memory_reserved(device)],
+        device=device,
+    )
+    gathered = [torch.zeros_like(stats) for _ in range(world_size)]
+    dist.all_gather(gathered, stats)
+
+    if rank == 0:
+        allocs = [t[0].item() / (1024**3) for t in gathered]
+        reservs = [t[1].item() / (1024**3) for t in gathered]
+        alloc_min, alloc_max, alloc_mean = summarize(allocs)
+        res_min, res_max, res_mean = summarize(reservs)
+        print(f"[MEM] {tag} alloc_gb min={alloc_min:.2f} max={alloc_max:.2f} mean={alloc_mean:.2f}")
+        print(f"[MEM] {tag} reserv_gb min={res_min:.2f} max={res_max:.2f} mean={res_mean:.2f}")
+
+
+def main():
+    args = parse_args()
+    deepspeed.init_distributed()
+
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    local_rank = int(os.environ.get("LOCAL_RANK", rank))
+    torch.cuda.set_device(local_rank)
+    device = torch.device(f"cuda:{local_rank}")
+
+    tp_group, dp_group, tp_rank, dp_rank = build_tp_dp_groups(
+        rank, world_size, args.tp_size, args.dp_size
+    )
+
+    dtype = get_precision_dtype(args.precision)
+
+    torch.cuda.reset_peak_memory_stats(device)
+    model = AutoModelForCausalLM.from_pretrained(args.model_name, dtype=dtype)
+    model = model.to(device)
+    gather_and_print("after_model_load", device, rank, world_size)
+
+    if args.mode == "traditional":
+        model = deepspeed.tp_model_init(model, tp_size=args.tp_size, dtype=dtype)
+
+    ds_config = {
+        "train_batch_size": args.batch_size * args.dp_size,
+        "train_micro_batch_size_per_gpu": args.batch_size,
+        "gradient_accumulation_steps": 1,
+        "zero_optimization": {"stage": args.zero_stage},
+        "tensor_parallel": {"autotp_size": args.tp_size},
+        "data_parallel_size": args.dp_size,
+    }
+    if args.precision == "bf16":
+        ds_config["bf16"] = {"enabled": True}
+    elif args.precision == "fp16":
+        ds_config["fp16"] = {"enabled": True}
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
+    mpu = ModelParallelUnit(tp_group, dp_group, args.tp_size, args.dp_size, tp_rank, dp_rank)
+
+    torch.cuda.reset_peak_memory_stats(device)
+    engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config, mpu=mpu)
+    gather_and_print("after_initialize", device, rank, world_size)
+
+    vocab_size = model.config.vocab_size
+    torch.cuda.reset_peak_memory_stats(device)
+    for _ in range(args.num_steps):
+        if tp_rank == 0:
+            input_ids = torch.randint(0, vocab_size, (args.batch_size, args.seq_length), device=device)
+            labels = input_ids.clone()
+        else:
+            input_ids = torch.empty((args.batch_size, args.seq_length), dtype=torch.long, device=device)
+            labels = torch.empty((args.batch_size, args.seq_length), dtype=torch.long, device=device)
+
+        tp_src_rank = dp_rank * args.tp_size
+        broadcast_inputs(input_ids, labels, tp_group, tp_src_rank)
+        outputs = engine(input_ids=input_ids, labels=labels)
+        engine.backward(outputs.loss)
+        engine.step()
+
+    gather_and_print("after_train", device, rank, world_size)
+
+    if rank == 0:
+        print(f"AutoTP memory compare completed for mode={args.mode}.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/training/tensor_parallel/custom_patterns/README.md b/training/tensor_parallel/custom_patterns/README.md
new file mode 100644
index 000000000..8d8c9054d
--- /dev/null
+++ b/training/tensor_parallel/custom_patterns/README.md
@@ -0,0 +1,66 @@
+# AutoTP (Tensor Parallel) Custom Patterns Example
+
+This example extends the minimal AutoTP script with:
+
+- custom layer sharding patterns (`partition_config`)
+- a small text dataset and tokenizer
+- a DP-rank random sampler so each DP rank sees different samples
+
+The TP ranks inside the same DP group share the same data order.
+AutoTP is enabled by the DeepSpeed config (`tensor_parallel.autotp_size`), so
+you do not need to call any initialization helpers before `deepspeed.initialize`.
+
+## Key code (custom patterns)
+
+The config below targets **Pythia 6.9B (GPT-NeoX)**, which uses a fused
+`query_key_value` projection. We provide a `shape` so AutoTP can split the
+fused Q/K/V tensor cleanly across tensor-parallel ranks. The MLP uses
+`dense_h_to_4h` / `dense_4h_to_h`, so no extra shape is needed there.
+
+```python
+ds_config = {
+    "zero_optimization": {"stage": 2},
+    "tensor_parallel": {
+        "autotp_size": args.tp_size,
+        "partition_config": {
+            "use_default_specs": False,
+            "layer_specs": [
+                {
+                    "patterns": [".*(self_attention|attention)\\.query_key_value\\.weight$"],
+                    "partition_type": "column",
+                    "shape": ((q_size, kv_size, kv_size), -1),
+                    "partition_dim": 0,
+                },
+                {
+                    "patterns": [".*(self_attention|attention)\\.dense\\.weight$"],
+                    "partition_type": "row",
+                },
+                {
+                    "patterns": [".*mlp\\.dense_h_to_4h\\.weight$"],
+                    "partition_type": "column",
+                },
+                {
+                    "patterns": [".*mlp\\.dense_4h_to_h\\.weight$"],
+                    "partition_type": "row",
+                },
+            ],
+        },
+    },
+    "data_parallel_size": args.dp_size,
+}
+```
+
+## How to run
+Pick a world size where `tp_size * dp_size = world_size`.
+
+```bash
+deepspeed --num_gpus 8 autotp_custom_patterns.py \
+    --model_name EleutherAI/pythia-6.9b \
+    --tp_size 4 \
+    --dp_size 2 \
+    --seq_length 512 \
+    --num_steps 20
+```
+
+`torchrun` also works if you prefer the PyTorch launcher.
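+
+## Where `q_size` and `kv_size` come from
+
+The `shape` entry needs the widths of the fused Q and K/V blocks. The example
+script derives them from the Hugging Face model config; a condensed sketch
+(attribute names follow GPT-NeoX/Pythia, other architectures may differ):
+
+```python
+num_heads = model.config.num_attention_heads
+kv_heads = getattr(model.config, "num_key_value_heads", None) or num_heads
+head_dim = getattr(model.config, "head_dim", None) or model.config.hidden_size // num_heads
+
+q_size = num_heads * head_dim   # width of the fused Q block
+kv_size = kv_heads * head_dim   # width of each of the K and V blocks
+```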
diff --git a/training/tensor_parallel/custom_patterns/autotp_custom_patterns.py b/training/tensor_parallel/custom_patterns/autotp_custom_patterns.py
new file mode 100644
index 000000000..0ab2dea23
--- /dev/null
+++ b/training/tensor_parallel/custom_patterns/autotp_custom_patterns.py
@@ -0,0 +1,314 @@
+import argparse
+import os
+from dataclasses import dataclass
+from typing import Iterable, List
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader, Dataset, Sampler
+import deepspeed
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+IGNORE_INDEX = -100
+
+
+@dataclass
+class ModelParallelUnit:
+    """Minimal MPU for DeepSpeed TP+DP."""
+
+    tp_group: dist.ProcessGroup
+    dp_group: dist.ProcessGroup
+    tp_size: int
+    dp_size: int
+    tp_rank: int
+    dp_rank: int
+
+    def get_data_parallel_group(self):
+        return self.dp_group
+
+    def get_model_parallel_group(self):
+        return self.tp_group
+
+    def get_data_parallel_world_size(self):
+        return self.dp_size
+
+    def get_model_parallel_world_size(self):
+        return self.tp_size
+
+    def get_data_parallel_rank(self):
+        return self.dp_rank
+
+    def get_model_parallel_rank(self):
+        return self.tp_rank
+
+
+class ToyTextDataset(Dataset):
+    def __init__(self, tokenizer, seq_length: int):
+        texts = [
+            "DeepSpeed makes distributed training faster.",
+            "AutoTP shards large layers across GPUs.",
+            "Tensor parallelism reduces per-GPU memory.",
+            "ZeRO optimizes optimizer state memory.",
+            "This is a small in-memory dataset.",
+            "We are testing AutoTP training.",
+            "Distributed training requires careful data sharding.",
+            "Sharded model weights reduce memory pressure.",
+            "This example uses a custom AutoTP config.",
+            "Random samplers ensure data diversity.",
+        ]
+        self.samples = []
+        for text in texts:
+            tokenized = tokenizer(
+                text,
+                truncation=True,
+                max_length=seq_length,
+                add_special_tokens=True,
+            )
+            input_ids = torch.tensor(tokenized["input_ids"], dtype=torch.long)
+            attention_mask = torch.ones_like(input_ids)
+            self.samples.append((input_ids, attention_mask))
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        input_ids, attention_mask = self.samples[idx]
+        labels = input_ids.clone()
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+
+
+class DPRandomSampler(Sampler[int]):
+    """Random sampler sharded by DP rank."""
+
+    def __init__(self, data_source: Dataset, dp_rank: int, dp_size: int, seed: int):
+        self.data_source = data_source
+        self.dp_rank = dp_rank
+        self.dp_size = dp_size
+        self.seed = seed
+        self.epoch = 0
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def __iter__(self) -> Iterable[int]:
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+        indices = torch.randperm(len(self.data_source), generator=g).tolist()
+        return iter(indices[self.dp_rank :: self.dp_size])
+
+    def __len__(self) -> int:
+        return (len(self.data_source) + self.dp_size - 1) // self.dp_size
+
+
+def collate_batch(samples: List[dict], pad_token_id: int) -> dict:
+    input_ids = [s["input_ids"] for s in samples]
+    labels = [s["labels"] for s in samples]
+
+    input_ids = torch.nn.utils.rnn.pad_sequence(
+        input_ids, batch_first=True, padding_value=pad_token_id
+    )
+    labels = torch.nn.utils.rnn.pad_sequence(
+        labels, batch_first=True, padding_value=IGNORE_INDEX
+    )
+    attention_mask = input_ids.ne(pad_token_id)
+    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="AutoTP custom patterns example.")
+    parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-6.9b")
+    parser.add_argument("--tp_size", type=int, default=4)
+    parser.add_argument("--dp_size", type=int, default=2)
+    parser.add_argument("--zero_stage", type=int, default=2)
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--seq_length", type=int, default=512)
+    parser.add_argument("--num_steps", type=int, default=20)
+    parser.add_argument("--learning_rate", type=float, default=2e-6)
+    parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        help="Allow loading models with custom code from the Hub.",
+    )
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=0,
+        help="Local rank passed by the launcher.",
+    )
+    parser.add_argument("--seed", type=int, default=42)
+    return parser.parse_args()
+
+
+def build_tp_dp_groups(rank: int, world_size: int, tp_size: int, dp_size: int):
+    if tp_size * dp_size != world_size:
+        raise ValueError(f"tp_size ({tp_size}) * dp_size ({dp_size}) must equal world_size ({world_size})")
+
+    tp_rank = rank % tp_size
+    dp_rank = rank // tp_size
+
+    tp_group = None
+    dp_group = None
+
+    for dp_idx in range(dp_size):
+        tp_ranks = list(range(dp_idx * tp_size, (dp_idx + 1) * tp_size))
+        group = dist.new_group(tp_ranks)
+        if rank in tp_ranks:
+            tp_group = group
+
+    for tp_idx in range(tp_size):
+        dp_ranks = [tp_idx + dp_idx * tp_size for dp_idx in range(dp_size)]
+        group = dist.new_group(dp_ranks)
+        if rank in dp_ranks:
+            dp_group = group
+
+    return tp_group, dp_group, tp_rank, dp_rank
+
+
+def main():
+    args = parse_args()
+    deepspeed.init_distributed()
+
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
+    torch.cuda.set_device(local_rank)
+    device = torch.device(f"cuda:{local_rank}")
+
+    if args.precision == "bf16":
+        torch_dtype = torch.bfloat16
+    elif args.precision == "fp16":
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = torch.float32
+
+    tp_group, dp_group, tp_rank, dp_rank = build_tp_dp_groups(
+        rank, world_size, args.tp_size, args.dp_size
+    )
+
+    trust_remote_code = args.trust_remote_code
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_name, use_fast=False, trust_remote_code=trust_remote_code
+        )
+    except ValueError:
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_name, use_fast=True, trust_remote_code=trust_remote_code
+        )
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_name,
+        dtype=torch_dtype,
+        low_cpu_mem_usage=True,
+        trust_remote_code=trust_remote_code,
+    )
+    model.config.pad_token_id = tokenizer.pad_token_id
+    model = model.to(device)
+
+    num_heads = model.config.num_attention_heads
+    kv_heads = getattr(model.config, "num_kv_heads", None)
+    if kv_heads is None:
+        kv_heads = getattr(model.config, "num_key_value_heads", None)
+    if kv_heads is None:
+        kv_heads = num_heads
+    head_dim = getattr(model.config, "head_dim", None)
+    if head_dim is None:
+        head_dim = model.config.hidden_size // num_heads
+    uses_mqa = bool(getattr(model.config, "multi_query", False))
+    if kv_heads % args.tp_size != 0:
+        uses_mqa = True
+    q_size = num_heads * head_dim
+    kv_size = kv_heads * head_dim
+
+    if rank == 0 and uses_mqa:
+        print("Using row-parallel QKV for MQA (KV heads not shardable).")
+
+    qkv_spec = {
+        "patterns": [".*(self_attention|attention)\\.query_key_value\\.weight$"],
+        "partition_type": "row" if uses_mqa else "column",
+    }
+    if not uses_mqa:
+        q_size = num_heads * head_dim
+        kv_size = kv_heads * head_dim
+        qkv_spec.update(
+            {
+                "shape": ((q_size, kv_size, kv_size), -1),
+                "partition_dim": 0,
+            }
+        )
+
+    # AutoTP is enabled via the DeepSpeed config.
+    ds_config = {
+        "train_batch_size": args.batch_size * args.dp_size,
+        "train_micro_batch_size_per_gpu": args.batch_size,
+        "zero_optimization": {"stage": args.zero_stage},
+        "tensor_parallel": {
+            "autotp_size": args.tp_size,
+            "partition_config": {
+                "use_default_specs": False,
+                "layer_specs": [
+                    qkv_spec,
+                    {
+                        "patterns": [".*(self_attention|attention)\\.dense\\.weight$"],
+                        "partition_type": "row",
+                    },
+                    {
+                        "patterns": [".*mlp\\.dense_h_to_4h\\.weight$"],
+                        "partition_type": "column",
+                    },
+                    {
+                        "patterns": [".*mlp\\.dense_4h_to_h\\.weight$"],
+                        "partition_type": "row",
+                    },
+                ],
+            },
+        },
+        "data_parallel_size": args.dp_size,
+    }
+    if args.precision == "bf16":
+        ds_config["bf16"] = {"enabled": True}
+    elif args.precision == "fp16":
+        ds_config["fp16"] = {"enabled": True}
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
+    mpu = ModelParallelUnit(tp_group, dp_group, args.tp_size, args.dp_size, tp_rank, dp_rank)
+    engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config, mpu=mpu)
+
+    dataset = ToyTextDataset(tokenizer, args.seq_length)
+    sampler = DPRandomSampler(dataset, dp_rank=dp_rank, dp_size=args.dp_size, seed=args.seed)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        sampler=sampler,
+        collate_fn=lambda samples: collate_batch(samples, tokenizer.pad_token_id),
+    )
+
+    engine.train()
+    data_iter = iter(dataloader)
+    for step in range(args.num_steps):
+        try:
+            batch = next(data_iter)
+        except StopIteration:
+            sampler.set_epoch(step)
+            data_iter = iter(dataloader)
+            batch = next(data_iter)
+
+        input_ids = batch["input_ids"].to(device, non_blocking=True)
+        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
+        labels = batch["labels"].to(device, non_blocking=True)
+
+        outputs = engine(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+        engine.backward(outputs.loss)
+        engine.step()
+
+        if rank == 0 and step % 5 == 0:
+            print(f"step={step} loss={outputs.loss.item():.4f}")
+
+    if rank == 0:
+        print("AutoTP custom patterns example completed.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/training/tensor_parallel/hf_integration/README.md b/training/tensor_parallel/hf_integration/README.md
new file mode 100644
index 000000000..c1b043929
--- /dev/null
+++ b/training/tensor_parallel/hf_integration/README.md
@@ -0,0 +1,9 @@
+# AutoTP (Tensor Parallel) HuggingFace Integration Example
+
+This project is adapted from https://github.com/tatsu-lab/stanford_alpaca.
+It uses Hugging Face `Trainer` with a DeepSpeed config that enables AutoTP via `tensor_parallel.autotp_size`.
+We only modified the DeepSpeed config and logging, as an example use case.
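+
+The AutoTP-relevant part of that config is a single block; a minimal sketch is
+shown below (the value is a placeholder, and the shipped
+`configs/ds_config.json` remains the source of truth for all other settings):
+
+```json
+{
+  "tensor_parallel": {
+    "autotp_size": 4
+  }
+}
+```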
+
+**Script**
+
+`bash run.sh` or `bash run.sh MODE`
diff --git a/training/tensor_parallel/alpaca_data.json b/training/tensor_parallel/hf_integration/alpaca_data.json
similarity index 100%
rename from training/tensor_parallel/alpaca_data.json
rename to training/tensor_parallel/hf_integration/alpaca_data.json
diff --git a/training/tensor_parallel/configs/ds_config.json b/training/tensor_parallel/hf_integration/configs/ds_config.json
similarity index 100%
rename from training/tensor_parallel/configs/ds_config.json
rename to training/tensor_parallel/hf_integration/configs/ds_config.json
diff --git a/training/tensor_parallel/configs/ds_config_temp.json b/training/tensor_parallel/hf_integration/configs/ds_config_temp.json
similarity index 100%
rename from training/tensor_parallel/configs/ds_config_temp.json
rename to training/tensor_parallel/hf_integration/configs/ds_config_temp.json
diff --git a/training/tensor_parallel/requirements.txt b/training/tensor_parallel/hf_integration/requirements.txt
similarity index 100%
rename from training/tensor_parallel/requirements.txt
rename to training/tensor_parallel/hf_integration/requirements.txt
diff --git a/training/tensor_parallel/run.sh b/training/tensor_parallel/hf_integration/run.sh
similarity index 100%
rename from training/tensor_parallel/run.sh
rename to training/tensor_parallel/hf_integration/run.sh
diff --git a/training/tensor_parallel/train.py b/training/tensor_parallel/hf_integration/train.py
similarity index 100%
rename from training/tensor_parallel/train.py
rename to training/tensor_parallel/hf_integration/train.py
diff --git a/training/tensor_parallel/train_bench_length.py b/training/tensor_parallel/hf_integration/train_bench_length.py
similarity index 100%
rename from training/tensor_parallel/train_bench_length.py
rename to training/tensor_parallel/hf_integration/train_bench_length.py
diff --git a/training/tensor_parallel/utils.py b/training/tensor_parallel/hf_integration/utils.py
similarity index 100%
rename from training/tensor_parallel/utils.py
rename to training/tensor_parallel/hf_integration/utils.py