[zero] sharded model supports reuse of the fp16 shard (#495)

* sharded model supports reusing the fp16 shard

* rename variable

* polish code

* polish code

* polish code
ver217 2022-03-23 14:59:59 +08:00 committed by GitHub
parent f24b5ed201
commit 9ec1ce6ab1
7 changed files with 62 additions and 42 deletions


@ -56,6 +56,8 @@ class CPUAdam(torch.optim.Optimizer):
bias_correction2,
loss_scale,
use_adamw=False):
# FIXME(ver217): remove the below line when replace torch adam with fused adam
grad = grad.float()
if loss_scale is not None:
grad.div_(loss_scale)

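The CPUAdam hunk above casts the incoming gradient to fp32 and unscales it by the loss scale before the Adam math runs; this is what later lets the optimizer accept the fp16 gradients produced when `reuse_fp16_shard` is enabled. A minimal sketch of just that preamble (the helper name `unscale_grad` is illustrative, not part of the patch):

```python
import torch

def unscale_grad(grad: torch.Tensor, loss_scale=None) -> torch.Tensor:
    # Cast the (possibly fp16) grad shard to fp32 so the torch Adam math stays stable;
    # this mirrors the FIXME above until a fused Adam replaces the torch path.
    grad = grad.float()
    if loss_scale is not None:
        # Undo the dynamic loss scaling applied during the fp16 backward pass.
        grad.div_(loss_scale)
    return grad
```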

@ -29,24 +29,22 @@ class ShardedModelV2(nn.Module):
compared to classic data parallelism while the computational granularity and communication efficiency are retained.
Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
:param module: A sharded module, which must be initialized by `ZeroInitContext`.
:type module: nn.Module
:param shard_strategy: A shard strategy to manage shard behavior.
:type shard_strategy: BaseShardStrategy
:param process_group: Data parallel process group, defaults to None
:type process_group: Optional[ProcessGroup], optional
:param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
:type reduce_scatter_process_group: Optional[ProcessGroup], optional
:param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
:type reduce_scatter_bucket_size_mb: int, optional
:param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
:type fp32_reduce_scatter: bool, optional
:param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
:type offload_config: Optional[dict], optional
:param gradient_predivide_factor: Gradient is divived by this value before reduce-scatter, defaults to 1.0
:type gradient_predivide_factor: Optional[float], optional
:param use_memory_tracer: Whether to use memoty tracer, defaults to False
:type use_memory_tracer: bool, optional
Args:
module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
offload_config (Optional[dict], optional): We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload. Defaults to None.
gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
use_memory_tracer (bool, optional): Whether to use memory tracer. Defaults to False.
reuse_fp16_shard (bool, optional): Whether to reuse the fp16 shard for param and grad.
Enabling this can reduce GPU memory usage, but you must disable it when using gradient accumulation.
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
We find that PyTorch's optimizers don't support mixed precision,
so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
"""
def __init__(self,
@ -58,7 +56,8 @@ class ShardedModelV2(nn.Module):
fp32_reduce_scatter: bool = False,
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
use_memory_tracer: bool = False):
use_memory_tracer: bool = False,
reuse_fp16_shard: bool = False):
super().__init__()
self.logger = get_dist_logger()
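To make the new flag concrete, here is a self-contained toy illustration of what `reuse_fp16_shard` means, written in plain PyTorch rather than ColossalAI code: one fp16 buffer plays the role of the parameter shard during forward/backward and of the reduce-scattered gradient afterwards, so no second fp16 buffer is allocated.

```python
import torch

# Toy illustration only (plain torch, not ColossalAI code).
param_shard = torch.randn(1024, dtype=torch.float16)    # fp16 param shard used in forward
master_shard = param_shard.float()                      # fp32 master copy kept by the optimizer

reduced_grad = torch.randn(1024, dtype=torch.float16)   # stand-in for the reduce-scatter output
param_shard.copy_(reduced_grad)                         # the shard buffer now holds the fp16 grad

# The optimizer must therefore update fp32 params from fp16 grads (mixed precision);
# PyTorch's built-in optimizers don't do this, hence the CPUAdam recommendation above.
master_shard -= 1e-3 * param_shard.float()              # toy SGD-style update using the fp16 grad
param_shard.copy_(master_shard.half())                  # write the new weights back into the shard
```

This buffer role-switch is also why the docstring above restricts the flag to optimizers that accept fp32 params with fp16 grads, such as CPUAdam with CPU offload.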
@ -97,8 +96,8 @@ class ShardedModelV2(nn.Module):
self.fp32_reduce_scatter = fp32_reduce_scatter
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
for param in module.parameters():
# Init `offload_fp32_grad`
param.col_attr.offload_fp32_grad = self._cpu_offload
# Init `offload_grad`
param.col_attr.offload_grad = self._cpu_offload
# We find that if gradient_predivide_factor != 1.0, there may be a precision problem
# So we use 1.0 as the default gradient_predivide_factor
@ -114,6 +113,7 @@ class ShardedModelV2(nn.Module):
self._require_backward_grad_sync: bool = True
self._cuda_margin_space = 0
self.reuse_fp16_shard = reuse_fp16_shard
@property
def cuda_margin_space(self):
@ -143,11 +143,7 @@ class ShardedModelV2(nn.Module):
for ophook in self._ophook_list:
ophook.post_iter()
@torch.no_grad()
def _post_backward_operations(self) -> None:
"""
This method performs the operations required after the backward pass
"""
def _update_memstats(self):
if self._iter_cnter == 0 and self._memstats_collector:
self._memstats_collector.finish_collection()
if self._memstats_collector:
@ -160,6 +156,13 @@ class ShardedModelV2(nn.Module):
self._iter_cnter += 1
@torch.no_grad()
def _post_backward_operations(self) -> None:
"""
This method performs the operations required after the backward pass
"""
self._update_memstats()
if self._require_backward_grad_sync:
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self.comm_stream):
@ -171,9 +174,11 @@ class ShardedModelV2(nn.Module):
self.reducer.free()
# In case some post bwd hook is not fired
if self.shard_param:
tensor_list = []
for p in self.module.parameters():
if not p.col_attr.param_is_sharded:
self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
tensor_list.append(p.col_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
for p in self.module.parameters():
p.col_attr.bwd_count = 0
if not p.requires_grad:
@ -191,13 +196,17 @@ class ShardedModelV2(nn.Module):
# If world size == 1 and sharded param,
# the shape of `grad` is the same as the unsharded param
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
if p.col_attr.offload_fp32_grad:
if self.reuse_fp16_shard:
grad = p.col_attr.sharded_data_tensor.payload
else:
grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
if p.col_attr.offload_grad:
col_move_to_cpu(grad)
if p.col_attr.fp32_grad is not None:
assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
grad = p.col_attr.fp32_grad
p.grad.data = grad.view(-1)
p.grad.data = grad
p.col_attr.fp16_grad = None
p.col_attr.fp32_grad = None
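The hunk above is the consuming side of the reused shard: after backward, the gradient is either taken directly from the fp16 shard payload or cast from `fp16_grad` to fp32, optionally offloaded, and optionally accumulated into `fp32_grad`. An illustrative distillation of that decision flow (not the actual `ShardedModelV2` method):

```python
import torch

# Attribute names mirror the diff, but this helper is illustrative only.
def resolve_grad(fp16_grad, fp16_shard_payload, fp32_accum,
                 reuse_fp16_shard: bool, offload_grad: bool) -> torch.Tensor:
    if reuse_fp16_shard:
        grad = fp16_shard_payload          # the param shard buffer already holds the fp16 grad
    else:
        grad = fp16_grad.float()           # cast_tensor_to_fp32 in the real code
    if offload_grad:
        grad = grad.cpu()                  # col_move_to_cpu in the real code
    if fp32_accum is not None:             # gradient-accumulation buffer from a previous step
        assert not reuse_fp16_shard, 'gradient accumulation is not supported with reuse_fp16_shard'
        fp32_accum.add_(grad.view_as(fp32_accum))
        grad = fp32_accum
    return grad                            # assigned to p.grad.data, then the fp16/fp32 buffers are cleared
```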
@ -250,11 +259,15 @@ class ShardedModelV2(nn.Module):
return empty_grad
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
reduced_grad = reduced_grad.view(-1)
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)
param.col_attr.fp16_grad = reduced_grad.data
if self.reuse_fp16_shard:
param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
param.col_attr.sharded_data_tensor.is_sharded = True
else:
param.col_attr.fp16_grad = reduced_grad.data
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],

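The last hunk is the producing side: `_reduce_scatter_callback` either stores the reduced gradient as a separate `fp16_grad` buffer or, with `reuse_fp16_shard`, rebinds the parameter's shard payload to it, so the same storage represents the gradient until the optimizer writes updated weights back. A sketch with the same shape as the diff (illustrative only):

```python
import torch

# `col_attr` field names follow the diff; this is not the real callback.
def on_reduce_scatter(param, reduced_grad: torch.Tensor,
                      gradient_postdivide_factor: float, reuse_fp16_shard: bool) -> None:
    reduced_grad = reduced_grad.view(-1)
    if gradient_postdivide_factor > 1:
        reduced_grad.data.div_(gradient_postdivide_factor)    # average across ranks, DDP-style
    if reuse_fp16_shard:
        # Rebind the shard's payload to the reduced grad: the buffer now represents
        # the gradient rather than the weights, and it is marked as sharded again.
        param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
        param.col_attr.sharded_data_tensor.is_sharded = True
    else:
        param.col_attr.fp16_grad = reduced_grad.data          # keep a separate fp16 grad buffer
```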

@ -224,5 +224,5 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
p.grad.data = p.grad.data.to(torch.cuda.current_device())
p.col_attr.offload_fp32_grad = False
p.col_attr.offload_grad = False
fp32_shards_used_cuda_margin_mem += shard_mem

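This optimizer hunk is where the renamed `offload_grad` flag is flipped off: fp32 master shards are moved onto the GPU as long as they fit into the leftover CUDA margin, and their gradients then no longer need CPU offload. A hedged sketch of that loop (names are illustrative; the real logic lives in `ShardedOptimizerV2`):

```python
import torch

def fill_cuda_margin(master_params: dict, grads: dict, offload_flags: dict,
                     margin_bytes: int) -> int:
    used = 0
    for key, shard in master_params.items():
        shard_mem = shard.numel() * shard.element_size()
        if used + shard_mem < margin_bytes:
            device = torch.cuda.current_device()
            master_params[key] = shard.to(device)   # keep this fp32 master shard on GPU
            grads[key] = grads[key].to(device)      # its grad can stay on GPU as well
            offload_flags[key] = False              # so offload_grad is switched off for it
            used += shard_mem
    return used
```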

@ -14,7 +14,7 @@ class ShardedParamV2(object):
self.fp16_grad: Optional[torch.Tensor] = None
self.fp32_grad: Optional[torch.Tensor] = None
# This attribute must be initialized in ShardedModel
self.offload_fp32_grad: bool = False
self.offload_grad: bool = False
# make sure the shared param is the only owner of payload
# The param.data may be used to init other parts of the model.

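For orientation, a simplified picture of the per-parameter bookkeeping this patch touches; the real `ShardedParamV2` carries more state (including the sharded data tensor itself), so this dataclass is only a sketch:

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class _ColAttrSketch:
    fp16_grad: Optional[torch.Tensor] = None   # reduce-scattered grad; unused when the shard is reused
    fp32_grad: Optional[torch.Tensor] = None   # fp32 buffer for gradient accumulation
    offload_grad: bool = False                 # renamed from offload_fp32_grad: the offloaded grad may now be fp16
```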

@ -16,7 +16,8 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
offload_config=None,
gradient_predivide_factor=1.0,
use_memory_tracer=False,
shard_strategy=TensorShardStrategy())
shard_strategy=TensorShardStrategy(),
reuse_fp16_shard=False)
_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
initial_scale=2**5,
@ -116,10 +117,13 @@ def check_params_padding(model, zero_model, loose=False):
assert allclose(p, zero_p, loose=loose)
def check_sharded_params_padding(model, zero_model, loose=False):
def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
if reuse_fp16_shard:
zero_p = zero_p.data.to(p.device).float()
else:
zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue

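The tail of `check_sharded_model_params` lies outside this hunk. A plausible completion, assuming `allclose` is the comparison helper already defined in `common.py`, would compare this rank's chunk of the full fp32 parameter against the (possibly padded) zero shard:

```python
import torch
import torch.distributed as dist

# Sketch of a completed checker; the real test helper may differ in details.
def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        if reuse_fp16_shard:
            # With reuse, the param's .data is itself the fp16 shard, so compare it directly.
            zero_p = zero_p.data.to(p.device).float()
        else:
            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
        chunks = torch.flatten(p).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
        p_shard = chunks[rank].float()
        zero_shard = torch.flatten(zero_p)
        # The zero shard may be padded to equal sizes across ranks; trim before comparing.
        assert allclose(p_shard, zero_shard[:p_shard.numel()], loose=loose)
```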

@ -18,7 +18,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_sharded_params_padding
from common import CONFIG, check_sharded_model_params
def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@ -65,7 +65,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
zero_model = ShardedModelV2(zero_model,
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None,
use_memory_tracer=gpu_margin_mem_ratio > 0.0)
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
reuse_fp16_shard=use_cpuadam)
model = model_builder(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
@ -92,7 +93,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
data, label = data.cuda(), label.cuda()
_run_step(apex_model, apex_optimizer, data, label, criterion, False)
_run_step(zero_model, sharded_optim, data, label, criterion, False)
check_sharded_params_padding(model, zero_model, loose=True)
check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
for param in model.parameters():
assert not has_inf_or_nan(param)

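Only the signature of `_run_step` appears in this diff. A plausible body, sketched here for context (the real test may differ), drives the step through the optimizer wrapper so that loss scaling is applied before backward:

```python
import torch

def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        output = model(data)
        loss = criterion(output, label).float()
    if hasattr(optimizer, 'backward'):      # sharded / loss-scaled optimizer wrapper (e.g. ShardedOptimizerV2)
        optimizer.backward(loss)
    else:                                   # plain torch-style optimizer in the baseline path
        loss.backward()
    optimizer.step()
```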

@ -16,7 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)
def run_dist(rank, world_size, port, parallel_config):
@ -87,7 +87,7 @@ def run_dist(rank, world_size, port, parallel_config):
if parallel_config == MP_PARALLEL_CONFIG:
check_params(torch_model, colo_model, loose=True)
elif parallel_config == ZERO_PARALLEL_CONFIG:
check_sharded_params_padding(torch_model, colo_model, loose=True)
check_sharded_model_params(torch_model, colo_model, loose=True)
# FIXME: enable this test in next PR