mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] support fp32 for HybridPlugin/merge shardformer test and pipeline test into one file (#4354)
* add naive optimizer for 3DPlugin/refactor gpt2 shardformer test * merge tests of PP/DP/TP combinations into one test file * fix bug when sync grad for dp in HybridPlugin * update supported precisions for 3DPlugin/fix bug when shifting tp_degree * improve the passing of lazy_init * modify lazy_init/use sync_shared_paramspull/4445/head
parent
f13954cd58
commit
0ceec8f9a9
|
@ -134,7 +134,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
|
|||
working_param = self.master_to_working_map[p]
|
||||
if p is working_param:
|
||||
continue
|
||||
if working_param.grad is None:
|
||||
if working_param.grad is not None:
|
||||
p.grad = working_param.grad.data.float()
|
||||
working_param.grad = None
|
||||
total_norm = self._compute_grad_norm()
|
||||
|
|
|
@ -42,6 +42,8 @@ class HybridParallelModule(ModelWrapper):
|
|||
module = module.half().cuda()
|
||||
elif precision == 'bf16':
|
||||
module = module.to(dtype=torch.bfloat16).cuda()
|
||||
else:
|
||||
module = module.cuda() # train without AMP
|
||||
# TODO(ver217): support TP+DP
|
||||
super().__init__(module)
|
||||
|
||||
|
@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper):
|
|||
for p in self.module.parameters():
|
||||
if p.grad is not None:
|
||||
dist.all_reduce(p.grad, group=self.dp_group)
|
||||
p.grad.div_(self.dp_group.size())
|
||||
|
||||
|
||||
def init_pipeline_optimizer(optim: Optimizer, model: Module):
|
||||
|
@ -72,7 +75,15 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
|
|||
optim.__setstate__({'param_groups': new_param_groups})
|
||||
|
||||
|
||||
class HybridParallelOptimizer(MixedPrecisionOptimizer):
|
||||
class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||
|
||||
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
super().__init__(optim)
|
||||
|
||||
|
||||
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||
|
||||
def __init__(self,
|
||||
optim: Optimizer,
|
||||
|
@ -192,7 +203,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return ['cuda']
|
||||
|
||||
def supported_precisions(self) -> List[str]:
|
||||
return ['fp16', 'bf16']
|
||||
return ['fp16', 'bf16', 'fp32']
|
||||
|
||||
def control_device(self) -> bool:
|
||||
return True
|
||||
|
@ -218,12 +229,17 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.zero_stage == 0:
|
||||
optimizer = HybridParallelOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
precision=self.precision,
|
||||
max_norm=self.max_norm,
|
||||
**self.amp_config)
|
||||
if self.precision in ['fp16', 'bf16']:
|
||||
optimizer = HybridParallelAMPOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
precision=self.precision,
|
||||
max_norm=self.max_norm,
|
||||
**self.amp_config)
|
||||
else:
|
||||
optimizer = HybridParallelNaiveOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism)
|
||||
else:
|
||||
optimizer = HybridParallelZeroOptimizer(optimizer,
|
||||
model,
|
||||
|
@ -241,7 +257,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
data_iter: Iterator,
|
||||
model: HybridParallelModule,
|
||||
criterion: Callable[[Any, Any], torch.Tensor],
|
||||
optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer],
|
||||
optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
|
||||
HybridParallelZeroOptimizer],
|
||||
return_loss: bool = True,
|
||||
return_outputs: bool = False) -> dict:
|
||||
assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
|
||||
|
@ -250,7 +267,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
with ctx:
|
||||
outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
|
||||
return_outputs)
|
||||
# model.sync_shared_params()
|
||||
model.sync_shared_params()
|
||||
if isinstance(optimizer, HybridParallelZeroOptimizer):
|
||||
optimizer.sync_grad()
|
||||
else:
|
||||
|
|
|
@ -456,12 +456,12 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[0], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
input_.shape, self.weight.shape, self.weight.shape[0])
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
|
||||
input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions)
|
||||
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
|
|
|
@ -16,6 +16,11 @@ from .sharding_spec import ShardingSpec
|
|||
layout_converter = LayoutConverter()
|
||||
|
||||
|
||||
def clear_layout_converter():
|
||||
global layout_converter
|
||||
layout_converter.cached_solution.clear()
|
||||
|
||||
|
||||
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
|
||||
"""
|
||||
Check whether the given tensor is a distributed tensor.
|
||||
|
|
|
@ -70,7 +70,8 @@ config = transformers.GPT2Config(n_layer=2,
|
|||
resid_pdrop=0,
|
||||
summary_first_dropout=0,
|
||||
hidden_dropout=0,
|
||||
problem_type="single_label_classification")
|
||||
problem_type="single_label_classification",
|
||||
pad_token_id=50256)
|
||||
|
||||
# register the following models
|
||||
model_zoo.register(name='transformers_gpt',
|
||||
|
|
|
@ -160,7 +160,6 @@ def check_llama(rank, world_size, port):
|
|||
run_llama_test()
|
||||
|
||||
|
||||
@pytest.mark.skip('This test will fail')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
|
|
@ -1,85 +1,180 @@
|
|||
import copy
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
from torch.optim import Adam
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.lazy.lazy_init import LazyInitContext
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
from colossalai.tensor.d_tensor.api import (
|
||||
clear_layout_converter,
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
)
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
|
||||
# do backward
|
||||
use_lazy_init = False
|
||||
if 'use_lazy_init' in test_config:
|
||||
use_lazy_init = test_config.pop('use_lazy_init')
|
||||
|
||||
if use_lazy_init:
|
||||
ctx = LazyInitContext()
|
||||
else:
|
||||
ctx = nullcontext()
|
||||
|
||||
# prepare booster
|
||||
plugin = HybridParallelPlugin(**test_config)
|
||||
booster = Booster(plugin=plugin)
|
||||
stage_manager = plugin.stage_manager
|
||||
|
||||
# prepare models and optimizers
|
||||
with ctx:
|
||||
org_model = model_fn().cuda()
|
||||
sharded_model = copy.deepcopy(org_model)
|
||||
|
||||
if use_lazy_init:
|
||||
org_model = ctx.materialize(org_model)
|
||||
|
||||
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
|
||||
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
|
||||
criterion = loss_fn
|
||||
|
||||
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
outputs = output_transform_fn(outputs)
|
||||
loss = criterion(outputs)
|
||||
return loss
|
||||
|
||||
# do forward and backward
|
||||
data = data_gen_fn()
|
||||
sharded_model.train()
|
||||
if stage_manager:
|
||||
data = {
|
||||
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
|
||||
for k, v in data.items()
|
||||
}
|
||||
data_iter = iter([data])
|
||||
sharded_output = booster.execute_pipeline(data_iter,
|
||||
sharded_model,
|
||||
_criterion,
|
||||
sharded_optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True)
|
||||
sharded_loss = sharded_output['loss']
|
||||
else:
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
sharded_output = sharded_model(**data)
|
||||
sharded_loss = criterion(sharded_output)
|
||||
sharded_loss.backward()
|
||||
|
||||
org_model.train()
|
||||
org_output = org_model(**data)
|
||||
org_loss = criterion(org_output)
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
|
||||
# check last hidden state
|
||||
if org_model.__class__.__name__ == 'GPT2Model':
|
||||
org_hidden_state = org_output.last_hidden_state
|
||||
|
||||
if stage_manager is None:
|
||||
sharded_hidden_state = sharded_output.last_hidden_state
|
||||
|
||||
if stage_manager and stage_manager.is_last_stage():
|
||||
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']],
|
||||
dim=0)
|
||||
|
||||
assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \
|
||||
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
|
||||
|
||||
# check loss
|
||||
assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \
|
||||
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'GPT2Model':
|
||||
org_model = org_model
|
||||
sharded_model = sharded_model
|
||||
sharded_model = sharded_model.unwrap()
|
||||
else:
|
||||
org_model = org_model.transformer
|
||||
sharded_model = sharded_model.transformer
|
||||
sharded_model = sharded_model.unwrap().transformer
|
||||
|
||||
# check mlp grad
|
||||
org_grad = org_model.h[0].mlp.c_fc.weight.grad
|
||||
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
|
||||
shard_weight = sharded_model.h[0].mlp.c_fc.weight
|
||||
# check weights and gradients
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=1)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
assert torch.allclose(
|
||||
org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
shard_weight = sharded_model.h[0].mlp.c_fc.weight
|
||||
org_grad = org_model.h[0].mlp.c_fc.weight.grad
|
||||
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
|
||||
|
||||
# check embedding weights
|
||||
org_grad = org_model.wte.weight.grad
|
||||
shard_grad = sharded_model.wte.weight.grad
|
||||
shard_weight = sharded_model.wte.weight
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(plugin.tp_size)]
|
||||
dist.all_gather(shard_grad_list, shard_grad, plugin.tp_group)
|
||||
shard_grad = torch.cat(shard_grad_list, dim=1)
|
||||
|
||||
assert torch.allclose(org_grad, shard_grad, atol=1e-5, rtol=1e-3), \
|
||||
f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
|
||||
org_weight = org_model.h[0].mlp.c_fc.weight
|
||||
shard_weight = sharded_model.h[0].mlp.c_fc.weight
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)]
|
||||
dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group)
|
||||
shard_weight = torch.cat(shard_weight_list, dim=1)
|
||||
|
||||
assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \
|
||||
f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}"
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
else:
|
||||
all_shard_grad = shard_grad
|
||||
assert torch.allclose(
|
||||
org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': 1,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'use_lazy_init': True
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_fused_normalization': False,
|
||||
'use_lazy_init': False
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
'enable_fused_normalization': True,
|
||||
'use_lazy_init': False
|
||||
}])
|
||||
@clear_cache_before_run()
|
||||
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
def run_gpt2_test(test_config):
|
||||
|
||||
# TODO: add plugin_config for TP+DP after supporting & debugging it
|
||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
@ -93,7 +188,7 @@ def check_gpt2(rank, world_size, port):
|
|||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_gpt2():
|
||||
spawn(check_gpt2, 2)
|
||||
spawn(check_gpt2, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,72 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_pipeline_model
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# TODO: add tests for forward/backward later
|
||||
pass
|
||||
|
||||
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_gpt2
|
||||
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
inputs = data_gen_fn()
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
input_ids = inputs['input_ids']
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = sharded_model.config.n_embd
|
||||
hidden_state_shape = (batch_size, seq_len, hidden_size)
|
||||
|
||||
if not stage_manager.is_first_stage():
|
||||
# change inputs if not the first stage
|
||||
hidden_states = torch.zeros(*hidden_state_shape).cuda()
|
||||
inputs['input_ids'] = None
|
||||
inputs['hidden_states'] = hidden_states
|
||||
|
||||
sharded_model.train()
|
||||
output = sharded_model(**inputs)
|
||||
if stage_manager.is_last_stage():
|
||||
if name == 'transformers_gpt':
|
||||
assert output[0].shape == hidden_state_shape
|
||||
else:
|
||||
assert output.loss is not None
|
||||
else:
|
||||
assert output['hidden_states'].shape == hidden_state_shape
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_gpt2(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_gpt2_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_gpt2():
|
||||
spawn(check_gpt2, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gpt2()
|
Loading…
Reference in New Issue