mirror of https://github.com/hpcaitech/ColossalAI
[Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
pull/5894/head
parent 3420921101
commit 8ec24b6a4d
@@ -3,6 +3,12 @@
 import os
 
+# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation,
+# the order of kernel launches on GPUs is the same as on the CPU so that comm is launched first.
+# see https://github.com/NVIDIA/Megatron-LM/issues/533
+# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
+os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+
 import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator
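The point of this hunk is that the environment variable is now exported at import time, before any CUDA context exists, since the driver only reads it when the context is created. A minimal sketch of that ordering requirement (the check and the tensor are illustrative, not part of the commit):

import os

# Must be set before the first CUDA call; it is read at context creation time.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")

import torch  # imported after the env var is set on purpose

if torch.cuda.is_available():
    x = torch.ones(1, device="cuda")  # first CUDA op; the setting is already in effect
    print("CUDA_DEVICE_MAX_CONNECTIONS =", os.environ["CUDA_DEVICE_MAX_CONNECTIONS"])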
@@ -81,7 +81,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
         # Delay the start of weight gradient computation shortly (3us) to have
         # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
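This backward pass launches the all-reduce of grad_input asynchronously and then computes grad_weight while the communication is in flight. With CUDA_DEVICE_MAX_CONNECTIONS=1 the GPU receives kernels in the order the CPU launched them, so the dummy "torch.empty(1) + 1" delay kernel is no longer needed to get the all-reduce scheduled first. A hedged sketch of the same overlap pattern with illustrative names (process-group setup omitted, not taken from ColossalAI):

import os

os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")

import torch
import torch.distributed as dist


def backward_with_overlap(grad_output, total_input, weight, group=None):
    # grad_input for y = x @ weight.t(): [*, out] @ [out, in] -> [*, in]
    grad_input = grad_output.matmul(weight)
    # Launch the all-reduce first so it reaches the GPU ahead of the matmul below.
    handle = dist.all_reduce(grad_input, group=group, async_op=True)
    # The weight-gradient matmul overlaps with the in-flight all-reduce.
    grad_weight = grad_output.t().matmul(total_input)
    handle.wait()
    return grad_input, grad_weight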
@@ -1,4 +1,3 @@
-import os
 from typing import Dict, List, Tuple
 
 import torch.distributed as dist
@@ -11,9 +10,6 @@ from ..policies.base_policy import Policy
 from .shard_config import ShardConfig
 from .sharder import ModelSharder
 
-# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when communication and computation overlap, the order of core scheduling is correct
-os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-
 
 class ShardFormer:
     """
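Setting the variable only at shardformer import time is fragile: if a CUDA context was already created before this module is imported, the assignment has no effect, which is why the first hunk moves it to an earlier import path. A hedged helper (illustrative, not part of the commit) that makes this failure mode visible:

import os
import warnings

import torch


def set_cuda_max_connections(value: str = "1") -> None:
    # If the CUDA context already exists, changing the env var will not apply to it.
    if torch.cuda.is_initialized() and os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") != value:
        warnings.warn("CUDA is already initialized; CUDA_DEVICE_MAX_CONNECTIONS will not take effect.")
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = value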
@@ -292,7 +292,7 @@ def main():
     with get_profile_context(
         args.profile,
         args.ignore_steps,
-        len(dataloader) - 1,
+        1,  # avoid creating massive log files
         save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
     ) as prof:
         if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
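The benchmark now profiles a single active step instead of len(dataloader) - 1, which keeps the trace files small. get_profile_context is a helper in the example's utilities; a hedged equivalent using torch.profiler directly (the function name and parameters here are assumptions, not the helper's real signature):

import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler


def make_profiler(warmup_steps: int, active_steps: int, save_dir: str):
    # active_steps=1 records one step per cycle, keeping the trace files small.
    return profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
        on_trace_ready=tensorboard_trace_handler(save_dir),
    )

Typical usage mirrors the benchmark loop: enter it with "with make_profiler(args.ignore_steps, 1, save_dir) as prof:" and call prof.step() once per training iteration.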