|
|
@ -1,13 +1,14 @@ |
|
|
|
#!/usr/bin/env python |
|
|
|
#!/usr/bin/env python |
|
|
|
# coding: utf-8 |
|
|
|
# coding: utf-8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
|
|
|
|
|
|
import types |
|
|
|
|
|
|
|
from typing import Callable, List |
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn as nn |
|
|
|
from colossalai.tensor import ColoParameter, ColoTensor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import types |
|
|
|
from colossalai.tensor import ColoParameter, ColoTensor |
|
|
|
import inspect |
|
|
|
|
|
|
|
from typing import List, Callable |
|
|
|
|
|
|
|
from colossalai.utils.model.utils import substitute_init_recursively |
|
|
|
from colossalai.utils.model.utils import substitute_init_recursively |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -35,14 +36,15 @@ class LazyInitContext(): |
|
|
|
assert not model.weight.is_meta and torch.all(model.weight == 0) |
|
|
|
assert not model.weight.is_meta and torch.all(model.weight == 0) |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
Args: |
|
|
|
to_meta (bool): optional, whether to initialize the model with meta tensors, default is False. |
|
|
|
to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This |
|
|
|
|
|
|
|
argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet. |
|
|
|
extra_torch_tensor_func (List[str]): extra torch tensor functions related |
|
|
|
extra_torch_tensor_func (List[str]): extra torch tensor functions related |
|
|
|
to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. |
|
|
|
to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. |
|
|
|
""" |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
tensor_set_value_func = ['zero_', 'fill_'] |
|
|
|
tensor_set_value_func = ['zero_', 'fill_'] |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None): |
|
|
|
def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None): |
|
|
|
# TODO: hijack the torch constructor functions as well |
|
|
|
# TODO: hijack the torch constructor functions as well |
|
|
|
self._to_meta = to_meta |
|
|
|
self._to_meta = to_meta |
|
|
|
self._intercepted_nn_init_func_cache = {} |
|
|
|
self._intercepted_nn_init_func_cache = {} |
|
|
@ -212,18 +214,19 @@ class LazyInitContext(): |
|
|
|
materialized_tensor = torch.empty_like(tensor, device=device) |
|
|
|
materialized_tensor = torch.empty_like(tensor, device=device) |
|
|
|
# if this tensor is a meta tensor, it must have an init function |
|
|
|
# if this tensor is a meta tensor, it must have an init function |
|
|
|
assert tensor in self._intercepted_nn_init_func_cache |
|
|
|
assert tensor in self._intercepted_nn_init_func_cache |
|
|
|
tensor = materialized_tensor |
|
|
|
else: |
|
|
|
|
|
|
|
materialized_tensor = tensor |
|
|
|
|
|
|
|
|
|
|
|
# apply init function |
|
|
|
# apply init function |
|
|
|
if tensor in self._intercepted_nn_init_func_cache: |
|
|
|
if tensor in self._intercepted_nn_init_func_cache: |
|
|
|
init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1] |
|
|
|
init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1] |
|
|
|
init_func(tensor, *args, **kwargs) |
|
|
|
init_func(materialized_tensor, *args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
# convert it to ColoTensor or ColoParameter |
|
|
|
# convert it to ColoTensor or ColoParameter |
|
|
|
if is_param: |
|
|
|
if is_param: |
|
|
|
tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad) |
|
|
|
tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad) |
|
|
|
else: |
|
|
|
else: |
|
|
|
tensor = ColoTensor.from_torch_tensor(tensor) |
|
|
|
tensor = ColoTensor.from_torch_tensor(materialized_tensor) |
|
|
|
|
|
|
|
|
|
|
|
# override the original tensor |
|
|
|
# override the original tensor |
|
|
|
with torch.no_grad(): |
|
|
|
with torch.no_grad(): |
|
|
|