From 077a5cdde409cc89b726240b4788717fba1e62c4 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Sun, 29 Jan 2023 15:09:57 +0800
Subject: [PATCH] [zero] fix gradient clipping in hybrid parallelism (#2521)

* [zero] fix gradient clipping in hybrid parallelism

* [testing] change model name to avoid pytest warning

* [hotfix] fix unit testing
---
 .../grad_scaler/dynamic_grad_scaler.py        | 20 +++++++++++++++++--
 colossalai/zero/sharded_optim/_utils.py       | 11 +++++-----
 .../test_zero/low_level_zero/test_grad_acc.py |  8 ++++----
 .../test_zero/low_level_zero/test_zero1_2.py  |  8 ++++----
 .../low_level_zero/test_zero_init.py          |  8 ++++----
 .../test_zero/low_level_zero/test_zero_tp.py  | 16 ++++++++-------
 6 files changed, 45 insertions(+), 26 deletions(-)

diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index 6d6f2f287..e899b9ca4 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -58,10 +58,12 @@ class DynamicGradScaler(BaseGradScaler):
 
         if self._min_scale:
             assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
+            assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
         if self._max_scale:
-            assert self._min_scale > 0, 'The maximum gradient scale cannot be zero or negative'
+            assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
+            assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
         assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
-        assert self._backoff_factor < 1 and self._backoff_factor > 0, 'The backoff factor must be between 0 and 1'
+        assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
         assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
 
     def update(self, overflow: bool) -> None:
@@ -103,3 +105,17 @@ class DynamicGradScaler(BaseGradScaler):
             self._scale = self._scale * self._growth_factor
             if self._max_scale:
                 self._scale = torch.min(self._scale, self._max_scale)
+
+    def state_dict(self):
+        state_dict = dict()
+        state_dict['scale'] = self._scale
+        state_dict['growth_factor'] = self._growth_factor
+        state_dict['backoff_factor'] = self._backoff_factor
+        state_dict['hysteresis'] = self._hysteresis
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
+        self._growth_factor = state_dict['growth_factor']
+        self._backoff_factor = state_dict['backoff_factor']
+        self._hysteresis = state_dict['hysteresis']
diff --git a/colossalai/zero/sharded_optim/_utils.py b/colossalai/zero/sharded_optim/_utils.py
index 70d9c040c..e67434401 100644
--- a/colossalai/zero/sharded_optim/_utils.py
+++ b/colossalai/zero/sharded_optim/_utils.py
@@ -6,9 +6,7 @@ import torch.distributed as dist
 from torch._six import inf
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter
 from colossalai.utils import is_model_parallel_parameter
 
 
@@ -225,7 +223,10 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
         for g, p in zip(gradients, params):
             # Pipeline parallelism may replicate parameters. Avoid multi-counting.
-            if is_model_parallel_parameter(p) or mp_rank == 0:
+            tp_param_flag = False
+            if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
+                tp_param_flag = True
+            if tp_param_flag or mp_rank == 0:
                 param_norm = g.data.double().norm(2)
                 total_norm += param_norm.item()**2
 
@@ -234,7 +235,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
         torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
 
         if mp_group is not None:
-            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)
+            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
 
         total_norm = total_norm_cuda[0].item()**(1. / norm_type)
 
diff --git a/tests/test_zero/low_level_zero/test_grad_acc.py b/tests/test_zero/low_level_zero/test_grad_acc.py
index 69795ed6a..1e157c70a 100644
--- a/tests/test_zero/low_level_zero/test_grad_acc.py
+++ b/tests/test_zero/low_level_zero/test_grad_acc.py
@@ -15,10 +15,10 @@ from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
 
 
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
 
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(MlpModel, self).__init__()
         self.linear1 = nn.Linear(128, 256)
         self.linear2 = nn.Linear(256, 512)
 
@@ -33,7 +33,7 @@ def exam_zero_1_2_grad_acc():
     seed_all(2009)
     # create model
-    zero1_model = TestModel().cuda()
+    zero1_model = MlpModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
@@ -89,7 +89,7 @@ def exam_zero_1_grad_acc():
     seed_all(2008)
 
     # create models
-    zero_model = TestModel()
+    zero_model = MlpModel()
     torch_model = copy.deepcopy(zero_model)
 
     seed_all(2008)
diff --git a/tests/test_zero/low_level_zero/test_zero1_2.py b/tests/test_zero/low_level_zero/test_zero1_2.py
index 8771bfbe6..494963072 100644
--- a/tests/test_zero/low_level_zero/test_zero1_2.py
+++ b/tests/test_zero/low_level_zero/test_zero1_2.py
@@ -14,10 +14,10 @@ from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
 
 
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
 
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(MlpModel, self).__init__()
         self.linear1 = nn.Linear(128, 256)
         self.linear2 = nn.Linear(256, 512)
 
@@ -55,7 +55,7 @@ def exam_zero_1_2():
     seed_all(2001)
 
     # create model
-    zero1_model = TestModel().cuda()
+    zero1_model = MlpModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)
 
     # create optimizer
@@ -111,7 +111,7 @@ def exam_zero_1_torch_ddp():
     seed_all(1453)
 
     # create models
-    zero_model = TestModel()
+    zero_model = MlpModel()
     torch_model = copy.deepcopy(zero_model)
 
     zero_model = zero_model.cuda().half()
diff --git a/tests/test_zero/low_level_zero/test_zero_init.py b/tests/test_zero/low_level_zero/test_zero_init.py
index 84d7b8c51..1305da5df 100644
--- a/tests/test_zero/low_level_zero/test_zero_init.py
+++ b/tests/test_zero/low_level_zero/test_zero_init.py
@@ -13,10 +13,10 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero import LowLevelZeroOptimizer
 
 
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
 
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(MlpModel, self).__init__()
         self.linear1 = nn.Linear(128, 256)
         self.linear2 = nn.Linear(256, 512)
 
@@ -28,9 +28,9 @@ def exam_zero_init():
     dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2)
-    model1 = TestModel().cuda()
+    model1 = MlpModel().cuda()
     with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg):
-        model2 = TestModel()
+        model2 = MlpModel()
     optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1))
     optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1))
diff --git a/tests/test_zero/low_level_zero/test_zero_tp.py b/tests/test_zero/low_level_zero/test_zero_tp.py
index 8ba6e3cb6..ea8e3a0a3 100644
--- a/tests/test_zero/low_level_zero/test_zero_tp.py
+++ b/tests/test_zero/low_level_zero/test_zero_tp.py
@@ -20,10 +20,10 @@ def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
     return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)
 
 
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
 
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(MlpModel, self).__init__()
         self.linear1 = nn.Linear(32, 128)
         self.act = nn.GELU()
         self.linear2 = nn.Linear(128, 32)
@@ -42,8 +42,8 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
     tp_pg = ProcessGroup(tp_degree=2)
 
     with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
-        hybrid_model = TestModel()
-    torch_model = TestModel().cuda()
+        hybrid_model = MlpModel()
+    torch_model = MlpModel().cuda()
     for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
         pt.data.copy_(ph.data)
 
@@ -55,10 +55,11 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
             split_param_col_tp1d(param, tp_pg)
 
     torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
-    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1)
-    hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1)
+    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2)    # set to 1e-2 for torch-1.11
+    hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2)
     hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
-                                         initial_scale=1,
+                                         initial_scale=2,
+                                         clip_grad_norm=1.0,
                                          overlap_communication=overlap_flag,
                                          partition_grad=partition_flag)
 
@@ -71,6 +72,7 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
     assert_close(torch_loss, hybrid_loss)
 
     torch_loss.backward()
+    torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
    hybrid_optim.backward(hybrid_loss)
    hybrid_optim.sync_grad()
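
A minimal sketch of how the new DynamicGradScaler checkpoint hooks could be driven from user code. It assumes the class is importable from colossalai.amp.naive_amp.grad_scaler (the path is inferred from the file location above) and that scaler is an already-constructed instance; the helper names are illustrative only:

    import torch

    from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler  # import path inferred from the diff above

    def save_scaler_state(scaler: DynamicGradScaler, path: str) -> None:
        # state_dict() packs the scale tensor together with the plain-float hyper-parameters
        torch.save(scaler.state_dict(), path)

    def restore_scaler_state(scaler: DynamicGradScaler, path: str) -> None:
        # load_state_dict() moves the stored scale tensor back onto the current CUDA device
        scaler.load_state_dict(torch.load(path, map_location='cpu'))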
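
The compute_norm change has two parts: a gradient is counted on every tensor-parallel rank when its parameter is genuinely sharded (is_model_parallel_parameter, or a ColoParameter that is not replicated) and only on mp_rank 0 otherwise, and the partial sums are then reduced first over the ZeRO data-parallel group and then over the model-parallel group instead of the default world group. A sketch of that reduction pattern for the 2-norm, with illustrative names rather than library API:

    import torch
    import torch.distributed as dist

    def reduce_grad_norm_sq(local_sq_sum: float, dp_group, mp_group) -> float:
        # each ZeRO shard contributes a partial sum of squared gradient norms
        buf = torch.tensor([local_sq_sum], dtype=torch.float, device='cuda')
        # sum the partial results across the data-parallel (ZeRO) ranks
        dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=dp_group)
        # then sum across tensor/model-parallel ranks; restricting this call to
        # mp_group rather than the default world group avoids reducing over the
        # same ranks a second time, which inflated the norm before this fix
        if mp_group is not None:
            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=mp_group)
        return buf[0].item()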
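
As exercised in test_zero_tp.py above, the hybrid setup requests clipping through the clip_grad_norm argument of LowLevelZeroOptimizer, while the plain DDP baseline clips explicitly with torch.nn.utils.clip_grad_norm_ before stepping. A condensed usage sketch; the step() call on the wrapper is assumed here, since the hunk above is cut off before it:

    import torch
    from colossalai.zero import LowLevelZeroOptimizer  # import used by the tests above

    def build_zero_optimizer(model: torch.nn.Module) -> LowLevelZeroOptimizer:
        inner = torch.optim.Adam(model.parameters(), lr=1e-2)
        # clip_grad_norm=1.0 asks the wrapper to clip the global gradient norm itself,
        # mirroring torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) on the baseline
        return LowLevelZeroOptimizer(inner, initial_scale=2, clip_grad_norm=1.0)

    def train_step(optim: LowLevelZeroOptimizer, loss: torch.Tensor) -> None:
        optim.backward(loss)    # scaled backward, as in the test above
        optim.sync_grad()       # reduce gradients across the ZeRO shards
        optim.step()            # assumed step() call; clipping is applied internally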