mirror of https://github.com/InternLM/InternLM
merge reduce-scatter
commit 1804d01bb3
@@ -152,19 +152,19 @@ zero1 parallel (dict):
 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
 1. size: int, the size of tensor parallel.
-2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
-the sequence_parallel should be True.
+2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
+defaults to 'none', means the sequence parallel will be disabled.
+3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using 'intern' mode sp,
+defaults to False.
 pipeline parallel (dict):
 1. size: int, the size of pipeline parallel.
 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
 defaults to False.
-sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
 zero1=dict(size=-1, fsdp=False),
-tensor=dict(size=8, sp="intern", intern_overlap=True),
+tensor=dict(size=8, sp="megatron", intern_overlap=True),
 pipeline=dict(size=1, interleaved_overlap=True),
-sequence_parallel=True,
 )

 cudnn_deterministic = False

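For readers updating existing configs, a minimal sketch of the new schema follows; only the field names come from the docstring above, the concrete values are illustrative:

# Illustrative sketch of the new tensor-parallel config schema (values are examples only).
parallel = dict(
    zero1=dict(size=-1, fsdp=False),     # ZeRO-1 group spans the whole data-parallel group
    tensor=dict(
        size=8,                          # tensor parallel world size
        sp="intern",                     # one of: 'none', 'megatron', 'flash-attn', 'intern'
        intern_overlap=True,             # only meaningful when sp == 'intern'
    ),
    pipeline=dict(size=1, interleaved_overlap=True),
)
# Note: with sp set to 'megatron', 'flash-attn' or 'intern', args_sanity_check()
# forces gpc.config.parallel.sequence_parallel = True, so the old explicit
# sequence_parallel flag is no longer needed in the config.
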
@@ -194,7 +194,6 @@ class NonPipelineScheduler(BaseScheduler):
 _output, _loss, _moe_loss = self._train_one_batch(
 _data, _label, engine, forward_only, return_loss, self._grad_accum_size
 )
-engine.optimizer.reset_reduce_bucket()

 if return_loss:
 loss += _loss

@@ -306,15 +306,20 @@ def args_sanity_check():
 ), "sequence parallel does not support use_flash_attn=False"

 if isinstance(gpc.config.parallel["tensor"], int):
-gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")
-if gpc.config.parallel["tensor"].get("mode", None) is None:
-gpc.config.parallel["tensor"]["mode"] = "origin_tp"
-if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
-assert (
-gpc.config.parallel.sequence_parallel is True
-), "when the tp_mode is fstp, the sequence_parallel should be True."
+gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
+if gpc.config.parallel["tensor"].get("sp", None) is None:
+gpc.config.parallel["tensor"]["sp"] = "none"
+if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
+gpc.config.parallel["tensor"]["intern_overlap"] = False
+assert gpc.config.parallel["tensor"].get("sp", None) in [
+"none",
+"megatron",
+"flash-attn",
+"intern",
+], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported"
+# adapt to old version's sequence parallel config
+if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
+gpc.config.parallel.sequence_parallel = True

 # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
 if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:

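Read as a standalone function, the normalization this hunk adds behaves roughly as in the sketch below (hypothetical helper name, plain dict in place of gpc.config):

# Sketch of the config normalization performed by args_sanity_check();
# `tensor_cfg` stands in for gpc.config.parallel["tensor"].
def normalize_tensor_parallel_cfg(tensor_cfg):
    if isinstance(tensor_cfg, int):
        tensor_cfg = dict(size=tensor_cfg, sp="none", intern_overlap=False)
    tensor_cfg.setdefault("sp", "none")
    tensor_cfg.setdefault("intern_overlap", False)
    assert tensor_cfg["sp"] in ("none", "megatron", "flash-attn", "intern"), "invalid sp mode"
    # any non-'none' sp mode implies sequence parallelism
    sequence_parallel = tensor_cfg["sp"] in ("megatron", "flash-attn", "intern")
    return tensor_cfg, sequence_parallel

cfg, seq_par = normalize_tensor_parallel_cfg(8)
assert cfg == dict(size=8, sp="none", intern_overlap=False) and seq_par is False
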
@@ -456,10 +456,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
 setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
 if child.bias is not None:
 setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
-# _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
-# setattr(child.weight, "_fstp_all_reduce_str", f"{_full_name}.weight")
-# if child.bias is not None:
-# setattr(child.bias, "_fstp_all_reduce_str", f"{_full_name}.bias")
 else:
 continue
 elif isinstance(children, ScaleColumnParallelLinear):

@@ -473,27 +469,15 @@ class CoarseGrainedFSTPAllGatherSyncHandler:

 return self.zero_const_pool[size]

-def _all_gather_block_weight(self, block_index: int):
-#block = self.index_to_block[block_index]
-fsdp_modules = self.index_to_fsdp_modules[block_index]
-# self.block_handles[block] = []
-for module in fsdp_modules:
-total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
-self.FSTP_global_weights[module] = total_weight
-self.FSTP_global_handle[module] = weight_handle
-# self.block_handles[block].append(weight_handle)

 def _all_gather_block_weight_memory_pool(self, block_index: int):
 fsdp_modules = self.index_to_fsdp_modules[block_index]
-# self.block_handles[block] = []
 for module in fsdp_modules:
 module_index = self.module_name_index[module]
 name = self.module_name[module_index]
-weight_handle = all_gather_raw_memory_pool(module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name)
-# self.FSTP_global_weights[module] = total_weight
+weight_handle = all_gather_raw_memory_pool(
+module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name
+)
 self.FSTP_global_handle[module] = weight_handle
-# self.block_handles[block].append(weight_handle)

 def _register_sync_parameters_hook(self) -> None:
 """

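The memory-pool variant gathers every shard of a block into one of two preallocated buffer sets selected by block_index % 2, so the gather for block i+1 can run while block i computes. A minimal sketch of that pattern, assuming an initialized torch.distributed process group and a shards dict of {module_name: local weight shard} (both assumptions, not part of the diff):

# Sketch of the double-buffered all-gather: buffers are keyed by block_index % 2
# so prefetch for the next block overlaps with compute on the current block.
import torch
import torch.distributed as dist

def prefetch_block_weights(block_memory, shards, block_index, process_group):
    """Launch async all-gathers for every module of a block into pooled buffers."""
    handles = []
    for module_name, shard in shards.items():
        out_buf = block_memory[block_index % 2][module_name]  # preallocated full weight
        handle = dist.all_gather_into_tensor(
            out_buf, shard.contiguous(), group=process_group, async_op=True
        )
        handles.append(handle)
    return handles  # caller waits on these right before the block's forward pass
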
@@ -510,41 +494,14 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
 block_index = self.module_to_index[module]
 # start the all-gather for next block
 if block_index + 1 < gpc.config.NUM_LAYER:
-# self._all_gather_block_weight(block_index + 1)
 self._all_gather_block_weight_memory_pool(block_index + 1)

-def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
-block_index = self.block_to_index[block]
-if block_index == 0:
-# all gather weight for block 0
-fsdp_modules = self.index_to_fsdp_modules[block_index]
-for module in fsdp_modules:
-total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
-weight_handle.wait()
-self.FSTP_global_weights[module] = total_weight
-else:
-# wait handle for current block
-handles = self.block_handles[block]
-for handle in handles:
-handle.wait()
-
-def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
-# self._all_gather_block_weight(0)
+def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
 self._all_gather_block_weight_memory_pool(0)

-def _post_forward_hook_for_block(block: nn.Module, input, output):
-block_index = self.block_to_index[block]
-fsdp_modules = self.index_to_fsdp_modules[block_index]
-if block in self.block_handles:
-del self.block_handles[block]
-for module in fsdp_modules:
-del self.FSTP_global_weights[module]
-
-def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,):
-block_index = self.module_to_index[module]
-handler = self.FSTP_global_handle[module]
-handler.wait()
+def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+handle = self.FSTP_global_handle[module]
+handle.wait()

 def _post_forward_hook_for_module(module: nn.Module, input, output):
 if module in self.FSTP_global_weights:

@@ -552,56 +509,29 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
 if module in self.FSTP_global_handle:
 del self.FSTP_global_handle[module]

-def _pre_backward_hook_for_block(block: nn.Module, grad_output):
-# import pdb; pdb.set_trace()
-block_index = self.block_to_index[block]
-# if block_index == gpc.config.NUM_LAYER - 1:
-# # all gather weight for the last block
-# fsdp_modules = self.index_to_fsdp_modules[block_index]
-# for module in fsdp_modules:
-# total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
-# weight_handle.wait()
-# self.FSTP_global_weights[module] = total_weight
-# else:
-# # wait handle for current block
-# handles = self.block_handles[block]
-# for handle in handles:
-# handle.wait()
-# if block_index == gpc.config.NUM_LAYER - 1:
-# self._all_gather_block_weight(block_index)
-# start the all-gather for next block
-if block_index - 1 >= 0:
-self._all_gather_block_weight(block_index - 1)

 def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
 first_module = self.block_module[gpc.config.NUM_LAYER - 1][4]
 total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True)
 self.FSTP_global_handle[first_module] = weight_handler
 self.FSTP_global_weights[first_module] = total_weight

-def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
-block_index = self.block_to_index[block]
-fsdp_modules = self.index_to_fsdp_modules[block_index]
-if block in self.block_handles:
-del self.block_handles[block]
-for module in fsdp_modules:
-del self.FSTP_global_weights[module]

 def _pre_backward_hook_for_module_memory_pool(module: nn.Module, grad_output):
 block_index = self.module_to_index[module]
 name_index = self.module_name_index[module]

 if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
-# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
 weight_handler = self.FSTP_global_handle[module]
 weight_handler.wait()
-# self.FSTP_global_weights[module] = total_weight

 # start the all-gather for next module
 next_module = self.block_module[block_index][name_index - 1]
 next_name = self.module_name[name_index - 1]
 weights_handler = all_gather_raw_memory_pool(
-next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=next_name
+next_module.weight,
+self.process_group,
+async_op=True,
+block_index=block_index,
+module_name=next_name,
 )
 self.FSTP_global_handle[next_module] = weights_handler
 elif name_index == 0:

@@ -612,7 +542,11 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
 next_module = self.block_module[block_index - 1][4]
 name = self.module_name[4]
 weights_handler = all_gather_raw_memory_pool(
-next_module.weight, self.process_group, async_op=True, block_index=block_index - 1, module_name=name,
+next_module.weight,
+self.process_group,
+async_op=True,
+block_index=block_index - 1,
+module_name=name,
 )
 self.FSTP_global_handle[next_module] = weights_handler
 else:

@@ -625,48 +559,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
 next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name
 )
 self.FSTP_global_handle[next_module] = weights_handler
-# if module in self.FSTP_global_handle:
-# handler = self.FSTP_global_handle[module]
-# handler.wait()
-
-def _pre_backward_hook_for_module(module: nn.Module, grad_output):
-block_index = self.module_to_index[module]
-name_index = self.module_name_index[module]
-
-if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
-# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-weight_handler = self.FSTP_global_handle[module]
-weight_handler.wait()
-# self.FSTP_global_weights[module] = total_weight
-
-# start the all-gather for next module
-next_module = self.block_module[block_index][name_index - 1]
-self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-next_module.weight, self.process_group, async_op=True
-)
-self.FSTP_global_handle[next_module] = weights_handler
-elif name_index == 0:
-handler = self.FSTP_global_handle[module]
-handler.wait()
-
-if block_index - 1 >= 0:
-next_module = self.block_module[block_index - 1][4]
-self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-next_module.weight, self.process_group, async_op=True
-)
-self.FSTP_global_handle[next_module] = weights_handler
-else:
-handler = self.FSTP_global_handle[module]
-handler.wait()
-if name_index != 0:
-next_module = self.block_module[block_index][name_index - 1]
-self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-next_module.weight, self.process_group, async_op=True
-)
-self.FSTP_global_handle[next_module] = weights_handler
-# if module in self.FSTP_global_handle:
-# handler = self.FSTP_global_handle[module]
-# handler.wait()

 def _post_backward_hook_for_module(module, grad_input, grad_output):
 if module in self.FSTP_global_weights:

@@ -675,26 +567,16 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
 del self.FSTP_global_handle[module]

 for embedding in self.embedding:
-embedding.register_forward_hook(_pre_forward_hook_for_embedding)
+embedding.register_forward_hook(_post_forward_hook_for_embedding)

 for head in self.head:
 head.register_full_backward_hook(_post_backward_hook_for_head)

-# for block in self.FSTP_blocks:
-# block.register_forward_pre_hook(_pre_forward_hook_for_block)
-# block.register_forward_hook(_post_forward_hook_for_block)
-# block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
-# block.register_full_backward_hook(_post_backward_hook_for_block)
-
 for out_proj in self.FSTP_outs:
 out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)

-# for wqkv in self.FSTP_wqkvs:
-# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
-
 for module in self.FSTP_modules:
 module.register_forward_pre_hook(_pre_forward_hook_for_module)
 module.register_forward_hook(_post_forward_hook_for_module)
-# module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
 module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool)
 module.register_full_backward_hook(_post_backward_hook_for_module)

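The registration above wires one prefetch/wait/release cycle per linear module rather than per transformer block. A schematic of that hook pattern on a generic handler (ToyOverlapHandler and its method names are illustrative stand-ins, not the handler's real API):

import torch
from torch import nn

class ToyOverlapHandler:
    """Schematic of the per-module hook cycle: wait before forward, release after
    forward, prefetch during backward. The bodies marked `pass` are placeholders."""

    def __init__(self):
        self.pending = {}  # module -> async all-gather handle

    def register(self, modules):
        for m in modules:
            m.register_forward_pre_hook(self._wait_for_weight)
            m.register_forward_hook(self._release_weight)
            m.register_full_backward_pre_hook(self._prefetch_for_backward)

    def _wait_for_weight(self, module, inputs):
        handle = self.pending.pop(module, None)
        if handle is not None:
            handle.wait()  # the gathered full weight is now ready to use

    def _release_weight(self, module, inputs, output):
        pass  # the gathered full weight for `module` can be dropped here to cap memory

    def _prefetch_for_backward(self, module, grad_output):
        pass  # launch the async all-gather for the module backward will need next

# usage sketch: ToyOverlapHandler().register([nn.Linear(8, 8), nn.Linear(8, 8)])
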
@@ -396,7 +396,7 @@ class PackedFlashInternLm1D(nn.Module):
 assert len(indexes) == 1
 # The indexes are used to indicate the actual position IDs of each token in the packed input.
 indexes = indexes[0]
-# if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
+# if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension.
 if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
 indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

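With sp == 'intern', the packed position indexes are split along dim 0 so each tensor-parallel rank keeps only its slice of the sequence. A pure-tensor illustration of the forward behaviour of split_forward_gather_backward (the gather in backward is omitted; world size and rank are assumed values):

import torch

def split_along_dim0(t: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # each rank keeps an equal, contiguous chunk of the packed sequence
    assert t.size(0) % world_size == 0
    return t.chunk(world_size, dim=0)[rank]

indexes = torch.arange(16)  # packed position ids for the full sequence
local = split_along_dim0(indexes, world_size=4, rank=1)
assert torch.equal(local, torch.arange(4, 8))
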
@@ -1,20 +1,20 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from typing import Any, Optional, Union
+from typing import Optional

 import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
-from flash_attn.utils.distributed import all_reduce_raw #, reduce_scatter_raw
+from flash_attn.utils.distributed import all_reduce_raw
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.utils.logger import get_logger
 from internlm.utils.common import get_current_device
+from internlm.utils.logger import get_logger

 logger = get_logger(__file__)

@@ -125,9 +125,20 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
 )
 return output, handle

-def all_gather_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, gather_dim: int = 0, block_index: int = None, module_name: str = None):
+def all_gather_raw_memory_pool(
+input_: Tensor,
+process_group: ProcessGroup,
+async_op: bool = False,
+gather_dim: int = 0,
+block_index: int = None,
+module_name: str = None,
+):
 handle = torch.distributed.all_gather_into_tensor(
-gpc.config.block_memory[block_index % 2][module_name], input_.contiguous(), group=process_group, async_op=async_op
+gpc.config.block_memory[block_index % 2][module_name],
+input_.contiguous(),
+group=process_group,
+async_op=async_op,
 )
 return handle

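The difference from all_gather_raw is that the pooled variant writes into a buffer that already exists and returns only the communication handle. A sketch of both shapes of the call, assuming an initialized process group pg and a weight shard of shape [out_features // world_size, in_features] (assumed names, not the file's API):

import torch
import torch.distributed as dist

def all_gather_alloc(w_shard, pg):
    # allocates a fresh full-weight buffer on every call
    world_size = dist.get_world_size(pg)
    out = torch.empty(world_size * w_shard.shape[0], *w_shard.shape[1:],
                      dtype=w_shard.dtype, device=w_shard.device)
    handle = dist.all_gather_into_tensor(out, w_shard.contiguous(), group=pg, async_op=True)
    return out, handle

def all_gather_pooled(w_shard, pg, pooled_out):
    # pooled_out is reused across iterations, so no per-step allocation happens
    return dist.all_gather_into_tensor(pooled_out, w_shard.contiguous(), group=pg, async_op=True)
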
@@ -142,23 +153,25 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
 def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
 world_size = torch.distributed.get_world_size(process_group)
 assert input_.shape[0] % world_size == 0
-output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:],
-dtype=input_.dtype, device=input_.device).contiguous()
-handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
-group=process_group,
-async_op=async_op)
+output = torch.empty(
+input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
+).contiguous()
+handle = torch.distributed.reduce_scatter_tensor(
+output, input_.contiguous(), group=process_group, async_op=async_op
+)
 return output, handle


 def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
 world_size = torch.distributed.get_world_size(process_group)
 assert input_.shape[0] % world_size == 0
 size = (input_.shape[0] // world_size, *input_.shape[1:])
 index = check_reduce_scatter_memory_pool(size)
-output = gpc.config.reduce_scatter_memory[size]['data'][index]
+output = gpc.config.reduce_scatter_memory[size]["data"][index]
 setattr(output, "index", index)
-handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
-group=process_group,
-async_op=async_op)
+handle = torch.distributed.reduce_scatter_tensor(
+output, input_.contiguous(), group=process_group, async_op=async_op
+)
 return output, handle

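Both reduce-scatter helpers follow the same shape rule: the output's first dimension is the input's divided by the world size. A sketch of the allocating variant, assuming an initialized process group pg (an assumption of this example, not stated in the diff):

import torch
import torch.distributed as dist

def reduce_scatter_async(grad, pg):
    # grad has shape [out, in] with out divisible by the world size
    world_size = dist.get_world_size(pg)
    assert grad.shape[0] % world_size == 0
    out = torch.empty(grad.shape[0] // world_size, *grad.shape[1:],
                      dtype=grad.dtype, device=grad.device)
    handle = dist.reduce_scatter_tensor(out, grad.contiguous(), group=pg, async_op=True)
    return out, handle  # the optimizer waits on `handle` before accumulating `out`
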
@@ -444,7 +457,18 @@ class FSTPFusedDenseFunc(torch.autograd.Function):

 @staticmethod
 @custom_fwd
-def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, overlap_handler=None, block_index=None, module_name=None):
+def forward(
+ctx,
+x,
+weight,
+bias,
+return_residual=False,
+process_group=None,
+module=None,
+overlap_handler=None,
+block_index=None,
+module_name=None,
+):
 ctx.compute_weight_gradient = weight.requires_grad
 ctx.return_residual = return_residual
 ctx.process_group = process_group

@@ -540,7 +564,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
 overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
 grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
 if grad_bias is not None:
-grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
+grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(
+grad_bias, process_group, async_op=True
+)
 assert hasattr(bias, "_fstp_reduce_scatter_str")
 overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
 grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)

@@ -619,7 +645,9 @@ def fstp_fused_dense_func(
 x.dtype == torch.float32 and torch.is_autocast_enabled()
 )
 if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler, block_index, module_name)
+return FSTPFusedDenseFunc.apply(
+x, weight, bias, return_residual, process_group, module, handler, block_index, module_name
+)
 else:
 assert process_group is None
 out = F.linear(x, weight, bias)

@@ -666,36 +694,37 @@ def Silu(w1_o, w2_o):

 Silu = torch.jit.script(Silu)

-def check_reduce_scatter_memory_pool(key):

+def check_reduce_scatter_memory_pool(key):
 return_idx = 0

 # if key not in dict
 if key not in gpc.config.reduce_scatter_memory:
-gpc.config.reduce_scatter_memory[key] = {'data': [], 'used': []}
+gpc.config.reduce_scatter_memory[key] = {"data": [], "used": []}

 # if the data is empty
-if len(gpc.config.reduce_scatter_memory[key]['data']) == 0:
-gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
-dtype=gpc.config.model.get("dtype", torch.half),
-device=get_current_device()).contiguous())
-gpc.config.reduce_scatter_memory[key]['used'].append(True)
+if len(gpc.config.reduce_scatter_memory[key]["data"]) == 0:
+gpc.config.reduce_scatter_memory[key]["data"].append(
+torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
+)
+gpc.config.reduce_scatter_memory[key]["used"].append(True)
 return_idx = 0
 return return_idx
 else: # if not empty
-for index, used in enumerate(gpc.config.reduce_scatter_memory[key]['used']):
-if used == False:
-gpc.config.reduce_scatter_memory[key]['used'][index] = True
+for index, used in enumerate(gpc.config.reduce_scatter_memory[key]["used"]):
+if used is False:
+gpc.config.reduce_scatter_memory[key]["used"][index] = True
 return_idx = index
 return return_idx
 # if the memory pool is all used
-length = len(gpc.config.reduce_scatter_memory[key]['data'])
-gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
-dtype=gpc.config.model.get("dtype", torch.half),
-device=get_current_device()).contiguous())
-gpc.config.reduce_scatter_memory[key]['used'].append(True)
+length = len(gpc.config.reduce_scatter_memory[key]["data"])
+gpc.config.reduce_scatter_memory[key]["data"].append(
+torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
+)
+gpc.config.reduce_scatter_memory[key]["used"].append(True)
 return_idx = length
 return return_idx


 def release_reduce_scatter_memory_pool(size, index):
-gpc.config.reduce_scatter_memory[size]['used'][index] = False
+gpc.config.reduce_scatter_memory[size]["used"][index] = False

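The pool keyed by output shape keeps parallel "data"/"used" lists: acquiring a buffer either reuses a free slot or grows the pool, and releasing only flips the flag so the tensor itself is kept for reuse. A standalone CPU-only toy of the same protocol (no gpc; the acquire/release names are hypothetical):

import torch

_pool = {}  # shape tuple -> {"data": [tensor, ...], "used": [bool, ...]}

def acquire(shape, dtype=torch.float16):
    entry = _pool.setdefault(shape, {"data": [], "used": []})
    for idx, used in enumerate(entry["used"]):
        if not used:                      # reuse a released slot
            entry["used"][idx] = True
            return idx, entry["data"][idx]
    entry["data"].append(torch.zeros(shape, dtype=dtype))  # grow the pool
    entry["used"].append(True)
    return len(entry["data"]) - 1, entry["data"][-1]

def release(shape, index):
    _pool[shape]["used"][index] = False   # buffer stays allocated for reuse

i0, buf0 = acquire((4, 8))
i1, _ = acquire((4, 8))        # pool grows because buf0 is still in use
release((4, 8), i0)
i2, buf2 = acquire((4, 8))     # reuses the slot released above
assert (i0, i1, i2) == (0, 1, 0) and buf2 is buf0
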
@@ -3,6 +3,7 @@

 import math
 from functools import partial
+from typing import List, Optional

 import torch
 import torch.distributed as dist

@@ -10,7 +11,7 @@ from torch.optim import Optimizer

 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
+from internlm.model.utils import release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
 BucketStore,

@@ -65,7 +66,8 @@ class HybridZeroOptimizer(BaseOptimizer):
 hysteresis = grad_scal_cfg.hysteresis
 max_scale = grad_scal_cfg.max_scale

-if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
+self._fstp_handler = None
+if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
 self._fstp_handler = gpc.config.fstp_handler

 # Zero related args

@@ -85,8 +87,8 @@ class HybridZeroOptimizer(BaseOptimizer):
 # it will not manage the tensors used by mixed precision training
 self._param_store = ParameterStore(ParallelMode.ZERO1)
 self._grad_store = GradientStore(ParallelMode.DATA)
-self._bucket_store = []
-self._bucket_store_2 = []
+self._bucket_store: List[BucketStore] = []
+self._accum_grad_buckets: List[BucketStore] = []
 self._bucket_in_progress = []

 # fp16 and fp32 params for mixed precision training

@@ -155,7 +157,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
 self._broadcast_parallel_mode.append(zero_mode)
 self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"]))
-self._bucket_store_2.append(BucketStore(group_id, param_group["dp_mode"]))
+self._accum_grad_buckets.append(BucketStore(group_id, param_group["dp_mode"]))

 # assign parameters to ranks the params in the list are sorted
 params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group)

@@ -303,7 +305,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 )

 reduce_scatter_checker = partial(
-self._wait_reduce_scatter_and_accumulate_grad,
+self._wait_reduce_scatter_and_accumulate_grads,
 param=param,
 reduce_rank=reduce_rank,
 )

@@ -312,7 +314,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 # NOT IMPORTANT BUT GOOD TO KNOW:
 # args here is not grad, but allow_unreacable and accumulate_grad
 def reduce_grad_hook(*args): # pylint: disable=W0613
-if gpc.config.fstp_handler is not None:
+if self._fstp_handler is not None:
 reduce_scatter_checker()

 if self.skip_grad_reduce is False:

@@ -336,56 +338,36 @@ class HybridZeroOptimizer(BaseOptimizer):
 group_id = getattr(param, "group_id")
 return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])

-def reset_reduce_bucket(self) -> None:
-for bucket in self._bucket_store_2:
-for rank, params in bucket._params.items():
-for _param in params:
+def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optional[int] = None) -> None:
+for _param in bucket.get_param(reduce_rank):
 if not hasattr(_param, "_fstp_reduce_scatter_str"):
 continue

-key = getattr(_param, "_fstp_reduce_scatter_str")
-comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
-comm_handle.wait()
+# wait and accumulate gardient.
+_key = getattr(_param, "_fstp_reduce_scatter_str")
+_comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[_key]
+_comm_handle.wait()
 _param.grad.add_(_grad)
-# self._fstp_handler.reduce_scatter_handlers[key] = None
-# del _grad
+
+# release cuda memory.
 release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
-del self._fstp_handler.reduce_scatter_handlers[key]
-self._fstp_handler.reduce_scatter_handlers[key] = None
-assert key in self._fstp_handler.reduce_scatter_handlers
+self._fstp_handler.reduce_scatter_handlers[_key] = None

-bucket.reset_by_rank(rank)
+bucket.reset_by_rank(reduce_rank)

-def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
+def _wait_reduce_scatter_and_accumulate_grads(self, param, reduce_rank: Optional[int] = None):
 param_size = param.numel()

+group_id = getattr(param, "group_id")
+current_bucket = self._accum_grad_buckets[group_id]
+
 # check if the bucket is full
 # if full, will reduce the grads already in the bucket
 # after reduction, the bucket will be empty
-group_id = getattr(param, "group_id")
-current_bucket = self._bucket_store_2[group_id]
-
-if current_bucket.num_elements_in_bucket(reduce_rank) >= 512 * 1024 * 1024:
-# wait reduce scatter communication
-params = current_bucket.get_param(reduce_rank)
-for _param in params:
-if not hasattr(_param, "_fstp_reduce_scatter_str"):
-continue
-
-key = getattr(_param, "_fstp_reduce_scatter_str")
-comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
-comm_handle.wait()
-_param.grad.add_(_grad)
-# self._fstp_handler.reduce_scatter_handlers[key] = None
-# del _grad
-release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
-del self._fstp_handler.reduce_scatter_handlers[key]
-self._fstp_handler.reduce_scatter_handlers[key] = None
-assert key in self._fstp_handler.reduce_scatter_handlers
-
-current_bucket.reset_by_rank(reduce_rank)
+if current_bucket.num_elements_in_bucket(reduce_rank) >= self._reduce_bucket_size:
+self._accum_grads_store_in_bucket(current_bucket, reduce_rank)

+# otherwise, add the parameter into bucket.
 current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
 current_bucket.add_param(param, reduce_rank)

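The drain step now lives in _accum_grads_store_in_bucket: for every bucketed FSTP param it waits on the pending reduce-scatter, accumulates the scattered gradient shard into param.grad, and returns the pooled buffer. A toy sketch of that control flow with faked handles (no torch.distributed required; the release callback stands in for release_reduce_scatter_memory_pool):

import torch

class _FakeHandle:
    def wait(self):  # stands in for torch.distributed.Work.wait()
        pass

def drain_bucket(params, handlers, release_fn):
    """Wait on pending reduce-scatters and accumulate the scattered grad shards."""
    for p in params:
        key = getattr(p, "_fstp_reduce_scatter_str", None)
        if key is None:
            continue
        handle, shard_grad = handlers[key]
        handle.wait()               # reduce-scatter of this grad has finished
        p.grad.add_(shard_grad)     # accumulate the gradient shard
        release_fn(shard_grad)      # return the pooled buffer
        handlers[key] = None

# tiny demo
w = torch.nn.Parameter(torch.zeros(2, 3))
w.grad = torch.zeros_like(w)
w._fstp_reduce_scatter_str = "demo.weight"
handlers = {"demo.weight": (_FakeHandle(), torch.ones(2, 3))}
drain_bucket([w], handlers, release_fn=lambda buf: None)
assert torch.equal(w.grad, torch.ones(2, 3))
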
@@ -612,6 +594,10 @@ class HybridZeroOptimizer(BaseOptimizer):
 for group_id in range(self.num_param_groups):
 self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)

+# we need to accumulate gradients left in the accumulate gardient bucket
+for group_id in range(self.num_param_groups):
+self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
+
 # compute norm for gradients in the before bucket
 groups_norms = []
 for group_id in range(self.num_param_groups):

@@ -110,9 +110,8 @@ def initialize_model():

 gpc.config.fstp_handler = None

-if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
+if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
 handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
-# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
 handler._register_sync_parameters_hook()
 gpc.config.fstp_handler = handler

@@ -123,31 +122,44 @@ def initialize_model():
 mlp_hidden_size = int(hidden_size * mlp_ratio)
 mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
 world_size = gpc.get_world_size(ParallelMode.TENSOR)
-size_key = [(3 * hidden_size // world_size, hidden_size), (mlp_hidden_size // world_size, hidden_size), (hidden_size // world_size, mlp_hidden_size), (hidden_size // world_size, hidden_size)]
-module_name = ['Wqkv', 'out_proj', 'w1', 'w2', 'w3']
+size_key = [
+(3 * hidden_size // world_size, hidden_size),
+(mlp_hidden_size // world_size, hidden_size),
+(hidden_size // world_size, mlp_hidden_size),
+(hidden_size // world_size, hidden_size),
+]
+module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
 for i in range(2):
 weight = {}
 for name in module_name:
-if name == 'Wqkv':
-weight[name] = torch.zeros((3 * hidden_size, hidden_size),
-dtype=gpc.config.model.get("dtype", torch.half),
-device=get_current_device()).contiguous()
-elif name == 'out_proj':
-weight[name] = torch.zeros((hidden_size, hidden_size),
-dtype=gpc.config.model.get("dtype", torch.half),
-device=get_current_device()).contiguous()
-elif name == 'w1' or name == 'w2':
-weight[name] = torch.zeros((mlp_hidden_size, hidden_size),
-dtype=gpc.config.model.get("dtype", torch.half),
-device=get_current_device()).contiguous()
+if name == "Wqkv":
+weight[name] = torch.zeros(
+(3 * hidden_size, hidden_size),
+dtype=gpc.config.model.get("dtype", torch.half),
+device=get_current_device(),
+).contiguous()
+elif name == "out_proj":
+weight[name] = torch.zeros(
+(hidden_size, hidden_size),
+dtype=gpc.config.model.get("dtype", torch.half),
+device=get_current_device(),
+).contiguous()
+elif name == "w1" or name == "w2":
+weight[name] = torch.zeros(
+(mlp_hidden_size, hidden_size),
+dtype=gpc.config.model.get("dtype", torch.half),
+device=get_current_device(),
+).contiguous()
 else:
-weight[name] = torch.zeros((hidden_size, mlp_hidden_size),
-dtype=gpc.config.model.get("dtype", torch.half),
-device=get_current_device()).contiguous()
+weight[name] = torch.zeros(
+(hidden_size, mlp_hidden_size),
+dtype=gpc.config.model.get("dtype", torch.half),
+device=get_current_device(),
+).contiguous()
 block_memory[i] = weight
 reduce_scatter_memory = {}
 for key in size_key:
-reduce_scatter_memory[key] = {'data': [], 'used': []}
+reduce_scatter_memory[key] = {"data": [], "used": []}

 gpc.config.block_memory = block_memory
 gpc.config.reduce_scatter_memory = reduce_scatter_memory

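For a sense of scale, the buffer shapes produced by this setup with illustrative InternLM-7B-like settings (hidden_size=4096, mlp_ratio=8/3, tensor world size 8; these values are assumptions, not part of the diff):

# Worked example of the pre-allocated buffer shapes (illustrative settings).
hidden_size, mlp_ratio, world_size = 4096, 8 / 3, 8
mlp_hidden_size = 256 * ((int(hidden_size * mlp_ratio) + 256 - 1) // 256)  # -> 11008

full_weight_shapes = {                 # one buffer per module, duplicated over 2 block slots
    "Wqkv": (3 * hidden_size, hidden_size),        # (12288, 4096)
    "out_proj": (hidden_size, hidden_size),        # (4096, 4096)
    "w1": (mlp_hidden_size, hidden_size),          # (11008, 4096)
    "w2": (mlp_hidden_size, hidden_size),          # (11008, 4096)
    "w3": (hidden_size, mlp_hidden_size),          # (4096, 11008)
}
reduce_scatter_keys = [                # sharded gradient shapes used as pool keys
    (3 * hidden_size // world_size, hidden_size),  # (1536, 4096)
    (mlp_hidden_size // world_size, hidden_size),  # (1376, 4096)
    (hidden_size // world_size, mlp_hidden_size),  # (512, 11008)
    (hidden_size // world_size, hidden_size),      # (512, 4096)
]
assert mlp_hidden_size == 11008
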
@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 def switch_sequence_parallel_mode():
 prev_mode = gpc.config.parallel.sequence_parallel
 try:
-if gpc.config.parallel["tensor"]["mode"] == "fstp":
+if gpc.config.parallel["tensor"]["sp"] == "intern":
 gpc.config.parallel.sequence_parallel = True
 else:
 gpc.config.parallel.sequence_parallel = False

@@ -106,7 +106,7 @@ def evaluate_on_val_dls(
 total_val_bsz = len(batch[1])
 assert total_val_bsz % data_cfg.micro_bsz == 0
 num_microbatches = total_val_bsz // data_cfg.micro_bsz
-if gpc.config.parallel["tensor"]["mode"] == "fstp":
+if gpc.config.parallel["tensor"]["sp"] == "intern":
 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
 tensor_shape = torch.Size(
 [