[zero] able to place params on cpu after zero init context (#365)

* place params on cpu after zero init context

* polish code
pull/394/head
Jiarui Fang 2022-03-10 14:08:58 +08:00 committed by Frank Lee
parent b66f3b994c
commit 44e4891f57
6 changed files with 58 additions and 20 deletions

View File

@@ -1,7 +1,7 @@
 import torch
 from colossalai.registry import OPHOOKS
 from colossalai.zero.shard_utils import BaseShardStrategy
+from colossalai.utils import get_current_device
 from ._base_ophook import BaseOpHook
@@ -14,11 +14,15 @@ class ZeroHook(BaseOpHook):
     def __init__(self, shard_strategy: BaseShardStrategy):
         super().__init__()
         self.shard_strategy = shard_strategy
+        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
+        self.computing_device = torch.device(f'cuda:{get_current_device()}')

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             self.shard_strategy.gather([param.col_attr.data])
+            if param.col_attr.data.device != self.computing_device:
+                param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload

     def post_fwd_exec(self, module: torch.nn.Module, *args):
@@ -31,6 +35,8 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             self.shard_strategy.gather([param.col_attr.data])
+            if param.col_attr.data.device != self.computing_device:
+                param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload
             # Store local accumulated grad shard
             if param.grad is not None:
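
The hook-side change above exists because, after this PR, parameters can still be resident on CPU when a forward or backward pass starts (for example when the init context was given target_device=torch.device('cpu')), while the computation itself always runs on GPU, as the NOTE comment says. A minimal, self-contained sketch of the pattern; _ToyShardedTensor and pre_op below are hypothetical stand-ins for illustration, not ColossalAI APIs:

import torch

class _ToyShardedTensor:
    # Hypothetical stand-in for ShardedTensor, used only to illustrate the hook pattern.
    def __init__(self, payload: torch.Tensor):
        self.payload = payload

    @property
    def device(self) -> torch.device:
        return self.payload.device

    def to(self, device: torch.device) -> None:
        # Move the payload in place, mirroring the ShardedTensor.to added in this commit.
        self.payload = self.payload.to(device)

def pre_op(sharded: _ToyShardedTensor, computing_device: torch.device) -> torch.Tensor:
    # After the (omitted here) gather step, make sure the payload sits on the
    # computing device before it is exposed as param.data.
    if sharded.device != computing_device:
        sharded.to(computing_device)
    return sharded.payload

if __name__ == '__main__':
    computing_device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    t = _ToyShardedTensor(torch.zeros(4))    # parameter payload initialized on CPU
    data = pre_op(t, computing_device)
    print(data.device)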

View File

@@ -1,7 +1,6 @@
 import functools
 import torch
-from colossalai.utils.cuda import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
 from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
@@ -82,6 +81,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     1. Convert the model to fp16.
     2. The parameters of the module are adapted to type ShardedParameter.
     3. Shard the param and grad according to flags.
+    target_device: the device on which param data is placed after exiting the context.
+    shard_strategy: the shard strategy instance to use.
+    shard_param: whether the param is sharded after exiting the context.
+    shard_grad: whether the grad is sharded after exiting the context.
     rm_torch_payload_on_the_fly:
         True: remove tensor payload on param.data after module init finished.
         False: remove tensor payload on param.data after the context exits.
@@ -91,18 +96,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     def __init__(self,
                  convert_fp16: bool,
-                 convert_cuda: bool,
+                 target_device: torch.device,
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
                  rm_torch_payload_on_the_fly=False):
         super().__init__()
         self.convert_fp16 = convert_fp16
-        self.convert_cuda = convert_cuda
+        self.target_device = target_device
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
-        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
+        # FIXME(jiaruifang) now setting it to True is invalid.
+        self.rm_torch_payload_on_the_fly = False
         self.initialized_param_list = []

     def _post_context_exec(self):
@@ -123,17 +129,19 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if hasattr(param, 'col_attr'):
                 continue
-            if self.convert_cuda:
-                target_device = get_current_device()
-            else:
-                target_device = param.data.device
+            target_device = self.target_device

-            # convert to fp16 and cuda if necessary
+            # convert to fp16 if necessary
             if self.convert_fp16:
-                param.data = param.data.to(torch.half).to(target_device)
+                param.data = param.data.to(torch.half)
                 if param.grad is not None:
                     param.grad = param.grad.to(torch.half).to(target_device)

+            # move torch parameters to the target device
+            param.data = param.data.to(target_device)
+            if param.grad is not None:
+                param.grad = param.grad.to(target_device)

             param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

             self.initialized_param_list.append(param)
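
With the new constructor contract, typical use looks like the sketch below. This is hedged: the import paths are assumed from the surrounding files and may differ, the code must run inside an initialized distributed environment (the tests call colossalai.launch first), and nn.Linear stands in for a real model builder.

import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext          # import path assumed
from colossalai.zero.shard_utils import TensorShardStrategy   # import path assumed

# Keep fp16 model data on CPU right after initialization; ZeroHook moves gathered
# shards onto the GPU only when forward/backward actually runs.
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cpu'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = nn.Linear(16, 16)    # any nn.Module constructor works here

for p in model.parameters():
    # Mirrors the assertion added to the init-context test below.
    assert p.col_attr.data.payload.device.type == 'cpu'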

View File

@@ -30,7 +30,7 @@ class TensorShardStrategy(BaseShardStrategy):
     def _gather_tensor(self, t: ShardedTensor):
         if not t.is_sharded:
             return
+        target_device = t.device
         buffer_list = []
         payload_numel = t.payload.numel()
         for i in range(self.world_size):
@@ -45,4 +45,5 @@
                              async_op=False)
         gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
         t.reset_payload(gathered_payload)
+        t.to(target_device)
         t.is_sharded = False
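
The pair target_device = t.device / t.to(target_device) means gathering no longer forces the tensor onto a particular device: the full payload comes back on whatever device the shard occupied before the all-gather. A single-process sketch of the same remember-then-restore pattern; gather_like is a hypothetical helper (no torch.distributed involved), not a ColossalAI API:

import torch

def gather_like(shards, origin_numel, origin_shape):
    # Remember where the local shard lived before "gathering".
    target_device = shards[0].device
    gathered = torch.narrow(torch.cat(shards), 0, 0, origin_numel).reshape(origin_shape)
    # Restore the original device, mirroring the step this commit adds.
    return gathered.to(target_device)

if __name__ == '__main__':
    full = torch.arange(6, dtype=torch.float32)
    shards = list(full.chunk(2))    # pretend these came from two ranks
    print(gather_like(shards, full.numel(), full.shape))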

View File

@@ -47,11 +47,18 @@ class ShardedTensor(object):
         del self._payload
         self._payload = tensor

+    @property
+    def device(self):
+        return self._payload.device
+
     @property
     def dtype(self):
         assert self._payload.dtype == self._origin_dtype
         return self._origin_dtype

+    def to(self, device: torch.device):
+        self._payload = self._payload.to(device)
+
     @property
     def shape(self):
         return self._payload.shape
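
One detail worth noting: unlike torch.Tensor.to, which returns a tensor, the ShardedTensor.to added here reassigns self._payload in place and returns None. Callers such as ZeroHook and TensorShardStrategy therefore invoke it purely for its side effect. A small illustration of the contrast:

import torch

x = torch.zeros(2)
y = x.to(torch.device('cpu'))    # torch.Tensor.to returns a tensor (possibly the same object)
assert isinstance(y, torch.Tensor)

# ShardedTensor.to (above) instead updates self._payload and returns None, so code
# like param.col_attr.data.to(device) relies on the in-place move and keeps reading
# param.col_attr.data.payload afterwards.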

View File

@@ -4,6 +4,7 @@
 from functools import partial

 import colossalai
+from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -17,13 +18,13 @@ from common import CONFIG
 from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, init_device):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         with ZeroInitContext(convert_fp16=True,
-                             convert_cuda=True,
+                             target_device=init_device,
                              shard_strategy=TensorShardStrategy(),
                              shard_param=True):
             model = model_builder(checkpoint=True)
@@ -32,18 +33,26 @@ def run_dist(rank, world_size, port):
             assert hasattr(param, 'col_attr')
             assert param.col_attr.data.dtype == torch.half
             assert param.col_attr.data.is_sharded
-            assert param.col_attr.data.payload.device.type == 'cuda'
+            assert param.col_attr.data.payload.device.type == init_device.type, \
+                f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'

+        print(f'cpu usgae {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
         print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
-        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        if init_device.type == 'cuda':
+            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        elif init_device.type == 'cpu':
+            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
-def test_zero_init_context(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
+def test_zero_init_context(world_size, init_device):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_init_context(2)
+    test_zero_init_context(2, torch.device('cpu'))
+    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))

View File

@@ -5,6 +5,7 @@ import copy
 from functools import partial

 import pytest
+import torch
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -30,8 +31,14 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, _, _, criterion = get_components_func()

+    rm_torch_payload_on_the_fly = False
     if use_zero_init_ctx:
-        with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
+        with ZeroInitContext(convert_fp16=True,
+                             target_device=torch.device('cpu'),
+                             shard_strategy=shard_strategy,
+                             shard_param=True,
+                             rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
             zero_model = model_builder(checkpoint=True)
         zero_model = ShardedModelV2(zero_model, shard_strategy)