[fp8] add fp8 comm for low level zero

Branch: pull/5961/head
Author: ver217
Date: 2024-08-02 11:12:12 +08:00
Parent: 5fd0592767
Commit: ae486ce005
3 changed files with 32 additions and 9 deletions


@@ -293,6 +293,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -315,6 +316,7 @@ class LowLevelZeroPlugin(DPPluginBase):
             partition_grad=(stage == 2),
             cpu_offload=cpu_offload,
             master_weights=master_weights,
+            fp8_communication=fp8_communication,
         )
         self.lora_enabled = False
         self.verbose = verbose


@@ -4,6 +4,8 @@ import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
+from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
+
 
 class TensorBucket:
     def __init__(self, size):
@@ -61,11 +63,14 @@
         for old, new in zip(self._bucket, unflattened_tensor_list):
             old.copy_(new)
 
-    def all_gather(self, group=None):
+    def all_gather(self, group=None, fp8_communication: bool = False):
         flat = self.flatten()
-        buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
-        dist.all_gather(buffers, flat, group=group)
-        unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
+        buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
+        if fp8_communication:
+            all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
+        else:
+            dist.all_gather_into_tensor(buffer, flat, group=group)
+        unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
         # transpose the list of list
         unflat_buffers = list(map(list, zip(*unflat_buffers)))
         for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
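Here `all_gather_into_tensor_flat_fp8` is imported from `colossalai.quantization.fp8`; its body is outside this diff. As a rough, hedged mental model of what such a helper has to do (quantize to 8-bit floats with a scale, move the raw bytes, dequantize on arrival), a standalone sketch might look like the following. The function name `naive_all_gather_fp8` and the per-rank-scale scheme are illustrative assumptions, not the library's implementation.

# Conceptual sketch of an fp8 all-gather with per-rank scales; NOT the
# implementation of colossalai.quantization.fp8.all_gather_into_tensor_flat_fp8.
# Requires a PyTorch build with float8 dtypes and an initialized process group.
import torch
import torch.distributed as dist

def naive_all_gather_fp8(output: torch.Tensor, shard: torch.Tensor, group=None) -> None:
    world_size = dist.get_world_size(group)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # quantize the local shard: pick a scale that maps it into the fp8 range, then cast
    scale = shard.abs().max().clamp(min=1e-12) / fp8_max
    shard_fp8 = (shard / scale).to(torch.float8_e4m3fn)

    # gather the raw fp8 bytes (viewed as uint8, which NCCL supports) and the scales
    byte_buf = torch.empty(world_size * shard.numel(), dtype=torch.uint8, device=shard.device)
    dist.all_gather_into_tensor(byte_buf, shard_fp8.view(torch.uint8), group=group)
    scales = torch.empty(world_size, dtype=torch.float32, device=shard.device)
    dist.all_gather_into_tensor(scales, scale.reshape(1).float(), group=group)

    # dequantize each rank's shard back into the output dtype
    gathered = byte_buf.view(torch.float8_e4m3fn).view(world_size, -1).to(output.dtype)
    output.copy_((gathered * scales.unsqueeze(1).to(output.dtype)).flatten())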


@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
+from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8
 
 from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
@@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         dp_process_group: Optional[ProcessGroup] = None,
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
+        fp8_communication: bool = False,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -123,6 +125,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
+        self._fp8_communication = fp8_communication
 
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
@@ -323,6 +326,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                    flat_grads = flat_grads.to(self._communication_dtype)
 
                if not self._partition_grads:
-                    dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
+                    if self._fp8_communication:
+                        all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
+                    else:
+                        dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
                    if flat_grads.dtype != grad_dtype:
                        flat_grads = flat_grads.to(grad_dtype)
@@ -333,6 +339,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                else:
                    flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
                    recieved_grad = torch.zeros_like(flat_grads_list[0])
-                    dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
+                    if self._fp8_communication:
+                        reduce_scatter_fp8(
+                            recieved_grad,
+                            flat_grads_list,
+                            group=bucket_store.torch_pg,
+                        )
+                    else:
+                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
 
                    if recieved_grad.dtype != grad_dtype:
@@ -553,18 +566,21 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                    buffer_tensor = torch.empty_like(
                        torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
                    )
-                    dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
+                    if self._fp8_communication:
+                        all_gather_into_tensor_flat_fp8(buffer_tensor, param_to_gather, pg, fp8_format="e4m3")
+                    else:
+                        dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
                    working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
                    continue
                try:
                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
                except RuntimeError:
-                    self.pg_to_tensor_bucket[pg].all_gather(pg)
+                    self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
        for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
            if not tensor_bucket.is_empty():
-                tensor_bucket.all_gather(pg)
+                tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
    def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
        r"""