mirror of https://github.com/hpcaitech/ColossalAI
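# Benchmarks OpenMoE causal-LM training throughput and peak CUDA memory under
# PyTorch FSDP, driving the model with randomly generated token sequences.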
import argparse
import functools
import os

import torch
import torch.distributed as dist
import tqdm
from model.modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel

from colossalai.moe.manager import MOE_MANAGER


class RandomDataset(Dataset):
    # Synthetic corpus of random token ids; labels mirror the inputs (causal-LM style).
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def fsdp_main(rank, world_size, args):
    # initialize the process group
    dist.init_process_group("nccl")

    # No ColossalAI expert parallelism here; FSDP below is the only form of sharding.
    MOE_MANAGER.setup(parallel=None)

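    # Size the synthetic dataset so each rank sees exactly (warmup + active) steps
    # at the requested batch size.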
    dp_size = dist.get_world_size()
    dataset = RandomDataset(
        max_length=args.seq_length,
        num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
    )
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
    train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
    train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
    torch.cuda.set_device(rank)

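    # Load the pretrained OpenMoE config and attach the MoE-specific settings
    # (expert count, MoE layer interval, kernel/overlap switches) used by
    # OpenMoeForCausalLM.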
    config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-%s" % args.model_name)
    set_openmoe_args(
        config,
        num_experts=config.num_experts,
        moe_layer_interval=config.moe_layer_interval,
        enable_load_balance=False,
        enable_kernel=False,
        enable_comm_overlap=False,
    )
    # Build the model with float16 as the default dtype so the initial weights are
    # allocated in half precision, then restore float32 as the default.
    torch.set_default_dtype(torch.float16)
    model = OpenMoeForCausalLM(config)
    torch.set_default_dtype(torch.float32)
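    # Wrap every OpenMoeDecoderLayer as its own FSDP unit and run parameters,
    # gradient reductions, and buffers in bfloat16.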
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            OpenMoeDecoderLayer,
        },
    )
    model = FSDP(
        model,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
    )
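    # FSDP hands each rank only its parameter shard, so the Adam state below is
    # sharded as well.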
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
    model.train()

    model_numel = get_model_numel(model)
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        enable_grad_checkpoint=True,
        ignore_steps=args.warmup,
        dp_world_size=dist.get_world_size(),
    )

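    # Benchmark loop: PerformanceEvaluator times every step and excludes the first
    # `warmup` steps from the reported throughput.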
    for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
        performance_evaluator.on_step_start(step)
        input_ids, attention_mask, labels = (
            data["input_ids"].cuda(),
            data["attention_mask"].cuda(),
            data["labels"].cuda(),
        )

        optimizer.zero_grad()
        output = model(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
            chunk_head=False,
        )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        performance_evaluator.on_step_end(input_ids)

    performance_evaluator.on_fit_end()
    if dist.get_rank() == 0:
        print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


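# Expects torchrun-style environment variables (WORLD_SIZE, LOCAL_RANK).
# Illustrative launch (adjust the script name and GPU count to your setup):
#   torchrun --nproc_per_node=8 <this_script>.py --model_name base --batch_size 1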
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="base",
        choices=["base", "8b"],
        help="base or 8b",
    )
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--seq_length", type=int, default=2048)
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--active", type=int, default=20)
    args = parser.parse_args()

    torch.manual_seed(42)

    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    fsdp_main(local_rank, world_size, args)