diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index ab65c4e22..e66f90ef5 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -1,7 +1,7 @@
 import torch
 from colossalai.registry import OPHOOKS
 from colossalai.zero.shard_utils import BaseShardStrategy
-
+from colossalai.utils import get_current_device
 from ._base_ophook import BaseOpHook
 
 
@@ -14,11 +14,15 @@ class ZeroHook(BaseOpHook):
     def __init__(self, shard_strategy: BaseShardStrategy):
         super().__init__()
         self.shard_strategy = shard_strategy
+        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
+        self.computing_device = torch.device(f'cuda:{get_current_device()}')
 
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             self.shard_strategy.gather([param.col_attr.data])
+            if param.col_attr.data.device != self.computing_device:
+                param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload
 
     def post_fwd_exec(self, module: torch.nn.Module, *args):
@@ -31,6 +35,8 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             self.shard_strategy.gather([param.col_attr.data])
+            if param.col_attr.data.device != self.computing_device:
+                param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload
             # Store local accumulated grad shard
             if param.grad is not None:
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index f045e144f..340206661 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -1,7 +1,6 @@
 import functools
 
 import torch
-from colossalai.utils.cuda import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
 from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
@@ -82,6 +81,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     1. Convert the model to fp16.
     2. The paramaters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
+
+    target_device: the device where param data after exiting the context
+    shard_strategy: shard strategy instance
+    shard_param: is param sharded after exiting the context
+    shard_grad: is param sharded after exiting the context
+
     rm_torch_payload_on_the_fly:
         True: remove tensor payload on param.data after module init finished.
         False: remove tensor payload on param.data afther the context exist.
@@ -91,18 +96,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
     def __init__(self,
                  convert_fp16: bool,
-                 convert_cuda: bool,
+                 target_device: torch.device,
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
                  rm_torch_payload_on_the_fly=False):
         super().__init__()
         self.convert_fp16 = convert_fp16
-        self.convert_cuda = convert_cuda
+        self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
-        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
+        # FIXME(jiaruifang) now setting it to True is invalid.
+        self.rm_torch_payload_on_the_fly = False
         self.initialized_param_list = []
 
     def _post_context_exec(self):
@@ -123,17 +129,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if hasattr(param, 'col_attr'):
                 continue
 
-            if self.convert_cuda:
-                target_device = get_current_device()
-            else:
-                target_device = param.data.device
+            target_device = self.target_device
 
-            # convert to fp16 and cuda if necessary
+            # convert to fp16 if necessary
             if self.convert_fp16:
-                param.data = param.data.to(torch.half).to(target_device)
+                param.data = param.data.to(torch.half)
                 if param.grad is not None:
                     param.grad = param.grad.to(torch.half).to(target_device)
 
+            # move torch parameters to the target device
+            param.data = param.data.to(target_device)
+            if param.grad is not None:
+                param.grad = param.grad.to(target_device)
+
             param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
 
             self.initialized_param_list.append(param)
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py
index ae58bb6aa..94d40e9fb 100644
--- a/colossalai/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py
@@ -30,7 +30,7 @@ class TensorShardStrategy(BaseShardStrategy):
     def _gather_tensor(self, t: ShardedTensor):
         if not t.is_sharded:
             return
-
+        target_device = t.device
         buffer_list = []
         payload_numel = t.payload.numel()
         for i in range(self.world_size):
@@ -45,4 +45,5 @@
                         async_op=False)
         gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
         t.reset_payload(gathered_payload)
+        t.to(target_device)
         t.is_sharded = False
diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py
index 093889b4b..eaaa8ab99 100644
--- a/colossalai/zero/sharded_param/sharded_tensor.py
+++ b/colossalai/zero/sharded_param/sharded_tensor.py
@@ -47,11 +47,18 @@ class ShardedTensor(object):
         del self._payload
         self._payload = tensor
 
+    @property
+    def device(self):
+        return self._payload.device
+
     @property
     def dtype(self):
         assert self._payload.dtype == self._origin_dtype
         return self._origin_dtype
 
+    def to(self, device: torch.device):
+        self._payload = self._payload.to(device)
+
     @property
     def shape(self):
         return self._payload.shape
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index e29a266eb..0a2f0d960 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -4,6 +4,7 @@
 from functools import partial
 
 import colossalai
+from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -17,13 +18,13 @@ from common import CONFIG
 from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, init_device):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         with ZeroInitContext(convert_fp16=True,
-                             convert_cuda=True,
+                             target_device=init_device,
                              shard_strategy=TensorShardStrategy(),
                              shard_param=True):
             model = model_builder(checkpoint=True)
@@ -32,18 +33,26 @@ def run_dist(rank, world_size, port):
             assert hasattr(param, 'col_attr')
             assert param.col_attr.data.dtype == torch.half
             assert param.col_attr.data.is_sharded
-            assert param.col_attr.data.payload.device.type == 'cuda'
+            assert param.col_attr.data.payload.device.type == init_device.type, \
+                f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
 
+    print(f'cpu usgae {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
     print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
-    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+
+    if init_device.type == 'cuda':
+        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+    elif init_device.type == 'cpu':
+        assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
-def test_zero_init_context(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
+def test_zero_init_context(world_size, init_device):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_zero_init_context(2)
+    test_zero_init_context(2, torch.device('cpu'))
+    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index a1885e8f0..23a75cfcd 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -5,6 +5,7 @@
 import copy
 from functools import partial
 
 import pytest
+import torch
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -30,8 +31,14 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, _, _, criterion = get_components_func()
 
+    rm_torch_payload_on_the_fly = False
+
     if use_zero_init_ctx:
-        with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
+        with ZeroInitContext(convert_fp16=True,
+                             target_device=torch.device('cpu'),
+                             shard_strategy=shard_strategy,
+                             shard_param=True,
+                             rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
             zero_model = model_builder(checkpoint=True)
         zero_model = ShardedModelV2(zero_model, shard_strategy)
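
For reference, the sketch below shows how the reworked context is expected to be used after this change. It is an illustration only: it assumes ZeroInitContext and TensorShardStrategy are importable from colossalai.zero.init_ctx and colossalai.zero.shard_utils, that colossalai.launch(...) has already initialized the distributed backend (as run_dist does in the tests above), and the torch.nn.Linear is a toy stand-in for a real model builder.

import torch

from colossalai.zero.init_ctx import ZeroInitContext          # assumed import path
from colossalai.zero.shard_utils import TensorShardStrategy   # assumed import path

# Assumes colossalai.launch(...) has already set up the process group,
# exactly as run_dist() does in the tests above.

# Parameters are converted to fp16, sharded, and left on the CPU when the
# context exits; ZeroHook later gathers them onto the GPU for FWD/BWD.
init_device = torch.device('cpu')
# Alternatively, place them directly on the current GPU:
#   from colossalai.utils.cuda import get_current_device
#   init_device = torch.device(f'cuda:{get_current_device()}')

with ZeroInitContext(convert_fp16=True,
                     target_device=init_device,
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = torch.nn.Linear(128, 256)    # toy stand-in for model_builder()

for param in model.parameters():
    # Each parameter now carries a ShardedParamV2 in col_attr, and its sharded
    # payload lives on the requested target device.
    assert param.col_attr.data.payload.device.type == init_device.type

Choosing a CPU target_device keeps sharded parameters out of GPU memory between steps at the cost of host-to-device traffic when shards are gathered, which is why ZeroHook above now moves gathered payloads back to its computing device before forward and backward.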