diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index 69390c512..913f82ed7 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -34,13 +34,13 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.data.device != self.computing_device:
-                param.col_attr.data.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
-            param.data = param.col_attr.data.payload
+            if param.col_attr.sharded_data_tensor.device != self.computing_device:
+                param.col_attr.sharded_data_tensor.to(self.computing_device)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            param.data = param.col_attr.sharded_data_tensor.payload
 
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
@@ -49,7 +49,7 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters():
             param.col_attr.remove_torch_payload()
@@ -58,13 +58,13 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.data.device != self.computing_device:
-                param.col_attr.data.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
-            param.data = param.col_attr.data.payload
+            if param.col_attr.sharded_data_tensor.device != self.computing_device:
+                param.col_attr.sharded_data_tensor.to(self.computing_device)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            param.data = param.col_attr.sharded_data_tensor.payload
             # Store local accumulated grad shard
             if param.grad is not None:
                 if param.col_attr.bwd_count == 0:
@@ -75,7 +75,7 @@ class ZeroHook(BaseOpHook):
                 else:
                     # We have stored local accumulated grad
                     # The grad here must be locally computed full grad in this backward pass
-                    assert param.grad.shape == param.col_attr.data.origin_shape
+                    assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
             param.col_attr.bwd_count += 1
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
@@ -84,7 +84,7 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters():
             param.col_attr.remove_torch_payload()
diff --git a/colossalai/utils/commons/bucket_tensor_copy.py b/colossalai/utils/commons/bucket_tensor_copy.py
index 6febb9705..f65a75a81 100644
--- a/colossalai/utils/commons/bucket_tensor_copy.py
+++ b/colossalai/utils/commons/bucket_tensor_copy.py
@@ -50,7 +50,7 @@ class BucketizedTensorCopy(object):
         self._cuda_buffer.copy_(self._cpu_buffer)
         flush_offset = 0
         for sparam, numel in zip(self._buffered_param_list, self._numel_list):
-            sparam.data.copy_payload(self._cpu_buffer.narrow(0, flush_offset, numel))
+            sparam.sharded_data_tensor.copy_payload(self._cpu_buffer.narrow(0, flush_offset, numel))
             flush_offset += numel
         self.reset()
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 2a43d240d..53889db6d 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -160,8 +160,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.initialized_param_list.append(param)
 
         if self.shard_param:
-            self.shard_strategy.shard([param.col_attr._data_sharded_tensor], self.dp_process_group)
-            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
+            self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
         # if param.col_attr.grad and self.shard_grad:
         #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
         #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index d8e309d66..7ca598e82 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -165,7 +165,7 @@ class ShardedModelV2(nn.Module):
         if self.shard_param:
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.data], self.process_group)
+                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
@@ -249,13 +249,15 @@ class ShardedModelV2(nn.Module):
             param.col_attr.fp16_grad = reduced_grad.data
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()], self.process_group)
+        self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+                                   self.process_group)
         prev_params = {}
         for p in self.module.parameters():
             prev_params[p] = p.data
-            p.data = p.col_attr.data.payload
+            p.data = p.col_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()], self.process_group)
+        self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+                                  self.process_group)
         for p in self.module.parameters():
             p.data = prev_params[p]
         return gathered_state_dict
diff --git a/colossalai/zero/sharded_model/utils.py b/colossalai/zero/sharded_model/utils.py
index 7b7c634d3..4489afdc9 100644
--- a/colossalai/zero/sharded_model/utils.py
+++ b/colossalai/zero/sharded_model/utils.py
@@ -11,9 +11,9 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
     """
     for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
        assert hasattr(zero_param, 'col_attr')
-        shard_flag = zero_param.col_attr.data.is_sharded
+        shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
         if shard_flag:
-            sharded_model.shard_strategy.gather([zero_param.col_attr.data])
-        param.data = copy.deepcopy(zero_param.col_attr.data.payload)
+            sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
+        param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
         if shard_flag:
-            sharded_model.shard_strategy.shard([zero_param.col_attr.data])
+            sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 4f111921d..5ec57c083 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -109,17 +109,17 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.col_attr.data.is_sharded
+                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # TODO (ver217): we may not use shard / gather here
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
+                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
-                    self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
+                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
 
     def step(self, *args, **kwargs):
         # unscale grads if scaled
@@ -149,24 +149,24 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
-                is_param_sharded = p.col_attr.data.is_sharded
+                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # We use ZeRO-2 here
-                    # The `p.col_attr.data` saves full fp16 param
+                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
                     # So we first shard full fp16 param and copy fp32 param shard to it
                     # Then we will gather them
-                    self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
+                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
                 # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.col_attr.data is fp16
+                # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.data.copy_payload(p.data)
+                p.col_attr.sharded_data_tensor.copy_payload(p.data)
                 if not is_param_sharded:
                     # We gather full fp16 param here
-                    self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
-                p.data = p.col_attr.data.payload
+                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.col_attr.sharded_data_tensor.payload
         return ret
 
     def backward(self, loss: Tensor) -> None:
diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py
index 95c9b0471..5642a504a 100644
--- a/colossalai/zero/sharded_param/__init__.py
+++ b/colossalai/zero/sharded_param/__init__.py
@@ -1,4 +1,4 @@
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.zero.sharded_param.sharded_param import ShardedParam, ShardedParamV2
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 
-__all__ = ['ShardedParam', 'ShardedTensor', 'ShardedParamV2']
+__all__ = ['ShardedTensor', 'ShardedParamV2']
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 01987c1e2..de5f8f849 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -1,12 +1,7 @@
-from typing import Optional, Tuple, Union
-
-import numpy
 import torch
 import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param import ShardedTensor
+from typing import Optional
 
 
 class ShardedParamV2(object):
@@ -15,7 +10,7 @@ class ShardedParamV2(object):
                  param: torch.nn.Parameter,
                  process_group: Optional[dist.ProcessGroup] = None,
                  rm_torch_payload=False) -> None:
-        self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
+        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
         self.fp16_grad: Optional[torch.Tensor] = None
         self.fp32_grad: Optional[torch.Tensor] = None
 
@@ -37,105 +32,9 @@ class ShardedParamV2(object):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
 
     @property
-    def data(self):
-        return self._data_sharded_tensor
+    def sharded_data_tensor(self):
+        return self._sharded_data_tensor
 
     @property
     def param_is_sharded(self):
-        return self._data_sharded_tensor.is_sharded
-
-
-class ShardedParam(object):
-    r"""
-    A wrapper to torch.nn.Parameter. Shard a param
-    on memory space of different processes.
-    """
-
-    def __init__(self,
-                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
-                 process_group: Optional[dist.ProcessGroup] = None,
-                 is_sharded: bool = False,
-                 device: Optional[torch.device] = None) -> None:
-        r"""
-        other: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
-        process_group: the process group storing the shared data.
-        is_sharded: is shared the param during __init__.
-        device: the device to place param data payload on
-        """
-        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
-        self.world_size = dist.get_world_size(self.process_group)
-        self.local_rank = dist.get_rank(self.process_group)
-        self.is_sharded = False
-        self.device = device
-
-        # Hijack the data payload of param
-        if isinstance(other, torch.nn.Parameter):
-            self._param_payload = other.data.to(device)
-            self._origin_shape = other.shape
-            self._origin_numel = other.numel()
-            if is_sharded:
-                self.shard()
-        elif isinstance(other, tuple):
-            self._origin_shape = other
-            self._origin_numel = numpy.prod(other)
-
-            # TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
-            assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
-            self._param_payload = torch.empty(self._origin_shape, device=device)
-            if is_sharded:
-                self.shard()
-        else:
-            raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")
-
-        self._payload_numel = None
-
-    def payload(self, target_device: Optional[torch.device] = None):
-        r"""
-        get the payload and move it to target device
-        """
-        if target_device is not None:
-            return self._param_payload.to(target_device)
-        return self._param_payload
-
-    def set_payload(self, data: torch.Tensor):
-        r"""
-        set payload as data
-        """
-        assert self._param_payload.shape == data.shape
-        self._param_payload.copy_(data)
-
-    def shard(self):
-        r"""
-        Distributed the payload of param to all processes.
-        """
-        if self.is_sharded:
-            return
-        self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
-        self.is_sharded = True
-
-    def gather(self):
-        r"""
-        Collect the payload of param from different processes to process of local rank.
-        The payload has to be moved to cuda memory before communication.
-        """
-        if not self.is_sharded:
-            return
-
-        buffer_list = []
-        payload_numel = self._param_payload.numel()
-        for i in range(self.world_size):
-            if i == self.local_rank:
-                buffer_list.append(self._param_payload.cuda())
-            else:
-                buffer_list.append(torch.zeros(payload_numel).cuda())
-
-        torch.distributed.all_gather(buffer_list,
-                                     buffer_list[self.local_rank],
-                                     group=self.process_group,
-                                     async_op=False)
-        self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
-        self.is_sharded = False
-
-    @property
-    def origin_dtype(self):
-        return self._origin_dtype
+        return self._sharded_data_tensor.is_sharded
diff --git a/tests/test_utils/test_bucket_tensor_copy.py b/tests/test_utils/test_bucket_tensor_copy.py
index 198d7b691..31d534b78 100644
--- a/tests/test_utils/test_bucket_tensor_copy.py
+++ b/tests/test_utils/test_bucket_tensor_copy.py
@@ -17,7 +17,6 @@ def test_bucket_copy():
     for shape in shape_list:
         # on CPU
         src_param = torch.nn.Parameter(torch.randn(shape, dtype=torch.float, device=torch.device('cpu')))
-        print(src_param)
         # on GPU
         tgt_param = ShardedParamV2(torch.nn.Parameter(torch.ones(shape, dtype=torch.half, device=torch.device('cuda'))))
 
@@ -29,9 +28,10 @@
     copyer.flush()
 
     for src_param, tgt_param in zip(src_param_list, tgt_param_list):
-        print(tgt_param.data.payload)
-        diff = src_param.cpu().float() - tgt_param.data.payload.cpu().float()
-        assert torch.allclose(src_param.cpu().float(), tgt_param.data.payload.cpu().float(), rtol=1e-03,
+        diff = src_param.cpu().float() - tgt_param.sharded_data_tensor.payload.cpu().float()
+        assert torch.allclose(src_param.cpu().float(),
+                              tgt_param.sharded_data_tensor.payload.cpu().float(),
+                              rtol=1e-03,
                               atol=1e-03), f"diff {diff}"
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index fc95d59b4..7e6f881dc 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -119,7 +119,7 @@ def check_params_padding(model, zero_model, loose=False):
 def check_sharded_params_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.col_attr.data.payload.to(p.device).float()
+        zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index 0fe0cd19c..f1b41ee09 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -34,10 +34,10 @@ def run_model_test(init_device, shard_strategy_class):
 
         for param in model.parameters():
             assert hasattr(param, 'col_attr')
-            assert param.col_attr.data.dtype == torch.half
-            assert param.col_attr.data.is_sharded
-            assert param.col_attr.data.payload.device.type == init_device.type, \
-                f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
+            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+            assert param.col_attr.sharded_data_tensor.is_sharded
+            assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+                f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
 
         print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
         print(f'numel {model_numel_tensor}')
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index 16cac2c15..b3ca1bda1 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 from copy import deepcopy
 from functools import partial
 
@@ -8,13 +5,11 @@
 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
-from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
+from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_zero_data_parallel.common import CONFIG, allclose
 
@@ -52,7 +47,7 @@ def _run_shard_param_v2(rank, world_size, port):
     param_ref = deepcopy(param)
     sparam = ShardedParamV2(param=param, process_group=None)
-    allclose(sparam.data.payload, param_ref.data)
+    allclose(sparam.sharded_data_tensor.payload, param_ref.data)
 
     sparam.remove_torch_payload()
     assert (param.data.numel() == 1)
@@ -65,69 +60,6 @@ def test_shard_param_v2(world_size):
     mp.spawn(run_func, nprocs=world_size)
 
 
-def _run_test_shard_param(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    param = torch.nn.Parameter(torch.randn(2, 3))
-    param_ref = deepcopy(param)
-    sparam = ShardedParamV2(param=param, process_group=None)
-    print(sparam.data)
-    print(param_ref.data)
-
-    logger = get_dist_logger()
-    for get_components_func in non_distributed_component_funcs:
-        model_builder, *_ = get_components_func()
-        model = model_builder(checkpoint=True)
-        # add an attribute as col_attr to hijack the access to param.data
-        for _, param in model.named_parameters():
-            numel_ref = (param.numel() + world_size - 1) // world_size
-            param.col_attr = ShardedParam(param)
-            param.col_attr.shard()
-            param_data = param.col_attr.payload(torch.device('cpu'))
-            assert (numel_ref == param_data.numel())
-
-        for _, param in model.named_parameters():
-            param.col_attr.gather()
-            param_data = param.col_attr.payload(torch.device('cpu'))
-
-    disable_existing_loggers([logger])
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_param(world_size):
-    run_func = partial(_run_test_shard_param, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-def _run_init_shard_param(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(data=torch.rand(world_size, 3))
-    sparam = ShardedParam(param, None, True)
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-    del sparam
-
-    param_shape = (world_size, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-
-    param_shape = (world_size, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [world_size, 3])
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 4])
-def test_init_shard_param(world_size):
-    run_func = partial(_run_init_shard_param, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
 if __name__ == '__main__':
     test_shard_tensor(2)
-    test_shard_param(2)
     test_shard_param_v2(2)
-    test_init_shard_param(4)
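
Note (not part of the patch): the hunks above are a mechanical rename of the `ShardedParamV2` accessor, `param.col_attr.data` becomes `param.col_attr.sharded_data_tensor`, together with the removal of the legacy `ShardedParam` class. The minimal sketch below mirrors `_run_shard_param_v2` from the tests to show the renamed accessor in use; it assumes it runs inside a process set up by `colossalai.launch`, since `ShardedParamV2` wraps the parameter in a `ShardedTensor` bound to a process group.

# Illustrative sketch only -- mirrors calls visible in the diff above.
import torch
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

def demo(param: torch.nn.Parameter) -> None:
    col_attr = ShardedParamV2(param=param, process_group=None)
    # Call sites that previously read `col_attr.data` now read the
    # `sharded_data_tensor` property; `.payload` access is unchanged.
    full_payload = col_attr.sharded_data_tensor.payload
    assert full_payload.shape == param.shape
    col_attr.remove_torch_payload()   # param.data is replaced by a 0-d placeholder
    assert param.data.numel() == 1

The rename makes the property name state explicitly that it returns the wrapped `ShardedTensor` rather than a raw `torch.Tensor`, which is what every call-site change in this patch reflects.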