mirror of https://github.com/hpcaitech/ColossalAI
[npu] add npu support for hybrid plugin and llama (#5090)
* llama 3d * update * fix autocastpull/5099/head
parent
aae496631c
commit
3acbf6d496
|
@ -6,6 +6,7 @@ from torch import Tensor
|
|||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils.device import autocast
|
||||
|
||||
from .mixed_precision_base import MixedPrecision
|
||||
|
||||
|
@ -88,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
|
|||
super().__init__(module)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
with torch.cuda.amp.autocast():
|
||||
with autocast():
|
||||
return self.module(*args, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils
|
|||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.tensor.d_tensor.api import is_distributed_tensor
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from .pp_plugin_base import PipelinePluginBase
|
||||
|
||||
|
@ -81,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
|
|||
self.mixed_precision = torch.bfloat16
|
||||
if self.mixed_precision is not None:
|
||||
module = module.to(self.mixed_precision)
|
||||
module = module.cuda()
|
||||
module = module.to(get_current_device())
|
||||
|
||||
# setting input type cast when using mixed precision
|
||||
self.convert_fn = None
|
||||
|
@ -345,7 +346,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
|
||||
if norm_type == inf:
|
||||
total_norm = max(grad.data.abs().max() for grad in gradients)
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||
if self.pp_size > 1:
|
||||
|
@ -384,7 +385,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
|
||||
total_norm_exponentiated += grad_norm_exponentiated
|
||||
|
||||
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
||||
total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
|
||||
if self.tp_size > 1:
|
||||
# compute norm in tp process group
|
||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
|
||||
|
@ -542,7 +543,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
# so we need to calculate the norm of 'tp' and 'pp' gradients.
|
||||
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
|
||||
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
|
||||
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||
|
@ -585,7 +586,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
|
||||
total_norm_exponentiated += grad_norm_exponentiated
|
||||
|
||||
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
||||
total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
|
||||
if self.tp_size > 1:
|
||||
# compute norm in tp process group
|
||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
|
||||
|
@ -797,7 +798,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
# so we only need to calculate the norm 'tp' of 'pp' gradients.
|
||||
total_norm = super()._compute_grad_norm(gradients, norm_type)
|
||||
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
|
||||
|
||||
if tp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||
|
@ -836,7 +837,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
|
||||
total_norm_exponentiated += grad_norm_exponentiated
|
||||
|
||||
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
||||
total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
|
||||
if dp_size > 1:
|
||||
# compute norm in dp process group
|
||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
|
||||
|
@ -1027,7 +1028,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return self.pp_size > 1
|
||||
|
||||
def supported_devices(self) -> List[str]:
|
||||
return ["cuda"]
|
||||
return ["cuda", "npu"]
|
||||
|
||||
def supported_precisions(self) -> List[str]:
|
||||
return ["fp16", "bf16", "fp32"]
|
||||
|
|
|
@ -38,7 +38,7 @@ class DeviceMesh:
|
|||
device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
|
||||
"""
|
||||
|
||||
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
|
||||
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch.cuda.amp as torch_amp
|
||||
from colossalai.utils.device import autocast
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
@ -70,7 +71,7 @@ class TorchAMPModel(nn.Module):
|
|||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
@torch_amp.autocast()
|
||||
@autocast()
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Execute forward under the torch amp context
|
||||
|
@ -89,7 +90,7 @@ class TorchAMPLoss(nn.Module):
|
|||
super().__init__()
|
||||
self.loss = loss
|
||||
|
||||
@torch_amp.autocast()
|
||||
@autocast()
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Execute forward under the torch amp context
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from torch.utils.checkpoint import check_backward_validity, detach_variable
|
||||
|
||||
from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.device import autocast, get_current_device
|
||||
|
||||
|
||||
def copy_to_device(obj, device):
|
||||
|
@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
|
|||
inputs[idx] = tensors[i]
|
||||
detached_inputs = detach_variable(tuple(inputs))
|
||||
if ctx.had_autocast_in_fwd:
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast():
|
||||
with torch.enable_grad(), autocast():
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
else:
|
||||
with torch.enable_grad():
|
||||
|
@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
|
|||
|
||||
# rerun forward, the inner_pack will store all the activations in storage
|
||||
if has_autocast_in_fwd:
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
|
||||
with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
|
||||
inner_pack, inner_unpack
|
||||
):
|
||||
_unused = function(*args)
|
||||
|
|
|
@ -6,6 +6,7 @@ import torch.distributed as dist
|
|||
from torch import nn
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from torch.distributed import ProcessGroup, get_world_size
|
||||
from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
|
||||
|
||||
|
||||
class SeqParallelUtils:
|
||||
|
@ -104,14 +105,14 @@ class Randomizer:
|
|||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
# Handle CUDA rng state
|
||||
# Handle device rng state
|
||||
# 1. get the current rng state
|
||||
# 2. set the seed and store the rng state
|
||||
# 3. recover the original rng state
|
||||
cuda_original_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.manual_seed(seed)
|
||||
self.cuda_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(cuda_original_rng_state)
|
||||
device_original_rng_state = get_rng_state()
|
||||
manual_seed(seed)
|
||||
self.device_rng_state = get_rng_state()
|
||||
set_rng_state(device_original_rng_state)
|
||||
|
||||
# to the same for cpu rng state
|
||||
cpu_original_rng_state = torch.get_rng_state()
|
||||
|
@ -119,11 +120,11 @@ class Randomizer:
|
|||
self.cpu_rng_state = torch.get_rng_state()
|
||||
torch.set_rng_state(cpu_original_rng_state)
|
||||
|
||||
def _set_cuda_rng_state(self, rng_state):
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
def _set_device_rng_state(self, rng_state):
|
||||
set_rng_state(rng_state)
|
||||
|
||||
def _get_cuda_rng_state(self):
|
||||
current_state = torch.cuda.get_rng_state()
|
||||
def _get_device_rng_state(self):
|
||||
current_state = get_rng_state()
|
||||
return current_state
|
||||
|
||||
def _set_cpu_rng_state(self, rng_state):
|
||||
|
@ -144,16 +145,16 @@ class Randomizer:
|
|||
>>> input = super().forward(input)
|
||||
"""
|
||||
try:
|
||||
current_cuda_rng_state = self._get_cuda_rng_state()
|
||||
self._set_cuda_rng_state(self.cuda_rng_state)
|
||||
current_device_rng_state = self._get_device_rng_state()
|
||||
self._set_device_rng_state(self.device_rng_state)
|
||||
|
||||
if enable_cpu:
|
||||
current_cpu_rng_state = self._get_cpu_rng_state()
|
||||
self._set_cpu_rng_state(self.cpu_rng_state)
|
||||
yield
|
||||
finally:
|
||||
self.cuda_rng_state = self._get_cuda_rng_state()
|
||||
self._set_cuda_rng_state(current_cuda_rng_state)
|
||||
self.device_rng_state = self._get_device_rng_state()
|
||||
self._set_device_rng_state(current_device_rng_state)
|
||||
|
||||
if enable_cpu:
|
||||
self.cpu_rng_state = self._get_cpu_rng_state()
|
||||
|
@ -208,7 +209,7 @@ class Randomizer:
|
|||
index = Randomizer.index()
|
||||
if dist.is_initialized():
|
||||
# convert the index to tensor
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
|
||||
|
||||
# all gather the index
|
||||
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
|
||||
|
@ -230,7 +231,7 @@ class Randomizer:
|
|||
|
||||
if dist.is_initialized():
|
||||
# convert the index to tensor
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
|
||||
|
||||
# all gather the index
|
||||
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Callable, List
|
|||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from packaging import version
|
||||
from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
|
||||
|
||||
|
||||
def parameterize(argument: str, values: List[Any]) -> Callable:
|
||||
|
@ -198,7 +199,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
|
|||
|
||||
def _wrap_func(f):
|
||||
def _execute_by_gpu_num(*args, **kwargs):
|
||||
num_avail_gpu = torch.cuda.device_count()
|
||||
num_avail_gpu = device_count()
|
||||
if num_avail_gpu >= min_gpus:
|
||||
f(*args, **kwargs)
|
||||
|
||||
|
@ -262,11 +263,11 @@ def clear_cache_before_run():
|
|||
|
||||
def _wrap_func(f):
|
||||
def _clear_cache(*args, **kwargs):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_max_memory_cached()
|
||||
torch.cuda.synchronize()
|
||||
empty_cache()
|
||||
reset_peak_memory_stats()
|
||||
reset_max_memory_allocated()
|
||||
reset_max_memory_cached()
|
||||
synchronize()
|
||||
gc.collect()
|
||||
f(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -191,6 +191,10 @@ def reset_max_memory_allocated(device=None) -> None:
|
|||
return _dispatch_device_func("reset_max_memory_allocated", device)
|
||||
|
||||
|
||||
def reset_max_memory_cached(device=None) -> None:
|
||||
return _dispatch_device_func("reset_max_memory_cached", device)
|
||||
|
||||
|
||||
def memory_reserved(device=None) -> int:
|
||||
return _dispatch_device_func("memory_reserved", device)
|
||||
|
||||
|
@ -205,3 +209,15 @@ def set_per_process_memory_fraction(fraction: float, device=None) -> None:
|
|||
|
||||
def reset_peak_memory_stats(device=None) -> None:
|
||||
return _dispatch_device_func("reset_peak_memory_stats", device)
|
||||
|
||||
|
||||
# amp
|
||||
|
||||
|
||||
def autocast() -> Callable:
|
||||
if torch.cuda.is_available():
|
||||
return torch.cuda.amp.autocast()
|
||||
elif IS_NPU_AVAILABLE:
|
||||
return torch.npu.amp.autocast()
|
||||
else:
|
||||
raise RuntimeError("No device available")
|
||||
|
|
|
@ -131,7 +131,7 @@ def main():
|
|||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
zero_stage=args.zero,
|
||||
enable_fused_normalization=True,
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
num_microbatches=args.mbs,
|
||||
precision="bf16",
|
||||
)
|
||||
|
@ -141,7 +141,7 @@ def main():
|
|||
pp_size=args.pp,
|
||||
zero_stage=args.zero,
|
||||
cpu_offload=True,
|
||||
enable_fused_normalization=True,
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
num_microbatches=args.mbs,
|
||||
initial_scale=2**8,
|
||||
precision="bf16",
|
||||
|
|
Loading…
Reference in New Issue