
[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)

* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit test for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support fsdp and fp8 communication

* fix some minor bugs observed in unit tests

* add all_gather_into_tensor_flat_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* skip the test if torch < 2.2.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* skip the test if torch < 2.2.0

* skip the test if torch < 2.2.0

* add fp8_comm flag

* rebase latest fp8 operators

* rebase latest fp8 operators

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5978/head · Hanks committed 4 months ago via GitHub · parent commit b480eec738
14 changed files:

1. colossalai/booster/plugin/gemini_plugin.py (2 lines changed)
2. colossalai/booster/plugin/torch_ddp_plugin.py (7 lines changed)
3. colossalai/booster/plugin/torch_fsdp_plugin.py (15 lines changed)
4. colossalai/quantization/fp8.py (207 lines changed)
5. colossalai/quantization/utils.py (112 lines changed)
6. colossalai/zero/gemini/chunk/chunk.py (15 lines changed)
7. colossalai/zero/gemini/chunk/manager.py (4 lines changed)
8. colossalai/zero/gemini/gemini_ddp.py (3 lines changed)
9. examples/language/bert/finetune.py (17 lines changed)
10. examples/language/gpt/hybridparallelism/finetune.py (2 lines changed)
11. examples/language/llama/benchmark.py (12 lines changed)
12. tests/test_fp8/test_fp8_cast.py (26 lines changed)
13. tests/test_fp8/test_fp8_ddp_comm_hook.py (87 lines changed)
14. tests/test_fp8/test_fp8_fsdp_comm_hook.py (107 lines changed)
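For orientation, a minimal usage sketch of the new flag across the three plugins (not part of this diff; the toy model, optimizer, and launch call are illustrative placeholders):

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin, TorchFSDPPlugin

colossalai.launch_from_torch()

model = nn.Linear(128, 128).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Each of the three plugins now accepts fp8_communication; pick one.
plugin = TorchDDPPlugin(fp8_communication=True)
# plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=True)
# plugin = TorchFSDPPlugin(fp8_communication=True)

booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)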

colossalai/booster/plugin/gemini_plugin.py (2 lines changed)

@@ -364,6 +364,7 @@ class GeminiPlugin(DPPluginBase):
enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
@@ -395,6 +396,7 @@ class GeminiPlugin(DPPluginBase):
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
fp8_communication=fp8_communication,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,

colossalai/booster/plugin/torch_ddp_plugin.py (7 lines changed)

@@ -177,6 +177,7 @@ class TorchDDPPlugin(DPPluginBase):
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
self.ddp_kwargs = dict(
@@ -187,6 +188,7 @@ class TorchDDPPlugin(DPPluginBase):
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.fp8_communication = fp8_communication
def support_no_sync(self) -> bool:
return True
@@ -226,6 +228,11 @@ class TorchDDPPlugin(DPPluginBase):
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
if self.fp8_communication:
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async
model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
return model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
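The same hook can also be registered by hand on a plain torch DDP model, which is what the plugin change above does inside boost(); a sketch mirroring tests/test_fp8/test_fp8_ddp_comm_hook.py below, with the process-group setup and toy model as placeholders:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

ddp_model = DDP(nn.Linear(10, 10).to(rank), device_ids=[rank])
# Each gradient bucket is cast to FP8 before the collective and decompressed back afterwards.
ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)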

colossalai/booster/plugin/torch_fsdp_plugin.py (15 lines changed)

@@ -298,6 +298,7 @@ class TorchFSDPPlugin(DPPluginBase):
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
sync_module_states: bool = False,
fp8_communication: bool = False,
):
super().__init__()
self.fsdp_kwargs = dict(
@@ -311,6 +312,7 @@ class TorchFSDPPlugin(DPPluginBase):
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
)
self.fp8_communication = fp8_communication
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -347,6 +349,19 @@ class TorchFSDPPlugin(DPPluginBase):
# wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
if self.fp8_communication:
from colossalai.quantization.utils import patch_fsdp_params_comm_hook
patch_fsdp_params_comm_hook()
from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook
fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook
fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
if optimizer is not None:
if len(optimizer.param_groups) > 1:
warnings.warn(

colossalai/quantization/fp8.py (207 lines changed)

@@ -15,6 +15,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
scale: scaling factor for fp8 casting. If it is None, it is computed automatically. Per-channel scaling
is applied if the input tensor is 2-dimensional; otherwise, per-tensor scaling is applied.
fp8_format: e4m3 or e5m2
Returns:
A tuple (fp8_tensor, scale_inv)
"""
@@ -29,12 +30,13 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
per_channel_max = inp.abs().max(dim=-1).values.float()
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max[:, None]
scale_inv = per_channel_max / fp8_max
else:
per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
scale_inv = 1.0 / scale
scale_inv = 1.0 / scale
ret = (scale * inp.float()).to(fp8_type)
return ret, scale_inv
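A quick round-trip sketch for these casting helpers, mirroring tests/test_fp8/test_fp8_cast.py further down (shapes, dtype, and device are illustrative):

import torch
from torch.testing import assert_close

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

x = torch.rand(100, 10, dtype=torch.bfloat16, device="cuda")
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")  # quantize with per-tensor scaling
x_back = cast_from_fp8(x_fp8, scale_inv, x.dtype)     # dequantize with the inverse scale
assert_close(x_back, x, rtol=0.1, atol=0.1)           # FP8 is lossy, so tolerances are loose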
@@ -185,7 +187,11 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
return
assert "hidden_states" in inp, "required by pipeline parallelism."
assert (
inp["hidden_states"].size(-1) % 2 == 0
), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16"
inp_tensor = inp["hidden_states"]
inp_dtype = inp_tensor.dtype
min_val, max_val = inp_tensor.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs())
@@ -206,6 +212,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
inp["fp8_scale"] = scale.float().reciprocal()
inp["dtype"] = torch.zeros_like(scale).to(inp_dtype)
def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
@@ -230,10 +237,11 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
else:
raise TypeError("Only float16, bfloat16 are implemented.")
inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale
inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale
if del_metadata:
del inp["fp8_scale"]
del inp["dtype"]
def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
@@ -273,6 +281,199 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
output.data = summed_out
def fp8_compress_ddp_grad_comm_hook_async(
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
fp8_format: str = "e5m2",
) -> torch.futures.Future[torch.Tensor]:
"""
Compress gradients by casting the ``GradBucket`` tensor to an FP8 floating-point format and dividing by the process group size.
This DDP communication hook implements a simple gradient compression approach that casts the ``GradBucket`` tensor
to an FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``) and then divides it
by the process group size.
Once the compressed gradient tensors are all-reduced, the chained callback ``decompress`` casts them back
to the input data type (such as ``float32``).
Example::
>>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
input_tensor = bucket.buffer()
world_size = dist.get_world_size()
input_type = input_tensor.dtype
input_device = input_tensor.device
flat_padded_x = input_tensor.flatten()
if flat_padded_x.size(0) % world_size != 0:
pad_size = world_size - flat_padded_x.size(0) % world_size
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
inp = ret.view(torch.uint8)
output_chunks_single = torch.empty_like(inp)
split_sizes = [inp.numel() // world_size for _ in range(world_size)]
fut0 = dist.all_to_all_single(
output_chunks_single,
inp,
output_split_sizes=split_sizes,
input_split_sizes=split_sizes,
group=group_to_use,
async_op=True,
).get_future()
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
fut1 = dist.all_gather_into_tensor(
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
).get_future()
all_to_all_fut = torch.futures.collect_all([fut0, fut1])
def sum_and_allgather(fut):
output_chunks_single = fut.value()[0].wait()[0]
scale_list_single = fut.value()[1].wait()[0]
output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0))
scale_list = scale_list_single.chunk(world_size, dim=0)
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
summed_out.div_(world_size)
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8)
fut2 = dist.all_gather_into_tensor(
tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True
).get_future()
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
fut3 = dist.all_gather_into_tensor(
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
).get_future()
fut_combined2 = torch.futures.collect_all([fut2, fut3])
return fut_combined2
def decompress(fut):
tensor_list_single = fut.value().wait()[0].value()[0]
scale_list_single = fut.value().wait()[1].value()[0]
tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0))
scale_list = scale_list_single.chunk(world_size, dim=0)
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
out = torch.cat(tensor_list, dim=0)
input_tensor_size = input_tensor.numel()
input_shape = input_tensor.shape
out = out[:input_tensor_size]
input_tensor.copy_(out.view(input_shape).to(input_type))
return input_tensor
return all_to_all_fut.then(sum_and_allgather).then(decompress)
def fp8_compress_ddp_grad_comm_hook_sync(
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
fp8_format="e5m2",
) -> torch.futures.Future[torch.Tensor]:
"""
Return a future that wraps the input after it has been all-reduced. However, the all-reduce communication is synchronous,
which breaks the overlap between all-reduce communication and backward computation.
This hook should **only** be used for debugging purposes, not for normal gradient synchronization.
For an asynchronous implementation, use ``fp8_compress_ddp_grad_comm_hook_async`` instead.
Example::
>>> # xdoctest: +SKIP
>>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync)
"""
buffer = bucket.buffer()
all_reduce_fp8(buffer, fp8_format=fp8_format)
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
fut.set_result(bucket.buffer())
return fut
def fp8_compress_fsdp_grad_comm_hook(
state: object,
unsharded_gradient_flattened: torch.Tensor,
sharded_gradient: torch.Tensor,
group=None,
fp8_format="e5m2",
) -> None:
"""
This communication hook implements a simple gradient compression approach that casts the ``unsharded_gradient_flattened`` tensor
to an FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``) and then performs the reduce-scatter logic
using all_to_all and all_gather within the process group.
Example::
>>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
"""
grad = unsharded_gradient_flattened
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
input_type = grad.dtype
input_device = grad.device
world_size = dist.get_world_size(group=group)
grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format)
uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8)
dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale, group=group)
buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0))
sharded_gradient.zero_()
for tensor, scale in zip(buffer_list, scale_list):
sharded_gradient += cast_from_fp8(tensor, scale, input_type)
def fp8_compress_fsdp_params_comm_hook(
state: object,
padded_unsharded_flat_param: torch.Tensor,
sharded_flat_param: torch.Tensor,
group=None,
fp8_format="e5m2",
) -> None:
"""
This hook is a temporary workaround pending official FSDP support for a parameter communication hook (i.e. ``register_params_comm_hook``).
Example::
>>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
"""
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max
inp = sharded_flat_param
out = padded_unsharded_flat_param
per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group)
scale = fp8_max / per_tensor_max
fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8)
fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device)
dist.all_gather_into_tensor(
fp8_out,
fp8_sharded_flat_param,
group=group,
)
padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype))
def split_chunk_by_channel(
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
):
@@ -342,7 +543,7 @@ all_gather_into_tensor_flat_fp8(
scale_inv = 1.0 / scale
buffer = torch.empty_like(output_tensor, dtype=fp8_type)
dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
numel = np.prod(output_shape)
numel = output_shape.numel()
valid_buffer = buffer[:numel].reshape(output_shape)
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
output_tensor[:numel].copy_(valid_buffer.view(-1))

colossalai/quantization/utils.py (112 lines changed)

@@ -0,0 +1,112 @@
import torch
import torch.distributed as dist
from packaging import version
from torch import Tensor
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
from torch.distributed.utils import _p_assert
def _all_gather_flat_param(
self,
padded_unsharded_flat_param: Tensor,
) -> Tensor:
"""
All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
Then switch to use the all-gathered tensor.
"""
_p_assert(
hasattr(self, "process_group") and hasattr(self, "world_size"),
"Expects a process group and world size to have been set via `shard()`",
)
sharded_flat_param = self.flat_param.data
expected_numel = sharded_flat_param.numel() * self.world_size
_p_assert(
padded_unsharded_flat_param.numel() == expected_numel,
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
)
pg = self._fake_process_group if self._use_fake_all_gather else self.process_group
# HACK this should be handled by C10D
if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
else:
if self._comm_hook is None:
dist.all_gather_into_tensor(
padded_unsharded_flat_param,
sharded_flat_param,
pg,
)
else:
self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)
if self._offload_params:
# In case of offloading, `flat_param.data` (i.e. sharded param) is
# created on the pre-unshard stream. We need to hand it over to the
# unshard stream for all-gather
_no_dispatch_record_stream(
sharded_flat_param,
self._device_handle.current_stream(), # unshard_stream
)
return padded_unsharded_flat_param
def register_params_comm_hook(self, state: object, hook: callable):
"""Register a communication hook for FlatParamHandle.
This is an enhancement that provides a flexible hook to users where they can specify how FSDP unshards
parameters across multiple workers.
.. warning ::
FSDP communication hook should be registered before running an initial forward pass
and only once.
Args:
state (object): Passed to the hook to maintain any state information during the training process.
hook (Callable): Callable, which has one of the following signatures:
1) ``hook: Callable[torch.Tensor] -> None``:
This function takes in a Python tensor, which represents
the full, flattened, unsharded gradient with respect to all variables
corresponding to the model this FSDP unit is wrapping
(that are not wrapped by other FSDP sub-units).
It then performs all necessary processing and returns ``None``;
2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
This function takes in two Python tensors, the first one represents
the full, flattened, unsharded gradient with respect to all variables
corresponding to the model this FSDP unit is wrapping
(that are not wrapped by other FSDP sub-units). The latter
represents a pre-sized tensor to store a chunk of a sharded gradient after
reduction.
In both cases, callable performs all necessary processing and returns ``None``.
Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case.
Callables with signature 2 are expected to handle gradient communication for sharded cases.
"""
if not self.check_is_root():
raise AssertionError("register_comm_hook can only be called on a root instance.")
# if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
# raise AssertionError(
# f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
# )
if self._handle._comm_hook is not None:
raise AssertionError("A communication hook is already registered")
if not callable(hook):
raise ValueError(f"The communication hook must be callable but got {hook}")
self._handle._comm_hook = hook
self._handle._comm_hook_state = state
def patch_fsdp_params_comm_hook():
if version.parse(torch.__version__) >= version.parse("2.2.0"):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._flat_param import FlatParamHandle
FlatParamHandle._comm_hook = None
FlatParamHandle._comm_hook_state = None
FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
FSDP.register_params_comm_hook = register_params_comm_hook
else:
raise RuntimeError("This fsdp_params_comm_hook patch is not supported while torch version under 2.2.0.")

colossalai/zero/gemini/chunk/chunk.py (15 lines changed)

@@ -166,6 +166,7 @@ class Chunk:
self.grad_chunk = None
# the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
self.grad_reduce_work = None
self.fp8_communication = False
@property
def memory_usage(self) -> Dict[str, int]:
@@ -521,9 +522,17 @@ class Chunk:
alloc_storage(self.cuda_global_chunk)
assert self.cuda_global_chunk.is_contiguous()
work = dist.all_gather_into_tensor(
self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
)
if self.fp8_communication:
assert async_op == False, "fp8 all-gather does not support async_op!"
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
work = all_gather_into_tensor_flat_fp8(
self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg
)
else:
work = dist.all_gather_into_tensor(
self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
)
self.cuda_shard = None
self.is_gathered = True
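The flat FP8 all-gather added by this commit can also be called directly; the argument order below follows the call site above (output buffer, local shard, output shape, process group), with sizes and setup as illustrative placeholders:

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
world_size = dist.get_world_size()

shard = torch.randn(1024, device="cuda")              # this rank's shard of the flat chunk
full = torch.empty(1024 * world_size, device="cuda")  # gathered flat buffer
# Shards are cast to FP8, all-gathered as uint8, then dequantized into the output buffer.
all_gather_into_tensor_flat_fp8(full, shard, full.shape, dist.group.WORLD)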

colossalai/zero/gemini/chunk/manager.py (4 lines changed)

@@ -26,6 +26,7 @@ class ChunkManager:
init_device: Optional[torch.device] = None,
reuse_fp16_chunk: bool = True,
max_prefetch: int = 0,
fp8_communication: bool = False,
) -> None:
self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -44,6 +45,7 @@ class ChunkManager:
self.accumulating_grads = False
self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None
self.fp8_communication = fp8_communication
def register_tensor(
self,
@@ -101,6 +103,8 @@ class ChunkManager:
extra_dp_group=extra_dp_group,
**chunk_kwargs,
)
if self.fp8_communication:
chunk.fp8_communication = True
chunk_group.append(chunk)
chunk.append_tensor(tensor)

colossalai/zero/gemini/gemini_ddp.py (3 lines changed)

@@ -98,6 +98,7 @@ class GeminiDDP(ModelWrapper):
extra_dp_group: Optional[ProcessGroup] = None,
verbose: bool = False,
enable_async_reduce: bool = True,
fp8_communication: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -122,6 +123,8 @@ class GeminiDDP(ModelWrapper):
verbose=verbose,
max_prefetch=max_prefetch,
)
if fp8_communication:
self.chunk_manager.fp8_communication = True
self.gemini_manager = GeminiManager(
placement_policy,
self.chunk_manager,

examples/language/bert/finetune.py (17 lines changed)

@@ -179,7 +179,7 @@ def main():
"--plugin",
type=str,
default="torch_ddp",
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel", "torch_fsdp"],
help="plugin to use",
)
parser.add_argument(
@@ -215,9 +215,9 @@ def main():
if args.plugin == "torch_ddp_fp16":
booster_kwargs["mixed_precision"] = "fp16"
if args.plugin.startswith("torch_ddp"):
plugin = TorchDDPPlugin()
plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
elif args.plugin == "gemini":
plugin = GeminiPlugin(initial_scale=2**5)
plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm)
elif args.plugin == "low_level_zero":
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == "hybrid_parallel":
@@ -235,6 +235,17 @@ def main():
initial_scale=1,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "torch_fsdp":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from colossalai.booster.plugin import TorchFSDPPlugin
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
),
fp8_communication=args.use_fp8_comm,
)
booster = Booster(plugin=plugin, **booster_kwargs)

examples/language/gpt/hybridparallelism/finetune.py (2 lines changed)

@@ -212,7 +212,7 @@ def main():
if args.plugin == "torch_ddp_fp16":
booster_kwargs["mixed_precision"] = "fp16"
if args.plugin.startswith("torch_ddp"):
plugin = TorchDDPPlugin()
plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
elif args.plugin == "gemini":
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == "low_level_zero":

examples/language/llama/benchmark.py (12 lines changed)

@@ -98,7 +98,7 @@ def main():
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
args = parser.parse_args()
colossalai.launch_from_torch()
@@ -158,6 +158,7 @@ def main():
buffer_dtype=torch.float16,
),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -165,7 +166,8 @@ def main():
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16,
)
),
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "fsdp_cpu":
if use_empty_init:
@@ -177,6 +179,7 @@ def main():
),
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -186,6 +189,7 @@ def main():
buffer_dtype=torch.float16,
),
cpu_offload=CPUOffload(offload_params=True),
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
@@ -200,9 +204,9 @@ def main():
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
@@ -293,7 +297,7 @@ def main():
with get_profile_context(
args.profile,
args.ignore_steps,
1, # avoid creating massive log files
len(dataloader) - 1,
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:

tests/test_fp8/test_fp8_cast.py (26 lines changed)

@@ -0,0 +1,26 @@
import torch
from torch.testing import assert_close
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize
@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
out = cast_from_fp8(ret, scale_inv, x.dtype)
assert_close(out, x, rtol=0.1, atol=0.1)
if x.size(-1) % 2 == 0:
inp_dict = {"hidden_states": x.clone()}
cast_to_fp8_pipeline(inp_dict)
cast_from_fp8_pipeline(inp_dict)
assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)
if __name__ == "__main__":
test_fp8_cast()

tests/test_fp8/test_fp8_ddp_comm_hook.py (87 lines changed)

@@ -0,0 +1,87 @@
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank, world_size):
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
def get_grads_after_one_iteration(hook=None):
torch.manual_seed(0)
# create model and move it to GPU with id rank
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
if hook is not None:
ddp_model.register_comm_hook(None, hook)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
torch.distributed.barrier()
grad_dict = {}
for name, params in ddp_model.named_parameters():
grad_dict[name] = params.grad
return grad_dict
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync
grad_dict = get_grads_after_one_iteration()
for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]:
grad_dict_w_hook = get_grads_after_one_iteration(hook)
if dist.get_rank() == 0:
for name in grad_dict:
assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
cleanup()
def run_demo(demo_fn, world_size):
mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)
if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
run_demo(demo_basic, world_size)

tests/test_fp8/test_fp8_fsdp_comm_hook.py (107 lines changed)

@@ -0,0 +1,107 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing import assert_close
from colossalai import launch
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(100, 100)
self.relu = nn.ReLU()
self.net2 = nn.Linear(100, 50)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
@parameterize("mode", ["grad", "params"])
def run_model(mode):
rank = dist.get_rank()
from colossalai.quantization.utils import patch_fsdp_params_comm_hook
patch_fsdp_params_comm_hook()
def get_grads_after_one_iteration(grad_hook=None, params_hook=None):
torch.manual_seed(0)
# create model and move it to GPU with id rank
model = ToyModel().to(rank)
fsdp_model = FSDP(model)
if grad_hook is not None:
fsdp_model.register_comm_hook(None, grad_hook)
if params_hook is not None:
fsdp_model.register_params_comm_hook(None, params_hook)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(fsdp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = fsdp_model(torch.randn(20, 100))
labels = torch.randn(20, 50).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
torch.distributed.barrier()
grad_dict = {}
for name, params in fsdp_model.named_parameters():
grad_dict[name] = params.grad
return grad_dict
from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook
if mode == "grad":
grad_dict = get_grads_after_one_iteration()
for hook in [
fp8_compress_fsdp_grad_comm_hook,
]:
grad_dict_w_hook = get_grads_after_one_iteration(grad_hook=hook)
if dist.get_rank() == 0:
for name in grad_dict:
assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
elif mode == "params":
grad_dict = get_grads_after_one_iteration()
for hook in [
fp8_compress_fsdp_params_comm_hook,
]:
grad_dict_w_hook = get_grads_after_one_iteration(params_hook=hook)
if dist.get_rank() == 0:
for name in grad_dict:
assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
else:
raise NotImplementedError
def demo_basic(rank, world_size, port):
print(f"Running basic FSDP example on rank {rank}.")
launch(rank=rank, world_size=world_size, port=port, host="localhost")
run_model()
cleanup()
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("2.2.0"), reason="torch version < 2.2.0.")
@rerun_if_address_is_in_use()
def test_fsdp():
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
spawn(demo_basic, n_gpus)
if __name__ == "__main__":
test_fsdp()