mirror of https://github.com/hpcaitech/ColossalAI
[gemini] fixes for benchmarking (#5847)
* [gemini] fix missing return
* [gemini] fix missing arg pass
* [gemini] use gather tensor instead of list
* [test] enable flash attention for benchmark by default

Co-authored-by: genghaozhe <939857490@qq.com>
parent 2a25a2aff7
commit 8e718a1421
@@ -369,9 +369,9 @@ class GeminiPlugin(DPPluginBase):
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
-        if placement_policy == "auto" and enable_async_reduce:
+        if enable_async_reduce and not pin_memory:
             logging.warning(
-                f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set."
+                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
             )
             pin_memory = True
         self.gemini_config = dict(
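With this change, enable_async_reduce no longer depends on the placement policy: whenever it is requested without pin_memory, the plugin warns and turns pin_memory on. A minimal usage sketch, assuming a distributed environment already set up via colossalai.launch_from_torch(); the argument values are illustrative, not taken from the benchmark script:

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# pin_memory is deliberately omitted: after this commit the plugin logs a warning
# and enables it automatically whenever enable_async_reduce=True.
plugin = GeminiPlugin(
    precision="bf16",
    placement_policy="auto",
    enable_async_reduce=True,
)
booster = Booster(plugin=plugin)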
@@ -403,9 +403,9 @@ class Chunk:
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )

-            input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            self.grad_reduce_work = dist.reduce_scatter(
-                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+            assert self.cuda_global_chunk.is_contiguous()
+            self.grad_reduce_work = dist.reduce_scatter_tensor(
+                self.cuda_shard, self.cuda_global_chunk, group=self.torch_pg, async_op=async_op
             )

             if self.extra_dp_group is not None:
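The reduce path of Chunk now hands the contiguous global chunk straight to reduce_scatter_tensor instead of splitting it into a Python list of views for reduce_scatter. The two collectives produce the same shard; the tensor variant just avoids building the list. A standalone sketch of the pattern, assuming an initialized NCCL process group (e.g. launched with torchrun):

import torch
import torch.distributed as dist

def reduce_scatter_chunk(global_chunk: torch.Tensor, world_size: int) -> torch.Tensor:
    """Reduce-scatter a flat, contiguous buffer; each rank keeps one shard."""
    shard = torch.empty(
        global_chunk.numel() // world_size, dtype=global_chunk.dtype, device=global_chunk.device
    )
    # Old style: build a list of per-rank views and call reduce_scatter.
    #   input_list = list(torch.chunk(global_chunk, chunks=world_size, dim=0))
    #   dist.reduce_scatter(shard, input_list)
    # New style: pass the contiguous buffer directly.
    assert global_chunk.is_contiguous()
    dist.reduce_scatter_tensor(shard, global_chunk)
    return shard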
@@ -520,8 +520,10 @@ class Chunk:
             assert self.cuda_shard is not None

             alloc_storage(self.cuda_global_chunk)
-            gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
+            assert self.cuda_global_chunk.is_contiguous()
+            work = dist.all_gather_into_tensor(
+                self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
+            )

             self.cuda_shard = None
             self.is_gathered = True
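The gather path mirrors the reduce path: the shard is gathered directly into the pre-allocated global chunk with all_gather_into_tensor rather than into a list of chunk views. A sketch under the same assumptions (initialized process group, flat contiguous buffers):

import torch
import torch.distributed as dist

def gather_chunk(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    """All-gather one shard per rank into a single contiguous buffer."""
    global_chunk = torch.empty(
        shard.numel() * world_size, dtype=shard.dtype, device=shard.device
    )
    assert global_chunk.is_contiguous()
    # Equivalent in result to dist.all_gather(list(torch.chunk(global_chunk, world_size)), shard),
    # but without materializing the list of views.
    dist.all_gather_into_tensor(global_chunk, shard)
    return global_chunk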
@@ -133,12 +133,12 @@ class ChunkManager:
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)

-    def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
+    def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False, async_move=False) -> None:
         """Move the shard of the chunk to the target device."""
         if not chunk.can_move or chunk.device_type == device.type:
             return
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.shard_move(device, force_copy)
+        chunk.shard_move(device, force_copy, non_blocking=async_move)
         self.__add_memory_usage(chunk.memory_usage)

     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
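The new async_move flag is simply forwarded to Chunk.shard_move as non_blocking. Non-blocking device-to-host copies only overlap with compute when the destination is pinned host memory, which is also why the plugin change above forces pin_memory=True. A generic PyTorch sketch of the idea (not Gemini internals):

import torch

def offload_async(shard: torch.Tensor) -> torch.Tensor:
    """Copy a CUDA shard to pinned host memory without blocking the current stream."""
    host_buf = torch.empty(shard.shape, dtype=shard.dtype, device="cpu", pin_memory=True)
    host_buf.copy_(shard, non_blocking=True)  # enqueued on the current CUDA stream
    return host_buf  # the caller must synchronize before reading the buffer on the CPU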
@@ -387,6 +387,7 @@ class GeminiDDP(ModelWrapper):
         p: nn.Parameter,
         async_reduce_stream: Optional[torch.cuda.Stream] = None,
     ):
+        async_reduce_scatter = async_reduce_stream is not None
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
@@ -426,7 +427,7 @@ class GeminiDDP(ModelWrapper):
             async_reduce_stream.wait_stream(torch.cuda.current_stream())

         with torch.cuda.stream(async_reduce_stream):
-            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
+            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
         if reduced:
             grad_chunk.wait_async_reduce()
             if not chunk_manager.reuse_fp16_chunk:
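grad_handle already waits on the default stream and then issues the chunk reduction on async_reduce_stream; the new async_reduce_scatter flag reuses that condition when calling reduce_chunk and, below, move_chunk. A generic sketch of the side-stream pattern, using all_reduce for brevity where Gemini issues a chunk-level reduce-scatter:

import torch
import torch.distributed as dist

side_stream = torch.cuda.Stream()  # requires a CUDA device and an initialized process group

def reduce_on_side_stream(grad_buffer: torch.Tensor):
    """Launch an asynchronous reduction on a side stream so backward can keep running."""
    # The side stream must first see all work that produced grad_buffer.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        work = dist.all_reduce(grad_buffer, async_op=True)
    return work  # call work.wait() before the optimizer consumes grad_buffer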
@@ -447,9 +448,13 @@ class GeminiDDP(ModelWrapper):
                 # record l2 norm for gradient clipping. flag is bound to fp16 chunk
                 if chunk.l2_norm_flag:
                     grad_chunk.set_l2_norm()
-                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                chunk_manager.move_chunk(
+                    grad_chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+                )
                 if not (master_weights) or (enable_gradient_accumulation):
-                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                    chunk_manager.move_chunk(
+                        chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+                    )
         return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
@@ -253,8 +253,13 @@ def main():
         init_kwargs["empty_init"] = False

     with init_ctx:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
+        model = AutoModelForCausalLM.from_config(
+            config,
+            trust_remote_code=True,
+            **init_kwargs,
+            attn_implementation="flash_attention_2",
+            torch_dtype=torch.bfloat16,
+        )
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if config.model_type == "chatglm":
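The benchmark now hard-codes attn_implementation="flash_attention_2" together with torch_dtype=torch.bfloat16, so it assumes the flash-attn package is installed on a GPU that supports it. A hedged sketch of a fallback (pick_attn_implementation is a hypothetical helper, not part of the benchmark script):

import importlib.util
import torch

def pick_attn_implementation() -> str:
    """Prefer FlashAttention-2 when the flash-attn package is installed, else fall back to SDPA."""
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"  # PyTorch scaled-dot-product attention as a portable fallback

The return value could then be passed as attn_implementation=pick_attn_implementation() in the from_config call above.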
@@ -286,7 +291,7 @@ def main():

     with get_profile_context(
         args.profile,
-        1,
+        args.ignore_steps,
         len(dataloader) - 1,
         save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
     ) as prof:
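Passing args.ignore_steps instead of the hard-coded 1 lets the profiling window skip the warm-up iterations that the benchmark already excludes from its throughput numbers. get_profile_context is a helper local to the example; the same warm-up skipping is commonly expressed with a torch.profiler schedule, sketched here with illustrative numbers:

import torch
from torch.profiler import ProfilerActivity, profile, schedule

ignore_steps = 2  # illustrative; mirrors the intent of args.ignore_steps
prof_schedule = schedule(skip_first=ignore_steps, wait=0, warmup=1, active=3, repeat=1)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], schedule=prof_schedule
) as prof:
    for step in range(10):  # stand-in for the real dataloader loop
        ...  # forward / backward / optimizer step
        prof.step()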