diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index e66f90ef5..938826b55 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -1,7 +1,8 @@
 import torch
 from colossalai.registry import OPHOOKS
-from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.utils import get_current_device
+from colossalai.zero.shard_utils import BaseShardStrategy
+
 from ._base_ophook import BaseOpHook
 
 
@@ -18,23 +19,32 @@ class ZeroHook(BaseOpHook):
         self.computing_device = torch.device(f'cuda:{get_current_device()}')
 
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload
 
     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload
@@ -52,10 +62,13 @@ class ZeroHook(BaseOpHook):
             param.col_attr.bwd_count += 1
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()
 
     def pre_iter(self):
         pass
diff --git a/colossalai/zero/shard_utils/__init__.py b/colossalai/zero/shard_utils/__init__.py
index 417e201e8..5e5d63a7e 100644
--- a/colossalai/zero/shard_utils/__init__.py
+++ b/colossalai/zero/shard_utils/__init__.py
@@ -1,4 +1,5 @@
-from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
-from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
+from .base_shard_strategy import BaseShardStrategy
+from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
+from .tensor_shard_strategy import TensorShardStrategy
 
-__all__ = ['BaseShardStrategy', 'TensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
new file mode 100644
index 000000000..a2b9b0097
--- /dev/null
+++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
@@ -0,0 +1,38 @@
+from typing import List
+
+import torch
+import torch.distributed as dist
+from colossalai.utils import get_current_device
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from torch._utils import _flatten_dense_tensors as flatten
+
+from .tensor_shard_strategy import TensorShardStrategy
+
+
+class BucketTensorShardStrategy(TensorShardStrategy):
+
+    def gather(self, tensor_list: List[ShardedTensor]):
+        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
+        if len(tensor_list) == 0:
+            return
+        target_device = tensor_list[0].device
+        dtype = tensor_list[0].dtype
+        buffer_list: List[torch.Tensor] = []
+        tensor_numels = [t.payload.numel() for t in tensor_list]
+        buffer_size = sum(tensor_numels)
+        for i in range(self.world_size):
+            if i == self.local_rank:
+                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
+            else:
+                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
+        dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
+        # Move to target device before splitting buffer
+        # Ensure we utilize maximum PCIE bandwidth
+        buffer_list = [buffer.to(target_device) for buffer in buffer_list]
+        offset = 0
+        for i, t in enumerate(tensor_list):
+            gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
+            gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
+            t.reset_payload(gathered_payload)
+            t.is_sharded = False
+            offset += tensor_numels[i]
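
Note on the bucketing scheme: BucketTensorShardStrategy flattens all shard payloads into one buffer per rank, performs a single all_gather, then splits each rank's buffer back into per-parameter slices, concatenates them across ranks, trims to the original numel, and reshapes. Below is a minimal standalone sketch of that flatten/split arithmetic only; it runs on CPU, replaces the all_gather with precomputed per-rank buffers, and uses illustrative names (world_size, full_params, shards) that are not part of the patch.

    from typing import List

    import torch
    from torch._utils import _flatten_dense_tensors as flatten

    world_size = 2
    # Two "parameters", each split into world_size equal shards for simplicity.
    full_params = [torch.arange(6.0).view(2, 3), torch.arange(4.0).view(2, 2)]
    shards: List[List[torch.Tensor]] = [
        [p.flatten()[r * (p.numel() // world_size):(r + 1) * (p.numel() // world_size)] for p in full_params]
        for r in range(world_size)
    ]

    tensor_numels = [t.numel() for t in shards[0]]

    # One flat buffer per rank: this is what the single all_gather would exchange.
    buffer_list = [flatten(rank_shards) for rank_shards in shards]

    # Split each buffer into per-parameter slices and stitch the ranks back together.
    offset = 0
    for i, p in enumerate(full_params):
        gathered = torch.cat([buf[offset:offset + tensor_numels[i]] for buf in buffer_list])
        restored = gathered[:p.numel()].view(p.shape)
        assert torch.equal(restored, p)
        offset += tensor_numels[i]
    print("bucketed gather reconstruction OK")

Compared with the base TensorShardStrategy, which issues one collective per tensor, this batches every parameter of a module into a single communication call, which is the motivation for passing a whole tensor_list from ZeroHook instead of gathering/sharding one parameter at a time.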