mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5961 from ver217/feature/zeor-fp8
[fp8] add fp8 comm for low level zero
branch: pull/5963/head

commit c297e21bea
@@ -293,6 +293,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -315,6 +316,7 @@ class LowLevelZeroPlugin(DPPluginBase):
             partition_grad=(stage == 2),
             cpu_offload=cpu_offload,
             master_weights=master_weights,
+            fp8_communication=fp8_communication,
         )
         self.lora_enabled = False
         self.verbose = verbose

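The plugin change above only threads the new flag down to the ZeRO optimizer it builds. A minimal usage sketch, assuming the usual Booster workflow and an already-initialized distributed environment (the model and optimizer here are placeholders):

    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    # Enable FP8 payloads for the ZeRO collectives added in this commit;
    # every other argument keeps its existing default.
    plugin = LowLevelZeroPlugin(stage=2, fp8_communication=True)
    booster = Booster(plugin=plugin)

    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)
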
@@ -4,6 +4,8 @@ import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
+from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
+
 
 class TensorBucket:
     def __init__(self, size):
@@ -61,11 +63,14 @@ class TensorBucket:
         for old, new in zip(self._bucket, unflattened_tensor_list):
             old.copy_(new)
 
-    def all_gather(self, group=None):
+    def all_gather(self, group=None, fp8_communication: bool = False):
         flat = self.flatten()
-        buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
-        dist.all_gather(buffers, flat, group=group)
-        unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
+        buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
+        if fp8_communication:
+            all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
+        else:
+            dist.all_gather_into_tensor(buffer, flat, group=group)
+        unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
         # transpose the list of list
         unflat_buffers = list(map(list, zip(*unflat_buffers)))
         for unflat_shards, tensor in zip(unflat_buffers, self._bucket):

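Besides the new flag, the rewritten TensorBucket.all_gather switches from the list-based dist.all_gather to a single flat output buffer that is chunked per rank; the FP8 branch only swaps the collective for all_gather_into_tensor_flat_fp8. A toy equivalent of the non-FP8 path, assuming an initialized process group (the helper name is illustrative, not part of the codebase):

    import torch
    import torch.distributed as dist

    def flat_all_gather(flat: torch.Tensor, group=None) -> list:
        # One contiguous output buffer of size world_size * numel ...
        world_size = dist.get_world_size(group)
        buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
        dist.all_gather_into_tensor(buffer, flat, group=group)
        # ... split back into per-rank shards, matching the old list-of-tensors output.
        return list(buffer.chunk(world_size))
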
@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
+from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8
 
 from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
@@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         dp_process_group: Optional[ProcessGroup] = None,
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
+        fp8_communication: bool = False,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
 
@@ -123,6 +125,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
+        self._fp8_communication = fp8_communication
 
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
@@ -323,7 +326,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             flat_grads = flat_grads.to(self._communication_dtype)
 
         if not self._partition_grads:
-            dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
+            if self._fp8_communication:
+                all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
+            else:
+                dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
             if flat_grads.dtype != grad_dtype:
                 flat_grads = flat_grads.to(grad_dtype)
 
@@ -333,7 +339,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         else:
             flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
             recieved_grad = torch.zeros_like(flat_grads_list[0])
-            dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
+            if self._fp8_communication:
+                reduce_scatter_fp8(
+                    recieved_grad,
+                    flat_grads_list,
+                    group=bucket_store.torch_pg,
+                )
+            else:
+                dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
 
             if recieved_grad.dtype != grad_dtype:
                 recieved_grad = recieved_grad.to(grad_dtype)
@@ -553,18 +566,21 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     buffer_tensor = torch.empty_like(
                         torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
                     )
-                    dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
+                    if self._fp8_communication:
+                        all_gather_into_tensor_flat_fp8(buffer_tensor, param_to_gather, pg, fp8_format="e4m3")
+                    else:
+                        dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
                     working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
                     continue
                 try:
                     self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
                 except RuntimeError:
-                    self.pg_to_tensor_bucket[pg].all_gather(pg)
+                    self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
                     self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
         for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
             if not tensor_bucket.is_empty():
-                tensor_bucket.all_gather(pg)
+                tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
     def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""

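The FP8 collectives used above (all_reduce_fp8, reduce_scatter_fp8, all_gather_into_tensor_flat_fp8) move gradients and parameters over the wire in 8-bit floating point instead of the working dtype, with fp8_format="e4m3" selecting the e4m3 encoding for the parameter all-gather. As an illustration of the underlying idea only (not the colossalai.quantization.fp8 implementation), a per-tensor scaled cast to torch.float8_e4m3fn and back looks roughly like this; it needs a PyTorch build with float8 dtypes and is lossy, which is why the test below relaxes its comparisons:

    import torch

    def fp8_cast_roundtrip(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
        # Per-tensor scaling: map the tensor's absolute max onto the FP8 range,
        # cast to FP8 (this is what would travel over the wire), then cast back.
        fp8_max = torch.finfo(fp8_dtype).max
        scale = fp8_max / x.abs().max().float().clamp(min=1e-12)
        x_fp8 = (x * scale).to(fp8_dtype)
        return x_fp8.to(x.dtype) / scale

    x = torch.randn(1024, dtype=torch.float16)
    err = (x - fp8_cast_roundtrip(x)).abs().max()
    print(f"max round-trip error: {err}")  # non-zero: FP8 transport is lossy
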
@@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
     return splited_grad
 
 
-def exam_zero_1_2():
+@parameterize("fp8_communication", [True, False])
+def exam_zero_1_2(fp8_communication: bool):
     """
     In this test, we want to test whether zero stage 1 and 2
     deliver the same numerical results despite different communication
@@ -73,10 +74,18 @@ def exam_zero_1_2():
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
     zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
     zero1_optimizer = LowLevelZeroOptimizer(
-        zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True
+        zero1_optimizer,
+        overlap_communication=True,
+        initial_scale=128,
+        verbose=True,
+        fp8_communication=fp8_communication,
     )
     zero2_optimizer = LowLevelZeroOptimizer(
-        zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128
+        zero2_optimizer,
+        overlap_communication=True,
+        partition_grad=True,
+        initial_scale=128,
+        fp8_communication=fp8_communication,
     )
     # create data
     seed_all(2001 + local_rank)
@@ -97,7 +106,10 @@ def exam_zero_1_2():
         if g1 is None or g2 is None:
             assert g1 is None and g2 is None
             continue
-        assert torch.allclose(g1, g2)
+        if fp8_communication:
+            loose_close(g1, g2, dtype=torch.float16)
+        else:
+            assert torch.allclose(g1, g2)
 
     # step
     zero1_optimizer.step()
@@ -105,7 +117,8 @@ def exam_zero_1_2():
 
     # check updated param
     for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
-        assert torch.allclose(z1p, z2p)
+        if not fp8_communication:
+            assert torch.allclose(z1p, z2p)
 
 
 @parameterize("dtype", [torch.float16, torch.bfloat16])

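When FP8 transport is enabled, the test trades the exact torch.allclose checks for loose_close (imported from the test utilities) and skips the exact parameter comparison entirely. The real tolerances live in that helper; a stand-in with the same call shape, using tolerances assumed here purely for illustration, could look like:

    import torch

    def loose_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype = torch.float16) -> None:
        # Hypothetical stand-in: compare in a common dtype with tolerances loose
        # enough to absorb FP8 quantization error (values are assumptions, not the
        # project's actual settings).
        rtol, atol = 1.25e-1, 2.5e-1
        a = a.detach().to(dtype)
        b = b.detach().to(dtype).to(a.device)
        assert torch.allclose(a, b, rtol=rtol, atol=atol), f"max diff: {(a - b).abs().max()}"
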