[npu] add npu support for hybrid plugin and llama (#5090)

* llama 3d

* update

* fix autocast
pull/5099/head
Xuanlei Zhao 2023-11-22 19:23:21 +08:00 committed by GitHub
parent aae496631c
commit 3acbf6d496
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 61 additions and 40 deletions

View File

@ -6,6 +6,7 @@ from torch import Tensor
from torch.optim import Optimizer
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils.device import autocast
from .mixed_precision_base import MixedPrecision
@ -88,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
super().__init__(module)
def forward(self, *args, **kwargs):
with torch.cuda.amp.autocast():
with autocast():
return self.module(*args, **kwargs)

View File

@ -29,6 +29,7 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.utils.device import get_current_device
from .pp_plugin_base import PipelinePluginBase
@ -81,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
module = module.cuda()
module = module.to(get_current_device())
# setting input type cast when using mixed precision
self.convert_fn = None
@ -345,7 +346,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if self.pp_size > 1:
@ -384,7 +385,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@ -542,7 +543,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
# so we need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@ -585,7 +586,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@ -797,7 +798,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# so we only need to calculate the norm 'tp' of 'pp' gradients.
total_norm = super()._compute_grad_norm(gradients, norm_type)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@ -836,7 +837,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
@ -1027,7 +1028,7 @@ class HybridParallelPlugin(PipelinePluginBase):
return self.pp_size > 1
def supported_devices(self) -> List[str]:
return ["cuda"]
return ["cuda", "npu"]
def supported_precisions(self) -> List[str]:
return ["fp16", "bf16", "fp32"]

View File

@ -38,7 +38,7 @@ class DeviceMesh:
device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
"""
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}
def __init__(
self,

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.cuda.amp as torch_amp
from colossalai.utils.device import autocast
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
@ -70,7 +71,7 @@ class TorchAMPModel(nn.Module):
super().__init__()
self.model = model
@torch_amp.autocast()
@autocast()
def forward(self, *args, **kwargs):
"""
Execute forward under the torch amp context
@ -89,7 +90,7 @@ class TorchAMPLoss(nn.Module):
super().__init__()
self.loss = loss
@torch_amp.autocast()
@autocast()
def forward(self, *args, **kwargs):
"""
Execute forward under the torch amp context

View File

@ -7,7 +7,7 @@ import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable
from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
from colossalai.utils import get_current_device
from colossalai.utils.device import autocast, get_current_device
def copy_to_device(obj, device):
@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
inputs[idx] = tensors[i]
detached_inputs = detach_variable(tuple(inputs))
if ctx.had_autocast_in_fwd:
with torch.enable_grad(), torch.cuda.amp.autocast():
with torch.enable_grad(), autocast():
outputs = ctx.run_function(*detached_inputs)
else:
with torch.enable_grad():
@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# rerun forward, the inner_pack will store all the activations in storage
if has_autocast_in_fwd:
with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
inner_pack, inner_unpack
):
_unused = function(*args)

View File

@ -6,6 +6,7 @@ import torch.distributed as dist
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size
from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
class SeqParallelUtils:
@ -104,14 +105,14 @@ class Randomizer:
def __init__(self, seed: int):
self.seed = seed
# Handle CUDA rng state
# Handle device rng state
# 1. get the current rng state
# 2. set the seed and store the rng state
# 3. recover the original rng state
cuda_original_rng_state = torch.cuda.get_rng_state()
torch.cuda.manual_seed(seed)
self.cuda_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(cuda_original_rng_state)
device_original_rng_state = get_rng_state()
manual_seed(seed)
self.device_rng_state = get_rng_state()
set_rng_state(device_original_rng_state)
# to the same for cpu rng state
cpu_original_rng_state = torch.get_rng_state()
@ -119,11 +120,11 @@ class Randomizer:
self.cpu_rng_state = torch.get_rng_state()
torch.set_rng_state(cpu_original_rng_state)
def _set_cuda_rng_state(self, rng_state):
torch.cuda.set_rng_state(rng_state)
def _set_device_rng_state(self, rng_state):
set_rng_state(rng_state)
def _get_cuda_rng_state(self):
current_state = torch.cuda.get_rng_state()
def _get_device_rng_state(self):
current_state = get_rng_state()
return current_state
def _set_cpu_rng_state(self, rng_state):
@ -144,16 +145,16 @@ class Randomizer:
>>> input = super().forward(input)
"""
try:
current_cuda_rng_state = self._get_cuda_rng_state()
self._set_cuda_rng_state(self.cuda_rng_state)
current_device_rng_state = self._get_device_rng_state()
self._set_device_rng_state(self.device_rng_state)
if enable_cpu:
current_cpu_rng_state = self._get_cpu_rng_state()
self._set_cpu_rng_state(self.cpu_rng_state)
yield
finally:
self.cuda_rng_state = self._get_cuda_rng_state()
self._set_cuda_rng_state(current_cuda_rng_state)
self.device_rng_state = self._get_device_rng_state()
self._set_device_rng_state(current_device_rng_state)
if enable_cpu:
self.cpu_rng_state = self._get_cpu_rng_state()
@ -208,7 +209,7 @@ class Randomizer:
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
@ -230,7 +231,7 @@ class Randomizer:
if dist.is_initialized():
# convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]

View File

@ -9,6 +9,7 @@ from typing import Any, Callable, List
import torch
import torch.multiprocessing as mp
from packaging import version
from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
def parameterize(argument: str, values: List[Any]) -> Callable:
@ -198,7 +199,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
def _wrap_func(f):
def _execute_by_gpu_num(*args, **kwargs):
num_avail_gpu = torch.cuda.device_count()
num_avail_gpu = device_count()
if num_avail_gpu >= min_gpus:
f(*args, **kwargs)
@ -262,11 +263,11 @@ def clear_cache_before_run():
def _wrap_func(f):
def _clear_cache(*args, **kwargs):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_max_memory_cached()
torch.cuda.synchronize()
empty_cache()
reset_peak_memory_stats()
reset_max_memory_allocated()
reset_max_memory_cached()
synchronize()
gc.collect()
f(*args, **kwargs)

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Callable
import torch
import torch.distributed as dist
@ -191,6 +191,10 @@ def reset_max_memory_allocated(device=None) -> None:
return _dispatch_device_func("reset_max_memory_allocated", device)
def reset_max_memory_cached(device=None) -> None:
return _dispatch_device_func("reset_max_memory_cached", device)
def memory_reserved(device=None) -> int:
return _dispatch_device_func("memory_reserved", device)
@ -205,3 +209,15 @@ def set_per_process_memory_fraction(fraction: float, device=None) -> None:
def reset_peak_memory_stats(device=None) -> None:
return _dispatch_device_func("reset_peak_memory_stats", device)
# amp
def autocast() -> Callable:
if torch.cuda.is_available():
return torch.cuda.amp.autocast()
elif IS_NPU_AVAILABLE:
return torch.npu.amp.autocast()
else:
raise RuntimeError("No device available")

View File

@ -131,7 +131,7 @@ def main():
tp_size=args.tp,
pp_size=args.pp,
zero_stage=args.zero,
enable_fused_normalization=True,
enable_fused_normalization=torch.cuda.is_available(),
num_microbatches=args.mbs,
precision="bf16",
)
@ -141,7 +141,7 @@ def main():
pp_size=args.pp,
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=True,
enable_fused_normalization=torch.cuda.is_available(),
num_microbatches=args.mbs,
initial_scale=2**8,
precision="bf16",