mirror of https://github.com/hpcaitech/ColossalAI
[zero] improve adaptability for not-shard parameters (#708)
* adapt post grad hooks for not-shard parameters
* adapt optimizer for not-shard parameters
* offload gradients for not-replicated parameters

branch pull/712/head
parent ab8c6b4a0e
commit a9b8300d54
@@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts
 class ForceFP32Parameter(torch.nn.Parameter):

     def half(self, memory_format=None):
-        return self
+        return self.data


 class NormalNoiseGenerator:
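For context on the one-line change above: `half()` on this `Parameter` subclass is deliberately a no-op cast, so a blanket `model.half()` leaves these weights in fp32; returning `self.data` hands back the plain underlying tensor rather than the parameter wrapper. A minimal standalone sketch of the idea (the class name below is an illustrative stand-in, not the library's):

import torch

class KeepFP32Parameter(torch.nn.Parameter):    # hypothetical stand-in for ForceFP32Parameter

    def half(self, memory_format=None):
        # refuse the fp16 cast and expose the raw fp32 payload instead
        return self.data

w = KeepFP32Parameter(torch.randn(4, 4))
print(w.half().dtype)    # torch.float32 -- the cast is swallowed
print(type(w.half()))    # plain torch.Tensor, not the Parameter subclass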
@@ -142,6 +142,7 @@ class CPUAdam(torch.optim.Optimizer):
                beta1, beta2 = group['betas']

                if target_device.type == 'cpu':
+                    assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size"
                    assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                    assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
                    self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
@@ -151,8 +152,8 @@ class CPUAdam(torch.optim.Optimizer):
                    assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                    assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"

-                    bias_correction1 = 1 - beta1 ** state['step']
-                    bias_correction2 = 1 - beta2 ** state['step']
+                    bias_correction1 = 1 - beta1**state['step']
+                    bias_correction2 = 1 - beta2**state['step']

                    # adam on cuda
                    self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
@@ -213,7 +213,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
        for param in self.param_list:
            assert hasattr(param, 'colo_attr')
-            if not param.colo_attr.param_is_sharded and param.is_replicated:
+            if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
                param.colo_attr.remove_torch_payload()

@@ -239,9 +239,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

        self.model_numel_tensor += param.numel()

-        # mark whether the param is replicated
-        param.is_replicated = self.is_replicated
-
        # convert parameters to half
        param_half = half_fn(param)
        param.data = param_half
@@ -261,6 +258,13 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
            param.data = param.colo_attr.sharded_data_tensor.payload    # set param.data to payload

+        # mark whether the param is replicated
+        param.colo_attr.is_replicated = self.is_replicated
+
+        # mark whether the param should keep not sharded
+        # if True, the param is used as Zero stage 2
+        param.colo_attr.keep_not_shard = not self.shard_param
+
        self.param_list.append(param)

        # We must cast buffers
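The two flags added above drive the rest of this commit: `is_replicated` marks parameters that are identical across data-parallel ranks (MoE expert weights, for example, are not), and `keep_not_shard` marks parameters that stay whole so the setup behaves like ZeRO-2. A hedged restatement of how downstream code is expected to read them (the helper and return strings below are illustrative only):

def classify_param(is_replicated: bool, keep_not_shard: bool) -> str:
    """Illustrative mapping of the two colo_attr flags to a handling strategy."""
    if keep_not_shard:
        # ZeRO-2 style: fp16 parameter stays whole, only grads/optimizer states shard
        return 'keep whole fp16 param; shard gradients and optimizer states'
    if is_replicated:
        # ordinary dense parameter: shard it across the data-parallel group
        return 'shard across data-parallel ranks'
    # e.g. MoE expert weights: private to a rank, no broadcast or all-reduce
    return 'rank-private parameter; skip collective communication'

for flags in [(True, False), (True, True), (False, False)]:
    print(flags, '->', classify_param(*flags))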
@@ -123,7 +123,7 @@ class ShardedModelV2(nn.Module):
            ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
        ]
        register_ophooks_recursively(self.module, self._ophook_list)
-        self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
+        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

        self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -177,8 +177,8 @@ class ShardedModelV2(nn.Module):
            self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0])
            if gpc.get_global_rank() == 0:
                with open(filename, 'w+') as f:
-                    f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
-                    f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
+                    f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
+                    f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
                    f.write('CUDA model data (GB)\n')
                    f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
                    f.write('\n')
@@ -254,10 +254,6 @@ class ShardedModelV2(nn.Module):
                    torch.cuda.current_stream().synchronize()
        self.reducer.free()

-        # all reduce gradients for unsharded parameters
-        reduce_list = [p for p in self.unshard_params if p.is_replicated]
-        bucket_allreduce(reduce_list, self.process_group)
-
        # 3. shard tensors not dealed in the zero hook
        tensor_list = []
        for p in self.sharded_params:
@@ -279,15 +275,6 @@ class ShardedModelV2(nn.Module):
            if not self._require_backward_grad_sync:
                continue

-            # move unsharded param grad to saved_grad
-            if not p.colo_attr.param_is_sharded:
-                if p.colo_attr.offload_grad:
-                    colo_model_data_move_to_cpu(p.grad)
-                if p.colo_attr.saved_grad.is_null():
-                    p.colo_attr.saved_grad.reset_payload(p.grad.data)
-                else:
-                    p.colo_attr.saved_grad.payload.add_(p.grad.data)
-
            p.grad = None

    @torch.no_grad()
@@ -316,6 +303,18 @@ class ShardedModelV2(nn.Module):
        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
        if not self._require_backward_grad_sync:
            return
+
+        if param.colo_attr.is_replicated:
+            self._reduce_scatter_handler(param, grad)
+        else:
+            self._save_grad(param, grad)
+
+        # used to cheat Pytorch, since we can't return None
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        return empty_grad
+
+    def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
        self.comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.comm_stream):
            new_grad = grad.clone()
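The hook above now routes by replication: replicated parameters keep the reduce-scatter path, while rank-private ones just stash their gradient locally. A self-contained sketch of that routing with plain PyTorch tensor hooks (no distributed setup; `saved_grads`, `world_size` and the `is_replicated` attribute are illustrative stand-ins for `colo_attr` and the process group):

import torch

saved_grads = {}    # stand-in for per-param colo_attr.saved_grad

def make_post_backward_hook(name: str, param: torch.nn.Parameter):

    def hook(grad: torch.Tensor):
        world_size = 1    # pretend data-parallel world size
        if getattr(param, 'is_replicated', True):
            # placeholder for reduce-scatter / gradient averaging across ranks
            processed = grad.clone().div_(world_size)
        else:
            # rank-private parameter (e.g. an MoE expert): no communication
            processed = grad.clone()
        # accumulate into the saved gradient, like saved_grad.payload.add_()
        saved_grads[name] = saved_grads.get(name, 0) + processed
        # return a freed placeholder, mirroring the "cheat PyTorch" trick above
        return torch.empty_like(grad)

    return hook

layer = torch.nn.Linear(4, 4)
for n, p in layer.named_parameters():
    p.register_hook(make_post_backward_hook(n, p))

layer(torch.randn(2, 4)).sum().backward()
print(sorted(saved_grads))    # ['bias', 'weight']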
@@ -334,9 +333,6 @@ class ShardedModelV2(nn.Module):
            self._reduce_scatter_callback(param, new_grad)
            orig_grad_data.record_stream(self.comm_stream)
        torch.cuda.current_stream().wait_stream(self.comm_stream)
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
-        return empty_grad

    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        assert isinstance(reduced_grad,
@@ -345,21 +341,35 @@ class ShardedModelV2(nn.Module):
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)
-        # FIXME(ver217): remove the below line when impl eviction policy
+        self._save_grad(param, reduced_grad)
+
+    # FIXME(ver217): refactor the below line when impl eviction policy
+    def _save_grad(self, param: Parameter, grad: torch.Tensor):
+        # move gradient to cpu
        if param.colo_attr.offload_grad:
-            colo_model_data_move_to_cpu(reduced_grad)
+            colo_model_data_move_to_cpu(grad)

        if self.reuse_fp16_shard:
+            # make parameters point to gradient
+
            assert param.colo_attr.saved_grad.is_null(
            ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
-            param.colo_attr.sharded_data_tensor.is_sharded = True
-            param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
+
+            param.colo_attr.saved_grad.reset_payload(grad)
+            param.colo_attr.sharded_data_tensor.reset_payload(grad)    # release the memory of param
+
+            if param.colo_attr.is_replicated:
+                param.colo_attr.sharded_data_tensor.is_sharded = True
        else:
-            reduced_grad = cast_tensor_to_fp32(reduced_grad)
+            fp32_grad = cast_tensor_to_fp32(grad)
+
            if param.colo_attr.saved_grad.is_null():
-                param.colo_attr.saved_grad.reset_payload(reduced_grad)
+                param.colo_attr.saved_grad.reset_payload(fp32_grad)
            else:
-                param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
+                param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))

+        # keep saved_grad in HOLD state
        param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
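`_save_grad` above has two paths: with `reuse_fp16_shard` the fp16 parameter storage is handed over to the gradient (so accumulation is impossible, hence the assert), otherwise the gradient is cast to fp32 and accumulated into `saved_grad`. A plain-tensor sketch of the same control flow (the `state` dict is a stand-in for `colo_attr`, not the library API):

import torch

def save_grad(state: dict, grad: torch.Tensor, reuse_fp16_shard: bool) -> None:
    """Illustrative version of the two saving paths; `state` stands in for colo_attr."""
    if reuse_fp16_shard:
        # hand the fp16 parameter storage over to the gradient: no accumulation possible
        assert state.get('saved_grad') is None, 'gradient accumulation unsupported here'
        state['saved_grad'] = grad
        state['param_payload'] = grad    # parameter memory now aliases the gradient
    else:
        fp32_grad = grad.float()         # mirrors cast_tensor_to_fp32
        if state.get('saved_grad') is None:
            state['saved_grad'] = fp32_grad
        else:
            state['saved_grad'].add_(fp32_grad.view_as(state['saved_grad']))

state = {}
g = torch.randn(8, dtype=torch.float16)
save_grad(state, g, reuse_fp16_shard=False)
save_grad(state, g, reuse_fp16_shard=False)    # second micro-batch accumulates in fp32
print(state['saved_grad'].dtype)               # torch.float32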
@@ -68,9 +68,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
-        keep_unsharded (bool, optional): if True, optimizer won't shard unsharded parameters.
-            In Zero-2, set keep_unsharded to False.
-            In Zero-3, set keep_unsharded to True.
        max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        dp_process_group (Optional[ProcessGroup], optional): data paralle process group. Defaults to None.
        mp_process_group (Optional[ProcessGroup], optional): model paralle process group. Defaults to None.
@@ -91,7 +88,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 growth_interval: float = 1000,
                 hysteresis: float = 2,
                 max_scale: int = 2**32,
-                 keep_unsharded: bool = False,
                 dp_process_group: Optional[ProcessGroup] = None,
                 mp_process_group: Optional[ProcessGroup] = None) -> None:
        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@@ -125,10 +121,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
        self._logger = get_dist_logger("ShardedOptimizerV2")

-        assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \
-            "Keeping unsharded parameters can't be used with hybrid OS placement right now."
-        self.keep_unshard = keep_unsharded
-
        # Store fp32 param shards
        self._register_master_weight()

@@ -139,6 +131,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        if self._use_memory_tracer:
            GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)

+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale.item()
+
    def get_memory_usage(self) -> Tuple[int, int]:
        """ Get the memory usage of the optimizer. Including master_params (param fp32),
        momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``)
@@ -166,6 +162,22 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

        return cuda_use, cpu_use

+    def zero_grad(self, *args, **kwargs):
+        self._zero_grad()
+
+    def backward(self, loss: Tensor) -> None:
+        loss = self.loss_scale * loss
+        self.optim_state = OptimState.SCALED
+        self.model.backward(loss)
+
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
+        self.model.backward_by_grad(tensor, grad)
+
+    def clip_grad_norm(self, model: nn.Module, max_norm: float):
+        if self.optim_state == OptimState.SCALED:
+            self._unscale_grads()
+        return super().clip_grad_norm(model, max_norm)
+
    def step(self, *args, **kwargs):
        self._prepare_grads()
        self._maybe_move_fp32_shards()
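The methods added above (`zero_grad`, `backward`, `backward_by_grad`, `clip_grad_norm`, plus the `loss_scale` property in the previous hunk) are moved up from later in the file rather than newly introduced; the next hunk deletes the old copies. For orientation, a typical step with this wrapper looks roughly like the outline below (a hypothetical helper that assumes the model, optimizer, data and criterion are built elsewhere, not a runnable end-to-end script):

def train_step(sharded_optim, model, data, label, criterion, max_norm=1.0):
    # illustrative call order for the wrapper API shown in the diff
    sharded_optim.zero_grad()                        # delegates to _zero_grad()
    loss = criterion(model(data), label)
    sharded_optim.backward(loss)                     # scales the loss, then model.backward(loss)
    sharded_optim.clip_grad_norm(model, max_norm)    # unscales grads first if still scaled
    sharded_optim.step()                             # updates the fp32 master shards
    return loss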
@@ -193,26 +205,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self._logger.debug(
            f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
            ranks=[0])
-        self._copy_master_param_to_param_fp16()
+        self._copy_master_model_to_model_fp16()
        return ret

-    def backward(self, loss: Tensor) -> None:
-        loss = self.loss_scale * loss
-        self.optim_state = OptimState.SCALED
-        self.model.backward(loss)
-
-    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
-        self.model.backward_by_grad(tensor, grad)
-
-    def clip_grad_norm(self, model: nn.Module, max_norm: float):
-        if self.optim_state == OptimState.SCALED:
-            self._unscale_grads()
-        return super().clip_grad_norm(model, max_norm)
-
-    @property
-    def loss_scale(self):
-        return self.grad_scaler.scale.item()
-
    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(0.0)
@@ -240,9 +235,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
            p.grad.data.div_(self.loss_scale)
        self.optim_state = OptimState.UNSCALED

-    def zero_grad(self, *args, **kwargs):
-        self._zero_grad()
-
    def _zero_grad(self, recover_data: bool = False):
        """zero grad and maybe recover fp16 params
        When `reuse_fp16_shard` is enabled,
@@ -262,13 +254,11 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # p.colo_attr.sharded_data_tensor stores grad now
                # we have to recover fp16 param
                reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
-                p.colo_attr.saved_grad.set_null()
                if recover_data and reuse_fp16_shard:
-                    # We should write like this to trigger ForceFP32Paramter's half method
-                    p.data = self.master_params[p].payload
-                    p.colo_attr.sharded_data_tensor.reset_payload(
-                        colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
-                    p.colo_attr.remove_torch_payload()
+                    self._copy_master_param_to_param_fp16(p)
+                else:
+                    # release saved gradient
+                    p.colo_attr.saved_grad.set_null()

    def sync_grad(self):
        pass
@@ -278,14 +268,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        for group in self.optim.param_groups:
            for p in group['params']:
                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded and not self.keep_unshard:
-                    # Please use keep_unsharded to control whether shard unsharded paramters
-                    # As we only store param shard, we shard it here
+                shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated
+                if shard_flag:
+                    # we always shard replicated paramters
                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                self.master_params[p] = StatefulTensor(
                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
-                if not is_param_sharded and not self.keep_unshard:
+                if shard_flag:
                    # In this branch, there's no need to shard param
                    # So we gather here
                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
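The condition change above is the crux of the optimizer-side adaptation: only tensors that are both unsharded and replicated get sharded when master weights are registered; rank-private (e.g. MoE expert) parameters are left whole. Restated as a tiny pure function for clarity (illustrative, not library code):

def should_shard(is_sharded: bool, is_replicated: bool) -> bool:
    # only replicated tensors that are still whole get sharded by the optimizer;
    # rank-private parameters are left alone
    return (not is_sharded) and is_replicated

assert should_shard(False, True) is True
assert should_shard(False, False) is False    # e.g. MoE expert weights
assert should_shard(True, True) is False      # already sharded by the model wrapper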
@@ -328,31 +317,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # Now p.data is sharded
                # So optimizer states are sharded naturally

-    def _copy_master_param_to_param_fp16(self):
+    def _copy_master_model_to_model_fp16(self):
        # Copy master param data (fp32) to payload of colo_attr (fp16)
        # TODO() improve efficiency by gathering tensors into a chunk and transfering
        # a chunk.
        for group in self.optim.param_groups:
            for p in group['params']:
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded and not self.keep_unshard:
-                    # We use ZeRO-2 here
-                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
-                    # But we only have updated fp32 param shard here
-                    # So we first shard full fp16 param and copy fp32 param shard to it
-                    # Then we will gather them
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
-
-                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
-                p.colo_attr.remove_torch_payload()
-
-                if not is_param_sharded and not self.keep_unshard:
-                    # We gather full fp16 param here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-
-                self.master_params[p].trans_state(TensorState.HOLD)
-                p.colo_attr.saved_grad.set_null()
+                self._copy_master_param_to_param_fp16(p)
+
+    def _copy_master_param_to_param_fp16(self, p):
+        # flush gradient
+        p.colo_attr.saved_grad.set_null()
+
+        # TODO() optimize this line CPU (fp32) -> GPU (fp16)
+        p.data = self.master_params[p].payload
+        p.colo_attr.sharded_data_tensor.reset_payload(
+            colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
+        p.colo_attr.remove_torch_payload()
+
+        if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
+            # We gather full fp16 param here
+            p.colo_attr.sharded_data_tensor.is_sharded = True    # since only gradient is sharded, we should set to True
+            self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+
+        self.master_params[p].trans_state(TensorState.HOLD)
@@ -71,9 +71,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):

        # the parameters in moe experts is not replicated
        if 'experts' in name:
-            assert not param.is_replicated
+            assert not param.colo_attr.is_replicated
        else:
-            assert param.is_replicated
+            assert param.colo_attr.is_replicated

        if param.colo_attr.param_is_sharded:
            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
@@ -36,7 +36,7 @@ def run_model_test(enable_autocast, shard_strategy_class):

    # check whether parameters are identical in ddp
    for name, p in zero_model.named_parameters():
-        if not p.colo_attr.param_is_sharded and p.is_replicated:
+        if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
            assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)

    model = MoeModel().half()
@@ -48,8 +48,13 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler):

@parameterize("cpu_offload", [True])
@parameterize("use_cpuadam", [True])    # We do not use Hybrid Adam right now, since it has a little bug
+@parameterize("reuse_fp16_shard", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
+def _run_test_sharded_optim_v2(cpu_offload,
+                               shard_strategy_class,
+                               use_cpuadam,
+                               reuse_fp16_shard,
+                               gpu_margin_mem_ratio=0.0):
    shard_strategy = shard_strategy_class()
    if use_cpuadam and cpu_offload is False:
        return
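The test above adds a `reuse_fp16_shard` axis. ColossalAI's `parameterize` decorators stack so the wrapped test body runs once per combination of values; a rough, self-contained imitation of that behaviour (the real `colossalai.testing.parameterize` may differ in details):

import functools

def parameterize(name, values):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(**kwargs):
            for v in values:
                fn(**{**kwargs, name: v})
        return wrapper
    return decorator

@parameterize("reuse_fp16_shard", [True, False])
@parameterize("cpu_offload", [True])
def fake_test(reuse_fp16_shard, cpu_offload):
    print(reuse_fp16_shard, cpu_offload)

fake_test()    # runs the body for each combination of decorator values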
@@ -63,17 +68,15 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
                              shard_param=True):
        zero_model = MoeModel()

-    zero_model = ShardedModelV2(
-        zero_model,
-        shard_strategy,
-        offload_config=dict(device='cpu') if cpu_offload else None,
-        use_memory_tracer=gpu_margin_mem_ratio > 0.0,
-        reuse_fp16_shard=use_cpuadam,
-    )
+    zero_model = ShardedModelV2(zero_model,
+                                shard_strategy,
+                                offload_config=dict(device='cpu') if cpu_offload else None,
+                                use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+                                reuse_fp16_shard=reuse_fp16_shard)

    # check whether parameters are identical in ddp
    for name, p in zero_model.named_parameters():
-        if not p.colo_attr.param_is_sharded and p.is_replicated:
+        if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
            assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))

    model = MoeModel().half()
@@ -88,8 +91,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
                                      sharded_optim,
                                      cpu_offload=cpu_offload,
                                      initial_scale=2**5,
-                                      gpu_margin_mem_ratio=gpu_margin_mem_ratio,
-                                      keep_unsharded=True)
+                                      gpu_margin_mem_ratio=gpu_margin_mem_ratio)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
    apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
        # zero_grad = zero_p.grad.clone().to(p.device)
-        if zero_p.colo_attr.param_is_sharded:
+        if zero_p.colo_attr.is_replicated:
            zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
            chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
            if rank >= len(chunks):
@@ -102,8 +102,9 @@ def check_grads_padding(model, zero_model, loose=False):
            if zero_grad.size(0) > grad.size(0):
                zero_grad = zero_grad[:grad.size(0)]
        else:
-            grad = p.grad
            zero_grad = zero_p.colo_attr.saved_grad.payload
+            grad = p.grad.to(zero_grad.dtype)

        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'

@@ -134,7 +135,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
            if zero_p.size(0) > p.size(0):
                zero_p = zero_p[:p.size(0)]
        else:
-            zero_p = zero_p.colo_attr.sharded_data_tensor.payload
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)

        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'