diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 474b78aa2..ad131fbe7 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -369,9 +369,9 @@ class GeminiPlugin(DPPluginBase):
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
-        if placement_policy == "auto" and enable_async_reduce:
+        if enable_async_reduce and not pin_memory:
             logging.warning(
-                f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set."
+                "enable_async_reduce implicitly sets pin_memory=True to achieve the best performance."
             )
             pin_memory = True
         self.gemini_config = dict(
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index 18fbf8fc3..969df9621 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -403,9 +403,9 @@ class Chunk:
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )

-            input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            self.grad_reduce_work = dist.reduce_scatter(
-                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+            assert self.cuda_global_chunk.is_contiguous()
+            self.grad_reduce_work = dist.reduce_scatter_tensor(
+                self.cuda_shard, self.cuda_global_chunk, group=self.torch_pg, async_op=async_op
             )

             if self.extra_dp_group is not None:
@@ -520,8 +520,10 @@ class Chunk:
             assert self.cuda_shard is not None

             alloc_storage(self.cuda_global_chunk)
-            gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
+            assert self.cuda_global_chunk.is_contiguous()
+            work = dist.all_gather_into_tensor(
+                self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
+            )

             self.cuda_shard = None
             self.is_gathered = True
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 3a5f0a5aa..d0e1755f4 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -133,12 +133,12 @@ class ChunkManager:
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)

-    def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
+    def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False, async_move=False) -> None:
         """Move the shard of the chunk to the target device."""
         if not chunk.can_move or chunk.device_type == device.type:
             return
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.shard_move(device, force_copy)
+        chunk.shard_move(device, force_copy, non_blocking=async_move)
         self.__add_memory_usage(chunk.memory_usage)

     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index ebdde83b4..80b2c7961 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -387,6 +387,7 @@ class GeminiDDP(ModelWrapper):
         p: nn.Parameter,
         async_reduce_stream: Optional[torch.cuda.Stream] = None,
     ):
+        async_reduce_scatter = async_reduce_stream is not None
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
@@ -426,7 +427,7 @@ class GeminiDDP(ModelWrapper):
             async_reduce_stream.wait_stream(torch.cuda.current_stream())

         with torch.cuda.stream(async_reduce_stream):
-            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
+            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
             if reduced:
                 grad_chunk.wait_async_reduce()
                 if not chunk_manager.reuse_fp16_chunk:
@@ -447,9 +448,13 @@ class GeminiDDP(ModelWrapper):
                 # record l2 norm for gradient clipping. flag is bound to fp16 chunk
                 if chunk.l2_norm_flag:
                     grad_chunk.set_l2_norm()
-                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                chunk_manager.move_chunk(
+                    grad_chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+                )
                 if not (master_weights) or (enable_gradient_accumulation):
-                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                    chunk_manager.move_chunk(
+                        chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+                    )
         return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 4b897770e..8a35db1f7 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -253,8 +253,13 @@ def main():
         init_kwargs["empty_init"] = False

     with init_ctx:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
-
+        model = AutoModelForCausalLM.from_config(
+            config,
+            trust_remote_code=True,
+            **init_kwargs,
+            attn_implementation="flash_attention_2",
+            torch_dtype=torch.bfloat16,
+        )
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if config.model_type == "chatglm":
@@ -286,7 +291,7 @@ def main():

     with get_profile_context(
         args.profile,
-        1,
+        args.ignore_steps,
         len(dataloader) - 1,
         save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
     ) as prof:
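For reference, here is a minimal sketch of the collective-API change in `chunk.py`: the list-based `dist.reduce_scatter` / `dist.all_gather` calls are replaced by the flat-tensor variants `dist.reduce_scatter_tensor` / `dist.all_gather_into_tensor`, which drop the per-rank chunk-list bookkeeping and only require that the global buffer is contiguous. This snippet is illustrative, not the library code; it assumes `torch.distributed` is already initialized with the NCCL backend (e.g. via `torchrun`), a 1-D chunk buffer, and a chunk size divisible by the world size.

```python
import torch
import torch.distributed as dist


def reduce_then_gather(global_chunk: torch.Tensor, pg=None, async_op: bool = False) -> torch.Tensor:
    """Illustrative stand-in for Chunk.reduce / Chunk.__gather using flat-tensor collectives."""
    assert global_chunk.is_contiguous()
    world_size = dist.get_world_size(pg)
    shard = torch.empty(
        global_chunk.numel() // world_size, dtype=global_chunk.dtype, device=global_chunk.device
    )

    # Reduce-scatter without building a per-rank list: rank i receives the reduced i-th slice.
    work = dist.reduce_scatter_tensor(shard, global_chunk, group=pg, async_op=async_op)
    if async_op and work is not None:
        work.wait()  # the real code stores this handle as grad_reduce_work and waits later

    # All-gather the shard back into the contiguous global buffer.
    dist.all_gather_into_tensor(global_chunk, shard, group=pg)
    return shard
```

Both flat-tensor collectives need a reasonably recent PyTorch; they were introduced around 1.13 as public replacements for the private `_reduce_scatter_base` / `_all_gather_base`.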
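Relatedly, a sketch of why the plugin now flips `pin_memory=True` whenever `enable_async_reduce` is set: a device-to-host copy issued with `non_blocking=True` can only overlap with the async reduce-scatter stream when the CPU destination lives in pinned (page-locked) memory; with pageable memory the transfer effectively synchronizes the host. The helper below is hypothetical (it is not `Chunk.shard_move`) and assumes a CUDA device is available.

```python
import torch


def offload_shard(shard_cuda: torch.Tensor, pin_memory: bool = True) -> torch.Tensor:
    """Copy a CUDA shard to the CPU, optionally through pinned memory."""
    # A page-locked destination lets the copy be enqueued asynchronously on the current stream.
    shard_cpu = torch.empty(
        shard_cuda.shape, dtype=shard_cuda.dtype, device="cpu", pin_memory=pin_memory
    )
    # With pin_memory=False this still works, but the host stalls until the copy completes.
    shard_cpu.copy_(shard_cuda, non_blocking=True)
    return shard_cpu
```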