|
|
|
@ -3,6 +3,7 @@ import functools
|
|
|
|
|
from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str |
|
|
|
|
from colossalai.builder.pipeline import partition_uniform, partition_balanced |
|
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
|
from colossalai.tensor import ColoTensor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PipelinableContext(InsertPostInitMethodToModuleSubClasses): |
|
|
|
@ -64,8 +65,15 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
|
|
|
|
layer_spec = LayerSpec(module.__class__, *modified_args, **kwargs) |
|
|
|
|
layer_spec.set_children(module.children()) |
|
|
|
|
self._layer_spec_dict[module_id] = layer_spec |
|
|
|
|
for param in module.parameters(recurse=False): |
|
|
|
|
param.data = torch.rand(1, 1) |
|
|
|
|
name_list = [] |
|
|
|
|
for name, param in module.named_parameters(): |
|
|
|
|
if isinstance(param, ColoTensor): |
|
|
|
|
continue |
|
|
|
|
name_list.append((name, param)) |
|
|
|
|
|
|
|
|
|
for name, param in name_list: |
|
|
|
|
delattr(module, name) |
|
|
|
|
setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=False)) |
|
|
|
|
|
|
|
|
|
def to_layer_list(self, exec_seq=None): |
|
|
|
|
""" |
|
|
|
|