
[zero] able to place params on cpu after zero init context (#365)

* place params on cpu after zero init context

* polish code
Branch: pull/394/head
Jiarui Fang authored 3 years ago, committed by Frank Lee
commit 44e4891f57
6 changed files:
  1. colossalai/engine/ophooks/zero_hook.py (8 lines)
  2. colossalai/zero/init_ctx/init_context.py (28 lines)
  3. colossalai/zero/shard_utils/tensor_shard_strategy.py (3 lines)
  4. colossalai/zero/sharded_param/sharded_tensor.py (7 lines)
  5. tests/test_zero_data_parallel/test_init_context.py (23 lines)
  6. tests/test_zero_data_parallel/test_shard_model_v2.py (9 lines)
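
Taken together, these changes let ZeroInitContext leave parameter data on a caller-chosen target_device (typically the CPU) after initialization, instead of forcing it onto the GPU. A hedged usage sketch mirroring the updated tests below; the import paths follow this commit's file layout, and a real run also needs colossalai.launch to set up distributed state first:

import torch
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# Assumes a distributed context has already been launched, e.g. via
# colossalai.launch(...), since shard_param=True triggers collective ops.
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cpu'),    # params rest on CPU afterwards
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = torch.nn.Linear(16, 16)    # any module built inside the context

# Each param now holds a sharded fp16 payload on the CPU; the ZeroHook moves
# it to the GPU on demand during forward and backward.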

colossalai/engine/ophooks/zero_hook.py (8 lines changed)

@@ -1,7 +1,7 @@
import torch
from colossalai.registry import OPHOOKS
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.utils import get_current_device
from ._base_ophook import BaseOpHook
@@ -14,11 +14,15 @@ class ZeroHook(BaseOpHook):

    def __init__(self, shard_strategy: BaseShardStrategy):
        super().__init__()
        self.shard_strategy = shard_strategy
        # NOTE(jiaruifang) For now, the computing device of FWD and BWD is always the GPU
        self.computing_device = torch.device(f'cuda:{get_current_device()}')

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.gather([param.col_attr.data])
            if param.col_attr.data.device != self.computing_device:
                param.col_attr.data.to(self.computing_device)
            param.data = param.col_attr.data.payload

    def post_fwd_exec(self, module: torch.nn.Module, *args):
@@ -31,6 +35,8 @@ class ZeroHook(BaseOpHook):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.gather([param.col_attr.data])
            if param.col_attr.data.device != self.computing_device:
                param.col_attr.data.to(self.computing_device)
            param.data = param.col_attr.data.payload
            # Store local accumulated grad shard
            if param.grad is not None:
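
The conditional move above is the heart of the hook change: gathered parameters may now start on the CPU, so they are copied to the computing device only when they are not already there. A toy sketch of the pattern with plain tensors (no real sharding; move_if_needed is a hypothetical helper for illustration):

import torch

def move_if_needed(t: torch.Tensor, computing_device: torch.device) -> torch.Tensor:
    # Skip the copy when the tensor already lives on the computing device.
    return t if t.device == computing_device else t.to(computing_device)

x = torch.zeros(4)                                 # e.g. a param resting on CPU
if torch.cuda.is_available():
    x = move_if_needed(x, torch.device('cuda:0'))  # brought over only for compute

Note that a plain torch.Tensor.to returns a new tensor, while the ShardedTensor.to added in this commit mutates in place, which is why the hook does not reassign its result.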

colossalai/zero/init_ctx/init_context.py (28 lines changed)

@@ -1,7 +1,6 @@
import functools
import torch
from colossalai.utils.cuda import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
@@ -82,6 +81,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
    1. Convert the model to fp16.
    2. The parameters of the module are adapted to type ShardedParameter.
    3. Shard the param and grad according to flags.

    target_device: the device where param data resides after exiting the context
    shard_strategy: shard strategy instance
    shard_param: whether the param is sharded after exiting the context
    shard_grad: whether the grad is sharded after exiting the context
    rm_torch_payload_on_the_fly:
        True: remove tensor payload on param.data after module init is finished.
        False: remove tensor payload on param.data after the context exits.
@@ -91,18 +96,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

    def __init__(self,
                 convert_fp16: bool,
                 convert_cuda: bool,
                 target_device: torch.device,
                 shard_strategy: BaseShardStrategy,
                 shard_param: bool = False,
                 shard_grad: bool = False,
                 rm_torch_payload_on_the_fly=False):
        super().__init__()
        self.convert_fp16 = convert_fp16
        self.convert_cuda = convert_cuda
        self.target_device = target_device
        self.shard_param = shard_param
        self.shard_grad = shard_grad
        self.shard_strategy = shard_strategy
        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
        # FIXME(jiaruifang) setting it to True is invalid for now.
        self.rm_torch_payload_on_the_fly = False
        self.initialized_param_list = []

    def _post_context_exec(self):
@@ -123,17 +129,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            if hasattr(param, 'col_attr'):
                continue
            if self.convert_cuda:
                target_device = get_current_device()
            else:
                target_device = param.data.device
            target_device = self.target_device

            # convert to fp16 and cuda if necessary
            # convert to fp16 if necessary
            if self.convert_fp16:
                param.data = param.data.to(torch.half).to(target_device)
                param.data = param.data.to(torch.half)
                if param.grad is not None:
                    param.grad = param.grad.to(torch.half).to(target_device)

            # move torch parameters to the target device
            param.data = param.data.to(target_device)
            if param.grad is not None:
                param.grad = param.grad.to(target_device)

            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
            self.initialized_param_list.append(param)
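
The conversion is now two explicit steps: an optional fp16 cast, then a move to the device chosen when entering the context. A minimal sketch of the same per-parameter logic (_convert_param is a hypothetical name for illustration):

import torch

def _convert_param(param: torch.nn.Parameter,
                   convert_fp16: bool,
                   target_device: torch.device) -> None:
    # Step 1: cast data (and grad, if present) to half precision.
    if convert_fp16:
        param.data = param.data.to(torch.half)
        if param.grad is not None:
            param.grad = param.grad.to(torch.half)
    # Step 2: move everything to the caller-chosen target device.
    param.data = param.data.to(target_device)
    if param.grad is not None:
        param.grad = param.grad.to(target_device)

p = torch.nn.Parameter(torch.randn(2, 2))
_convert_param(p, convert_fp16=True, target_device=torch.device('cpu'))
assert p.data.dtype == torch.half and p.data.device.type == 'cpu'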

colossalai/zero/shard_utils/tensor_shard_strategy.py (3 lines changed)

@@ -30,7 +30,7 @@ class TensorShardStrategy(BaseShardStrategy):

    def _gather_tensor(self, t: ShardedTensor):
        if not t.is_sharded:
            return
        target_device = t.device
        buffer_list = []
        payload_numel = t.payload.numel()
        for i in range(self.world_size):
@@ -45,4 +45,5 @@ class TensorShardStrategy(BaseShardStrategy):
                                async_op=False)
        gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
        t.reset_payload(gathered_payload)
        t.to(target_device)
        t.is_sharded = False
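
_gather_tensor now remembers the shard's device before the all-gather and moves the assembled payload back afterwards, so a tensor kept on the CPU stays on the CPU across a gather. The device bookkeeping, illustrated with plain tensors and no collective call:

import torch

shard = torch.arange(4)                  # shard living on CPU
target_device = shard.device             # remembered before gathering
gathered = torch.cat([shard, shard])     # stand-in for the all-gather result
gathered = gathered.to(target_device)    # full payload returns to its home device
assert gathered.device == target_device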

colossalai/zero/sharded_param/sharded_tensor.py (7 lines changed)

@@ -47,11 +47,18 @@ class ShardedTensor(object):
        del self._payload
        self._payload = tensor

    @property
    def device(self):
        return self._payload.device

    @property
    def dtype(self):
        assert self._payload.dtype == self._origin_dtype
        return self._origin_dtype

    def to(self, device: torch.device):
        self._payload = self._payload.to(device)

    @property
    def shape(self):
        return self._payload.shape
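
For reference, a self-contained toy version of the new accessors (a simplification for illustration, not the library class). The key detail is that to() mutates the wrapper in place and returns None, which is why _gather_tensor never reassigns its result:

import torch

class ToyShardedTensor:
    """Simplified stand-in for ShardedTensor, assuming only these members."""

    def __init__(self, tensor: torch.Tensor):
        self._payload = tensor
        self._origin_dtype = tensor.dtype

    @property
    def device(self) -> torch.device:
        # The logical device is wherever the payload currently lives.
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        # The payload's dtype must never drift from the original dtype.
        assert self._payload.dtype == self._origin_dtype
        return self._origin_dtype

    def to(self, device: torch.device) -> None:
        # In-place move; unlike torch.Tensor.to, nothing is returned.
        self._payload = self._payload.to(device)

t = ToyShardedTensor(torch.zeros(2))
t.to(torch.device('cpu'))
assert t.device.type == 'cpu'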

tests/test_zero_data_parallel/test_init_context.py (23 lines changed)

@@ -4,6 +4,7 @@
from functools import partial

import colossalai
from colossalai.utils.cuda import get_current_device
import pytest
import torch
import torch.multiprocessing as mp
@@ -17,13 +18,13 @@ from common import CONFIG
from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER


def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, init_device):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    for get_components_func in non_distributed_component_funcs:
        model_builder, _, _, _, _ = get_components_func()
        with ZeroInitContext(convert_fp16=True,
                             convert_cuda=True,
                             target_device=init_device,
                             shard_strategy=TensorShardStrategy(),
                             shard_param=True):
            model = model_builder(checkpoint=True)
@@ -32,18 +33,26 @@ def run_dist(rank, world_size, port):
            assert hasattr(param, 'col_attr')
            assert param.col_attr.data.dtype == torch.half
            assert param.col_attr.data.is_sharded
            assert param.col_attr.data.payload.device.type == 'cuda'
            assert param.col_attr.data.payload.device.type == init_device.type, \
                f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'

        print(f'cpu usage {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
        print(f'cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
        if init_device.type == 'cuda':
            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
        elif init_device.type == 'cpu':
            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
def test_zero_init_context(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
@pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
def test_zero_init_context(world_size, init_device):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_init_context(2)
    test_zero_init_context(2, torch.device('cpu'))
    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))

tests/test_zero_data_parallel/test_shard_model_v2.py (9 lines changed)

@@ -5,6 +5,7 @@ import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -30,8 +31,14 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, _, criterion = get_components_func()

        rm_torch_payload_on_the_fly = False
        if use_zero_init_ctx:
            with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
            with ZeroInitContext(convert_fp16=True,
                                 target_device=torch.device('cpu'),
                                 shard_strategy=shard_strategy,
                                 shard_param=True,
                                 rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
                zero_model = model_builder(checkpoint=True)
                zero_model = ShardedModelV2(zero_model, shard_strategy)
