mirror of https://github.com/InternLM/InternLM
fix(ci): fix ci train error (#199)
parent
ef851d16c6
commit
db13bc46bc
|
@ -547,6 +547,8 @@ def send_backward_and_recv_next_backward_async(
|
|||
|
||||
|
||||
class AsynCommunicator:
|
||||
"""AsynCommunicator for managing async communication."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
|
||||
|
||||
from functools import wraps
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
|
|
@ -139,7 +139,7 @@ def args_sanity_check():
|
|||
gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
|
||||
|
||||
if "snapshot_ckpt_folder" not in gpc.config.ckpt:
|
||||
gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot")
|
||||
gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot"))
|
||||
|
||||
if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
|
||||
gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
|
||||
|
|
|
@ -247,8 +247,12 @@ class ParameterStore(BaseStore):
|
|||
|
||||
def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
|
||||
if not last_bucket:
|
||||
if group_id not in self._former_bucket_reduced_param:
|
||||
return [], []
|
||||
return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
|
||||
else:
|
||||
if group_id not in self._last_bucket_reduced_param:
|
||||
return [], []
|
||||
return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
|
||||
|
||||
def reset_reduced_data_for_compute_norm(self):
|
||||
|
|
Loading…
Reference in New Issue