[zero] fix gradient clipping in hybrid parallelism (#2521)

* [zero] fix gradient clipping in hybrid parallelism

* [testing] change model name to avoid pytest warning

* [hotfix] fix unit testing
pull/2523/head
HELSON 2023-01-29 15:09:57 +08:00 committed by GitHub
parent fd8d19a6e7
commit 077a5cdde4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 45 additions and 26 deletions

View File

@ -58,10 +58,12 @@ class DynamicGradScaler(BaseGradScaler):
if self._min_scale:
assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
if self._max_scale:
assert self._min_scale > 0, 'The maximum gradient scale cannot be zero or negative'
assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
assert self._backoff_factor < 1 and self._backoff_factor > 0, 'The backoff factor must be between 0 and 1'
assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
def update(self, overflow: bool) -> None:
@ -103,3 +105,17 @@ class DynamicGradScaler(BaseGradScaler):
self._scale = self._scale * self._growth_factor
if self._max_scale:
self._scale = torch.min(self._scale, self._max_scale)
def state_dict(self):
state_dict = dict()
state_dict['scale'] = self._scale
state_dict['growth_factor'] = self._growth_factor
state_dict['backoff_factor'] = self._backoff_factor
state_dict['hysteresis'] = self._hysteresis
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
self._growth_factor = state_dict['growth_factor']
self._backoff_factor = state_dict['backoff_factor']
self._hysteresis = state_dict['hysteresis']

View File

@ -6,9 +6,7 @@ import torch.distributed as dist
from torch._six import inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.utils import is_model_parallel_parameter
@ -225,7 +223,10 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
for g, p in zip(gradients, params):
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if is_model_parallel_parameter(p) or mp_rank == 0:
tp_param_flag = False
if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
tp_param_flag = True
if tp_param_flag or mp_rank == 0:
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
@ -234,7 +235,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
if mp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)

View File

@ -15,10 +15,10 @@ from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer
class TestModel(nn.Module):
class MlpModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
@ -33,7 +33,7 @@ def exam_zero_1_2_grad_acc():
seed_all(2009)
# create model
zero1_model = TestModel().cuda()
zero1_model = MlpModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
@ -89,7 +89,7 @@ def exam_zero_1_grad_acc():
seed_all(2008)
# create models
zero_model = TestModel()
zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
seed_all(2008)

View File

@ -14,10 +14,10 @@ from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer
class TestModel(nn.Module):
class MlpModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
@ -55,7 +55,7 @@ def exam_zero_1_2():
seed_all(2001)
# create model
zero1_model = TestModel().cuda()
zero1_model = MlpModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
@ -111,7 +111,7 @@ def exam_zero_1_torch_ddp():
seed_all(1453)
# create models
zero_model = TestModel()
zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
zero_model = zero_model.cuda().half()

View File

@ -13,10 +13,10 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import LowLevelZeroOptimizer
class TestModel(nn.Module):
class MlpModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
@ -28,9 +28,9 @@ class TestModel(nn.Module):
def exam_zero_init():
dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2)
model1 = TestModel().cuda()
model1 = MlpModel().cuda()
with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg):
model2 = TestModel()
model2 = MlpModel()
optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1))
optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1))

View File

@ -20,10 +20,10 @@ def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)
class TestModel(nn.Module):
class MlpModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(32, 128)
self.act = nn.GELU()
self.linear2 = nn.Linear(128, 32)
@ -42,8 +42,8 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
tp_pg = ProcessGroup(tp_degree=2)
with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
hybrid_model = TestModel()
torch_model = TestModel().cuda()
hybrid_model = MlpModel()
torch_model = MlpModel().cuda()
for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
pt.data.copy_(ph.data)
@ -55,10 +55,11 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
split_param_col_tp1d(param, tp_pg)
torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1)
hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2) # set to 1e-2 for torch-1.11
hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2)
hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
initial_scale=1,
initial_scale=2,
clip_grad_norm=1.0,
overlap_communication=overlap_flag,
partition_grad=partition_flag)
@ -71,6 +72,7 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
assert_close(torch_loss, hybrid_loss)
torch_loss.backward()
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
hybrid_optim.backward(hybrid_loss)
hybrid_optim.sync_grad()