fix(ci): fix ci train error (#199)

pull/200/head
huangting4201 2023-08-15 20:09:54 +08:00 committed by GitHub
parent ef851d16c6
commit db13bc46bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 8 additions and 3 deletions

View File

@ -547,6 +547,8 @@ def send_backward_and_recv_next_backward_async(
class AsynCommunicator:
"""AsynCommunicator for managing async communication."""
def __init__(
self,
tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],

View File

@ -1,6 +1,5 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
from functools import wraps
from typing import List, Tuple, Union
import torch

View File

@ -139,7 +139,7 @@ class PipelineScheduler(BaseScheduler):
and gpc.is_initialized(ParallelMode.TENSOR)
and gpc.get_world_size(ParallelMode.TENSOR) > 1
)
if gpc.config.model.sequence_parallel:
self.scatter_gather_tensors = False

View File

@ -139,7 +139,7 @@ def args_sanity_check():
gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
if "snapshot_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot")
gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot"))
if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)

View File

@ -247,8 +247,12 @@ class ParameterStore(BaseStore):
def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
if not last_bucket:
if group_id not in self._former_bucket_reduced_param:
return [], []
return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
else:
if group_id not in self._last_bucket_reduced_param:
return [], []
return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
def reset_reduced_data_for_compute_norm(self):