ColossalAI/tests/test_shardformer/test_model/_utils.py

99 lines
4.2 KiB
Python

import copy
from contextlib import nullcontext
import torch
import torch.distributed as dist
from torch.nn import Module
from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
# create new model
org_model = model_fn()
model_copy = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()
def build_pipeline_model(model_fn,
stage_manager=None,
enable_fused_normalization=False,
enable_tensor_parallelism=False,
use_lazy_init: bool = False):
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
# create new model
org_model = model_fn()
model_copy = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
pipeline_stage_manager=stage_manager)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}
# switch to train mode
original_model.train()
sharded_model.train()
# run forward
org_output = original_model(**data)
org_output = output_transform_fn(org_output)
org_loss = loss_fn(org_output)
shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output)
shard_loss = loss_fn(shard_output)
return org_output, org_loss, shard_output, shard_loss
def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
org_sd = org_model.state_dict()
shard_sd = sharded_model.state_dict()
for k, v in org_sd.items():
assert k in shard_sd, f'{name} {k} not in sharded model'
shard_v = shard_sd[k]
assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False):
for suffix in layer_suffix:
org_grad = getattr_(original_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=dim)
else:
all_shard_grad = shard_grad
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}")
assert torch.allclose(
org_grad, all_shard_grad, rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}"