mirror of https://github.com/hpcaitech/ColossalAI
fix typo applications/Chat/coati/ (#3947)
parent e8ad3c88f5
commit d4fb7bfda7
@@ -205,15 +205,15 @@ class ExperienceMakerHolder:
                     self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
                 else:
                     new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
-                    state_dict_increasae = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
-                    self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increasae)
+                    state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
+                    self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
             if new_critic_state_dict is not None:
                 if not self._update_lora_weights or fully_update:
                     self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
                 else:
                     new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
-                    state_dict_increasae = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
-                    self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increasae)
+                    state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
+                    self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)

         # the lock must be released after both actor and critic being updated
         if chunk_end:
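For orientation, a minimal sketch of the two update paths this hunk touches: a plain load_state_dict when full weights are shipped, and a LoRA-increment reconstruction otherwise. The function and argument names below are placeholders, not the repository's API; only reconstruct_increase and load_state_dict_increase come from the diff.

from typing import Any, Dict, Optional

import torch
import torch.nn as nn


def apply_update(model: nn.Module,
                 new_state_dict: Dict[str, torch.Tensor],
                 lora_constructor: Optional[Any] = None,
                 lora_config_dict: Optional[Dict[str, Any]] = None,
                 fully_update: bool = False) -> None:
    # Placeholder for the update logic above: full load vs. LoRA increment.
    if lora_constructor is None or fully_update:
        # Full update: load the received weights directly.
        model.load_state_dict(new_state_dict, strict=False)
    else:
        # LoRA update: move the small LoRA tensors to the local device,
        # rebuild per-layer dense increments, then fold them into the model.
        device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        new_state_dict = {k: v.to(device) for k, v in new_state_dict.items()}
        state_dict_increase = lora_constructor.reconstruct_increase(new_state_dict, lora_config_dict)
        lora_constructor.load_state_dict_increase(model, state_dict_increase)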
@@ -19,7 +19,7 @@ class LoRAConfig:
 class LoRAConstructor:
     '''
         Tools for reconstructing a model from a remote LoRA model.
-        (Transfering only LoRA data costs much less!)
+        (Transferring only LoRA data costs much less!)
         Usage:
             Step 1 (Sender):
                 filter_state_dict_lora()
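The docstring's claim that transferring only LoRA data "costs much less" is easy to quantify; a back-of-the-envelope check follows (the layer sizes below are made up for illustration):

# Parameters per linear layer: full weight vs. its LoRA factors.
d_out, d_in, r = 4096, 4096, 8
full_weight = d_out * d_in            # ~16.8M values to send
lora_pair = r * (d_in + d_out)        # lora_A (r x d_in) + lora_B (d_out x r): ~65K values
print(f"LoRA transfer is ~{full_weight / lora_pair:.0f}x smaller per layer")  # ~256x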
@@ -52,7 +52,7 @@ class LoRAConstructor:
         if lora_config_dict is not None:
             self.register_lora_config(lora_config_dict)

-        state_dict_increasae = OrderedDict()
+        state_dict_increase = OrderedDict()
         config_iter = iter(self.lora_config_dict.items())
         lora_A, lora_B, layer_prefix = None, None, None
         for k, v in state_dict_lora.items():
@@ -65,11 +65,11 @@ class LoRAConstructor:
                 assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
                 lora_B = v
                 weight_data_increase = self._compute(lora_A, lora_B, config)
-                state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase
+                state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
                 lora_A, lora_B, layer_prefix = None, None, None
             else:
                 raise ValueError('unexpected key')
-        return state_dict_increasae
+        return state_dict_increase

     def _compute(self, lora_A, lora_B, config=LoRAConfig()):
         def T(w):
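The body of _compute sits mostly outside this hunk; for reference, a sketch of the standard LoRA increment it evidently produces is below. The r, lora_alpha, and fan_in_fan_out knobs mirror typical LoRAConfig fields and are assumptions here, not taken from the diff.

import torch


def lora_increase(lora_A: torch.Tensor,
                  lora_B: torch.Tensor,
                  r: int = 8,
                  lora_alpha: int = 32,
                  fan_in_fan_out: bool = False) -> torch.Tensor:
    # Dense weight increment from one LoRA pair: delta_W = (B @ A) * alpha / r.
    delta = (lora_B @ lora_A) * (lora_alpha / r)
    return delta.T if fan_in_fan_out else delta


# lora_A: (r, in_features), lora_B: (out_features, r) -> delta: (out_features, in_features)
delta = lora_increase(torch.randn(8, 1024), torch.randn(1024, 8))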
@@ -80,12 +80,12 @@ class LoRAConstructor:
             return weight_data_increase
         return 0

-    def load_state_dict_increase(self, model: nn.Module, state_dict_increasae: Dict[str, Any]):
+    def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
         '''
             The final reconstruction step
         '''
         # naive approach
-        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increasae.items()}, strict=False)
+        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)

     @staticmethod
     def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
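A toy check of what the "naive approach" in load_state_dict_increase does: adding the reconstructed increment to the base weight is the usual LoRA merge. Everything below is illustrative, not repository code.

import torch
import torch.nn as nn

layer = nn.Linear(16, 16, bias=False)
lora_A, lora_B = 0.01 * torch.randn(4, 16), 0.01 * torch.randn(16, 4)
state_dict_increase = {'weight': lora_B @ lora_A}

# Same pattern as the diff: new_weight = increment + current_weight, loaded non-strictly.
merged = {k: v + layer.state_dict()[k] for k, v in state_dict_increase.items()}
layer.load_state_dict(merged, strict=False)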
@@ -29,7 +29,7 @@ class ColossalAIStrategy(DDPStrategy):
         precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
-            This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
+            This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
         placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
             If it is "cpu", parameters, gradients and optimizer states will be offloaded to CPU,
             If it is "cuda", they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
@@ -39,7 +39,7 @@ class ColossalAIStrategy(DDPStrategy):
         hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
         min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
         gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
-        reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
+        reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
         overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
         initial_scale(float): The initial scale for the optimizer.
         growth_factor(float): The growth factor for the optimizer.
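For the ColossalAIStrategy arguments documented above, a hedged usage sketch follows. The import path and the stage keyword are assumptions based on how the Chat examples typically construct strategies; the remaining keyword names are the ones listed in the docstring.

from coati.trainer.strategies import ColossalAIStrategy

# ZeRO-2: reduce_bucket_size / overlap_communication apply; placement_policy does not.
zero2 = ColossalAIStrategy(
    stage=2,
    precision='fp16',
    seed=42,
    reduce_bucket_size=12 * 1024**2,   # bytes, ZeRO-1/2 only
    overlap_communication=True,        # ZeRO-1/2 only
)

# ZeRO-3 (gemini): fp16 only; 'cpu' offloads params/grads/optimizer states, 'cuda' keeps them on GPU.
zero3 = ColossalAIStrategy(stage=3, precision='fp16', placement_policy='cuda', shard_init=False)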