@@ -19,7 +19,7 @@ class LoRAConfig:
 class LoRAConstructor:
     '''
     Tools for reconstructing a model from a remote LoRA model.
-    (Transfering only LoRA data costs much less!)
+    (Transferring only LoRA data costs much less!)
     Usage:
         Step 1 (Sender):
             filter_state_dict_lora()
@@ -52,7 +52,7 @@ class LoRAConstructor:
         if lora_config_dict is not None:
             self.register_lora_config(lora_config_dict)
 
-        state_dict_increasae = OrderedDict()
+        state_dict_increase = OrderedDict()
         config_iter = iter(self.lora_config_dict.items())
         lora_A, lora_B, layer_prefix = None, None, None
         for k, v in state_dict_lora.items():
@@ -65,11 +65,11 @@ class LoRAConstructor:
                 assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
                 lora_B = v
                 weight_data_increase = self._compute(lora_A, lora_B, config)
-                state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase
+                state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
                 lora_A, lora_B, layer_prefix = None, None, None
             else:
                 raise ValueError('unexpected key')
-        return state_dict_increasae
+        return state_dict_increase
 
     def _compute(self, lora_A, lora_B, config=LoRAConfig()):
         def T(w):
@@ -80,12 +80,12 @@ class LoRAConstructor:
         return weight_data_increase
         return 0
 
-    def load_state_dict_increase(self, model: nn.Module, state_dict_increasae: Dict[str, Any]):
+    def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
         '''
         The final reconstruction step
        '''
         # naive approach
-        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increasae.items()}, strict=False)
+        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
 
     @staticmethod
     def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
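
A minimal sketch of how the renamed state_dict_increase is meant to be consumed on the receiver side, assuming the standard LoRA merge formula (delta_W = lora_B @ lora_A scaled by alpha / r). The lora_increment and apply_increase helpers below are illustrative names for this note, not functions from this repository.

from typing import Dict
import torch
import torch.nn as nn

def lora_increment(lora_A: torch.Tensor, lora_B: torch.Tensor, lora_alpha: int = 16) -> torch.Tensor:
    # Hypothetical helper in the spirit of _compute: turn a low-rank pair into a
    # dense weight increment. lora_A: (r, in), lora_B: (out, r) -> delta: (out, in).
    r = lora_A.shape[0]
    return (lora_B @ lora_A) * (lora_alpha / r)

def apply_increase(model: nn.Module, state_dict_increase: Dict[str, torch.Tensor]) -> None:
    # Same naive approach as load_state_dict_increase above: add each increment
    # onto the matching base weight and load the sums back (non-strict).
    base = model.state_dict()
    model.load_state_dict({k: base[k] + v for k, v in state_dict_increase.items()},
                          strict=False)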