# mirror of https://github.com/hpcaitech/ColossalAI
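# Demo: train a GPT-2 language model on random data with Colossal-AI's
# Gemini/ZeroDDP. colossalai.launch_from_torch reads the torch.distributed
# environment, so run with e.g. `torchrun --nproc_per_node=<N> <this_file>.py`
# (filename assumed, not given in the source).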
from functools import partial
from time import time

import psutil
import torch
import torch.nn as nn
from packaging import version

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel

class GPTLMModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits. The KV cache is incompatible with gradient
        # checkpointing, hence use_cache=not self.checkpoint.
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]

class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

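# Worked example of the shift above: with labels = input_ids = [t0, t1, t2, t3],
# shift_logits keeps the predictions at positions [t0, t1, t2] and shift_labels
# is [t1, t2, t3], so each position is scored against the *next* token -- the
# standard causal language-modeling objective.
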
def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask

def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=True):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_10b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)

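# Rough size check for gpt2_10b (hand arithmetic, not from the source): a GPT-2
# block has ~12 * hidden_size**2 parameters, so 12 * 50 * 4096**2 ~= 10.07e9,
# plus ~0.21e9 for the token embedding (50257 * 4096) -- about 10.3B parameters
# in total, which is where the "10b" name comes from.
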
def get_cpu_mem():
    # Resident set size of this process, in MiB.
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    # CUDA memory currently allocated by tensors on this device, in MiB.
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'

def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

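# Where the factor 8 comes from (my reading, not stated in the source): a
# forward pass costs ~2 * numel FLOPs per token (one multiply-add per weight),
# the backward pass ~4 * numel, and with activation checkpointing the forward
# is recomputed once more (~2 * numel), giving 8 * numel FLOPs per token.
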
def main():
    BATCH_SIZE = 8
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257
    NUM_STEPS = 10
    PLACEMENT_POLICY = 'auto'
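    # Placement policy (as I understand this Colossal-AI era): 'cuda' keeps all
    # tensors on GPU, 'cpu' offloads parameters and gradients to CPU, and
    # 'auto' lets Gemini migrate chunks between the two based on memory usage.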
    disable_existing_loggers()
    colossalai.launch_from_torch(config={})
    pg = ProcessGroup()
    logger = get_dist_logger()

    logger.info(get_mem_info(), ranks=[0])
    # build GPT model
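    # ColoInitContext intercepts parameter creation so the weights are built as
    # ColoTensors on the chosen device, ready to be chunked and managed by
    # Gemini/ZeroDDP below (my summary of its role, not from the source).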
    with ColoInitContext(device=get_current_device()):
        model = gpt2_medium(checkpoint=True)
    numel = sum(p.numel() for p in model.parameters())
    logger.info(f'Model numel: {numel}', ranks=[0])
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    cai_version = colossalai.__version__
    logger.info(f'using Colossal-AI version {cai_version}')
    if version.parse(cai_version) > version.parse("0.1.10"):
        from colossalai.nn.parallel import GeminiDDP
        model = GeminiDDP(model,
                          device=get_current_device(),
                          placement_policy=PLACEMENT_POLICY,
                          pin_memory=True,
                          search_range_mb=32)
    elif version.parse("0.1.9") <= version.parse(cai_version) <= version.parse("0.1.10"):
        from colossalai.gemini import ChunkManager, GeminiManager
        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
        # Build the chunk manager first; the Gemini manager takes it as input.
        chunk_manager = ChunkManager(chunk_size,
                                     pg,
                                     enable_distributed_storage=True,
                                     init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
        gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
        model = ZeroDDP(model, gemini_manager)

    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    # optimizer
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
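    # ZeroOptimizer wraps HybridAdam with fp16 loss scaling (initial_scale=2**5),
    # which is why the loop below calls optimizer.backward(loss) rather than the
    # usual loss.backward().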
    logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])

    model.train()
    for n in range(NUM_STEPS):
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        start = time()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
        optimizer.backward(loss)
        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
        optimizer.step()
        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
        step_time = time() - start
        logger.info(
            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
            ranks=[0])


if __name__ == '__main__':
    main()