@@ -66,11 +66,11 @@ class GPTLMLoss(nn.Module):


 def get_cpu_mem():
-    return psutil.Process().memory_info().rss / 1024**2  # MB unit
+    return psutil.Process().memory_info().rss / 1024**2


 def get_gpu_mem():
-    return torch.cuda.memory_allocated() / 1024**2  # MB unit
+    return torch.cuda.memory_allocated() / 1024**2


 def get_mem_info(prefix=""):
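Both helpers above report megabytes, so dropping the trailing "# MB unit" comments changes nothing at runtime. A minimal usage sketch (the report helper is hypothetical, assuming psutil and torch are importable as in the full script):

import psutil
import torch

def report(prefix=""):
    cpu_mb = psutil.Process().memory_info().rss / 1024**2  # resident set size, MB
    gpu_mb = torch.cuda.memory_allocated() / 1024**2       # CUDA tensors currently allocated, MB
    print(f"{prefix}CPU {cpu_mb:.2f} MB / GPU {gpu_mb:.2f} MB")

report(prefix="[after model init] ")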
@@ -78,7 +78,6 @@ def get_mem_info(prefix=""):


 def get_model_size(model: nn.Module):
-    # get the number of parameter of the model
     total_numel = 0
     for module in model.modules():
         for p in module.parameters(recurse=False):
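get_model_size visits every submodule and counts only its direct parameters (recurse=False), so each parameter is counted exactly once. A quick sanity check on a toy module; the loop's accumulation and return are inferred from the truncated body above, and the toy model is illustrative only:

import torch.nn as nn

def count_params(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()  # inferred completion of the loop body
    return total_numel

toy = nn.Linear(4, 3)           # 4*3 weights + 3 biases
assert count_params(toy) == 15  # 12 + 3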
@@ -130,7 +129,7 @@ def main():
     WARMUP_STEPS = 1
     assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
     assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
-    PROF_FLAG = True  # The flag of profiling, False by default
+    PROF_FLAG = False  # The flag of profiling, False by default

     disable_existing_loggers()
     colossalai.launch_from_torch()
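PROF_FLAG now defaults to False, matching its own comment, so profiling only runs when explicitly switched on. A hedged sketch of the kind of guard such a flag typically feeds; this context-manager helper is illustrative, not the script's actual code:

from contextlib import nullcontext
import torch

def profile_ctx(enabled: bool, warmup_steps: int, active_steps: int):
    # No-op context when profiling is disabled; a torch.profiler session otherwise.
    if not enabled:
        return nullcontext()
    return torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=0, warmup=warmup_steps, active=active_steps),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile"),
    )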
@@ -167,7 +166,7 @@ def main():
             stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True
         )
     elif args.distplan == "CAI_Gemini":
-        plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1)
+        plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
     else:
         raise RuntimeError

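The Gemini configuration loses the max_prefetch argument, while the search_range_m and hidden_dim hints stay as they were. For context, a minimal sketch of how such a plugin is usually handed to a Booster (argument values and the model/optimizer/criterion objects are assumed here, not taken from this diff):

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(search_range_m=128)  # illustrative arguments only
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)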
@@ -249,7 +248,7 @@ def main():
             prof.step()

     tflops_list.sort()
-    median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1)
+    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
     logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
     torch.cuda.synchronize()

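After sorting, the median is read directly at ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, which assumes tflops_list holds one entry per step (warm-up included) and that the slower warm-up entries sort to the front. A worked check with made-up numbers:

NUM_STEPS, WARMUP_STEPS = 10, 1
tflops_list = sorted([3.0, 40.1, 40.3, 40.2, 40.5, 40.4, 40.6, 40.0, 40.7, 40.8])
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS  # (9 >> 1) + 1 = 5
print(tflops_list[median_index])  # 40.4, the median of the nine post-warm-up values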