mirror of https://github.com/InternLM/InternLM
fix(ci): fix ci train error (#199)
parent: ef851d16c6
commit: db13bc46bc
@@ -547,6 +547,8 @@ def send_backward_and_recv_next_backward_async(


 class AsynCommunicator:
+    """AsynCommunicator for managing async communication."""
+
     def __init__(
         self,
         tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
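The hunk above only shows the top of the new class (its docstring and the start of __init__). As a hedged illustration of the pattern such a wrapper usually manages, a non-blocking torch.distributed send whose work handle is waited on later, here is a minimal standalone sketch; the class name, the start/wait methods, and the dst_rank parameter are hypothetical and are not taken from this commit:

import torch
import torch.distributed as dist

class AsyncSendSketch:
    """Hypothetical sketch of async point-to-point communication, not the InternLM API."""

    def __init__(self, tensor_to_send: torch.Tensor, dst_rank: int):
        self._tensor = tensor_to_send
        self._dst = dst_rank
        self._work = None

    def start(self):
        # isend returns a work handle immediately, letting compute overlap communication
        self._work = dist.isend(self._tensor, dst=self._dst)

    def wait(self):
        # block until the pending send has actually completed
        if self._work is not None:
            self._work.wait()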
@@ -1,6 +1,5 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication

-from functools import wraps
 from typing import List, Tuple, Union

 import torch
@@ -139,7 +139,7 @@ def args_sanity_check():
     gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")

     if "snapshot_ckpt_folder" not in gpc.config.ckpt:
-        gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot")
+        gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot"))

     if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
         gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
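This hunk is the substance of the CI fix: in the old line the closing parenthesis of os.path.join came too early, so "snapshot" was passed to _add_item as a stray third positional argument instead of being appended to the folder path. A small standalone illustration of the difference (the folder value below is made up for the example):

import os

save_ckpt_folder = "/mnt/ckpt"  # hypothetical value for illustration only

# before the fix: "snapshot" is not part of the join, so the stored path lacks it
os.path.join(save_ckpt_folder)              # -> "/mnt/ckpt"

# after the fix: "snapshot" is joined into the snapshot checkpoint folder
os.path.join(save_ckpt_folder, "snapshot")  # -> "/mnt/ckpt/snapshot"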
@@ -247,8 +247,12 @@ class ParameterStore(BaseStore):

     def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
         if not last_bucket:
+            if group_id not in self._former_bucket_reduced_param:
+                return [], []
             return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
         else:
+            if group_id not in self._last_bucket_reduced_param:
+                return [], []
             return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]

     def reset_reduced_data_for_compute_norm(self):
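The four added lines guard the dictionary lookups so that a group with no reduced parameters recorded yet returns empty lists instead of raising a KeyError. A self-contained sketch of the same guard pattern; the dictionaries below are stand-ins, not the real ParameterStore state:

from typing import Dict, List

def get_reduced(params: Dict[int, List], grads: Dict[int, List], group_id: int = 0):
    # without the membership check, params[group_id] raises KeyError for a group
    # that has not reduced anything yet (e.g. before its first bucket is flushed)
    if group_id not in params:
        return [], []
    return params[group_id], grads[group_id]

print(get_reduced({}, {}, group_id=0))            # ([], []) instead of KeyError
print(get_reduced({0: ["p"]}, {0: ["g"]}, 0))     # (['p'], ['g'])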