mirror of https://github.com/hpcaitech/ColossalAI
fix typo applications/Chat/coati/ (#3947)
parent e8ad3c88f5
commit d4fb7bfda7
@@ -205,15 +205,15 @@ class ExperienceMakerHolder:
                 self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
             else:
                 new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
-                state_dict_increasae = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
-                self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increasae)
+                state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
+                self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
         if new_critic_state_dict is not None:
             if not self._update_lora_weights or fully_update:
                 self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
             else:
                 new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
-                state_dict_increasae = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
-                self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increasae)
+                state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
+                self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)

         # the lock must be released after both actor and critic being updated
         if chunk_end:
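For readers skimming the hunk above: in both the actor and critic branches the incoming tensors are moved onto the local device before the LoRA increase is reconstructed and applied. A minimal sketch of what a state_dict_to-style helper does, assuming it simply copies every tensor to the target device; the name state_dict_to_device and the toy nn.Linear below are illustration only, not coati code:

from collections import OrderedDict

import torch
import torch.nn as nn


def state_dict_to_device(state_dict, device):
    # Hypothetical equivalent of the state_dict_to(...) call above: return a new
    # OrderedDict with each tensor copied onto the target device.
    return OrderedDict((k, v.to(device)) for k, v in state_dict.items())


# e.g. move a freshly received state dict onto the local GPU (or CPU if none).
received = nn.Linear(8, 8).state_dict()
device = torch.device('cuda', torch.cuda.current_device()) if torch.cuda.is_available() else 'cpu'
received = state_dict_to_device(received, device)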
@@ -19,7 +19,7 @@ class LoRAConfig:
 class LoRAConstructor:
     '''
         Tools for reconstructing a model from a remote LoRA model.
-        (Transfering only LoRA data costs much less!)
+        (Transferring only LoRA data costs much less!)
         Usage:
             Step 1 (Sender):
                 filter_state_dict_lora()
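A quick back-of-the-envelope check of the "costs much less" claim in the docstring above, for one hypothetical 4096x4096 linear layer with LoRA rank r = 8 (the layer size and rank are made-up numbers for illustration):

full_elems = 4096 * 4096                 # elements in the dense weight
lora_elems = 8 * 4096 + 4096 * 8         # elements in lora_A + lora_B
print(full_elems, lora_elems, lora_elems / full_elems)   # 16777216 65536 0.00390625

So transferring only the LoRA factors moves roughly 0.4% of the data of the full weight matrix.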
@@ -52,7 +52,7 @@ class LoRAConstructor:
         if lora_config_dict is not None:
             self.register_lora_config(lora_config_dict)

-        state_dict_increasae = OrderedDict()
+        state_dict_increase = OrderedDict()
         config_iter = iter(self.lora_config_dict.items())
         lora_A, lora_B, layer_prefix = None, None, None
         for k, v in state_dict_lora.items():
@@ -65,11 +65,11 @@ class LoRAConstructor:
                 assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
                 lora_B = v
                 weight_data_increase = self._compute(lora_A, lora_B, config)
-                state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase
+                state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
                 lora_A, lora_B, layer_prefix = None, None, None
             else:
                 raise ValueError('unexpected key')
-        return state_dict_increasae
+        return state_dict_increase

     def _compute(self, lora_A, lora_B, config=LoRAConfig()):
         def T(w):
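The increase stored under layer_prefix + '.weight' above is the usual LoRA delta. A self-contained sketch under the standard LoRA formulation; the full _compute body is not part of this diff, so the scaling and the fan_in_fan_out transpose handling are assumptions, and the shapes and names below are illustrative:

import torch

r, lora_alpha = 8, 16
in_features, out_features = 64, 64
lora_A = torch.randn(r, in_features)     # (r, in)
lora_B = torch.randn(out_features, r)    # (out, r)

# weight increase merged back into the frozen base weight; scaling = lora_alpha / r
weight_data_increase = (lora_B @ lora_A) * (lora_alpha / r)

base_weight = torch.randn(out_features, in_features)
merged_weight = base_weight + weight_data_increase   # same shape as base_weight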
@@ -80,12 +80,12 @@ class LoRAConstructor:
             return weight_data_increase
         return 0

-    def load_state_dict_increase(self, model: nn.Module, state_dict_increasae: Dict[str, Any]):
+    def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
         '''
             The final reconstruction step
         '''
         # naive approach
-        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increasae.items()}, strict=False)
+        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)

     @staticmethod
     def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
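The "naive approach" above adds each increase onto the matching base parameter and reloads with strict=False so parameters that receive no increase are left untouched. A toy check of that pattern; the nn.Linear model and the constant increase are hypothetical, not coati code:

import torch
import torch.nn as nn

model = nn.Linear(4, 4, bias=False)
state_dict_increase = {'weight': 0.01 * torch.ones(4, 4)}

# add the increase onto the current weights, then load only the merged keys
merged = {k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}
model.load_state_dict(merged, strict=False)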
@@ -29,7 +29,7 @@ class ColossalAIStrategy(DDPStrategy):
         precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
-            This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
+            This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
         placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
             If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
             If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
@@ -39,7 +39,7 @@ class ColossalAIStrategy(DDPStrategy):
         hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
         min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
         gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
-        reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
+        reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
         overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
         initial_scale(float): The initial scale for the optimizer.
         growth_factor(float): The growth factor for the optimizer.
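For context, a hedged sketch of how the documented options might be passed when constructing the strategy; the import path and the assumption that these docstring entries map one-to-one onto constructor keyword arguments are not confirmed by this diff:

# Assumed import path and kwargs; only option names listed in the docstring
# above are used here.
from coati.trainer.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy(
    precision='fp16',          # stage 3 only supports fp16
    seed=42,
    placement_policy='cuda',   # keep params/grads/optimizer states on GPU
)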