mirror of https://github.com/hpcaitech/ColossalAI
[example] update gpt example (#2225)
parent 49c601da21
commit 7010e18134
@@ -14,13 +14,13 @@ class GPTLMModel(nn.Module):
                  checkpoint=False):
         super().__init__()
         self.checkpoint = checkpoint
-        self.model = GPT2LMHeadModel(
-            GPT2Config(n_embd=hidden_size,
+        self.config = GPT2Config(n_embd=hidden_size,
                                  n_layer=num_layers,
                                  n_head=num_attention_heads,
                                  n_positions=max_seq_len,
                                  n_ctx=max_seq_len,
-                                 vocab_size=vocab_size))
+                                 vocab_size=vocab_size)
+        self.model = GPT2LMHeadModel(self.config)
         if checkpoint:
             self.model.gradient_checkpointing_enable()
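This hunk keeps the GPT2Config on the module instead of passing it inline to GPT2LMHeadModel, so later code can read the hidden size back from model.config. A minimal sketch of the resulting pattern, assuming transformers is installed; TinyGPTLMModel and its default sizes are hypothetical stand-ins for the example's GPTLMModel:

    # Hypothetical cut-down analogue of GPTLMModel, for illustration only.
    import torch.nn as nn
    from transformers import GPT2Config, GPT2LMHeadModel

    class TinyGPTLMModel(nn.Module):

        def __init__(self, hidden_size=64, num_layers=2, num_attention_heads=2,
                     max_seq_len=128, vocab_size=256, checkpoint=False):
            super().__init__()
            self.checkpoint = checkpoint
            # Store the config on the module so callers can inspect it later.
            self.config = GPT2Config(n_embd=hidden_size,
                                     n_layer=num_layers,
                                     n_head=num_attention_heads,
                                     n_positions=max_seq_len,
                                     vocab_size=vocab_size)
            self.model = GPT2LMHeadModel(self.config)
            if checkpoint:
                # Trade recompute for activation memory during backward.
                self.model.gradient_checkpointing_enable()

    model = TinyGPTLMModel()
    print(model.config.n_embd)  # 64, recoverable without re-deriving it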
@@ -120,6 +120,20 @@ def get_model_size(model: nn.Module):
     return total_numel


+def model_size_formatter(numel: int) -> str:
+    GB_SIZE = 10**9
+    MB_SIZE = 10**6
+    KB_SIZE = 10**3
+    if numel >= GB_SIZE:
+        return f'{numel / GB_SIZE:.1f}B'
+    elif numel >= MB_SIZE:
+        return f'{numel / MB_SIZE:.1f}M'
+    elif numel >= KB_SIZE:
+        return f'{numel / KB_SIZE:.1f}K'
+    else:
+        return str(numel)
+
+
 def set_cpu_maximum_parallelism():
     conf_str = torch.__config__.parallel_info()
     inter_str = conf_str.split("hardware_concurrency() : ")[1]
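The new model_size_formatter renders a raw parameter count with one decimal place and a B/M/K suffix. A quick illustrative check of its thresholds; expected outputs follow directly from the code above:

    # Illustrative only.
    print(model_size_formatter(1_300_000_000))  # '1.3B'
    print(model_size_formatter(125_000_000))    # '125.0M'
    print(model_size_formatter(50_000))         # '50.0K'
    print(model_size_formatter(512))            # '512'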
@@ -174,7 +188,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
                                   device=get_current_device(),
                                   placement_policy=placememt_policy,
                                   pin_memory=True,
-                                  hidden_dim=8192,
+                                  hidden_dim=model.config.n_embd,
                                   search_range_mb=64)
     if placememt_policy == 'const':
         model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
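This change replaces the hard-coded hidden_dim=8192 hint for Gemini's chunk-size search with the width stored on the config in the first hunk, so the search scales with whatever model is actually under test. Roughly, reusing the hypothetical TinyGPTLMModel from above:

    # Hypothetical: the chunk-search hint now tracks the model under test.
    model = TinyGPTLMModel(hidden_size=1024)
    hidden_dim = model.config.n_embd  # 1024 here, instead of a fixed 8192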
@@ -261,6 +275,7 @@ def main():
 
     # model is shared after TP
     numel = get_model_size(model)
+    logger.info(f"the size of testing model is {model_size_formatter(numel)}.")
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
 
     # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
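The trailing comment is the usual dense-transformer cost estimate, commonly read as roughly 2 FLOPs per parameter per token for the forward pass, 4 for the backward, plus about 2 more when checkpointing recomputes activations, hence the factor 8. A worked example with hypothetical numbers:

    # Hypothetical numbers, only to illustrate the commented formula.
    global_batch = 32          # sequences per step across all GPUs
    seq_len = 1024             # tokens per sequence
    global_numel = 10.7e9      # total parameter count
    n_gpu = 8
    flops_per_gpu = global_batch * global_numel * seq_len * 8 / n_gpu
    print(f'{flops_per_gpu / 1e12:.1f} TFLOP per GPU per step')  # ~350.6

Dividing that per-step work by the measured step time gives the TFLOPS throughput the example logs.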