[Inference] Lazy Init Support (#5785)

* lazy init support

* lazy init llama support

* lazy init support for baichuan

* align rpc

* add note for baichuan

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Runyu Lu 2024-06-27 18:02:15 +08:00 committed by GitHub
parent d9d5e7ea1f
commit 3c7cda0c9a
7 changed files with 205 additions and 105 deletions

View File

@@ -24,8 +24,9 @@ from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import get_model_size, has_index_file
 from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -122,16 +123,24 @@ class InferenceEngine:
             model_inference_config: the configuration for modeling initialization when inference.
             model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
         """
+        pretrained_path = None
         if isinstance(model_or_path, str):
+            import colossalai.interface.pretrained as pretrained_utils
+
             try:
-                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                 arch = getattr(hf_config, "architectures")[0]
                 if arch in _supported_models.keys():
-                    # NOTE(lry89757) Currently we load the model using transformers-api,
-                    # but we will use lazy tensor and checkpoint io to accelerate
-                    # the model load process in the future.
-                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                    if arch is "BaichuanForCausalLM":
+                        self.logger.warning(
+                            "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
+                        )
+                    ctx = LazyInitContext(default_device="cuda")
+                    with ctx:
+                        model = _supported_models[arch].from_pretrained(
+                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                        )
+                    pretrained_path = pretrained_utils.get_pretrained_path(model)
                 else:
                     # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
                     raise ValueError(f"Model {arch} is not supported.")
@@ -189,14 +198,13 @@ class InferenceEngine:
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )

-        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
-        # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
-        #     from colossalai.inference.core.plugin import InferCheckpoint_io
-        #     cpt_io = InferCheckpoint_io()
-        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
-        #     assert if_has_index_file, "the model path is invalid"
-        #     cpt_io.load_model(self.model, model_index_file)
+        if pretrained_path:
+            from colossalai.inference.core.plugin import InferCheckpoint_io
+
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(pretrained_path)
+            assert if_has_index_file, "the model path is invalid"
+            cpt_io.load_model(self.model, model_index_file)

 free_gpu_memory, _ = torch.cuda.mem_get_info()
 peak_memory = init_gpu_memory - free_gpu_memory
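
Once ShardFormer has sharded the lazily initialized model, the real weights are streamed in from disk. A sketch of that step built from the calls visible in this hunk (their signatures are assumed from the call sites):

    # Sketch: locate the checkpoint index and copy the shard contents into the sharded model.
    from colossalai.inference.core.plugin import InferCheckpoint_io
    from colossalai.inference.utils import has_index_file

    def load_real_weights(model, pretrained_path: str) -> None:
        # has_index_file returns (found, path), e.g. for model.safetensors.index.json
        if_has_index_file, model_index_file = has_index_file(pretrained_path)
        assert if_has_index_file, "the model path is invalid"
        # Reads the shard files listed in the index and fills the model's parameters.
        InferCheckpoint_io().load_model(model, model_index_file)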

View File

@@ -73,7 +73,9 @@ class RPCInferenceEngine(InferenceEngine):
         try:
             if isinstance(model_or_path, str):
-                self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                self.model_config = AutoConfig.from_pretrained(
+                    model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                )
             elif isinstance(model_or_path, nn.Module):
                 self.logger.error(
                     f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"

View File

@@ -18,8 +18,9 @@ from colossalai.inference.modeling.policy import (
     model_policy_map,
 )
 from colossalai.inference.sampler import search_tokens
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import get_model_size, has_index_file
 from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -178,20 +179,23 @@ class rpcWorkerService(rpyc.Service):
             model_policy (Policy): the policy to replace the model
         """
+        pretrained_path = None
         if isinstance(model_or_path, str):
-            # is_local = os.path.isdir(model_or_path)
+            import colossalai.interface.pretrained as pretrained_utils
+
             try:
-                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                 arch = getattr(hf_config, "architectures")[0]
-                # NOTE(lry89757) Currently we load the model using transformers-api,
-                # but we will use lazy tensor and checkpoint io to accelerate
-                # the model load process in the future.
-                model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
-                # if is_local:
-                #     model = _SUPPORTED_MODELS[arch](hf_config)
-                # else:
-                #     # load the real checkpoint
-                #     model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                if arch is "BaichuanForCausalLM":
+                    self.logger.warning(
+                        "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
+                    )
+                ctx = LazyInitContext(default_device="cuda")
+                with ctx:
+                    model = _SUPPORTED_MODELS[arch].from_pretrained(
+                        model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                    )
+                pretrained_path = pretrained_utils.get_pretrained_path(model)
             except Exception as e:
                 logger.error(
                     f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -240,14 +244,13 @@ class rpcWorkerService(rpyc.Service):
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )

-        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
-        # if isinstance(model_or_path, str) and is_local:
-        #     from colossalai.inference.core.plugin import InferCheckpoint_io
-        #     cpt_io = InferCheckpoint_io()
-        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
-        #     assert if_has_index_file, "the model path is invalid"
-        #     cpt_io.load_model(self.model, model_index_file)
+        if pretrained_path:
+            from colossalai.inference.core.plugin import InferCheckpoint_io
+
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(pretrained_path)
+            assert if_has_index_file, "the model path is invalid"
+            cpt_io.load_model(self.model, model_index_file)

 free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
 peak_memory = init_gpu_memory - free_gpu_memory

View File

@@ -1,8 +1,10 @@
 from typing import List, Union

+import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup

+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Linear1D_Col
 from colossalai.shardformer.layer.parallel_module import ParallelModule
@@ -12,17 +14,51 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
     def from_native_module(
         module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
     ) -> ParallelModule:
+        LazyInitContext.materialize(module)
         module.in_features = module.weight.size(1)
         module.out_features = module.weight.size(0)
         module.bias = None
         module.weight.data = nn.functional.normalize(
             module.weight
-        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+        )  # NOTE(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
         # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.
-        return Linear1D_Col.from_native_module(
-            module,
-            process_group,
-            *args,
+
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+
+        lmhead_1d = BaichuanLMHeadLinear1D_Col(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            process_group=process_group,
+            weight=module.weight,
+            bias_=module.bias,
             **kwargs,
         )
+
+        return lmhead_1d
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        state_dict[prefix + "weight"] = nn.functional.normalize(state_dict[prefix + "weight"])
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
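
The new _load_from_state_dict is the companion to the lazy-init switch: under lazy init, module.weight at wrap time is not the real weight, so normalizing module.weight.data in from_native_module can be lost once the checkpoint is loaded later. Normalizing the tensor as it leaves the state dict covers both paths. A self-contained sketch of the same pattern on a plain nn.Linear (illustrative only, not the ColossalAI class):

    import torch
    import torch.nn as nn

    class RowNormalizedLinear(nn.Linear):
        # Hypothetical example: apply the normalization to the incoming checkpoint
        # tensor instead of mutating module.weight.data at wrap time.
        def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        ):
            key = prefix + "weight"
            if key in state_dict:
                state_dict[key] = nn.functional.normalize(state_dict[key])
            super()._load_from_state_dict(
                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
            )

    layer = RowNormalizedLinear(4, 8, bias=False)
    layer.load_state_dict({"weight": torch.randn(8, 4)})
    print(layer.weight.norm(dim=1))  # every row now has unit L2 norm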

View File

@@ -70,7 +70,6 @@ class NopadBaichuanAttention(ParallelModule):
             attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
         """
         ParallelModule.__init__(self)
-        self.o_proj = attn_oproj

         self.config = config
         self.num_heads = num_heads
@@ -78,6 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
         self.head_dim = self.hidden_size // self.num_heads
         self.process_group = process_group
         self.W_pack = W_pack
+        self.o_proj = attn_oproj
         self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
         self.attention_backend = get_attention_backend(model_shard_infer_config)
         self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

View File

@@ -284,6 +284,10 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
         self.gate_up_weight = nn.Parameter(
             torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
         )
+        self.gate_up_dict = {
+            "gate_proj.weight": None,
+            "up_proj.weight": None,
+        }  # used and delattr in load/shard of gate/up weight
         self.down_proj = mlp_dproj
         self.process_group = process_group
@@ -321,44 +325,47 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
     ):
         # NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "gate_up_weight"
-        k1 = "gate_proj.weight"
-        k2 = "up_proj.weight"
-
-        gate_w = state_dict[prefix + k1]
-        up_w = state_dict[prefix + k2]
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
-        up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
-
-        gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
-
-        input_param = nn.Parameter(
-            gate_up_w
-        )  # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+        if hasattr(self, "gate_up_dict"):
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
+
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+
+            for weight_name in self.gate_up_dict:
+                prefix_weight_name = prefix + weight_name
+                if prefix_weight_name in state_dict.keys():
+                    w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+                    self.gate_up_dict[weight_name] = w.T
+
+            if None not in self.gate_up_dict.values():
+                # we've got all the weights of gate/up
+                gate_up_w = torch.stack(list(self.gate_up_dict.values()), dim=0)
+
+                input_param = nn.Parameter(
+                    gate_up_w
+                )  # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+                key = "gate_up_weight"
+                param = local_state.get(key, None)
+
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )
+
+                del self.gate_up_dict

 strict = False  # to avoid unexpected_keys
 super()._load_from_state_dict(
     state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
 )
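
For reference, the layout this loader has to reproduce: gate_up_weight stacks the transposed gate and up projections into one [2, hidden_size, intermediate_size] parameter. A plain-torch sketch of that layout and the standard SwiGLU arithmetic it serves (no tensor parallelism or distensors here; the real code additionally shards each tensor with distribute_tensor):

    import torch
    import torch.nn.functional as F

    hidden_size, intermediate_size = 16, 64
    # Checkpoint layout, as in LlamaMLP: [intermediate_size, hidden_size]
    gate_w = torch.randn(intermediate_size, hidden_size)
    up_w = torch.randn(intermediate_size, hidden_size)

    # Fused layout used by the module: [2, hidden_size, intermediate_size]
    gate_up_weight = torch.stack([gate_w.transpose(0, 1), up_w.transpose(0, 1)], dim=0)

    x = torch.randn(3, hidden_size)
    fused_out = F.silu(x @ gate_up_weight[0]) * (x @ gate_up_weight[1])
    reference = F.silu(x @ gate_w.T) * (x @ up_w.T)
    assert torch.allclose(fused_out, reference)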
@@ -429,7 +436,15 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             self.helper_layout = (
                 attn_qproj_w.dist_layout
             )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
+            self.qkv_dict = {
+                "q_proj.weight": None,
+                "k_proj.weight": None,
+                "v_proj.weight": None,
+            }  # used and delattr in load/shard of qkv weight
         else:
+            self.helper_layout = (
+                attn_qproj_w.dist_layout
+            )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
             self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
             self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
             self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
@@ -577,49 +592,83 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        if self.num_heads == self.num_key_value_heads:
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        device_mesh = self.helper_layout.device_mesh
+        sharding_spec = self.helper_layout.sharding_spec
+
+        if self.num_heads == self.num_key_value_heads and hasattr(self, "qkv_dict"):
             # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-            for hook in self._load_state_dict_pre_hooks.values():
-                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-            local_state = {k: v for k, v in local_name_params if v is not None}
-
             key = "qkv_weight"
-            k1 = "q_proj.weight"
-            k2 = "k_proj.weight"
-            k3 = "v_proj.weight"
-
-            q_w = state_dict[prefix + k1]
-            k_w = state_dict[prefix + k2]
-            v_w = state_dict[prefix + k3]
-
-            device_mesh = self.helper_layout.device_mesh
-            sharding_spec = self.helper_layout.sharding_spec
-            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
-
-            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
-
-            input_param = nn.Parameter(
-                qkv_w
-            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-            param = local_state[key]
-
-            try:
-                with torch.no_grad():
-                    param.copy_(input_param)
-            except Exception as ex:
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "whose dimensions in the model are {} and "
-                    "whose dimensions in the checkpoint are {}, "
-                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-                )
-
-            strict = False  # to avoid unexpected_keys
+
+            # NOTE(@lry89757) We will load the sharded checkpoint file according to the weight map from *.index.json
+            # Here we need the weight of q,k,v to stack the weights of q,k,v into one qkv weight.
+            # Unfortunately, it is highly like that all weights of q,k,v are not in the same sharded checkpoint file(like meta-llama/llama3-70B)
+            # so here we will stack them when we really collect all the three weights.
+            for weight_name in self.qkv_dict:
+                prefix_weight_name = prefix + weight_name
+                if prefix_weight_name in state_dict.keys():
+                    w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+                    self.qkv_dict[weight_name] = w.T
+
+            if None not in self.qkv_dict.values():
+                # we've got all the weights of q, k, v
+                qkv_w = torch.stack(list(self.qkv_dict.values()), dim=0)
+
+                input_param = nn.Parameter(
+                    qkv_w
+                )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+                param = local_state[key]
+
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )
+
+                del self.qkv_dict
+        else:
+
+            def _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight"):
+                if prefix + origin_weight_name in state_dict.keys():
+                    attn_qproj_w = state_dict[prefix + origin_weight_name]
+                    w = distribute_tensor(attn_qproj_w, device_mesh, sharding_spec)
+                    input_param = nn.Parameter(w.T)
+                    param = local_state[local_weight_name]
+                    try:
+                        with torch.no_grad():
+                            param.copy_(input_param)
+                    except Exception as ex:
+                        key = local_weight_name
+                        error_msgs.append(
+                            'While copying the parameter named "{}", '
+                            "whose dimensions in the model are {} and "
+                            "whose dimensions in the checkpoint are {}, "
+                            "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                        )
+
+            if prefix + "q_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight")
+            if prefix + "k_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="k_proj.weight", local_weight_name="k_proj_weight")
+            if prefix + "v_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="v_proj.weight", local_weight_name="v_proj_weight")
+
+        strict = False  # to avoid unexpected_keys

 super()._load_from_state_dict(
     state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
 )
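
The NOTE above is the crux of this rewrite: with a multi-file (indexed) checkpoint, q_proj, k_proj and v_proj can arrive in different _load_from_state_dict calls, so the fused qkv_weight can only be stacked once all three have been collected. A minimal plain-torch sketch of that accumulation pattern (hypothetical helper, no distribute_tensor or error handling):

    import torch

    qkv_dict = {"q_proj.weight": None, "k_proj.weight": None, "v_proj.weight": None}

    def consume_shard(shard_state_dict, qkv_dict, qkv_param):
        # Collect whatever q/k/v weights this shard contains; stack once complete.
        for name in qkv_dict:
            if name in shard_state_dict:
                qkv_dict[name] = shard_state_dict[name].T
        if None not in qkv_dict.values():
            with torch.no_grad():
                qkv_param.copy_(torch.stack(list(qkv_dict.values()), dim=0))
            return True  # fully assembled; the real code del's qkv_dict at this point
        return False

    hidden = 8
    qkv_param = torch.empty(3, hidden, hidden)
    shard_a = {"q_proj.weight": torch.randn(hidden, hidden)}
    shard_b = {"k_proj.weight": torch.randn(hidden, hidden), "v_proj.weight": torch.randn(hidden, hidden)}
    assert not consume_shard(shard_a, qkv_dict, qkv_param)  # still waiting on k and v
    assert consume_shard(shard_b, qkv_dict, qkv_param)      # all three present, stacked into qkv_param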

View File

@@ -674,6 +674,8 @@ class FusedLinear1D_Col(ParallelModule):
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
             n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
         """
+        LazyInitContext.materialize(module)
+
         # get the attributes
         in_features = module.in_features
         out_features = module.out_features
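
Materializing before reading shapes matters because a module built under LazyInitContext may still hold lazy tensors without real storage; LazyInitContext.materialize converts them into concrete tensors so attribute reads like module.weight.size(1) and the subsequent sharding operate on real data. A short sketch of the pattern (assuming materialize is also safe on ordinary, eagerly built modules, which is how it is used throughout this diff):

    import torch.nn as nn

    from colossalai.lazy import LazyInitContext

    def linear_shape(module: nn.Linear):
        # Convert any lazy parameters into real tensors before inspecting them.
        LazyInitContext.materialize(module)
        return module.weight.size(1), module.weight.size(0)  # (in_features, out_features)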