mirror of https://github.com/hpcaitech/ColossalAI
[example] update gpt example (#2225)
parent 49c601da21
commit 7010e18134
@@ -14,13 +14,13 @@ class GPTLMModel(nn.Module):
                  checkpoint=False):
         super().__init__()
         self.checkpoint = checkpoint
-        self.model = GPT2LMHeadModel(
-            GPT2Config(n_embd=hidden_size,
-                       n_layer=num_layers,
-                       n_head=num_attention_heads,
-                       n_positions=max_seq_len,
-                       n_ctx=max_seq_len,
-                       vocab_size=vocab_size))
+        self.config = GPT2Config(n_embd=hidden_size,
+                                 n_layer=num_layers,
+                                 n_head=num_attention_heads,
+                                 n_positions=max_seq_len,
+                                 n_ctx=max_seq_len,
+                                 vocab_size=vocab_size)
+        self.model = GPT2LMHeadModel(self.config)
         if checkpoint:
             self.model.gradient_checkpointing_enable()
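The first hunk builds the GPT2Config once, keeps it on the wrapper module, and then constructs the model from it, so later utilities can read the model's dimensions instead of hard-coding them. A minimal sketch of the resulting pattern, assuming a simplified constructor; TinyGPTWrapper and its default arguments are illustrative stand-ins, not code from the commit:

import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

class TinyGPTWrapper(nn.Module):
    """Hypothetical stand-in for GPTLMModel: the config is kept on the module."""

    def __init__(self, hidden_size=64, num_layers=2, num_attention_heads=2,
                 max_seq_len=128, vocab_size=1000, checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        # Build the config once and keep a reference to it ...
        self.config = GPT2Config(n_embd=hidden_size,
                                 n_layer=num_layers,
                                 n_head=num_attention_heads,
                                 n_positions=max_seq_len,
                                 vocab_size=vocab_size)
        # ... then construct the model from that config.
        self.model = GPT2LMHeadModel(self.config)
        if checkpoint:
            self.model.gradient_checkpointing_enable()

wrapper = TinyGPTWrapper(hidden_size=128)
print(wrapper.config.n_embd)  # 128 -- readable by other code, e.g. gemini_zero_dpp below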
@@ -120,6 +120,20 @@ def get_model_size(model: nn.Module):
     return total_numel
 
 
+def model_size_formatter(numel: int) -> str:
+    GB_SIZE = 10**9
+    MB_SIZE = 10**6
+    KB_SIZE = 10**3
+    if numel >= GB_SIZE:
+        return f'{numel / GB_SIZE:.1f}B'
+    elif numel >= MB_SIZE:
+        return f'{numel / MB_SIZE:.1f}M'
+    elif numel >= KB_SIZE:
+        return f'{numel / KB_SIZE:.1f}K'
+    else:
+        return str(numel)
+
+
 def set_cpu_maximum_parallelism():
     conf_str = torch.__config__.parallel_info()
     inter_str = conf_str.split("hardware_concurrency() : ")[1]
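For reference, a quick illustration (not part of the commit) of what the new helper prints; the thresholds are decimal powers, so a 1.3-billion-parameter model is reported as "1.3B":

def model_size_formatter(numel: int) -> str:
    # Same logic as the helper added above, repeated so this snippet runs standalone.
    GB_SIZE = 10**9
    MB_SIZE = 10**6
    KB_SIZE = 10**3
    if numel >= GB_SIZE:
        return f'{numel / GB_SIZE:.1f}B'
    elif numel >= MB_SIZE:
        return f'{numel / MB_SIZE:.1f}M'
    elif numel >= KB_SIZE:
        return f'{numel / KB_SIZE:.1f}K'
    return str(numel)

print(model_size_formatter(1_300_000_000))  # 1.3B
print(model_size_formatter(124_439_808))    # 124.4M
print(model_size_formatter(512))            # 512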
@@ -174,7 +188,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
                                   device=get_current_device(),
                                   placement_policy=placememt_policy,
                                   pin_memory=True,
-                                  hidden_dim=8192,
+                                  hidden_dim=model.config.n_embd,
                                   search_range_mb=64)
     if placememt_policy == 'const':
         model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
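The third hunk replaces the hard-coded hidden_dim=8192 hint with the width of the model that was actually built, which the first hunk exposes as model.config.n_embd. A small sketch of the difference, with the wrapping call itself elided and an assumed config used purely for illustration:

from transformers import GPT2Config

# Assumed to mirror the config kept on the wrapper in the first hunk.
config = GPT2Config(n_embd=1024, n_layer=4, n_head=8, n_positions=512, vocab_size=50257)

old_hidden_dim = 8192            # previous fixed value passed as hidden_dim
new_hidden_dim = config.n_embd   # value now derived from the model's own width
print(old_hidden_dim, new_hidden_dim)  # 8192 1024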
@@ -261,6 +275,7 @@ def main():
 
     # model is shared after TP
     numel = get_model_size(model)
+    logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
 
     # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
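The last hunk adds a log line that reports the parameter count through the new formatter. A self-contained sketch of what that line would print, where count_params and fmt are hypothetical condensations of get_model_size and model_size_formatter, not the script's own functions:

import torch.nn as nn

def count_params(model: nn.Module) -> int:
    # Simplified stand-in for get_model_size(): total number of parameters.
    return sum(p.numel() for p in model.parameters())

def fmt(numel: int) -> str:
    # Condensed stand-in for model_size_formatter().
    for size, suffix in ((10**9, 'B'), (10**6, 'M'), (10**3, 'K')):
        if numel >= size:
            return f'{numel / size:.1f}{suffix}'
    return str(numel)

toy = nn.Linear(1024, 1024)
print(f"the size of testing model size is {fmt(count_params(toy))}.")
# -> the size of testing model size is 1.0M.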