[pipeline] support fp32 for HybridPlugin/merge shardformer test and pipeline test into one file (#4354)

* add naive optimizer for 3DPlugin/refactor gpt2 shardformer test

* merge tests of PP/DP/TP combinations into one test file

* fix bug when syncing gradients for DP in HybridPlugin

* update supported precisions for 3DPlugin/fix bug when shifting tp_degree

* improve how lazy_init is passed

* modify lazy_init/use sync_shared_params
pull/4445/head
Baizhou Zhang 2023-08-01 17:29:09 +08:00 committed by Hongxin Liu
parent f13954cd58
commit 0ceec8f9a9
8 changed files with 187 additions and 142 deletions
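With this commit the plugin accepts precision='fp32' alongside 'fp16' and 'bf16', taking a plain optimizer wrapper instead of the AMP path. A minimal usage sketch (the tiny GPT-2 config and hyperparameters are illustrative, not taken from this diff):

```python
# Hedged sketch of training in full precision through the HybridParallelPlugin,
# assuming the usual Booster API; model size and learning rate are placeholders.
import transformers
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision='fp32')  # 'fp32' is now supported
booster = Booster(plugin=plugin)

model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=2))
optimizer = Adam(model.parameters(), lr=1e-3)

# boost() returns (model, optimizer, criterion, dataloader, lr_scheduler)
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```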

View File

@@ -134,7 +134,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
working_param = self.master_to_working_map[p]
if p is working_param:
continue
if working_param.grad is None:
if working_param.grad is not None:
p.grad = working_param.grad.data.float()
working_param.grad = None
total_norm = self._compute_grad_norm()

View File

@@ -42,6 +42,8 @@ class HybridParallelModule(ModelWrapper):
module = module.half().cuda()
elif precision == 'bf16':
module = module.to(dtype=torch.bfloat16).cuda()
else:
module = module.cuda() # train without AMP
# TODO(ver217): support TP+DP
super().__init__(module)
@@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper):
for p in self.module.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, group=self.dp_group)
p.grad.div_(self.dp_group.size())
def init_pipeline_optimizer(optim: Optimizer, model: Module):
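The hunk above is the DP gradient-sync fix from the commit message: gradients are summed with an all-reduce over the data-parallel group and then divided by the group size, so every rank ends up with the mean gradient. A standalone sketch of the same pattern (the helper name is hypothetical; dp_group is assumed to be a torch.distributed ProcessGroup):

```python
# Sketch of the all-reduce-then-average pattern applied in the gradient-sync loop above.
import torch.distributed as dist
from torch import nn


def average_dp_grads(module: nn.Module, dp_group: dist.ProcessGroup) -> None:
    for p in module.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, group=dp_group)  # sum gradients across DP ranks
            p.grad.div_(dp_group.size())             # divide by DP world size to get the mean
```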
@@ -72,7 +75,15 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
optim.__setstate__({'param_groups': new_param_groups})
class HybridParallelOptimizer(MixedPrecisionOptimizer):
class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim)
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(self,
optim: Optimizer,
@@ -192,7 +203,7 @@ class HybridParallelPlugin(PipelinePluginBase):
return ['cuda']
def supported_precisions(self) -> List[str]:
return ['fp16', 'bf16']
return ['fp16', 'bf16', 'fp32']
def control_device(self) -> bool:
return True
@@ -218,12 +229,17 @@ class HybridParallelPlugin(PipelinePluginBase):
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
optimizer = HybridParallelOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config)
if self.precision in ['fp16', 'bf16']:
optimizer = HybridParallelAMPOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config)
else:
optimizer = HybridParallelNaiveOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism)
else:
optimizer = HybridParallelZeroOptimizer(optimizer,
model,
@@ -241,7 +257,8 @@ class HybridParallelPlugin(PipelinePluginBase):
data_iter: Iterator,
model: HybridParallelModule,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer],
optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
HybridParallelZeroOptimizer],
return_loss: bool = True,
return_outputs: bool = False) -> dict:
assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
@@ -250,7 +267,7 @@ class HybridParallelPlugin(PipelinePluginBase):
with ctx:
outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
return_outputs)
# model.sync_shared_params()
model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer):
optimizer.sync_grad()
else:
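With the new wrappers, execute_pipeline accepts whichever optimizer the plugin produced, and shared parameters are now synchronized after the schedule runs. A hedged usage sketch mirroring how the merged test file below drives it (model, optimizer, and criterion are assumed to come from booster.boost(), and batch is a placeholder input dict):

```python
# Hedged sketch of one pipeline step through the booster; names are assumptions,
# mirroring the pattern used in the merged GPT-2 test further down.
data_iter = iter([batch])  # one training batch as a dict of tensors
out = booster.execute_pipeline(data_iter,
                               model,
                               criterion,      # callable(outputs, inputs) -> loss
                               optimizer,
                               return_loss=True,
                               return_outputs=True)
loss = out['loss']          # populated on the last pipeline stage
optimizer.step()
optimizer.zero_grad()
```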

View File

@@ -456,12 +456,12 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[0], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_.shape, self.weight.shape, self.weight.shape[0])
input_ = input_
else:
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
if self.stream_chunk_num > 1:
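The corrected messages now report the dimension the assertions actually check: GPT-2's fused Conv1D stores its weight as (in_features, out_features), so the row-parallel layer expects the input's last dimension to match weight.shape[0], not weight.shape[-1]. A toy shape check (sizes are illustrative):

```python
# Toy shape check for the row-parallel Conv1D layout assumed above:
# weight is (in_features, out_features), sharded along in_features across TP ranks.
import torch

in_features, out_features, tp_size = 8, 4, 2
local_weight = torch.randn(in_features // tp_size, out_features)  # one TP rank's shard
x_parallel = torch.randn(2, in_features // tp_size)               # parallel_input=True case

assert x_parallel.shape[-1] == local_weight.shape[0]
partial = x_parallel @ local_weight   # each rank's partial output; ranks would all-reduce this
print(partial.shape)                  # torch.Size([2, 4])
```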

View File

@@ -16,6 +16,11 @@ from .sharding_spec import ShardingSpec
layout_converter = LayoutConverter()
def clear_layout_converter():
global layout_converter
layout_converter.cached_solution.clear()
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
"""
Check whether the given tensor is a distributed tensor.

View File

@@ -70,7 +70,8 @@ config = transformers.GPT2Config(n_layer=2,
resid_pdrop=0,
summary_first_dropout=0,
hidden_dropout=0,
problem_type="single_label_classification")
problem_type="single_label_classification",
pad_token_id=50256)
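Setting pad_token_id matters for the sequence-classification variant registered in the model zoo: GPT2ForSequenceClassification uses the pad token to locate the last real token of each sequence and refuses batched inputs when none is defined. A small illustration (values other than pad_token_id=50256, GPT-2's EOS id, are toy):

```python
# Without pad_token_id, GPT2ForSequenceClassification errors out on batch sizes > 1;
# with it, the model can find the last non-padded position of each sequence.
import transformers

config = transformers.GPT2Config(n_layer=2,
                                 problem_type="single_label_classification",
                                 pad_token_id=50256)  # GPT-2 reuses its EOS token for padding
model = transformers.GPT2ForSequenceClassification(config)
```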
# register the following models
model_zoo.register(name='transformers_gpt',

View File

@@ -160,7 +160,6 @@ def check_llama(rank, world_size, port):
run_llama_test()
@pytest.mark.skip('This test will fail')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

View File

@@ -1,85 +1,180 @@
import copy
from contextlib import nullcontext
import pytest
import torch
from torch import distributed as dist
from torch.optim import Adam
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
from colossalai.tensor.d_tensor.api import (
clear_layout_converter,
is_customized_distributed_tensor,
is_distributed_tensor,
)
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# do backward
use_lazy_init = False
if 'use_lazy_init' in test_config:
use_lazy_init = test_config.pop('use_lazy_init')
if use_lazy_init:
ctx = LazyInitContext()
else:
ctx = nullcontext()
# prepare booster
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
stage_manager = plugin.stage_manager
# prepare models and optimizers
with ctx:
org_model = model_fn().cuda()
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
org_model = ctx.materialize(org_model)
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
# do forward and backward
data = data_gen_fn()
sharded_model.train()
if stage_manager:
data = {
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
data_iter = iter([data])
sharded_output = booster.execute_pipeline(data_iter,
sharded_model,
_criterion,
sharded_optimizer,
return_loss=True,
return_outputs=True)
sharded_loss = sharded_output['loss']
else:
data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data)
sharded_loss = criterion(sharded_output)
sharded_loss.backward()
org_model.train()
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()
shard_loss.backward()
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
if stage_manager is None or stage_manager.is_last_stage():
# check last hidden state
if org_model.__class__.__name__ == 'GPT2Model':
org_hidden_state = org_output.last_hidden_state
if stage_manager is None:
sharded_hidden_state = sharded_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']],
dim=0)
assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
# check loss
assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
# unwrap model
if org_model.__class__.__name__ == 'GPT2Model':
org_model = org_model
sharded_model = sharded_model
sharded_model = sharded_model.unwrap()
else:
org_model = org_model.transformer
sharded_model = sharded_model.transformer
sharded_model = sharded_model.unwrap().transformer
# check mlp grad
org_grad = org_model.h[0].mlp.c_fc.weight.grad
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
shard_weight = sharded_model.h[0].mlp.c_fc.weight
# check weights and gradients
if stage_manager is None or stage_manager.is_first_stage():
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=1)
else:
all_shard_grad = shard_grad
assert torch.allclose(
org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
shard_weight = sharded_model.h[0].mlp.c_fc.weight
org_grad = org_model.h[0].mlp.c_fc.weight.grad
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
# check embedding weights
org_grad = org_model.wte.weight.grad
shard_grad = sharded_model.wte.weight.grad
shard_weight = sharded_model.wte.weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(plugin.tp_size)]
dist.all_gather(shard_grad_list, shard_grad, plugin.tp_group)
shard_grad = torch.cat(shard_grad_list, dim=1)
assert torch.allclose(org_grad, shard_grad, atol=1e-5, rtol=1e-3), \
f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
org_weight = org_model.h[0].mlp.c_fc.weight
shard_weight = sharded_model.h[0].mlp.c_fc.weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)]
dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group)
shard_weight = torch.cat(shard_weight_list, dim=1)
assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}"
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
else:
all_shard_grad = shard_grad
assert torch.allclose(
org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('use_lazy_init', [False, True])
@parameterize('test_config', [{
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': True
}, {
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
@clear_cache_before_run()
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
def run_gpt2_test(test_config):
# TODO: add plugin_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
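Each parameterized config above has to tile the 4-GPU world size launched by spawn(check_gpt2, 4); whatever ranks are left over form the data-parallel dimension. A quick sanity check, assuming the plugin derives dp_size as world_size // (tp_size * pp_size):

```python
# Rough sanity check of the parallel layout for the test configs above on 4 GPUs
# (the dp_size formula is an assumption about how the plugin splits its process mesh).
world_size = 4
for cfg in ({'tp_size': 1, 'pp_size': 2}, {'tp_size': 2, 'pp_size': 2}, {'tp_size': 4, 'pp_size': 1}):
    dp_size = world_size // (cfg['tp_size'] * cfg['pp_size'])
    print(cfg, '-> dp_size =', dp_size)   # prints 2, 1 and 1 respectively
```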
@@ -93,7 +188,7 @@ def check_gpt2(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2():
spawn(check_gpt2, 2)
spawn(check_gpt2, 4)
if __name__ == "__main__":

View File

@@ -1,72 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# TODO: add tests for forward/backward later
pass
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_gpt2
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
input_ids = inputs['input_ids']
batch_size, seq_len = input_ids.shape
hidden_size = sharded_model.config.n_embd
hidden_state_shape = (batch_size, seq_len, hidden_size)
if not stage_manager.is_first_stage():
# change inputs if not the first stage
hidden_states = torch.zeros(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
if name == 'transformers_gpt':
assert output[0].shape == hidden_state_shape
else:
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape
torch.cuda.empty_cache()
def check_gpt2(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt2_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2():
spawn(check_gpt2, 4)
if __name__ == "__main__":
test_gpt2()