[zero] sharded model supports the reuse of fp16 shard (#495)

* sharded model supports reuse of fp16 shard

* rename variable

* polish code

* polish code

* polish code
ver217 2022-03-23 14:59:59 +08:00 committed by GitHub
parent f24b5ed201
commit 9ec1ce6ab1
7 changed files with 62 additions and 42 deletions


@@ -56,6 +56,8 @@ class CPUAdam(torch.optim.Optimizer):
                           bias_correction2,
                           loss_scale,
                           use_adamw=False):
+        # FIXME(ver217): remove the below line when replace torch adam with fused adam
+        grad = grad.float()
         if loss_scale is not None:
             grad.div_(loss_scale)
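For context on why the new `grad = grad.float()` line is needed: with `reuse_fp16_shard` enabled, the gradient handed to the optimizer is an fp16 shard, while CPUAdam keeps its master parameters and moment statistics in fp32. The snippet below is not the CPUAdam kernel; it is a minimal torch-style Adam step (the name `torch_style_adam_step` is ours) showing where the cast slots in so the fp32 moment buffers are accumulated in full precision.

import torch

def torch_style_adam_step(param_fp32, grad, exp_avg, exp_avg_sq,
                          lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    # Cast the (possibly fp16) grad to fp32 so the fp32 moment buffers
    # below are updated in full precision.
    grad = grad.float()
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    param_fp32.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)

# fp32 master param shard, fp16 grad shard (the reuse_fp16_shard case)
p = torch.zeros(4)
g = torch.full((4,), 0.5, dtype=torch.float16)
torch_style_adam_step(p, g, torch.zeros(4), torch.zeros(4))

As the FIXME notes, the cast becomes unnecessary once the torch-style update is replaced by a fused Adam kernel that consumes the fp16 grad directly.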


@@ -29,24 +29,22 @@ class ShardedModelV2(nn.Module):
     compared to classic data parallelism while the computational granularity and communication efficiency are retained.
     Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.

-    :param module: A sharded module, which must be initialized by `ZeroInitContext`.
-    :type module: nn.Module
-    :param shard_strategy: A shard strategy to manage shard behavior.
-    :type shard_strategy: BaseShardStrategy
-    :param process_group: Data parallel process group, defaults to None
-    :type process_group: Optional[ProcessGroup], optional
-    :param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
-    :type reduce_scatter_process_group: Optional[ProcessGroup], optional
-    :param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
-    :type reduce_scatter_bucket_size_mb: int, optional
-    :param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
-    :type fp32_reduce_scatter: bool, optional
-    :param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
-    :type offload_config: Optional[dict], optional
-    :param gradient_predivide_factor: Gradient is divived by this value before reduce-scatter, defaults to 1.0
-    :type gradient_predivide_factor: Optional[float], optional
-    :param use_memory_tracer: Whether to use memoty tracer, defaults to False
-    :type use_memory_tracer: bool, optional
+    Args:
+        module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
+        shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
+        process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
+        reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
+            Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
+        reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
+        fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
+        offload_config (Optional[dict], optional): We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload. Defaults to None.
+        gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
+        use_memory_tracer (bool, optional): Whether to use the memory tracer. Defaults to False.
+        reuse_fp16_shard (bool, optional): Whether to reuse the fp16 shard for param and grad.
+            Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
+            In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
+            We find that PyTorch's optimizers don't support mixed precision,
+            so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
     """

     def __init__(self,
@@ -58,7 +56,8 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 use_memory_tracer: bool = False):
+                 use_memory_tracer: bool = False,
+                 reuse_fp16_shard: bool = False):
         super().__init__()
         self.logger = get_dist_logger()
@@ -97,8 +96,8 @@ class ShardedModelV2(nn.Module):
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         for param in module.parameters():
-            # Init `offload_fp32_grad`
-            param.col_attr.offload_fp32_grad = self._cpu_offload
+            # Init `offload_grad`
+            param.col_attr.offload_grad = self._cpu_offload

         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
@@ -114,6 +113,7 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True

         self._cuda_margin_space = 0
+        self.reuse_fp16_shard = reuse_fp16_shard

     @property
     def cuda_margin_space(self):
@@ -143,11 +143,7 @@ class ShardedModelV2(nn.Module):
         for ophook in self._ophook_list:
             ophook.post_iter()

-    @torch.no_grad()
-    def _post_backward_operations(self) -> None:
-        """
-        The method includes operations required to be processed after backward
-        """
+    def _update_memstats(self):
         if self._iter_cnter == 0 and self._memstats_collector:
             self._memstats_collector.finish_collection()
         if self._memstats_collector:
@@ -160,6 +156,13 @@ class ShardedModelV2(nn.Module):
             self._iter_cnter += 1

+    @torch.no_grad()
+    def _post_backward_operations(self) -> None:
+        """
+        The method includes operations required to be processed after backward
+        """
+        self._update_memstats()
+
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -171,9 +174,11 @@ class ShardedModelV2(nn.Module):
                 self.reducer.free()
         # In case some post bwd hook is not fired
         if self.shard_param:
+            tensor_list = []
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
+                    tensor_list.append(p.col_attr.sharded_data_tensor)
+            self.shard_strategy.shard(tensor_list, self.process_group)
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
@@ -191,13 +196,17 @@ class ShardedModelV2(nn.Module):
                 # If world size == 1 and sharded param,
                 # the shape `grad` is the same as unsharded param
                 # So we can just use `view(-1)` to ensure grad is a flat tensor shard
-                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
-                if p.col_attr.offload_fp32_grad:
+                if self.reuse_fp16_shard:
+                    grad = p.col_attr.sharded_data_tensor.payload
+                else:
+                    grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+                if p.col_attr.offload_grad:
                     col_move_to_cpu(grad)
                 if p.col_attr.fp32_grad is not None:
+                    assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
                     p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
                     grad = p.col_attr.fp32_grad
-                p.grad.data = grad.view(-1)
+                p.grad.data = grad
                 p.col_attr.fp16_grad = None
                 p.col_attr.fp32_grad = None
@@ -250,11 +259,15 @@ class ShardedModelV2(nn.Module):
         return empty_grad

     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        reduced_grad = reduced_grad.view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
-        param.col_attr.fp16_grad = reduced_grad.data
+        if self.reuse_fp16_shard:
+            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            param.col_attr.sharded_data_tensor.is_sharded = True
+        else:
+            param.col_attr.fp16_grad = reduced_grad.data

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
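Taken together with the updated test below (which passes `reuse_fp16_shard=use_cpuadam` to `ShardedModelV2`), enabling the new flag looks roughly like the following. This is a hedged sketch, not part of the commit: the import paths reflect the code base around this commit, the `build_zero_components` helper is ours, the sharded module is assumed to have been created under `ZeroInitContext`, and the returned optimizer still has to be wrapped in `ShardedOptimizerV2` as the class docstring requires.

from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

def build_zero_components(sharded_module):
    # `sharded_module` is assumed to be an nn.Module initialized under ZeroInitContext.
    zero_model = ShardedModelV2(sharded_module,
                                TensorShardStrategy(),
                                offload_config=dict(device='cpu'),  # CPU offload, as recommended
                                reuse_fp16_shard=True)              # grads reuse the fp16 param shard
    # CPUAdam can consume fp16 grads against fp32 master params (see the cast above);
    # wrap it in ShardedOptimizerV2 before training.
    optimizer = CPUAdam(zero_model.parameters(), lr=1e-3)
    return zero_model, optimizer

Keep `reuse_fp16_shard` at its default (False) whenever gradient accumulation is needed; the new assert in `_post_backward_operations` rejects accumulation in reuse mode.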


@@ -224,5 +224,5 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                    p.col_attr.offload_fp32_grad = False
+                    p.col_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem


@@ -14,7 +14,7 @@ class ShardedParamV2(object):
         self.fp16_grad: Optional[torch.Tensor] = None
         self.fp32_grad: Optional[torch.Tensor] = None
         # This attribute must be initialized in ShardedModel
-        self.offload_fp32_grad: bool = False
+        self.offload_grad: bool = False

         # make sure the shared param is the only owner of payload
         # The param.data maybe used to init the other part of the model.


@@ -16,7 +16,8 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           offload_config=None,
                           gradient_predivide_factor=1.0,
                           use_memory_tracer=False,
-                          shard_strategy=TensorShardStrategy())
+                          shard_strategy=TensorShardStrategy(),
+                          reuse_fp16_shard=False)

 _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
                               initial_scale=2**5,
@@ -116,10 +117,13 @@ def check_params_padding(model, zero_model, loose=False):
             assert allclose(p, zero_p, loose=loose)


-def check_sharded_params_padding(model, zero_model, loose=False):
+def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+        if reuse_fp16_shard:
+            zero_p = zero_p.data.to(p.device).float()
+        else:
+            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
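The renamed `check_sharded_model_params` keeps the same core idea: flatten the reference parameter, chunk it by world size, and compare this rank's local shard against its chunk (truncating away any padding). A standalone, single-process illustration of that chunk-and-compare logic (the helper name `chunk_for_rank` is ours, not the test's):

import torch

def chunk_for_rank(full_param, world_size, rank):
    # Split the flattened reference parameter the same way the shard strategy does.
    chunks = torch.flatten(full_param).chunk(world_size)
    return chunks[rank] if rank < len(chunks) else None

full = torch.arange(10, dtype=torch.float32).reshape(2, 5)  # "unsharded" reference param
shard = torch.tensor([5., 6., 7., 8., 9.])                  # what rank 1 would hold locally
ref = chunk_for_rank(full, world_size=2, rank=1)
assert ref is not None and torch.allclose(shard[:ref.numel()], ref)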


@@ -18,7 +18,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

-from common import CONFIG, check_sharded_params_padding
+from common import CONFIG, check_sharded_model_params


 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -65,7 +65,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     zero_model = ShardedModelV2(zero_model,
                                 shard_strategy,
                                 offload_config=dict(device='cpu') if cpu_offload else None,
-                                use_memory_tracer=gpu_margin_mem_ratio > 0.0)
+                                use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+                                reuse_fp16_shard=use_cpuadam)

     model = model_builder(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
@@ -92,7 +93,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
         data, label = data.cuda(), label.cuda()
         _run_step(apex_model, apex_optimizer, data, label, criterion, False)
         _run_step(zero_model, sharded_optim, data, label, criterion, False)
-        check_sharded_params_padding(model, zero_model, loose=True)
+        check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
     for param in model.parameters():
         assert not has_inf_or_nan(param)


@@ -16,7 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

-from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
+from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)


 def run_dist(rank, world_size, port, parallel_config):
@@ -87,7 +87,7 @@ def run_dist(rank, world_size, port, parallel_config):
     if parallel_config == MP_PARALLEL_CONFIG:
         check_params(torch_model, colo_model, loose=True)
     elif parallel_config == ZERO_PARALLEL_CONFIG:
-        check_sharded_params_padding(torch_model, colo_model, loose=True)
+        check_sharded_model_params(torch_model, colo_model, loose=True)

     # FIXME: enable this test in next PR