mirror of https://github.com/hpcaitech/ColossalAI
[npu] add npu support for hybrid plugin and llama (#5090)
* llama 3d
* update
* fix autocast

parent aae496631c
commit 3acbf6d496
@@ -6,6 +6,7 @@ from torch import Tensor
 from torch.optim import Optimizer
 
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils.device import autocast
 
 from .mixed_precision_base import MixedPrecision
 
@@ -88,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
         super().__init__(module)
 
     def forward(self, *args, **kwargs):
-        with torch.cuda.amp.autocast():
+        with autocast():
             return self.module(*args, **kwargs)
 
 
@@ -29,6 +29,7 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
+from colossalai.utils.device import get_current_device
 
 from .pp_plugin_base import PipelinePluginBase
 
@@ -81,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
             self.mixed_precision = torch.bfloat16
         if self.mixed_precision is not None:
             module = module.to(self.mixed_precision)
-        module = module.cuda()
+        module = module.to(get_current_device())
 
         # setting input type cast when using mixed precision
         self.convert_fn = None
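Note: `module.cuda()` pins the wrapper to CUDA, while `module.to(get_current_device())` follows whichever accelerator the process is bound to. The body of `colossalai.utils.device.get_current_device` is not part of this diff; a minimal sketch of what such a helper looks like, assuming CUDA is preferred and `torch_npu` provides `torch.npu` on Ascend hosts:

import torch

def get_current_device() -> torch.device:
    # Sketch only: prefer CUDA, then Ascend NPU, then CPU.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    if getattr(torch, "npu", None) is not None and torch.npu.is_available():
        return torch.device(f"npu:{torch.npu.current_device()}")
    return torch.device("cpu")
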
@@ -345,7 +346,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
 
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
             if self.pp_size > 1:
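`torch.cuda.FloatTensor` allocates on CUDA unconditionally (and is deprecated), so it fails on an NPU-only host; `torch.tensor(..., device=get_current_device(), dtype=torch.float32)` places the reduction buffer on the active accelerator. The surrounding pattern, as a standalone hedged sketch (`reduce_inf_norm` is a hypothetical name; assumes an initialized process group):

import torch
import torch.distributed as dist

def reduce_inf_norm(gradients, group=None):
    # Local inf-norm: the largest absolute gradient entry on this rank.
    local_norm = max(grad.data.abs().max() for grad in gradients)
    # Device-agnostic buffer: works on CUDA and NPU alike.
    buf = torch.tensor([float(local_norm)], device=gradients[0].device, dtype=torch.float32)
    # For the inf-norm, the cross-rank reduction is MAX.
    dist.all_reduce(buf, op=dist.ReduceOp.MAX, group=group)
    return buf.item()
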
@@ -384,7 +385,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
 
                 total_norm_exponentiated += grad_norm_exponentiated
 
-        total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+        total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
         if self.tp_size > 1:
             # compute norm in tp process group
             dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
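For a finite norm order p, these hunks reduce the exponentiated partial sums with SUM (each parallel group contributes its local Σ|g|^p) and the 1/p root is taken once at the end, whereas the inf-norm above reduces with MAX. The same logic in isolation (hypothetical helper, same assumptions as the sketch above):

import torch
import torch.distributed as dist

def reduce_p_norm(gradients, p: float, group=None):
    # Each rank accumulates sum(|g|^p) over its local gradients.
    local_exp_sum = sum(grad.data.abs().pow(p).sum().item() for grad in gradients)
    buf = torch.tensor([local_exp_sum], device=gradients[0].device, dtype=torch.float32)
    # SUM across ranks, then a single p-th root.
    dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
    return buf.item() ** (1.0 / p)
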
@@ -542,7 +543,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             # so we need to calculate the norm of 'tp' and 'pp' gradients.
             total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
 
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
 
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -585,7 +586,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
 
                 total_norm_exponentiated += grad_norm_exponentiated
 
-        total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+        total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
         if self.tp_size > 1:
             # compute norm in tp process group
             dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -797,7 +798,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # so we only need to calculate the norm 'tp' of 'pp' gradients.
             total_norm = super()._compute_grad_norm(gradients, norm_type)
 
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
 
             if tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -836,7 +837,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
 
                 total_norm_exponentiated += grad_norm_exponentiated
 
-        total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+        total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
         if dp_size > 1:
             # compute norm in dp process group
             dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
@@ -1027,7 +1028,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         return self.pp_size > 1
 
     def supported_devices(self) -> List[str]:
-        return ["cuda"]
+        return ["cuda", "npu"]
 
     def supported_precisions(self) -> List[str]:
         return ["fp16", "bf16", "fp32"]
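Adding "npu" to `supported_devices()` is what lets a booster accept this plugin on Ascend hosts. A hypothetical pre-flight check, assuming the plugin is constructed inside an already-launched distributed context:

from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="bf16")
device_type = "npu"  # or "cuda", depending on the host
assert device_type in plugin.supported_devices(), f"{device_type} is not supported"
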
@@ -38,7 +38,7 @@ class DeviceMesh:
         device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
     """
 
-    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
+    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}
 
     def __init__(
         self,
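HCCL is the Ascend counterpart of NCCL, so the backend table now covers all three device types. A sketch of how such a mapping is typically consumed (`init_group_for` is a hypothetical wrapper):

import torch.distributed as dist

_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}

def init_group_for(device_type: str, **kwargs):
    # Pick the collective backend that matches the device type.
    dist.init_process_group(backend=_DIST_BACKEND[device_type], **kwargs)
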
@@ -1,7 +1,8 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-import torch.cuda.amp as torch_amp
+from colossalai.utils.device import autocast
+
 import torch.nn as nn
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
@@ -70,7 +71,7 @@ class TorchAMPModel(nn.Module):
         super().__init__()
         self.model = model
 
-    @torch_amp.autocast()
+    @autocast()
     def forward(self, *args, **kwargs):
         """
         Execute forward under the torch amp context
@@ -89,7 +90,7 @@ class TorchAMPLoss(nn.Module):
         super().__init__()
         self.loss = loss
 
-    @torch_amp.autocast()
+    @autocast()
     def forward(self, *args, **kwargs):
         """
         Execute forward under the torch amp context
@@ -7,7 +7,7 @@ import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable
 
 from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
-from colossalai.utils import get_current_device
+from colossalai.utils.device import autocast, get_current_device
 
 
 def copy_to_device(obj, device):
@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
             inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
-            with torch.enable_grad(), torch.cuda.amp.autocast():
+            with torch.enable_grad(), autocast():
                 outputs = ctx.run_function(*detached_inputs)
         else:
             with torch.enable_grad():
@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
 
             # rerun forward, the inner_pack will store all the activations in storage
             if has_autocast_in_fwd:
-                with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
+                with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
                     inner_pack, inner_unpack
                 ):
                     _unused = function(*args)
@@ -6,6 +6,7 @@ import torch.distributed as dist
 from torch import nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup, get_world_size
+from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
 
 
 class SeqParallelUtils:
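`get_rng_state`, `set_rng_state`, and `manual_seed` imported here are device-agnostic replacements for the `torch.cuda` RNG API; their bodies are not shown in this diff. A plausible sketch, assuming they simply dispatch to the active backend module:

import torch

def _device_mod():
    # Sketch only: torch.cuda on CUDA hosts, torch.npu on Ascend hosts
    # (the latter exists once torch_npu has been imported).
    return torch.cuda if torch.cuda.is_available() else torch.npu

def get_rng_state():
    return _device_mod().get_rng_state()

def set_rng_state(state):
    _device_mod().set_rng_state(state)

def manual_seed(seed: int):
    _device_mod().manual_seed(seed)
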
@@ -104,14 +105,14 @@ class Randomizer:
     def __init__(self, seed: int):
         self.seed = seed
 
-        # Handle CUDA rng state
+        # Handle device rng state
         # 1. get the current rng state
         # 2. set the seed and store the rng state
         # 3. recover the original rng state
-        cuda_original_rng_state = torch.cuda.get_rng_state()
-        torch.cuda.manual_seed(seed)
-        self.cuda_rng_state = torch.cuda.get_rng_state()
-        torch.cuda.set_rng_state(cuda_original_rng_state)
+        device_original_rng_state = get_rng_state()
+        manual_seed(seed)
+        self.device_rng_state = get_rng_state()
+        set_rng_state(device_original_rng_state)
 
         # to the same for cpu rng state
         cpu_original_rng_state = torch.get_rng_state()
@@ -119,11 +120,11 @@ class Randomizer:
         self.cpu_rng_state = torch.get_rng_state()
         torch.set_rng_state(cpu_original_rng_state)
 
-    def _set_cuda_rng_state(self, rng_state):
-        torch.cuda.set_rng_state(rng_state)
+    def _set_device_rng_state(self, rng_state):
+        set_rng_state(rng_state)
 
-    def _get_cuda_rng_state(self):
-        current_state = torch.cuda.get_rng_state()
+    def _get_device_rng_state(self):
+        current_state = get_rng_state()
         return current_state
 
     def _set_cpu_rng_state(self, rng_state):
@@ -144,16 +145,16 @@ class Randomizer:
         >>> input = super().forward(input)
         """
         try:
-            current_cuda_rng_state = self._get_cuda_rng_state()
-            self._set_cuda_rng_state(self.cuda_rng_state)
+            current_device_rng_state = self._get_device_rng_state()
+            self._set_device_rng_state(self.device_rng_state)
 
             if enable_cpu:
                 current_cpu_rng_state = self._get_cpu_rng_state()
                 self._set_cpu_rng_state(self.cpu_rng_state)
             yield
         finally:
-            self.cuda_rng_state = self._get_cuda_rng_state()
-            self._set_cuda_rng_state(current_cuda_rng_state)
+            self.device_rng_state = self._get_device_rng_state()
+            self._set_device_rng_state(current_device_rng_state)
 
             if enable_cpu:
                 self.cpu_rng_state = self._get_cpu_rng_state()
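The try/finally above is a swap-and-restore protocol: enter with the Randomizer's stored RNG state, run the block, persist the advanced state, and hand the caller's state back. The same idea as a standalone context manager, built on the hypothetical helpers sketched earlier:

from contextlib import contextmanager

@contextmanager
def fork_device_rng(holder: dict):
    # Swap in the stored state, yield to the with-block, then save the
    # advanced state and restore whatever the caller was using.
    outer_state = get_rng_state()
    set_rng_state(holder["state"])
    try:
        yield
    finally:
        holder["state"] = get_rng_state()
        set_rng_state(outer_state)
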
@@ -208,7 +209,7 @@ class Randomizer:
         index = Randomizer.index()
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
 
             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
@@ -230,7 +231,7 @@ class Randomizer:
 
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
 
             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
@@ -9,6 +9,7 @@ from typing import Any, Callable, List
 import torch
 import torch.multiprocessing as mp
 from packaging import version
+from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
 
 
 def parameterize(argument: str, values: List[Any]) -> Callable:
@@ -198,7 +199,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
 
     def _wrap_func(f):
         def _execute_by_gpu_num(*args, **kwargs):
-            num_avail_gpu = torch.cuda.device_count()
+            num_avail_gpu = device_count()
             if num_avail_gpu >= min_gpus:
                 f(*args, **kwargs)
 
@@ -262,11 +263,11 @@ def clear_cache_before_run():
 
     def _wrap_func(f):
         def _clear_cache(*args, **kwargs):
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
-            torch.cuda.reset_max_memory_allocated()
-            torch.cuda.reset_max_memory_cached()
-            torch.cuda.synchronize()
+            empty_cache()
+            reset_peak_memory_stats()
+            reset_max_memory_allocated()
+            reset_max_memory_cached()
+            synchronize()
             gc.collect()
             f(*args, **kwargs)
 
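With the dispatched helpers, `clear_cache_before_run` now scrubs allocator state on whichever accelerator is active. Hypothetical usage (assuming the decorator is exported from colossalai.testing as elsewhere in the test suite):

from colossalai.testing import clear_cache_before_run

@clear_cache_before_run()
def test_forward_memory():
    # Starts with caches emptied and peak-memory counters reset,
    # on CUDA or NPU depending on the host.
    ...
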
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Callable
 
 import torch
 import torch.distributed as dist
@@ -191,6 +191,10 @@ def reset_max_memory_allocated(device=None) -> None:
     return _dispatch_device_func("reset_max_memory_allocated", device)
 
 
+def reset_max_memory_cached(device=None) -> None:
+    return _dispatch_device_func("reset_max_memory_cached", device)
+
+
 def memory_reserved(device=None) -> int:
     return _dispatch_device_func("memory_reserved", device)
 
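`_dispatch_device_func` is used throughout but not defined in this diff; presumably it resolves the named function on the active backend module. A minimal sketch under that assumption:

import torch

IS_NPU_AVAILABLE = getattr(torch, "npu", None) is not None and torch.npu.is_available()

def _dispatch_device_func(fn_name: str, *args):
    # Sketch: look up e.g. "reset_max_memory_cached" on torch.cuda or torch.npu.
    if torch.cuda.is_available():
        return getattr(torch.cuda, fn_name)(*args)
    elif IS_NPU_AVAILABLE:
        return getattr(torch.npu, fn_name)(*args)
    raise RuntimeError("No device available")
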
@@ -205,3 +209,15 @@ def set_per_process_memory_fraction(fraction: float, device=None) -> None:
 
 def reset_peak_memory_stats(device=None) -> None:
     return _dispatch_device_func("reset_peak_memory_stats", device)
+
+
+# amp
+
+
+def autocast() -> Callable:
+    if torch.cuda.is_available():
+        return torch.cuda.amp.autocast()
+    elif IS_NPU_AVAILABLE:
+        return torch.npu.amp.autocast()
+    else:
+        raise RuntimeError("No device available")
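`autocast()` returns a fresh context-manager instance per call, which is why the call sites above read `with autocast():` and the legacy decorators read `@autocast()`. Usage, assuming a host where at least one of the two backends is available:

import torch
from colossalai.utils.device import autocast

linear = torch.nn.Linear(16, 16)

def forward(x: torch.Tensor) -> torch.Tensor:
    # Maps to torch.cuda.amp.autocast on CUDA and torch.npu.amp.autocast on NPU.
    with autocast():
        return linear(x)
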
@@ -131,7 +131,7 @@ def main():
             tp_size=args.tp,
             pp_size=args.pp,
             zero_stage=args.zero,
-            enable_fused_normalization=True,
+            enable_fused_normalization=torch.cuda.is_available(),
             num_microbatches=args.mbs,
             precision="bf16",
         )
@@ -141,7 +141,7 @@ def main():
             pp_size=args.pp,
             zero_stage=args.zero,
             cpu_offload=True,
-            enable_fused_normalization=True,
+            enable_fused_normalization=torch.cuda.is_available(),
             num_microbatches=args.mbs,
             initial_scale=2**8,
             precision="bf16",
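`enable_fused_normalization=torch.cuda.is_available()` keeps the benchmark runnable on NPU: the fused normalization kernel is a CUDA extension, so the flag has to degrade gracefully instead of hard-failing. The same conditional-feature pattern in isolation (`build_plugin_kwargs` is hypothetical):

import torch

def build_plugin_kwargs(base_kwargs: dict) -> dict:
    # Turn CUDA-only features on only where their kernels exist.
    kwargs = dict(base_kwargs)
    kwargs["enable_fused_normalization"] = torch.cuda.is_available()
    return kwargs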