mirror of https://github.com/hpcaitech/ColossalAI
[nfc] fix typo colossalai/pipeline tensor nn (#3899)
* fix typo colossalai/autochunk auto_parallel amp
* fix typo colossalai/auto_parallel nn utils etc.
* fix typo colossalai/auto_parallel autochunk fx/passes etc.
* fix typo docs/
* change placememt_policy to placement_policy in docs/ and examples/
* fix typo colossalai/ applications/
* fix typo colossalai/cli fx kernel
* fix typo colossalai/nn
* revert change warmuped
* fix typo colossalai/pipeline tensor nn
parent c1535ccbba
commit 0e484e6201

@@ -13,7 +13,7 @@ from .nvme_optimizer import NVMeOptimizer
 class CPUAdam(NVMeOptimizer):
     """Implements Adam algorithm.

-    Supports parameters updating on both GPU and CPU, depanding on the device of parameters.
+    Supports parameters updating on both GPU and CPU, depending on the device of parameters.
     But the parameters and gradients should on the same device:
       * Parameters on CPU and gradients on CPU is allowed.
       * Parameters on GPU and gradients on GPU is allowed.

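As a usage note for the docstring above: the optimizer simply follows wherever the parameters live. A minimal sketch, assuming ColossalAI is installed with its CPU Adam extension built (the model and hyperparameters are illustrative):

import torch
from colossalai.nn.optimizer import CPUAdam

model = torch.nn.Linear(128, 128)              # parameters (and thus gradients) on CPU
optimizer = CPUAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 128)).sum()
loss.backward()                                # gradients are created on the same device as the params
optimizer.step()                               # the CPU kernel performs the Adam update
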
@@ -14,7 +14,7 @@ from .cpu_adam import CPUAdam
 class HybridAdam(CPUAdam):
     """Implements Adam algorithm.

-    Supports parameters updating on both GPU and CPU, depanding on the device of parameters.
+    Supports parameters updating on both GPU and CPU, depending on the device of parameters.
     But the parameters and gradients should on the same device:
       * Parameters on CPU and gradients on CPU is allowed.
       * Parameters on GPU and gradients on GPU is allowed.

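HybridAdam carries the same contract; the practical difference (stated here as a hedged note, not a quote of the code) is that GPU-resident parameter groups can be updated with a fused CUDA kernel while CPU-resident ones take the CPU path, which is why it is commonly paired with Gemini/ZeRO offloading. A sketch with parameters on GPU, assuming a CUDA device is available:

import torch
from colossalai.nn.optimizer import HybridAdam

model = torch.nn.Linear(128, 128).cuda()       # parameters and gradients both on GPU
optimizer = HybridAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 128, device="cuda")).sum()
loss.backward()
optimizer.step()
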
@@ -83,7 +83,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         for k, v in kwargs.items():
             if isinstance(v, torch.nn.Module):
                 v = self._layer_spec_dict[id(v)]
-            # (lyl)TODO: analyse ColoTensor as well
+            # (lyl)TODO: analyze ColoTensor as well
             modified_kwargs[k] = v

         # keep track of the module children

@@ -117,7 +117,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
     def to_layer_list(self, exec_seq=None):
         """
         Create a layer spec list and func list with execution sequence given by user.
-        If exec_seq is None, we will take the module initizing order as execution order.
+        If exec_seq is None, we will take the module initializing order as execution order.
         """

         self._exec_seq = exec_seq

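For orientation, a sketch of how this is typically driven (import path and layer choices follow the ColossalAI pipeline examples and should be treated as assumptions): modules built inside the context are traced as layer specs, and calling to_layer_list() without exec_seq keeps that initialization order as the execution order.

import torch.nn as nn
from colossalai.pipeline.pipelinable import PipelinableContext

pipelinable = PipelinableContext()
with pipelinable:
    # layers are recorded in the order they are initialized here
    model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))

pipelinable.to_layer_list()    # no exec_seq: execution order = initialization order above
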
@@ -177,7 +177,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):

     def partition(self, num_chunks, pipeline_size, rank):
         """
-        Partitioned model will be built respect to partion policy.
+        Partitioned model will be built respect to partition policy.
         The real module instance will be built in this method.
         """
         if isinstance(self._policy, str):

@@ -193,7 +193,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
                 self.customized_parts = customized_partition(self._exec_seq)
                 assert len(self.customized_parts) == gpc.get_world_size(
                     ParallelMode.PIPELINE
-                ), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partions is {len(self.customized_parts)}'
+                ), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}'
                 parts = self.customized_parts[rank]
             else:
                 raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].")

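Continuing the sketch above, partition() then materializes this rank's slice of the traced model; the policy string must be one of 'uniform', 'balanced' or 'customized', and 'customized' additionally relies on the exec_seq passed to to_layer_list() to mark the cut points. The call below is an illustrative assumption, not code from the repository:

from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

pipelinable.policy = "uniform"     # or "balanced" / "customized"
partitioned_model = pipelinable.partition(
    num_chunks=1,
    pipeline_size=gpc.get_world_size(ParallelMode.PIPELINE),
    rank=gpc.get_local_rank(ParallelMode.PIPELINE),
)
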
@@ -123,7 +123,7 @@ class WorkerBase(ABC):
         self.device = device
         self._initialize_outstanding_range()

-        # variable and const for context managment
+        # variable and const for context management
         self.outstanding = 0
         self.forward_times = 0
         self.backward_times = 0

@@ -226,7 +226,7 @@ class WorkerBase(ABC):
         self.pp_rank_to_worker_rref = pp_rank_to_worker_rref

         # for some schedule need the other worker's info to initialise partition (like Chimera)
-        # construction of partition is executed after the registion of pp_rank_to_worker_rref
+        # construction of partition is executed after the registration of pp_rank_to_worker_rref
         self._initialize_partition()

         # res_use works for lifecycle counter,

@@ -418,7 +418,7 @@ class WorkerBase(ABC):
             # On current PP middleware design for DAG, get_output_by_key used by _subscribe_producer
             # can only be executed once for every producer-consumer stage pair, which is necessary
             # to count the lifecycle of work_item. So, keeping the _subscribe_producer in the same
-            # lock of work_item queue operation gurantees the consistency of lifecycle counter.
+            # lock of work_item queue operation guarantees the consistency of lifecycle counter.
             work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only)
             self.work_list[key] = work_item_from_producer
             self.work_list_condition_lock.notify_all()

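The comment above is really about atomicity: the producer subscription and the insertion into work_list have to sit in the same critical section, otherwise another thread could observe the queue between the two operations and the work item's lifecycle count would drift. A generic sketch of that pattern (plain Python threading, not the actual WorkerBase code; names are placeholders):

import threading

work_list = {}
work_list_condition_lock = threading.Condition()

def enqueue(key, subscribe):
    # the subscription and the queue insertion form one atomic step,
    # so lifecycle counting sees them together
    with work_list_condition_lock:
        work_list[key] = subscribe()
        work_list_condition_lock.notify_all()
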
@@ -460,7 +460,7 @@ class WorkerBase(ABC):
             # On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
             # can only be executed once for every producer-consumer stage pair, which is necessary
             # to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same
-            # lock of work_item queue operation gurantees the consistency of lifecycle counter.
+            # lock of work_item queue operation guarantees the consistency of lifecycle counter.
             work_item_from_consumer = self._subscribe_consumer(microbatch_id)
             self.work_list[key] = work_item_from_consumer
             self.work_list_condition_lock.notify_all()

@@ -508,7 +508,7 @@ class WorkerBase(ABC):
         assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
         assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"

-        # should be aranged in order, the order of the input of current forward
+        # should be arranged in order, the order of the input of current forward
         self.producer_stage_ids = self.get_producer_stage_ids()
         self.consumer_stage_ids = self.get_consumer_stage_ids()

@@ -123,7 +123,7 @@ class ChimeraWorker(WorkerBase):
         assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
         assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"

-        # should be aranged in order, the order of the input of current forward
+        # should be arranged in order, the order of the input of current forward
         self.producer_stage_ids = []
         self.consumer_stage_ids = []

@@ -174,7 +174,7 @@ class ChimeraWorker(WorkerBase):
         else:
             # if it is down pipeline, create partition by origin method
             co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
-            # get the coresponding model state dict and wait for its init
+            # get the corresponding model state dict and wait for its init
             state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
             super()._initialize_partition()
             self.module_partition.load_state_dict(state_dict)

@@ -228,7 +228,7 @@ class ChimeraWorker(WorkerBase):
         stage_num = self.actual_stage_num
         co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)

-        # if currrent pp_rank is not the first to do step
+        # if current pp_rank is not the first to do step
         # wait its previous pp_rank finish step
         grads = self.get_parameter_gradients()

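A quick worked example of the co_pp_rank formula above: Chimera runs an up and a down pipeline, so with actual_stage_num = 4 the ranks 0-3 and 4-7 pair up, and (pp_rank + stage_num) % (2 * stage_num) maps each rank to the peer holding the replica of its stage:

stage_num = 4
co_ranks = [(pp_rank + stage_num) % (2 * stage_num) for pp_rank in range(8)]
assert co_ranks == [4, 5, 6, 7, 0, 1, 2, 3]    # rank 1 pairs with 5, rank 5 pairs back with 1
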
@@ -113,7 +113,7 @@ def _binary_search(weights, num):

 def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
     assert num_items % num_chunks == 0, \
-        "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
+        "Layer length should be divided by the number of chunks, otherwise parameter method is recommended"

     logger = get_dist_logger()
     parts = [[] for _ in range(pipeline_parallel_size)]

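To make the assertion concrete (a hedged reading of the uniform policy, not a quote of the function body): with num_items = 24 layers, pipeline_parallel_size = 4 and num_chunks = 2 there are 8 parts of 3 layers each, while 25 layers would trip the assert because 25 is not divisible by num_chunks = 2.

num_items, pipeline_parallel_size, num_chunks = 24, 4, 2
assert num_items % num_chunks == 0                      # the check quoted above
layers_per_part = num_items // (pipeline_parallel_size * num_chunks)
print(layers_per_part)                                  # 3
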
@@ -28,7 +28,7 @@ class CommSpec:
     to determine the buffer shape, and logical_process_axis

     Argument:
-        comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
+        comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
         process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
         gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
         shard_dim(int, Optional): The shard_dim of the tensor will be sharded.

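To unpack gather_dim and shard_dim in plain tensor terms (a torch-only illustration with made-up shapes, outside of CommSpec itself): sharding along shard_dim splits that dimension across ranks, and gathering along gather_dim concatenates the shards back on that dimension.

import torch

full = torch.arange(8).reshape(2, 4)

shards = torch.chunk(full, chunks=2, dim=1)    # "shard" along shard_dim=1: each rank keeps a (2, 2) slice
restored = torch.cat(shards, dim=1)            # "gather" along gather_dim=1: concatenation restores the tensor
assert torch.equal(restored, full)
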
@@ -41,7 +41,7 @@ class DimSpec:

     def _convert_str_to_shard_list(self, str_spec):
         '''
-        Conver str_spec into shard_list.
+        Convert str_spec into shard_list.

         Argument:
             str_spec(str): dim spec in str type.

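For orientation, the dim-spec strings used throughout these specs read as follows (a hedged summary of the notation, not a quote of the parser): 'R' marks a replicated dimension and maps to an empty shard_list, while 'S0' or 'S01' mark a dimension sharded over logical process axis 0, or over axes 0 and 1 together.

# assumed outputs of _convert_str_to_shard_list for a few illustrative inputs
examples = {"R": [], "S0": [0], "S1": [1], "S01": [0, 1]}
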
@@ -58,7 +58,7 @@ class DimSpec:

     def build_difference_2d_dict(self):
         '''
-        Build a difference maping for 2D device mesh case. It will be used to
+        Build a difference mapping for 2D device mesh case. It will be used to
         compute the difference between DimSpec pairs.
         '''

@@ -164,7 +164,7 @@ def _get_grad_args(*args):
     for obj in args:
         if _is_grad_tensor(obj):
             return args, None
-    # otherwise, the first arguement should be a tuple of grad tensors
+    # otherwise, the first argument should be a tuple of grad tensors
     # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
     arg_zero = args[0]
     if not isinstance(arg_zero, tuple):

@@ -130,7 +130,7 @@ class ProcessGroup:
     @property
     def has_cpu_groups(self) -> bool:
         """has_cpu_groups
-        If cpu groups have been initailized.
+        If cpu groups have been initialized.

         Returns:
             bool: cpu process groups have been initialized or not.

@@ -252,7 +252,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
     def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
         '''
         Get all valid sharding specs from source_spec with single shard operation, and
-        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
+        accumulate communication cost on origin cost which will finally be used in auto sharding solver.
         For the sharding operation, we just care about legal sharding dimensions.

         Argument:

@@ -386,7 +386,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
     def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with one step transform, and
-        accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
+        accumulate communication cost on origin cost which will finally be used in auto sharding solver.
         Note:
             all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
             and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive,

@@ -577,7 +577,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         Step3:
             Repeat above steps until the source spec transform to target spec.

-        During finding the transform path, commucation cost will be accumulated, and it
+        During finding the transform path, communication cost will be accumulated, and it
         will be finally used in auto parallel solver.

         Additionally, to avoid repeating the path search in runtime, we cached all solved path

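The steps above amount to a greedy search: from the current spec, enumerate every spec reachable with one all-gather, all-to-all or shard, take the most promising one, and repeat until the target spec is reached, accumulating communication cost along the way (with solved paths cached for reuse). A compact sketch of that loop in generic Python, with placeholder callables rather than the real ShapeConsistencyManager API:

def greedy_transform_path(source, target, one_step_specs, step_cost, difference):
    """Illustrative greedy layout-conversion search (placeholder names, not the real API).

    one_step_specs(spec): specs reachable by a single all-gather / all-to-all / shard.
    step_cost(a, b):      communication cost of that one step.
    difference(a, b):     heuristic distance between two sharding specs (0 means equal).
    """
    path, total_cost, current = [source], 0.0, source
    while difference(current, target) != 0:
        # pick the candidate closest to the target; cheaper communication breaks ties
        nxt = min(one_step_specs(current),
                  key=lambda s: (difference(s, target), step_cost(current, s)))
        total_cost += step_cost(current, nxt)
        path.append(nxt)
        current = nxt
    return path, total_cost
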
@@ -45,7 +45,7 @@ class _DimSpec:

     def _convert_str_to_shard_list(self, str_spec):
         '''
-        Conver str_spec into shard_list.
+        Convert str_spec into shard_list.

         Argument:
             str_spec(str): dim spec in str type.

@@ -62,7 +62,7 @@ class _DimSpec:

     def build_difference_2d_dict(self):
         '''
-        Build a difference maping for 2D device mesh case. It will be used to
+        Build a difference mapping for 2D device mesh case. It will be used to
         compute the difference between DimSpec pairs.
         '''

@@ -166,7 +166,7 @@ class ShardingSpec:
         device_mesh(DeviceMesh): A logical view of a physical mesh.
         entire_shape(torch.Size): The entire shape of tensor before sharded.
         dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
-            and the value of the key decribe which logical axis will be sharded in that dimension.
+            and the value of the key describe which logical axis will be sharded in that dimension.
         sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
     '''

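A small worked example of the argument format above (shapes and axes chosen for illustration): for a tensor with entire_shape (64, 32, 16) on a 2D logical mesh, dim_partition_dict = {0: [0], 2: [1]} shards tensor dim 0 over logical axis 0 and tensor dim 2 over logical axis 1, which corresponds to the straight view [S0, R, S1].

entire_shape = (64, 32, 16)
dim_partition_dict = {0: [0], 2: [1]}    # tensor dim -> logical mesh axes it is sharded over
# equivalent sharding_sequence (straight view): [S0, R, S1]
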
@@ -77,7 +77,7 @@ def shard_simulator(target_pair, legal_sharding_dims):

     Argument:
         target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
-            and the second element decribes which logical axis will be sharded in that dimension.
+            and the second element describes which logical axis will be sharded in that dimension.
     '''
     _, shard_list = target_pair
     shard_list_list = []