[example] update gpt example (#2225)

pull/2238/head
HELSON 2022-12-29 12:01:45 +08:00 committed by GitHub
parent 49c601da21
commit 7010e18134
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 8 deletions

View File

@ -14,13 +14,13 @@ class GPTLMModel(nn.Module):
checkpoint=False):
super().__init__()
self.checkpoint = checkpoint
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size))
self.config = GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size)
self.model = GPT2LMHeadModel(self.config)
if checkpoint:
self.model.gradient_checkpointing_enable()

View File

@ -120,6 +120,20 @@ def get_model_size(model: nn.Module):
return total_numel
def model_size_formatter(numel: int) -> str:
    """Format an element count as a short human-readable string.

    Counts below 1000 are returned unchanged via ``str``. Larger counts are
    scaled by a power of 1000 and rendered with one decimal place followed by
    ``K`` (10**3), ``M`` (10**6) or ``B`` (10**9).

    The unit is chosen *after* rounding, so a value that rounds up to
    ``1000.0`` in one unit is promoted to the next one (999_999 becomes
    ``'1.0M'`` rather than ``'1000.0K'``). ``B`` is the largest unit, so
    counts of 10**12 and above are still expressed in ``B``.

    Args:
        numel: Number of elements to format.

    Returns:
        The formatted count, e.g. ``'999'``, ``'1.5K'``, ``'124.0M'``,
        ``'1.0B'``.
    """
    if numel < 10**3:
        return str(numel)

    # Try the smaller units first; fall through whenever the rounded
    # mantissa would need four integer digits in that unit.
    for scale, suffix in ((10**3, 'K'), (10**6, 'M')):
        mantissa = f'{numel / scale:.1f}'
        if float(mantissa) < 1000:
            return f'{mantissa}{suffix}'

    # Largest supported unit: no further promotion is possible.
    return f'{numel / 10**9:.1f}B'
def set_cpu_maximum_parallelism():
conf_str = torch.__config__.parallel_info()
inter_str = conf_str.split("hardware_concurrency() : ")[1]
@ -174,7 +188,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
device=get_current_device(),
placement_policy=placememt_policy,
pin_memory=True,
hidden_dim=8192,
hidden_dim=model.config.n_embd,
search_range_mb=64)
if placememt_policy == 'const':
model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
@ -261,6 +275,7 @@ def main():
# model is shared after TP
numel = get_model_size(model)
logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu