zero init ctx receives a dp process group (#471)

commit 3cb3fc275e
parent 7e30068a22
Author: ver217
Date: 2022-03-21 11:18:55 +08:00 (committed by GitHub)
1 changed file with 10 additions and 4 deletions


@@ -1,11 +1,15 @@
 import functools
+from typing import Optional
 
 import torch
 
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
+from torch.distributed import ProcessGroup
 
 # Inserts _post_init_method at the end of init method
@@ -103,8 +107,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
-                 rm_torch_payload_on_the_fly=False,
-                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int)):
+                 rm_torch_payload_on_the_fly: bool = False,
+                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
+                 dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
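The new keyword argument threads a data-parallel ProcessGroup into the context. A minimal usage sketch, not part of this commit: it assumes a `TensorShardStrategy` implementation of `BaseShardStrategy` is available, and the 4-rank group and `MyModel` are hypothetical.

import torch
import torch.distributed as dist

# Hypothetical: ranks 0-3 form the data-parallel replica group.
dp_group = dist.new_group(ranks=[0, 1, 2, 3])

with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     dp_process_group=dp_group):
    model = MyModel()  # hypothetical module; its params are sharded over dp_group

Omitting dp_process_group keeps the old behavior, via the fallback added in the next hunk.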
@@ -115,6 +120,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.rm_torch_payload_on_the_fly = False
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
+        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
     def _post_context_exec(self):
         """The callback function when the context exits.
@@ -154,10 +160,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             self.initialized_param_list.append(param)
 
             if self.shard_param:
-                self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
+                self.shard_strategy.shard([param.col_attr._data_sharded_tensor], self.dp_process_group)
                 GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
-            #     self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+            #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
             # We must cast buffers
             # If we use BN, buffers may be on CPU and Float
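The shard() calls above now pass self.dp_process_group as a positional second argument, so a strategy communicates only within its data-parallel replicas rather than the whole world. As an illustration of that interface only (this is not ColossalAI's actual strategy code, and the settable .payload attribute is an assumption carried over from the diff), a naive group-aware strategy could look like:

import torch
import torch.distributed as dist
import torch.nn.functional as F

class NaiveShardStrategy:
    """Illustrative sketch: keep only this rank's slice of each tensor's payload."""

    def shard(self, tensor_list, process_group=None):
        # group=None falls back to the default (world) group in torch.distributed.
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        for t in tensor_list:
            flat = t.payload.flatten()  # assumes a .payload attribute, as in the diff
            shard_size = (flat.numel() + world_size - 1) // world_size
            # Pad so the payload splits evenly across the group's ranks.
            padded = F.pad(flat, (0, shard_size * world_size - flat.numel()))
            # Each rank retains only its own contiguous chunk.
            t.payload = padded[rank * shard_size:(rank + 1) * shard_size].clone()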