mirror of https://github.com/hpcaitech/ColossalAI
fix some typo with colossalai/device colossalai/tensor/ etc. (#4171)
Co-authored-by: flybird11111 <1829166702@qq.com>

parent d8ceeac14e
commit 9c2feb2f0b
@@ -59,7 +59,7 @@ class DeviceMesh:
         # 2. directly supply the logical mesh id
         assert mesh_shape is None or logical_mesh_id is None, \
             "Only one of mesh_shape and logical_mesh_id can be specified." \
-            "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"
+            "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"

         if logical_mesh_id is None:
             self._mesh_shape = mesh_shape
@@ -74,7 +74,7 @@ class DeviceMesh:
         assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
             "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
         assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
-            "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
+            "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
         assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
             "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."

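As a quick orientation for the checks above, here is a minimal sketch of constructing a DeviceMesh both ways; the constructor signature and import path are assumed from common ColossalAI usage rather than shown in this diff, and the IDs reuse the docstring example further down (torch.arange(0, 16) with a (4, 4) mesh).

```python
import torch
from colossalai.device.device_mesh import DeviceMesh  # import path assumed

physical_mesh_id = torch.arange(0, 16)

# option 1: give mesh_shape and let DeviceMesh derive the logical mesh
mesh = DeviceMesh(physical_mesh_id, mesh_shape=(4, 4))

# option 2: supply the logical mesh id directly (mutually exclusive with mesh_shape)
logical_mesh_id = physical_mesh_id.reshape(4, 4)
mesh = DeviceMesh(physical_mesh_id, logical_mesh_id=logical_mesh_id)
```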
@@ -118,7 +118,7 @@ class DeviceMesh:
         self._global_rank_of_current_process = None
         self._is_initialized = False

-        # attribute used to inidicate whether this objectd
+        # attribute used to indicate whether this object
         # is created using DeviceMesh.from_process_group
         # this attribute can be used to do some check in methods
         # such get_process_group as no global rank information
@@ -395,7 +395,7 @@ class DeviceMesh:
         Example:

         ```python
-        sphysical_mesh_id = torch.arange(0, 16)
+        physical_mesh_id = torch.arange(0, 16)
         mesh_shape = (4, 4)

         # logical mesh will look like
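The example above is cut off by the hunk boundary; assuming the logical mesh is simply the physical IDs reshaped to mesh_shape (as the constructor path earlier in this diff suggests), the layout it refers to can be reproduced with plain torch:

```python
import torch

physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)

# assumption: the logical mesh is the physical IDs reshaped to mesh_shape
logical_mesh = physical_mesh_id.reshape(mesh_shape)
print(logical_mesh)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
```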
@@ -438,7 +438,7 @@ class DeviceMesh:
             # the _local_rank refers to the local rank of the current process
             for _local_rank in range(self.logical_mesh_id.shape[dim]):

-                # if this dimension is not initailized yet,
+                # if this dimension is not initialized yet,
                 # initialize it with an empty array
                 if dim not in processes_in_the_same_process_group:
                     processes_in_the_same_process_group[dim] = []
@@ -447,7 +447,7 @@ class DeviceMesh:
                 process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()

                 # replace the local rank in the given dimension with the
-                # lcoal rank of the current process iterated
+                # local rank of the current process iterated
                 process_coordinates[dim] = _local_rank
                 processes_in_the_same_process_group[dim].append(process_coordinates)

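To make the loop above concrete, here is a standalone sketch of the same grouping idea on a 4x4 logical mesh: for each mesh dimension, keep the current process's other coordinates fixed and vary only the local rank along that dimension. The starting coordinates and names are illustrative only, not taken from the file.

```python
import torch

# hypothetical 4x4 logical mesh with the current process at coordinates (1, 2)
logical_mesh_id = torch.arange(16).reshape(4, 4)
current_coordinates = [1, 2]

processes_in_the_same_process_group = {}
for dim in range(logical_mesh_id.dim()):
    processes_in_the_same_process_group[dim] = []
    for _local_rank in range(logical_mesh_id.shape[dim]):
        # vary only the coordinate along `dim`, keep the others fixed
        coords = current_coordinates.copy()
        coords[dim] = _local_rank
        processes_in_the_same_process_group[dim].append(coords)

print(processes_in_the_same_process_group)
# {0: [[0, 2], [1, 2], [2, 2], [3, 2]], 1: [[1, 0], [1, 1], [1, 2], [1, 3]]}
```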
@@ -28,7 +28,7 @@ class CommSpec:
     to determine the buffer shape, and logical_process_axis

     Argument:
-        comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
+        comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
         process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
         gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
         shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
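A single-process way to picture what gather_dim and shard_dim describe (this simulates the data movement with plain torch and is not the library's collective implementation): a tensor split along shard_dim across ranks is rebuilt by concatenating the shards along gather_dim.

```python
import torch

# single-process illustration of shard_dim / gather_dim semantics (no real collectives)
full = torch.arange(16).reshape(4, 4)

shard_dim = 0
shards = torch.chunk(full, chunks=2, dim=shard_dim)   # what two ranks would each hold

gather_dim = 0
rebuilt = torch.cat(shards, dim=gather_dim)           # effect of an all-gather along gather_dim
assert torch.equal(rebuilt, full)
```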
@@ -339,7 +339,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         RS01 -> RR
         '''
         valid_spec_dict = {}
-        comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
+        comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
         tensor_dims = len(source_spec.entire_shape)
         for f_index in range(tensor_dims - 1):
             for b_index in range(f_index + 1, tensor_dims):
@@ -362,7 +362,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                     b_target_pair = (b_index, [])

                 gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
-                comm_spec = CommSpec(comm_pathern,
+                comm_spec = CommSpec(comm_pattern,
                                      sharding_spec=source_spec,
                                      gather_dim=gather_dim,
                                      logical_process_axis=logical_process_axes,
@@ -43,7 +43,7 @@ def data_gen_for_t5_model():
 # output transform function
 output_transform_fn = lambda x: x

-# define loss funciton
+# define loss function
 loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
 loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
 loss_fn_for_conditional_generation = lambda x: x.loss
@@ -64,7 +64,7 @@ def check_torch_ddp_no_sync():
     model = DummyModel()
     criterion = lambda x: x.mean()
     optimizer = SGD(model.parameters(), lr=1e-3)
-    # create a custom dasetset with 0 to 10
+    # create a custom dataset with 0 to 10
     dataset = torch.arange(0, 10)
     train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
     model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
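For context, a minimal sketch of how the boosted objects above are typically driven for one training step; booster.backward follows the general ColossalAI Booster API, and the loop details are assumptions rather than part of this test.

```python
# minimal sketch of one training step with the boosted objects above
# (booster.backward is assumed from the Booster API; details are illustrative)
for batch in train_dataloader:
    optimizer.zero_grad()
    loss = criterion(model(batch))
    booster.backward(loss, optimizer)
    optimizer.step()
```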
@@ -15,7 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo


-# test baisc fsdp function
+# test basic fsdp function
 def run_fn(model_fn, data_gen_fn, output_transform_fn):
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)