mirror of https://github.com/hpcaitech/ColossalAI

from typing import List, Union

import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.layer.parallel_module import ParallelModule


class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        module.in_features = module.weight.size(1)
        module.out_features = module.weight.size(0)
        module.bias = None
        # NOTE(lry89757): This may not behave correctly under lazy init, where the weight
        # seen by shardformer is not the real weight yet. To cover that case, this class
        # overrides `_load_from_state_dict` below and re-normalizes the weight once the
        # checkpoint is actually loaded.
        module.weight.data = nn.functional.normalize(module.weight)
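        # For reference: with its default arguments, nn.functional.normalize rescales
        # along dim=1, so each row of the weight (one vocabulary embedding) is divided
        # by its L2 norm. This mirrors Baichuan's NormHead, which normalizes the
        # lm_head weight rather than learning it freely.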

        # Read the attributes needed to construct the parallel module.
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None  # always False here, since the bias was dropped above
        device = module.weight.device

        # Ensure that only one process group is passed.
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        # Too few output features to shard across all ranks: keep the native module.
        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"out_features ({out_features}) is not an integer multiple of the tensor parallel size ({tp_size})!"
            )
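
        # Column-parallel sharding: Linear1D_Col splits the output dimension, so each
        # of the tp_size ranks holds out_features // tp_size rows of the weight; the
        # divisibility check above guarantees an even split.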
        lmhead_1d = BaichuanLMHeadLinear1D_Col(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            process_group=process_group,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )

        return lmhead_1d

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Re-apply the NormHead normalization once real weights arrive from a checkpoint,
        # covering the lazy-init case noted in `from_native_module`.
        state_dict[prefix + "weight"] = nn.functional.normalize(state_dict[prefix + "weight"])
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
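
# A minimal usage sketch (not part of the file above): in shardformer, a parallel
# module like this is typically wired in through a policy's module-replacement
# description, after which shardformer calls `from_native_module(...)` on the
# matched submodule. The snippet below is an illustrative assumption, not taken
# from this file; `SubModuleReplacementDescription` is shardformer's descriptor
# for swapping a submodule, and `gather_output=True` asks the column-parallel
# layer to gather the sharded logits across ranks.
#
#   from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription
#
#   lm_head_replacement = SubModuleReplacementDescription(
#       suffix="lm_head",
#       target_module=BaichuanLMHeadLinear1D_Col,
#       kwargs={"gather_output": True},
#   )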
|