Merge pull request #5961 from ver217/feature/zeor-fp8

[fp8] add fp8 comm for low level zero
pull/5963/head
Hanks 2024-08-02 20:38:58 +08:00 committed by GitHub
commit c297e21bea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 50 additions and 14 deletions

View File

@ -293,6 +293,7 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@ -315,6 +316,7 @@ class LowLevelZeroPlugin(DPPluginBase):
partition_grad=(stage == 2),
cpu_offload=cpu_offload,
master_weights=master_weights,
fp8_communication=fp8_communication,
)
self.lora_enabled = False
self.verbose = verbose

View File

@ -4,6 +4,8 @@ import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
class TensorBucket:
def __init__(self, size):
@ -61,11 +63,14 @@ class TensorBucket:
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
def all_gather(self, group=None):
def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
dist.all_gather(buffers, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if fp8_communication:
all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):

View File

@ -20,6 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, TensorBucket
@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
dp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
fp8_communication: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@ -123,6 +125,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
self._fp8_communication = fp8_communication
# gradient clipping
self._clip_grad_norm = clip_grad_norm
@ -323,7 +326,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
flat_grads = flat_grads.to(self._communication_dtype)
if not self._partition_grads:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if self._fp8_communication:
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
else:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
@ -333,7 +339,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
if self._fp8_communication:
reduce_scatter_fp8(
recieved_grad,
flat_grads_list,
group=bucket_store.torch_pg,
)
else:
dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
@ -553,18 +566,21 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
if self._fp8_communication:
all_gather_into_tensor_flat_fp8(buffer_tensor, param_to_gather, pg, fp8_format="e4m3")
else:
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""

View File

@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
return splited_grad
def exam_zero_1_2():
@parameterize("fp8_communication", [True, False])
def exam_zero_1_2(fp8_communication: bool):
"""
In this test, we want to test whether zero stage 1 and 2
deliver the same numerical results despite different communication
@ -73,10 +74,18 @@ def exam_zero_1_2():
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(
zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True
zero1_optimizer,
overlap_communication=True,
initial_scale=128,
verbose=True,
fp8_communication=fp8_communication,
)
zero2_optimizer = LowLevelZeroOptimizer(
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128
zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=128,
fp8_communication=fp8_communication,
)
# create data
seed_all(2001 + local_rank)
@ -97,7 +106,10 @@ def exam_zero_1_2():
if g1 is None or g2 is None:
assert g1 is None and g2 is None
continue
assert torch.allclose(g1, g2)
if fp8_communication:
loose_close(g1, g2, dtype=torch.float16)
else:
assert torch.allclose(g1, g2)
# step
zero1_optimizer.step()
@ -105,7 +117,8 @@ def exam_zero_1_2():
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert torch.allclose(z1p, z2p)
if not fp8_communication:
assert torch.allclose(z1p, z2p)
@parameterize("dtype", [torch.float16, torch.bfloat16])