mirror of https://github.com/hpcaitech/ColossalAI
zero init ctx receives a dp process group (#471)
parent
7e30068a22
commit
3cb3fc275e
|
@ -1,11 +1,15 @@
|
|||
import functools
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
||||
GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
|
||||
from colossalai.zero.sharded_param import ShardedParamV2
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# Inserts _post_init_method at the end of init method
|
||||
|
||||
|
@ -103,8 +107,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
shard_strategy: BaseShardStrategy,
|
||||
shard_param: bool = False,
|
||||
shard_grad: bool = False,
|
||||
rm_torch_payload_on_the_fly=False,
|
||||
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int)):
|
||||
rm_torch_payload_on_the_fly: bool = False,
|
||||
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
|
||||
dp_process_group: Optional[ProcessGroup] = None):
|
||||
super().__init__()
|
||||
self.convert_fp16 = convert_fp16
|
||||
self.target_device = target_device
|
||||
|
@ -115,6 +120,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
self.rm_torch_payload_on_the_fly = False
|
||||
self.initialized_param_list = []
|
||||
self.model_numel_tensor = model_numel_tensor
|
||||
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
|
||||
|
||||
def _post_context_exec(self):
|
||||
"""The callback function when the context exits.
|
||||
|
@ -154,10 +160,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
self.initialized_param_list.append(param)
|
||||
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
|
||||
self.shard_strategy.shard([param.col_attr._data_sharded_tensor], self.dp_process_group)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
|
||||
# if param.col_attr.grad and self.shard_grad:
|
||||
# self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
|
||||
# self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
|
||||
# GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
|
||||
# We must cast buffers
|
||||
# If we use BN, buffers may be on CPU and Float
|
||||
|
|
Loading…
Reference in New Issue