mirror of https://github.com/hpcaitech/ColossalAI
[polish] rename col_attr -> colo_attr (#558)
parent 2c45efc398
commit 7675366fce
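This commit mechanically renames the per-parameter attribute `col_attr` to `colo_attr` across the ZeRO op hooks, the init context, the sharded model, the sharded optimizer and their tests. For orientation, here is a minimal sketch of how downstream code reads the renamed attribute; the helper function and its model argument are illustrative assumptions, only the attribute names come from this diff.

import torch

def describe_sharded_params(model: torch.nn.Module) -> None:
    # Hypothetical helper (not in the repo): inspect the renamed attribute.
    for name, param in model.named_parameters():
        if not hasattr(param, 'colo_attr'):    # was `param.col_attr` before this commit
            continue
        st = param.colo_attr.sharded_data_tensor
        print(f'{name}: is_sharded={st.is_sharded}, dtype={st.dtype}, '
              f'device={st.payload.device.type}')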
@@ -35,58 +35,58 @@ class ZeroHook(BaseOpHook):
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
+            assert hasattr(param, 'colo_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            param.data = param.colo_attr.sharded_data_tensor.payload

         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)

         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
+            assert hasattr(param, 'colo_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
+            assert hasattr(param, 'colo_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)

         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
+            assert hasattr(param, 'colo_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)

         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

     def pre_iter(self):
         pass
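The hook methods above follow a gather, compute, re-shard lifecycle: before a module's forward or backward pass the sharded parameter payloads are gathered and moved to the computing device, and afterwards they are re-sharded and the torch payload is dropped. The toy, self-contained sketch below mimics that lifecycle; ToyShardedTensor, ToyAttr and ToyStrategy are mocks invented for illustration and are not the ColossalAI classes touched by this diff.

import torch

class ToyShardedTensor:
    def __init__(self, full: torch.Tensor):
        self._full = full                                    # stand-in for the gathered tensor
        self.payload = full.flatten().chunk(2)[0].clone()    # pretend rank 0 holds one shard
        self.is_sharded = True

class ToyAttr:
    def __init__(self, param: torch.nn.Parameter):
        self.sharded_data_tensor = ToyShardedTensor(param.data.clone())

class ToyStrategy:
    def gather(self, tensors):
        for t in tensors:                # a real strategy would all-gather across ranks
            t.payload, t.is_sharded = t._full, False

    def shard(self, tensors):
        for t in tensors:                # a real strategy would keep only the local shard
            t.payload, t.is_sharded = t._full.flatten().chunk(2)[0].clone(), True

layer = torch.nn.Linear(4, 4)
strategy = ToyStrategy()
for p in layer.parameters():
    p.colo_attr = ToyAttr(p)             # mirrors what ZeroInitContext attaches to each param

tensor_list = [p.colo_attr.sharded_data_tensor for p in layer.parameters()]
strategy.gather(tensor_list)             # pre_fwd_exec: materialize full parameters
for p in layer.parameters():
    p.data = p.colo_attr.sharded_data_tensor.payload
layer(torch.randn(1, 4))                 # compute with full parameters
strategy.shard(tensor_list)              # post_fwd_exec: drop back to shards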
@@ -45,8 +45,8 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
     cuda_mem_usage = 0
     cpu_mem_usage = 0
     for param in model.parameters():
-        if hasattr(param, 'col_attr'):
+        if hasattr(param, 'colo_attr'):
-            t_cuda, t_cpu = param.col_attr.get_memory_usage()
+            t_cuda, t_cpu = param.colo_attr.get_memory_usage()
             cuda_mem_usage += t_cuda
             cpu_mem_usage += t_cpu
         else:
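colo_model_mem_usage, patched above, walks the model's parameters and sums the CUDA and CPU bytes reported by each colo_attr. A hypothetical usage sketch follows; the import path is an assumption, and the model is expected to have been built under ZeroInitContext so its parameters carry colo_attr.

import torch
# NOTE: import path assumed for illustration; it is not stated in this diff.
from colossalai.utils.memory_tracer.model_data_memtracer import colo_model_mem_usage

def report_model_data(model: torch.nn.Module) -> None:
    cuda_bytes, cpu_bytes = colo_model_mem_usage(model)    # returns (CUDA bytes, CPU bytes)
    print(f'model data: {cuda_bytes / 1e6:.2f} MB CUDA, {cpu_bytes / 1e6:.2f} MB CPU')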
@@ -162,8 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
-                assert hasattr(param, 'col_attr')
+                assert hasattr(param, 'colo_attr')
-                param.col_attr.remove_torch_payload()
+                param.colo_attr.remove_torch_payload()

         del self.initialized_param_list

@@ -178,7 +178,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
-            if hasattr(param, 'col_attr'):
+            if hasattr(param, 'colo_attr'):
                 continue

             self.model_numel_tensor += param.numel()

@@ -196,10 +196,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)

-            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+            param.colo_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

             if self.shard_param:
-                self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
             self.initialized_param_list.append(param)

         # We must cast buffers
@@ -70,9 +70,9 @@ class ShardedModelV2(nn.Module):
         sharded = []
         unsharded = []
         for param in module.parameters():
-            assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
+            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.col_attr.param_is_sharded)
+            sharded.append(param.colo_attr.param_is_sharded)
-            unsharded.append(not param.col_attr.param_is_sharded)
+            unsharded.append(not param.colo_attr.param_is_sharded)
         assert all(sharded) or all(
             unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
         self.shard_param = all(sharded)

@@ -103,7 +103,7 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         for param in module.parameters():
             # Init `offload_grad`
-            param.col_attr.offload_grad = self._cpu_offload
+            param.colo_attr.offload_grad = self._cpu_offload

         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
@@ -162,13 +162,13 @@ class ShardedModelV2(nn.Module):
             self._memstats_collector.start_collection()

         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
+            if hasattr(p, 'colo_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

     def _post_forward_operations(self):
         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
+            if hasattr(p, 'colo_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._pre_forward_operations()

@@ -228,10 +228,10 @@ class ShardedModelV2(nn.Module):
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
+                if not p.colo_attr.param_is_sharded:
-                    tensor_list.append(p.col_attr.sharded_data_tensor)
+                    tensor_list.append(p.colo_attr.sharded_data_tensor)
-                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.col_attr.remove_torch_payload()
+                    p.colo_attr.remove_torch_payload()
             self.shard_strategy.shard(tensor_list, self.process_group)

         # 4. move sharded param grad payload to param.grad
@@ -245,27 +245,27 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Reduced grad is saved in `p.col_attr.saved_grad`
+            # Reduced grad is saved in `p.colo_attr.saved_grad`
             # It can be on CPU or CUDA
             # It can be fp16 or fp32
             # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
             if self.reuse_fp16_shard:
-                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
             else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
             assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.col_attr.offload_grad:
+            if p.colo_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.col_attr.saved_grad.is_null():
+            if not p.colo_attr.saved_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
                 # Accumulate grad, saved grad must be fp32
-                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
-                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
             else:
-                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
+                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)

             p.grad = None
-            p.col_attr.fp16_grad.set_null()
+            p.colo_attr.fp16_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -273,7 +273,7 @@ class ShardedModelV2(nn.Module):
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
         full gradient for the local batch. The reduce-scatter op will save
         a single shard of the summed gradient across all
-        GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
+        GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]

@@ -285,7 +285,7 @@ class ShardedModelV2(nn.Module):

         The local GPU's ``optim.step`` is responsible for updating a single
         shard of params, also corresponding to the current GPU's rank. This
-        alignment is created by `param.col_attr.grad`, which ensures that
+        alignment is created by `param.colo_attr.grad`, which ensures that
         the local optimizer only sees the relevant parameter shard.
         """
         if grad is None:
|
||||||
# Average grad by world_size for consistency with PyTorch DDP.
|
# Average grad by world_size for consistency with PyTorch DDP.
|
||||||
reduced_grad.data.div_(self.gradient_postdivide_factor)
|
reduced_grad.data.div_(self.gradient_postdivide_factor)
|
||||||
if self.reuse_fp16_shard:
|
if self.reuse_fp16_shard:
|
||||||
param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
|
param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
|
||||||
param.col_attr.sharded_data_tensor.is_sharded = True
|
param.colo_attr.sharded_data_tensor.is_sharded = True
|
||||||
else:
|
else:
|
||||||
param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
|
param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
|
||||||
|
|
||||||
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
|
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
|
||||||
self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
|
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
|
||||||
self.process_group)
|
self.process_group)
|
||||||
prev_params = {}
|
prev_params = {}
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
prev_params[p] = p.data
|
prev_params[p] = p.data
|
||||||
p.data = p.col_attr.sharded_data_tensor.payload
|
p.data = p.colo_attr.sharded_data_tensor.payload
|
||||||
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
|
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
|
||||||
self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
|
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
|
||||||
self.process_group)
|
self.process_group)
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
p.data = prev_params[p]
|
p.data = prev_params[p]
|
||||||
|
|
|
@@ -10,10 +10,10 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
     Note the other_model has to be the same as self.
     """
     for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
-        assert hasattr(zero_param, 'col_attr')
+        assert hasattr(zero_param, 'colo_attr')
-        shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
+        shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
         if shard_flag:
-            sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
+            sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
-        param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
+        param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
         if shard_flag:
-            sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
+            sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
@@ -116,18 +116,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

         for group in self.optim.param_groups:
             for p in group['params']:
-                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
+                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # TODO (ver217): we may not use shard / gather here
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
+                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)

         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                            ranks=[0])

@@ -201,30 +201,30 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
             ranks=[0])
-        # Copy master param data (fp32) to payload of col_attr (fp16)
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # We use ZeRO-2 here
-                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
                     # So we first shard full fp16 param and copy fp32 param shard to it
                     # Then we will gather them
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16

                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.sharded_data_tensor.reset_payload(
+                p.colo_attr.sharded_data_tensor.reset_payload(
                     colo_model_tensor_clone(p.half(), torch.cuda.current_device()))

                 if not is_param_sharded:
                     # We gather full fp16 param here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.col_attr.sharded_data_tensor.payload
+                p.data = p.colo_attr.sharded_data_tensor.payload
         return ret

     def backward(self, loss: Tensor) -> None:
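The step() hunk above copies each updated fp32 master shard back into the fp16 sharded_data_tensor payload. Below is a minimal standalone sketch of that fp32-master / fp16-working-copy pattern, with plain SGD and no sharding; it illustrates the idea only and is not the ColossalAI API.

import torch

param = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
master = param.detach().clone().float()        # fp32 master copy held by the optimizer
grad = torch.randn(4, dtype=torch.float16)     # reduced fp16 gradient for this shard

master -= 1e-3 * grad.float()                  # update in fp32 for numerical precision
param.data.copy_(master.half())                # write the result back as fp16 (cf. reset_payload above)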
@@ -292,7 +292,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                    p.col_attr.offload_grad = False
+                    p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem

     def _prepare_grads(self):

@@ -301,7 +301,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
-                # We just set p.data = p.col_attr.saved_grad.payload here
+                # We just set p.data = p.colo_attr.saved_grad.payload here
-                p.data = p.col_attr.saved_grad.payload
+                p.data = p.colo_attr.saved_grad.payload
-                p.grad = p.col_attr.saved_grad.payload
+                p.grad = p.colo_attr.saved_grad.payload
-                p.col_attr.saved_grad.set_null()
+                p.colo_attr.saved_grad.set_null()
@@ -61,22 +61,22 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
         model = MoeModel()

     for name, param in model.named_parameters():
-        assert hasattr(param, 'col_attr')
+        assert hasattr(param, 'colo_attr')

         # the weights in the gate should be fp32
         if 'gate' in name:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
         else:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half

         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.col_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded
         else:
-            assert param.col_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.is_sharded

-        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'


 def _run_dist(rank, world_size, port):
@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
         # zero_grad = zero_p.grad.clone().to(p.device)
-        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
+        zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue

@@ -124,7 +124,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         if reuse_fp16_shard:
             zero_p = zero_p.data.to(p.device).float()
         else:
-            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
@@ -45,11 +45,11 @@ def run_model_test(init_device_type, shard_strategy_class):
             model = model_builder(checkpoint=True)

         for param in model.parameters():
-            assert hasattr(param, 'col_attr')
+            assert hasattr(param, 'colo_attr')
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
-            assert param.col_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.is_sharded
-            assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-                f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+                f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'

         cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
         model_data_cuda_mem_MB = cuda_mem_use / 1e6