mirror of https://github.com/hpcaitech/ColossalAI
commit 9470701110
@@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
        make_vocab_size_divisible_by: int = 64,
        dp_outside: bool = True,
        overlap_p2p: bool = True,
        fp8_communication: bool = False,
    ) -> None:
        super().__init__()
        assert (
@@ -1082,6 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                microbatch_size=microbatch_size,
                enable_metadata_cache=enable_metadata_cache,
                overlap_p2p=overlap_p2p,
                fp8_communication=fp8_communication,
            )
        elif pp_style == "1f1b":
            self.schedule = OneForwardOneBackwardSchedule(
@@ -1089,6 +1091,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                num_microbatches=num_microbatches,
                microbatch_size=microbatch_size,
                enable_metadata_cache=enable_metadata_cache,
                fp8_communication=fp8_communication,
            )
        else:
            raise NotImplementedError()
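The new flag only needs to be set on the plugin; both pipeline schedules receive it through the constructor calls above. A minimal usage sketch follows; the parallel sizes, precision, and the torchrun-based launch are illustrative assumptions, not part of this commit:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes the script is started via torchrun

plugin = HybridParallelPlugin(
    tp_size=2,                 # illustrative parallel sizes
    pp_size=2,
    num_microbatches=4,
    precision="fp16",
    fp8_communication=True,    # the flag added in this commit
)
booster = Booster(plugin=plugin)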
@@ -11,6 +11,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
from colossalai.utils import get_current_device

from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
@@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule):
        microbatch_size: Optional[int] = None,
        enable_metadata_cache: bool = True,
        overlap_p2p: bool = True,
        fp8_communication: bool = False,
    ) -> None:
        super().__init__(stage_manager)
        assert (
@@ -56,6 +58,8 @@ class InterleavedSchedule(PipelineSchedule):
        self.tensor_metadata_recv = None
        self.grad_metadata_recv = None

        self.fp8_communication = fp8_communication

    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from data iterator.
@@ -191,8 +195,12 @@ class InterleavedSchedule(PipelineSchedule):
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_last_stage():
                if self.fp8_communication:
                    cast_to_fp8_pipeline(output_tensor)
                send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
                self.send_tensor_metadata = not self.enable_metadata_cache
                if self.fp8_communication:
                    cast_from_fp8_pipeline(output_tensor)
                return send_handles
        return []
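The same cast-before-send / cast-back-after-send pattern recurs in every communication helper below. A minimal round-trip sketch of the two quantization helpers imported above, outside of any schedule; the tensor shape is illustrative and a PyTorch build with float8 dtypes is assumed:

import torch

from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline

# Illustrative pipeline object; in the schedules this is the dict exchanged between stages.
obj = {"hidden_states": torch.randn(2, 16, 64, dtype=torch.bfloat16)}

cast_to_fp8_pipeline(obj)    # in place: quantizes hidden_states and stores obj["fp8_scale"]
# ... the p2p send/recv of obj happens here in the real schedule ...
cast_from_fp8_pipeline(obj)  # in place: dequantizes hidden_states and removes "fp8_scale"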
@@ -210,10 +218,14 @@ class InterleavedSchedule(PipelineSchedule):
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if not self.stage_manager.is_first_stage():
                if self.fp8_communication:
                    cast_to_fp8_pipeline(input_tensor_grad)
                send_handles = self.comm.send_backward(
                    input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
                )
                self.send_grad_metadata = not self.enable_metadata_cache
                if self.fp8_communication:
                    cast_from_fp8_pipeline(input_tensor_grad)
                return send_handles
        return []
@@ -224,6 +236,8 @@ class InterleavedSchedule(PipelineSchedule):
            is_send = not self.stage_manager.is_last_stage()
        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
            is_recv = not self.stage_manager.is_first_stage()
        if self.fp8_communication:
            cast_to_fp8_pipeline(output_tensor)
        input_tensor, wait_handles = self.comm.send_forward_recv_forward(
            output_tensor,
            is_send,
@@ -237,6 +251,8 @@ class InterleavedSchedule(PipelineSchedule):
        if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
            self.tensor_metadata_recv = create_send_metadata(input_tensor)

        if self.fp8_communication:
            cast_from_fp8_pipeline(output_tensor)
        return input_tensor, wait_handles

    def send_backward_recv_backward(
@@ -246,6 +262,8 @@ class InterleavedSchedule(PipelineSchedule):
            is_send = not self.stage_manager.is_first_stage()
        with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
            is_recv = not self.stage_manager.is_last_stage()
        if self.fp8_communication:
            cast_to_fp8_pipeline(input_tensor_grad)
        output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
            input_tensor_grad,
            is_send,
@@ -258,6 +276,8 @@ class InterleavedSchedule(PipelineSchedule):
        self.send_grad_metadata = not self.enable_metadata_cache and is_send
        if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
            self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
        if self.fp8_communication:
            cast_from_fp8_pipeline(input_tensor_grad)
        return output_tensor_grad, wait_handles

    def forward_step(
@@ -379,6 +399,8 @@ class InterleavedSchedule(PipelineSchedule):

            # Wait until current input is received
            _wait_p2p(fwd_wait_handles)
            if self.fp8_communication and input_obj is not None:
                cast_from_fp8_pipeline(input_obj)
            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)

            if not last_batch:
@@ -441,6 +463,8 @@ class InterleavedSchedule(PipelineSchedule):

            # Wait for input
            _wait_p2p(fwd_wait_handles)
            if self.fp8_communication and input_obj is not None:
                cast_from_fp8_pipeline(input_obj)
            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
            input_objs[model_chunk_id].append(input_obj)
            output_objs[model_chunk_id].append(output_obj)
@@ -467,6 +491,8 @@ class InterleavedSchedule(PipelineSchedule):

            # Wait for input.
            _wait_p2p(fwd_wait_handles)
            if self.fp8_communication and input_obj is not None:
                cast_from_fp8_pipeline(input_obj)
            output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
            # Add input_obj and output_obj to end of list.
            input_objs[model_chunk_id].append(input_obj)
@@ -511,6 +537,8 @@ class InterleavedSchedule(PipelineSchedule):
                input_obj, fwd_wait_handles = send_forward_recv_forward()
            # Wait for upstream grad
            _wait_p2p(bwd_wait_handles)
            if self.fp8_communication and output_obj_grad is not None:
                cast_from_fp8_pipeline(output_obj_grad)
            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
            # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
            # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
@@ -532,6 +560,8 @@ class InterleavedSchedule(PipelineSchedule):

            # Wait for upstream grad
            _wait_p2p(bwd_wait_handles)
            if self.fp8_communication and output_obj_grad is not None:
                cast_from_fp8_pipeline(output_obj_grad)
            # backward local grads
            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
            if not last_batch:
@@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
from colossalai.utils import get_current_device

from ._utils import (
@@ -32,6 +33,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        num_microbatches: Optional[int] = None,
        microbatch_size: Optional[int] = None,
        enable_metadata_cache: bool = True,
        fp8_communication: bool = False,
    ) -> None:
        """1F1B pipeline schedule.
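For reference, a hedged sketch of constructing the 1F1B schedule directly with the new flag. That the first positional argument is a PipelineStageManager is inferred from the super().__init__(stage_manager) call in the surrounding code, and the microbatch count is illustrative:

from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule


def build_1f1b_schedule(stage_manager, use_fp8: bool = True):
    # stage_manager: a PipelineStageManager created by the caller (e.g. by HybridParallelPlugin)
    return OneForwardOneBackwardSchedule(
        stage_manager,
        num_microbatches=4,
        enable_metadata_cache=True,
        fp8_communication=use_fp8,  # new in this commit
    )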
@@ -61,6 +63,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        self.tensor_metadata_recv = None
        self.grad_metadata_recv = None

        self.fp8_communication = fp8_communication

    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from data iterator.
@@ -129,6 +133,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                self.tensor_metadata_recv = create_send_metadata(input_tensor)

            if self.fp8_communication:
                cast_from_fp8_pipeline(input_tensor)
            return input_tensor

    def recv_backward(self, next_rank: int = None) -> Any:
@@ -143,6 +149,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        """
        if not self.stage_manager.is_last_stage():
            output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
            if self.fp8_communication:
                cast_from_fp8_pipeline(output_tensor_grad)
            if self.enable_metadata_cache and self.grad_metadata_recv is None:
                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
@@ -157,9 +165,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            next_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_last_stage():
            if self.fp8_communication:
                cast_to_fp8_pipeline(output_tensor)
            self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
            self.send_tensor_metadata = not self.enable_metadata_cache

            if self.fp8_communication:
                cast_from_fp8_pipeline(output_tensor, del_metadata=False)

    def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
        """Sends the gradient tensor to the previous stage in pipeline.
        For 1F1B.
@@ -169,8 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            prev_rank (int, optional): The rank of the recipient of the tensor
        """
        if not self.stage_manager.is_first_stage():
            if self.fp8_communication:
                cast_to_fp8_pipeline(input_tensor_grad)
            self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
            self.send_grad_metadata = not self.enable_metadata_cache
            if self.fp8_communication:
                cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)

    def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
        """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
@@ -183,6 +200,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        if not self.stage_manager.is_last_stage():
            if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
                send_first = None
            if self.fp8_communication:
                cast_to_fp8_pipeline(output_tensor)
            output_tensor_grad, _ = self.comm.send_forward_recv_backward(
                output_tensor,
                send_metadata=self.send_tensor_metadata,
@@ -192,6 +211,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            self.send_tensor_metadata = not self.enable_metadata_cache
            if self.enable_metadata_cache and self.grad_metadata_recv is None:
                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
            if self.fp8_communication:
                cast_from_fp8_pipeline(output_tensor, del_metadata=False)
                cast_from_fp8_pipeline(output_tensor_grad)

            return output_tensor_grad
@@ -206,6 +228,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        if not self.stage_manager.is_first_stage():
            if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
                send_first = None  # must not fallback
            if self.fp8_communication:
                cast_to_fp8_pipeline(input_tensor_grad)
            input_tensor, _ = self.comm.send_backward_recv_forward(
                input_tensor_grad,
                send_metadata=self.send_grad_metadata,
@@ -215,6 +239,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            self.send_grad_metadata = not self.enable_metadata_cache
            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                self.tensor_metadata_recv = create_send_metadata(input_tensor)
            if self.fp8_communication:
                cast_from_fp8_pipeline(input_tensor)
                cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)

            return input_tensor
@@ -0,0 +1,172 @@
from typing import Any

import torch
import torch.distributed as dist

def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
    r"""
    Cast a torch Tensor into the specified fp8 format. The scaling factor is computed
    automatically: per-channel scaling is applied if the input tensor is 2-dimensional,
    otherwise per-tensor scaling is applied.
    Args:
        inp: input torch Tensor, should be torch.float32, torch.float16, or torch.bfloat16.
        fp8_format: e4m3 or e5m2

    Returns:
        Tuple: a tuple (fp8_tensor, scale_inv)
    """

    if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError("Only float16, bfloat16, and float32 are allowed.")

    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max

    if inp.dim() == 2:
        per_channel_max = inp.abs().max(dim=-1).values.float()
        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
        scale = fp8_max / per_channel_max[:, None]
    else:
        per_tensor_max = inp.abs().max().float()
        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
        scale = fp8_max / per_tensor_max

    scale_inv = 1.0 / scale
    ret = (scale * inp.float()).to(fp8_type)
    return ret, scale_inv

def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
    r"""
    Cast an fp8 tensor back to a regular floating-point tensor.
    Args:
        inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
        scale_inv: the inverse scaling factor returned by the cast_to_fp8 function.
        ret_type: the datatype of the returned tensor.

    Returns:
        torch.Tensor
    """
    if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
        raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")

    if inp.dim() == 2:
        ret = scale_inv[:, None] * inp.float()
    else:
        ret = scale_inv * inp.float()
    return ret.to(ret_type)

def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
    r"""
    This is an in-place operation for compressed all_reduce using fp8.
    It works like dist.all_reduce but during communication the data is cast to fp8 format.

    Args:
        tensor: torch.Tensor in fp32, fp16, bf16 datatype.
        fp8_format: e4m3 or e5m2

    Returns:
        None
    """

    world_size = dist.get_world_size()
    input_type = tensor.dtype
    input_shape = tensor.shape
    input_device = tensor.device
    input_size = tensor.numel()
    tensor = tensor.flatten()

    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2

    ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format)

    inp = ret.view(torch.uint8)
    input_chunks = list(torch.chunk(inp, world_size, dim=0))
    if dist.get_rank() == world_size - 1:
        output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
    else:
        output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
    dist.all_to_all(output_chunks, input_chunks)
    scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
    dist.all_gather(scale_list, scale)
    summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
    for scale, out in zip(scale_list, output_chunks):
        out = out.view(fp8_type)
        summed_out += cast_from_fp8(out, scale, input_type)

    summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
    dist.all_gather(scale_list, scale)

    tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8))
    for i in range(world_size):
        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
    tensor_out = torch.cat(tensor_list, dim=0)
    tensor.data = tensor_out.view(input_shape).to(input_type)

def cast_to_fp8_pipeline(inp: Any) -> None:
    """
    Cast the hidden_states tensor of the inp object to fp8 format before p2p communication in the pipeline.
    The activation tensor is indexed by 'hidden_states' in the inp dict.
    After FP8 casting, the result is stored viewed as float16 or bfloat16, so its size is halved.
    Metadata such as fp8_scale is saved into the inp dict for communication.
    """
    if inp is None:
        return
    # In pipeline parallelism, when inp is a torch.Tensor it contains only one element, so fp8 casting is skipped.
    if type(inp) == torch.Tensor:
        return

    assert "hidden_states" in inp, "required by pipeline parallelism."
    inp_tensor = inp["hidden_states"]

    min_val, max_val = inp_tensor.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs())

    finfo = torch.finfo(torch.float8_e4m3fn)
    if amax > finfo.max:
        fp8_type = torch.float8_e5m2
        fp8_view_type = torch.float16
    else:
        fp8_type = torch.float8_e4m3fn
        fp8_view_type = torch.bfloat16

    finfo = torch.finfo(fp8_type)
    scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
    q_tensor = inp_tensor.data.float() * scale
    # TODO: currently fp8_view_type (float16 vs. bfloat16) is used to indicate which fp8 format was chosen;
    # this is a temporary workaround due to 'Only support tensor for fast send'.
    # inp_tensor needs to stay a float datatype to avoid errors during gradient placement.
    inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)

    inp["fp8_scale"] = scale.float().reciprocal()

def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
    """
    Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
    del_metadata = False is useful when this function is called before p2p communication.
    """
    if inp is None:
        return
    if type(inp) == torch.Tensor:
        return

    assert "hidden_states" in inp, "required by pipeline parallelism."
    inp_tensor = inp["hidden_states"]
    scale = inp["fp8_scale"]

    fp8_view_type = inp_tensor.dtype
    if fp8_view_type == torch.float16:
        fp8_type = torch.float8_e5m2
    elif fp8_view_type == torch.bfloat16:
        fp8_type = torch.float8_e4m3fn
    else:
        raise TypeError("Only float16, bfloat16 are implemented.")

    inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale

    if del_metadata:
        del inp["fp8_scale"]
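A quick round trip through cast_to_fp8 and cast_from_fp8 defined above; the tensor size is illustrative and a PyTorch build with float8 dtypes is assumed:

import torch

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

x = torch.randn(16, dtype=torch.float16)
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")  # 1-D input, so per-tensor scaling is used
x_restored = cast_from_fp8(x_fp8, scale_inv, ret_type=torch.float16)
print((x - x_restored).abs().max())  # small quantization error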
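Similarly, a hedged sketch of all_reduce_fp8 under torch.distributed; the torchrun launch, the NCCL backend, and the tensor size are assumptions, and the flat tensor's numel should be divisible by the world size so the chunks line up:

# hypothetical launch: torchrun --nproc_per_node 2 fp8_allreduce_demo.py
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_reduce_fp8

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

t = torch.randn(1024, dtype=torch.float16, device="cuda")
all_reduce_fp8(t, fp8_format="e4m3")  # in place: t now holds the fp8-compressed sum over all ranks
dist.destroy_process_group()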
@@ -190,6 +190,7 @@ def main():
    )
    parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
    parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
    parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
    args = parser.parse_args()

    if args.model_type == "bert":
@@ -232,6 +233,7 @@ def main():
            zero_stage=1,
            precision="fp16",
            initial_scale=1,
            fp8_communication=args.use_fp8_comm,
        )

    booster = Booster(plugin=plugin, **booster_kwargs)
@@ -187,6 +187,7 @@ def main():
    )
    parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
    parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
    parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
    args = parser.parse_args()

    if args.model_type == "gpt2":
@@ -225,6 +226,7 @@ def main():
            zero_stage=1,
            precision="fp16",
            initial_scale=1,
            fp8_communication=args.use_fp8_comm,
        )

    booster = Booster(plugin=plugin, **booster_kwargs)
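One caveat on the new --use_fp8_comm option (and the existing --use_lazy_init) in both example scripts: argparse's type=bool applies bool() to the raw command-line string, so any non-empty value, including "False", enables the flag. A hedged, drop-in alternative that is not part of this commit:

import argparse


def str2bool(v: str) -> bool:
    # parse the string explicitly so that "--use_fp8_comm False" really disables fp8 communication
    return v.lower() in ("1", "true", "yes", "y")


parser = argparse.ArgumentParser()
parser.add_argument("--use_fp8_comm", type=str2bool, default=False, help="for using fp8 during communication")
args = parser.parse_args(["--use_fp8_comm", "False"])
print(args.use_fp8_comm)  # False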