mirror of https://github.com/hpcaitech/ColossalAI
[zero] sharded model support the reuse of fp16 shard (#495)
* sharded model supports reuse fp16 shard
* rename variable
* polish code
* polish code
* polish code
parent f24b5ed201
commit 9ec1ce6ab1
@@ -56,6 +56,8 @@ class CPUAdam(torch.optim.Optimizer):
                           bias_correction2,
                           loss_scale,
                           use_adamw=False):
+        # FIXME(ver217): remove the line below when replacing torch adam with fused adam
+        grad = grad.float()
         if loss_scale is not None:
             grad.div_(loss_scale)
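The two added lines matter because, with `reuse_fp16_shard` enabled in the sharded model below, the gradient shard handed to `CPUAdam` arrives in fp16 while the master weights stay in fp32. A minimal standalone sketch of that cast-and-unscale step (the tensors, loss scale, and learning rate here are illustrative, not taken from the commit):

import torch

master_param = torch.zeros(4, dtype=torch.float32)   # fp32 master weight shard
fp16_grad = torch.randn(4).half()                    # grad shard arrives in fp16
loss_scale = 2.0 ** 5

grad = fp16_grad.float()               # mirrors the added `grad = grad.float()`
grad.div_(loss_scale)                  # mirrors `grad.div_(loss_scale)` when a loss scale is set
master_param.add_(grad, alpha=-1e-3)   # stand-in for the Adam update that follows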
@@ -29,24 +29,22 @@ class ShardedModelV2(nn.Module):
     compared to classic data parallelism while the computational granularity and communication efficiency are retained.
     Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
 
-    :param module: A sharded module, which must be initialized by `ZeroInitContext`.
-    :type module: nn.Module
-    :param shard_strategy: A shard strategy to manage shard behavior.
-    :type shard_strategy: BaseShardStrategy
-    :param process_group: Data parallel process group, defaults to None
-    :type process_group: Optional[ProcessGroup], optional
-    :param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
-    :type reduce_scatter_process_group: Optional[ProcessGroup], optional
-    :param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
-    :type reduce_scatter_bucket_size_mb: int, optional
-    :param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
-    :type fp32_reduce_scatter: bool, optional
-    :param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
-    :type offload_config: Optional[dict], optional
-    :param gradient_predivide_factor: Gradient is divided by this value before reduce-scatter, defaults to 1.0
-    :type gradient_predivide_factor: Optional[float], optional
-    :param use_memory_tracer: Whether to use the memory tracer, defaults to False
-    :type use_memory_tracer: bool, optional
+    Args:
+        module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
+        shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
+        process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
+        reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
+            Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
+        reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
+        fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
+        offload_config (Optional[dict], optional): We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload. Defaults to None.
+        gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
+        use_memory_tracer (bool, optional): Whether to use the memory tracer. Defaults to False.
+        reuse_fp16_shard (bool, optional): Whether to reuse the fp16 shard for param and grad.
+            Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
+            In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
+            We find that PyTorch's optimizers don't support mixed precision,
+            so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
     """
 
     def __init__(self,
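Combined with the test changes further down, a hypothetical end-to-end use of the new flag could look like the sketch below. The import paths follow the repository layout at the time but are assumptions, the plain `nn.Linear` stands in for a module built under `ZeroInitContext`, and an already-initialized distributed context (e.g. via `colossalai.launch`) is assumed:

import torch.nn as nn
from colossalai.zero.shard_utils import TensorShardStrategy   # assumed import path
from colossalai.zero.sharded_model import ShardedModelV2      # assumed import path

model = nn.Linear(128, 128).half()            # placeholder for a ZeroInitContext-built module
shard_strategy = TensorShardStrategy()
sharded_model = ShardedModelV2(model,
                               shard_strategy,
                               offload_config=dict(device='cpu'),  # CPU offload, as in the test below
                               reuse_fp16_shard=True)              # param shard doubles as grad shard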
@@ -58,7 +56,8 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 use_memory_tracer: bool = False):
+                 use_memory_tracer: bool = False,
+                 reuse_fp16_shard: bool = False):
         super().__init__()
         self.logger = get_dist_logger()
@@ -97,8 +96,8 @@ class ShardedModelV2(nn.Module):
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         for param in module.parameters():
-            # Init `offload_fp32_grad`
-            param.col_attr.offload_fp32_grad = self._cpu_offload
+            # Init `offload_grad`
+            param.col_attr.offload_grad = self._cpu_offload
 
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
@@ -114,6 +113,7 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True
 
         self._cuda_margin_space = 0
+        self.reuse_fp16_shard = reuse_fp16_shard
 
     @property
     def cuda_margin_space(self):
@@ -143,11 +143,7 @@ class ShardedModelV2(nn.Module):
         for ophook in self._ophook_list:
             ophook.post_iter()
 
-    @torch.no_grad()
-    def _post_backward_operations(self) -> None:
-        """
-        The method includes operations required to be processed after backward
-        """
+    def _update_memstats(self):
         if self._iter_cnter == 0 and self._memstats_collector:
             self._memstats_collector.finish_collection()
         if self._memstats_collector:
@@ -160,6 +156,13 @@ class ShardedModelV2(nn.Module):
 
         self._iter_cnter += 1
 
+    @torch.no_grad()
+    def _post_backward_operations(self) -> None:
+        """
+        The method includes operations required to be processed after backward
+        """
+        self._update_memstats()
+
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -171,9 +174,11 @@ class ShardedModelV2(nn.Module):
         self.reducer.free()
         # In case some post bwd hook is not fired
         if self.shard_param:
+            tensor_list = []
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
+                    tensor_list.append(p.col_attr.sharded_data_tensor)
+            self.shard_strategy.shard(tensor_list, self.process_group)
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
@@ -191,13 +196,17 @@ class ShardedModelV2(nn.Module):
             # If world size == 1 and sharded param,
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
-            grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
-            if p.col_attr.offload_fp32_grad:
+            if self.reuse_fp16_shard:
+                grad = p.col_attr.sharded_data_tensor.payload
+            else:
+                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+            if p.col_attr.offload_grad:
                 col_move_to_cpu(grad)
             if p.col_attr.fp32_grad is not None:
+                assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
                 p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
                 grad = p.col_attr.fp32_grad
-            p.grad.data = grad.view(-1)
+            p.grad.data = grad
             p.col_attr.fp16_grad = None
             p.col_attr.fp32_grad = None
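The branch above is what decides the dtype of `p.grad` after backward: with `reuse_fp16_shard` the gradient is simply the fp16 payload of the param shard, otherwise it is cast to fp32. A plain-tensor sketch of the two paths (all names here are illustrative):

import torch

reuse_fp16_shard = True
fp16_shard_payload = torch.randn(8).half()   # stands in for p.col_attr.sharded_data_tensor.payload
fp16_grad = torch.randn(8).half()            # stands in for p.col_attr.fp16_grad

if reuse_fp16_shard:
    grad = fp16_shard_payload                # grad stays fp16; the param shard buffer is reused
else:
    grad = fp16_grad.float()                 # stands in for cast_tensor_to_fp32(...)

print(grad.dtype)   # torch.float16 in reuse mode, torch.float32 otherwise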
@@ -250,11 +259,15 @@ class ShardedModelV2(nn.Module):
         return empty_grad
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        reduced_grad = reduced_grad.view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
-        param.col_attr.fp16_grad = reduced_grad.data
+        if self.reuse_fp16_shard:
+            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            param.col_attr.sharded_data_tensor.is_sharded = True
+        else:
+            param.col_attr.fp16_grad = reduced_grad.data
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
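This callback is where the reuse actually happens: the reduce-scattered fp16 gradient is installed as the payload of the parameter's fp16 shard instead of being kept in a separate `fp16_grad` buffer. A toy illustration with plain tensors (a `copy_` stands in for `reset_payload`, which swaps the payload reference rather than copying):

import torch

fp16_param_shard = torch.randn(8).half()   # stands in for sharded_data_tensor.payload
reduced_grad = torch.randn(8).half()       # output of the gradient reduce-scatter

# Old path: keep a second fp16 buffer alive as the grad (fp16_grad = reduced_grad).
# New path with reuse_fp16_shard: overwrite the param shard's storage with the grad,
# so only one fp16 buffer per shard remains alive after backward.
fp16_param_shard.copy_(reduced_grad)       # analogous to reset_payload(reduced_grad.data)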
@@ -224,5 +224,5 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                 self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                 p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                p.col_attr.offload_fp32_grad = False
+                p.col_attr.offload_grad = False
                 fp32_shards_used_cuda_margin_mem += shard_mem
@@ -14,7 +14,7 @@ class ShardedParamV2(object):
         self.fp16_grad: Optional[torch.Tensor] = None
         self.fp32_grad: Optional[torch.Tensor] = None
         # This attribute must be initialized in ShardedModel
-        self.offload_fp32_grad: bool = False
+        self.offload_grad: bool = False
 
         # make sure the shared param is the only owner of payload
         # The param.data maybe used to init the other part of the model.
@@ -16,7 +16,8 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           offload_config=None,
                           gradient_predivide_factor=1.0,
                           use_memory_tracer=False,
-                          shard_strategy=TensorShardStrategy())
+                          shard_strategy=TensorShardStrategy(),
+                          reuse_fp16_shard=False)
 
 _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
                               initial_scale=2**5,
@@ -116,10 +117,13 @@ def check_params_padding(model, zero_model, loose=False):
         assert allclose(p, zero_p, loose=loose)
 
 
-def check_sharded_params_padding(model, zero_model, loose=False):
+def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+        if reuse_fp16_shard:
+            zero_p = zero_p.data.to(p.device).float()
+        else:
+            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
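For reference, the chunking in `check_sharded_model_params` mirrors how each rank owns one slice of the flattened parameter; a toy run with a world size of 4 (values chosen only for illustration):

import torch

world_size, rank = 4, 3
p = torch.arange(10, dtype=torch.float32)     # unsharded reference parameter
chunks = torch.flatten(p).chunk(world_size)   # rank i compares against chunks[i]
if rank < len(chunks):
    print(chunks[rank])                       # tensor([9.]) -- the last chunk may be shorter
# a rank with rank >= len(chunks) holds no slice and is skipped, as in the `continue` above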
@@ -18,7 +18,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from common import CONFIG, check_sharded_params_padding
+from common import CONFIG, check_sharded_model_params
 
 
 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -65,7 +65,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
     zero_model = ShardedModelV2(zero_model,
                                 shard_strategy,
                                 offload_config=dict(device='cpu') if cpu_offload else None,
-                                use_memory_tracer=gpu_margin_mem_ratio > 0.0)
+                                use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+                                reuse_fp16_shard=use_cpuadam)
 
     model = model_builder(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
@@ -92,7 +93,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
         data, label = data.cuda(), label.cuda()
         _run_step(apex_model, apex_optimizer, data, label, criterion, False)
         _run_step(zero_model, sharded_optim, data, label, criterion, False)
-        check_sharded_params_padding(model, zero_model, loose=True)
+        check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
         for param in model.parameters():
             assert not has_inf_or_nan(param)
@@ -16,7 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
+from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)
 
 
 def run_dist(rank, world_size, port, parallel_config):
@@ -87,7 +87,7 @@ def run_dist(rank, world_size, port, parallel_config):
     if parallel_config == MP_PARALLEL_CONFIG:
         check_params(torch_model, colo_model, loose=True)
     elif parallel_config == ZERO_PARALLEL_CONFIG:
-        check_sharded_params_padding(torch_model, colo_model, loose=True)
+        check_sharded_model_params(torch_model, colo_model, loose=True)
 
 
 # FIXME: enable this test in next PR