mirror of https://github.com/hpcaitech/ColossalAI
[pipelinable]use pipelinable to support GPT model. (#903)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [pipelinable]use pipelinable to support GPT model.
* fix a bug caused by ShardedModel
* polish
* fix front func list
pull/933/head
parent
b61d64685f
commit
32a45cd7ef
|
@ -119,9 +119,12 @@ class PipelineSchedule(BaseSchedule):
|
||||||
def pre_processing(self, engine):
|
def pre_processing(self, engine):
|
||||||
# TODO: remove this after testing new zero with pipeline parallelism
|
# TODO: remove this after testing new zero with pipeline parallelism
|
||||||
model = engine.model
|
model = engine.model
|
||||||
if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
|
if isinstance(model, NaiveAMPModel):
|
||||||
self.dtype = torch.half
|
self.dtype = torch.half
|
||||||
model = model.model
|
model = model.model
|
||||||
|
if isinstance(model, ShardedModelV2):
|
||||||
|
self.dtype = torch.half
|
||||||
|
model = model.module
|
||||||
sig = inspect.signature(model.forward)
|
sig = inspect.signature(model.forward)
|
||||||
for p in sig.parameters.values():
|
for p in sig.parameters.values():
|
||||||
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
|
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
|
||||||
|
@ -135,6 +138,12 @@ class PipelineSchedule(BaseSchedule):
|
||||||
else:
|
else:
|
||||||
sig = inspect.signature(model.forward)
|
sig = inspect.signature(model.forward)
|
||||||
if isinstance(batch_data, torch.Tensor):
|
if isinstance(batch_data, torch.Tensor):
|
||||||
|
for p in sig.parameters.values():
|
||||||
|
if p.kind == inspect.Parameter.VAR_KEYWORD:
|
||||||
|
if input_tensor is None:
|
||||||
|
return model(batch_data)
|
||||||
|
else:
|
||||||
|
return model(input_tensor)
|
||||||
if input_tensor is None:
|
if input_tensor is None:
|
||||||
return model(batch_data)
|
return model(batch_data)
|
||||||
elif len(sig.parameters) > 1:
|
elif len(sig.parameters) > 1:
|
||||||
|
@ -148,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
filter_batch = False
|
filter_batch = False
|
||||||
if filter_batch:
|
if filter_batch:
|
||||||
batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
|
batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
|
||||||
if input_tensor is None:
|
if input_tensor is None and filter_batch:
|
||||||
return model(**batch_data)
|
return model(**batch_data)
|
||||||
else:
|
else:
|
||||||
return model(input_tensor, **batch_data)
|
return model(input_tensor, **batch_data)
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
import torch
|
import torch
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
|
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||||
from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
|
from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
|
||||||
from colossalai.builder.pipeline import partition_uniform, partition_balanced
|
from colossalai.builder.pipeline import partition_uniform, partition_balanced
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.nn.layer.utils import CheckpointModule
|
||||||
from colossalai.tensor import ColoTensor
|
from colossalai.tensor import ColoTensor
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,11 +61,18 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
if issubclass(obj.__class__, torch.nn.modules.module.Module):
|
if issubclass(obj.__class__, torch.nn.modules.module.Module):
|
||||||
obj = self._layer_spec_dict[id(obj)]
|
obj = self._layer_spec_dict[id(obj)]
|
||||||
modified_args.append(obj)
|
modified_args.append(obj)
|
||||||
# (lyl)TODO: analyse kwargs as well
|
|
||||||
|
modified_kwargs = {}
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if issubclass(v.__class__, torch.nn.modules.module.Module):
|
||||||
|
v = self._layer_spec_dict[id(v)]
|
||||||
|
# (lyl)TODO: analyse ColoTensor as well
|
||||||
|
modified_kwargs[k] = v
|
||||||
|
|
||||||
modified_args = tuple(modified_args)
|
modified_args = tuple(modified_args)
|
||||||
self._root_children = list(module.children())
|
self._root_children = list(module.children())
|
||||||
self._model = module
|
self._model = module
|
||||||
layer_spec = LayerSpec(module.__class__, *modified_args, **kwargs)
|
layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
|
||||||
layer_spec.set_children(module.children())
|
layer_spec.set_children(module.children())
|
||||||
self._layer_spec_dict[module_id] = layer_spec
|
self._layer_spec_dict[module_id] = layer_spec
|
||||||
name_list = []
|
name_list = []
|
||||||
|
@ -82,27 +92,48 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
"""
|
"""
|
||||||
if exec_seq is None:
|
if exec_seq is None:
|
||||||
#if user do not provide the model executing sequence, we use the initialization order as the executing order.
|
#if user do not provide the model executing sequence, we use the initialization order as the executing order.
|
||||||
|
children_name = []
|
||||||
for child in self._root_children:
|
for child in self._root_children:
|
||||||
layer_spec = self._layer_spec_dict[id(child)]
|
layer_spec = self._layer_spec_dict[id(child)]
|
||||||
if layer_spec.typename in (torch.nn.modules.container.ModuleList,
|
if layer_spec.typename in (torch.nn.modules.container.ModuleList,
|
||||||
torch.nn.modules.container.Sequential):
|
torch.nn.modules.container.Sequential):
|
||||||
for child_in_container in layer_spec.children:
|
for child_in_container in layer_spec.children:
|
||||||
self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
|
self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
|
||||||
|
for name, module in self._model.named_modules():
|
||||||
|
if id(module) == id(child_in_container):
|
||||||
|
children_name.append(name)
|
||||||
|
break
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self._layer_spec_list.append(layer_spec)
|
self._layer_spec_list.append(layer_spec)
|
||||||
|
for name, module in self._model.named_modules():
|
||||||
|
if id(module) == id(child):
|
||||||
|
children_name.append(name)
|
||||||
|
break
|
||||||
|
|
||||||
else:
|
else:
|
||||||
func_key = "first"
|
front_funcs_list = []
|
||||||
for index, element in enumerate(exec_seq):
|
for index, element in enumerate(exec_seq):
|
||||||
if isinstance(element, str):
|
if isinstance(element, str):
|
||||||
module = dict(self._model.named_modules())[element]
|
module = dict(self._model.named_modules())[element]
|
||||||
layer_spec = self._layer_spec_dict[id(module)]
|
layer_spec = self._layer_spec_dict[id(module)]
|
||||||
func_key = layer_spec
|
if len(front_funcs_list) != 0:
|
||||||
|
func_key = (layer_spec, "front")
|
||||||
|
if func_key not in self._func_dict:
|
||||||
|
self._func_dict[func_key] = []
|
||||||
|
for f in front_funcs_list:
|
||||||
|
self._func_dict[func_key].append(f)
|
||||||
|
front_funcs_list = []
|
||||||
|
func_key = (layer_spec, "behind")
|
||||||
self._layer_spec_list.append(layer_spec)
|
self._layer_spec_list.append(layer_spec)
|
||||||
|
elif isinstance(element, tuple) and element[1] == "front":
|
||||||
|
front_funcs_list.append(element[0])
|
||||||
else:
|
else:
|
||||||
if func_key not in self._func_dict:
|
if func_key not in self._func_dict:
|
||||||
self._func_dict[func_key] = []
|
self._func_dict[func_key] = []
|
||||||
|
if isinstance(element, tuple):
|
||||||
|
self._func_dict[func_key].append(element[0])
|
||||||
|
else:
|
||||||
self._func_dict[func_key].append(element)
|
self._func_dict[func_key].append(element)
|
||||||
|
|
||||||
def partition(self, num_chunks, pipeline_size, rank):
|
def partition(self, num_chunks, pipeline_size, rank):
|
||||||
|
@ -128,17 +159,19 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
layers_to_build = []
|
layers_to_build = []
|
||||||
for start, end in parts:
|
for start, end in parts:
|
||||||
layers_to_build += self._layer_spec_list[start:end]
|
layers_to_build += self._layer_spec_list[start:end]
|
||||||
func_dict_in_partition = {}
|
behind_func_dict_in_partition = {}
|
||||||
|
front_func_dict_in_partition = {}
|
||||||
module_list_in_partition = []
|
module_list_in_partition = []
|
||||||
if rank == 0 and "first" in self._func_dict:
|
|
||||||
func_dict_in_partition["first"] = self._func_dict["first"]
|
|
||||||
for layer in layers_to_build:
|
for layer in layers_to_build:
|
||||||
module = layer.build()
|
module = layer.build()
|
||||||
module_list_in_partition.append(module)
|
module_list_in_partition.append(module)
|
||||||
if layer in self._func_dict:
|
if (layer, "front") in self._func_dict:
|
||||||
func_dict_in_partition[id(module)] = self._func_dict[layer]
|
front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")]
|
||||||
|
elif (layer, "behind") in self._func_dict:
|
||||||
|
behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")]
|
||||||
module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
|
module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
|
||||||
pipeline_model = PipelinableModel(module_list_in_partition, func_dict_in_partition)
|
pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition,
|
||||||
|
behind_func_dict_in_partition)
|
||||||
|
|
||||||
return pipeline_model
|
return pipeline_model
|
||||||
|
|
||||||
|
@ -146,31 +179,119 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
self._policy = policy
|
self._policy = policy
|
||||||
|
|
||||||
|
|
||||||
|
def _build_kwargs_for_module(function, kw_dict):
|
||||||
|
"""
|
||||||
|
Generally, the first argument of module.forward is an input tensor come from the previous layer.
|
||||||
|
Therefore, we just filter the kwargs from second element of the dictionary.
|
||||||
|
"""
|
||||||
|
sig = inspect.signature(function)
|
||||||
|
if len(sig.parameters) <= 1:
|
||||||
|
return None
|
||||||
|
args_name_list = list(sig.parameters.keys())
|
||||||
|
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
|
||||||
|
return kw_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _build_kwargs_for_function(function, kw_dict):
|
||||||
|
sig = inspect.signature(function)
|
||||||
|
kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
|
||||||
|
if len(kw_dict) == 0:
|
||||||
|
return None
|
||||||
|
return kw_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
|
||||||
|
"""
|
||||||
|
We suppose the callable object passed to to_layer_list method in two purpose:
|
||||||
|
a. use the callable object to modify input tensor, such as \
|
||||||
|
lambda x: torch.flatten(x, 1)
|
||||||
|
b. use the callable object to modify kwargs value, such as \
|
||||||
|
def foo(attention_mask=None):
|
||||||
|
if attention_mask is not None:
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
|
return attention_mask
|
||||||
|
"""
|
||||||
|
|
||||||
|
if kw_dict is not None:
|
||||||
|
rst = func(**kw_dict)
|
||||||
|
if isinstance(rst, tuple):
|
||||||
|
for i, k in enumerate(kw_dict.keys()):
|
||||||
|
kwargs[k] = rst[i]
|
||||||
|
else:
|
||||||
|
for k in kw_dict.keys():
|
||||||
|
kwargs[k] = rst
|
||||||
|
return input_tensor
|
||||||
|
return func(input_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def _exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
|
||||||
|
|
||||||
|
assert func_key in func_dict, f"{func_key} is not in the function_dict."
|
||||||
|
funcs_to_exec = func_dict[func_key]
|
||||||
|
if isinstance(funcs_to_exec, list):
|
||||||
|
for f in funcs_to_exec:
|
||||||
|
f_kwargs = _build_kwargs_for_function(f, kwargs)
|
||||||
|
input_tensor = _exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
|
||||||
|
else:
|
||||||
|
f_kwargs = _build_kwargs_for_function(funcs_to_exec, kwargs)
|
||||||
|
input_tensor = _exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
|
||||||
|
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
|
|
||||||
class PipelinableModel(torch.nn.Module):
|
class PipelinableModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, module_list, func_dict):
|
def __init__(self, module_list, front_func_dict, behind_func_dict):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._module_list = module_list
|
self._module_list = module_list
|
||||||
self._func_dict = func_dict
|
self._front_func_dict = front_func_dict
|
||||||
|
self._behind_func_dict = behind_func_dict
|
||||||
|
|
||||||
def forward(self, input_tensor):
|
def forward(self, input_tensor, **kwargs):
|
||||||
if "first" in self._func_dict:
|
|
||||||
funcs = self._func_dict["first"]
|
|
||||||
if isinstance(funcs, list):
|
|
||||||
for f in funcs:
|
|
||||||
input_tensor = f(input_tensor)
|
|
||||||
else:
|
|
||||||
input_tensor = funcs(input_tensor)
|
|
||||||
|
|
||||||
for module in self._module_list:
|
for module in self._module_list:
|
||||||
input_tensor = module(input_tensor)
|
|
||||||
if id(module) in self._func_dict:
|
if id(module) in self._front_func_dict:
|
||||||
funcs = self._func_dict[id(module)]
|
input_tensor = _exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
|
||||||
if isinstance(funcs, list):
|
|
||||||
for f in funcs:
|
if isinstance(module, CheckpointModule):
|
||||||
input_tensor = f(input_tensor)
|
forward_func = module._forward
|
||||||
else:
|
else:
|
||||||
input_tensor = funcs(input_tensor)
|
forward_func = module.forward
|
||||||
|
if input_tensor is None:
|
||||||
|
module_kwargs = _build_kwargs_for_function(forward_func, kwargs)
|
||||||
|
else:
|
||||||
|
module_kwargs = _build_kwargs_for_module(forward_func, kwargs)
|
||||||
|
if module_kwargs is not None and input_tensor is not None:
|
||||||
|
if isinstance(module, CheckpointModule):
|
||||||
|
convert_kwargs_to_args = []
|
||||||
|
for v in module_kwargs.values():
|
||||||
|
convert_kwargs_to_args.append(v)
|
||||||
|
rst = module(input_tensor, *convert_kwargs_to_args)
|
||||||
|
else:
|
||||||
|
rst = module(input_tensor, **module_kwargs)
|
||||||
|
if isinstance(rst, tuple):
|
||||||
|
input_tensor = rst[0]
|
||||||
|
else:
|
||||||
|
input_tensor = rst
|
||||||
|
elif module_kwargs is not None and input_tensor is None:
|
||||||
|
if isinstance(module, CheckpointModule):
|
||||||
|
convert_kwargs_to_args = []
|
||||||
|
for v in module_kwargs.values():
|
||||||
|
convert_kwargs_to_args.append(v)
|
||||||
|
rst = module(input_tensor, *convert_kwargs_to_args)
|
||||||
|
else:
|
||||||
|
rst = module(**module_kwargs)
|
||||||
|
if isinstance(rst, tuple):
|
||||||
|
input_tensor = rst[0]
|
||||||
|
else:
|
||||||
|
input_tensor = rst
|
||||||
|
else:
|
||||||
|
input_tensor = module(input_tensor)
|
||||||
|
|
||||||
|
if id(module) in self._behind_func_dict:
|
||||||
|
input_tensor = _exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
|
||||||
|
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
|
@ -203,7 +324,14 @@ class LayerSpec:
|
||||||
obj = obj.build()
|
obj = obj.build()
|
||||||
recovered_args.append(obj)
|
recovered_args.append(obj)
|
||||||
recovered_args = tuple(recovered_args)
|
recovered_args = tuple(recovered_args)
|
||||||
return self.typename(*recovered_args, **self.module_kwargs)
|
|
||||||
|
recovered_kwargs = {}
|
||||||
|
for k, v in self.module_kwargs.items():
|
||||||
|
if isinstance(v, LayerSpec):
|
||||||
|
v = v.build()
|
||||||
|
recovered_kwargs[k] = v
|
||||||
|
|
||||||
|
return self.typename(*recovered_args, **recovered_kwargs)
|
||||||
|
|
||||||
def set_children(self, children):
|
def set_children(self, children):
|
||||||
self.children = children
|
self.children = children
|
||||||
|
|
Loading…
Reference in New Issue