|
|
|
@ -1,11 +1,12 @@
|
|
|
|
|
import argparse |
|
|
|
|
import resource |
|
|
|
|
import time |
|
|
|
|
from contextlib import nullcontext |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
from data_utils import RandomDataset |
|
|
|
|
from model_utils import format_numel_str, get_model_numel |
|
|
|
|
from performance_evaluator import PerformanceEvaluator |
|
|
|
|
from performance_evaluator import PerformanceEvaluator, get_profile_context |
|
|
|
|
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision |
|
|
|
|
from tqdm import tqdm |
|
|
|
|
from transformers import AutoConfig, AutoModelForCausalLM |
|
|
|
@ -76,6 +77,7 @@ def main():
|
|
|
|
|
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") |
|
|
|
|
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") |
|
|
|
|
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) |
|
|
|
|
parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) |
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
colossalai.launch_from_torch() |
|
|
|
@ -110,6 +112,7 @@ def main():
|
|
|
|
|
extra_dp_size=args.extra_dp, |
|
|
|
|
enable_fused_normalization=torch.cuda.is_available(), |
|
|
|
|
enable_flash_attention=args.xformers, |
|
|
|
|
max_prefetch=10, |
|
|
|
|
) |
|
|
|
|
elif args.plugin == "gemini_auto": |
|
|
|
|
plugin = GeminiPlugin( |
|
|
|
@ -246,25 +249,37 @@ def main():
|
|
|
|
|
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: |
|
|
|
|
data_iter = iter(dataloader) |
|
|
|
|
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): |
|
|
|
|
performance_evaluator.on_step_start(step) |
|
|
|
|
booster.execute_pipeline( |
|
|
|
|
data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False |
|
|
|
|
) |
|
|
|
|
optimizer.step() |
|
|
|
|
optimizer.zero_grad() |
|
|
|
|
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) |
|
|
|
|
else: |
|
|
|
|
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): |
|
|
|
|
performance_evaluator.on_step_start(step) |
|
|
|
|
outputs = model(**batch) |
|
|
|
|
loss = outputs[0] |
|
|
|
|
booster.backward(loss, optimizer) |
|
|
|
|
optimizer.step() |
|
|
|
|
optimizer.zero_grad() |
|
|
|
|
performance_evaluator.on_step_end(**batch) |
|
|
|
|
with get_profile_context( |
|
|
|
|
args.profile, |
|
|
|
|
1, |
|
|
|
|
len(dataloader) - 1, |
|
|
|
|
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", |
|
|
|
|
) as prof: |
|
|
|
|
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: |
|
|
|
|
data_iter = iter(dataloader) |
|
|
|
|
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): |
|
|
|
|
performance_evaluator.on_step_start(step) |
|
|
|
|
booster.execute_pipeline( |
|
|
|
|
data_iter, |
|
|
|
|
model, |
|
|
|
|
criterion=lambda outputs, inputs: outputs[0], |
|
|
|
|
optimizer=optimizer, |
|
|
|
|
return_loss=False, |
|
|
|
|
) |
|
|
|
|
optimizer.step() |
|
|
|
|
optimizer.zero_grad() |
|
|
|
|
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) |
|
|
|
|
prof.step() |
|
|
|
|
else: |
|
|
|
|
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): |
|
|
|
|
performance_evaluator.on_step_start(step) |
|
|
|
|
outputs = model(**batch) |
|
|
|
|
loss = outputs[0] |
|
|
|
|
booster.backward(loss, optimizer) |
|
|
|
|
optimizer.step() |
|
|
|
|
optimizer.zero_grad() |
|
|
|
|
performance_evaluator.on_step_end(**batch) |
|
|
|
|
prof.step() |
|
|
|
|
|
|
|
|
|
performance_evaluator.on_fit_end() |
|
|
|
|
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") |
|
|
|
|