from collections import OrderedDict
from typing import Any, List, Optional, Tuple

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten


# This registration is needed for torch versions under 1.13.1 and may be removed in the future.
def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Any]:
    return list(d.values()), list(d.keys())


def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]":
    return OrderedDict((key, value) for key, value in zip(context, values))


_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
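
# Illustration (hypothetical values) of what the registration above enables:
# _odict_flatten(OrderedDict(a=1, b=2)) returns ([1, 2], ["a", "b"]), and
# _odict_unflatten([1, 2], ["a", "b"]) rebuilds the OrderedDict, so tree_flatten and
# tree_unflatten can recurse into OrderedDict nodes on these older torch versions.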


def tree_map_hf(fn: Any, pytree: Any):
    flat_args, spec = tree_flatten_hf(pytree)
    return tree_unflatten([fn(i) for i in flat_args], spec)
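
# Note: tree_map_hf mirrors torch.utils._pytree.tree_map, but it flattens with the
# ModelingOutput-aware tree_flatten_hf defined below so OrderedDict-style model outputs are handled.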


# Use this flatten function to handle ModelingOutput class instances, which subclass
# OrderedDict and therefore take the OrderedDict branch below.
def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
    """Flattens a pytree into a list of values and a TreeSpec that can be used
    to reconstruct the pytree.
    """
    if isinstance(pytree, OrderedDict):
        node_type = OrderedDict
        flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
        child_pytrees, context = flatten_fn(pytree)

        # Recursively flatten the children
        result: List[Any] = []
        children_specs: List["TreeSpec"] = []
        for child in child_pytrees:
            flat, child_spec = tree_flatten_hf(child)
            result += flat
            children_specs.append(child_spec)
        return result, TreeSpec(node_type, context, children_specs)
    else:
        result, tree_spec = tree_flatten(pytree)
        return result, tree_spec
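
# Round-trip sketch (illustrative names only): for an OrderedDict-style output such as
# out = OrderedDict(logits=t1, loss=t2), `values, spec = tree_flatten_hf(out)` gives
# values == [t1, t2]; `tree_unflatten([fn(v) for v in values], spec)` then rebuilds an
# OrderedDict with the same keys, which is exactly what tree_map_hf above relies on.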


def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
    """Move object to device if it is a tensor.

    Args:
        x (Any): Object to be moved.
        device (Optional[torch.device], optional): Target device. Defaults to None.

    Returns:
        Any: Moved object.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    return x


def get_batch_size(batch: Any) -> int:
    """Get the batch size (size of dimension 0) of the first tensor in the batch.

    Args:
        batch (Any): Batch to be inspected.

    Raises:
        RuntimeError: If no tensor is found in the batch.

    Returns:
        int: Batch size.
    """
    data_list, _ = tree_flatten(batch)
    for data in data_list:
        if isinstance(data, torch.Tensor):
            return data.size(0)
    raise RuntimeError("No tensor found in the batch")
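
# Example (hypothetical shapes): for batch = {"input_ids": torch.zeros(16, 128)}, the first
# tensor found has size(0) == 16, so get_batch_size returns 16.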


def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:
    """Get a micro batch of the original batch.

    Args:
        batch (Any): Batch to be sliced.
        start (int): Start index of the micro batch.
        micro_batch_size (int): Size of the micro batch.

    Returns:
        Any: Target micro batch.
    """

    def _get_tensor_slice(x: Any):
        if isinstance(x, torch.Tensor):
            return x[start : start + micro_batch_size]
        return x

    return tree_map(_get_tensor_slice, batch)
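
# Usage sketch (hypothetical shapes): with batch = {"input_ids": x} where x has shape
# [8, 128], get_micro_batch(batch, start=0, micro_batch_size=2) returns a dict containing
# x[0:2]; non-tensor leaves are passed through unchanged.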


def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any:
    """Call the model's forward function with data and internal inputs.

    Args:
        model (Module): Model to be called.
        data (Any): Data loaded from the data iterator.
        internal_inputs (Optional[dict]): Data from the previous stage. It must be a dict, or None if this is the first stage.

    Returns:
        Any: Outputs of the model.
    """
    if internal_inputs is None:
        internal_inputs = {}
    if isinstance(data, (list, tuple)):
        return model(*data, **internal_inputs)
    elif isinstance(data, dict):
        return model(**data, **internal_inputs)
    return model(data, **internal_inputs)
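
# Dispatch summary: list/tuple data is unpacked as positional arguments, dict data as
# keyword arguments, and anything else is passed as a single positional argument, while
# internal_inputs from the previous pipeline stage are always forwarded as keyword arguments.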


def retain_grad(x: Any) -> None:
    """Call retain_grad() on a tensor.

    Args:
        x (Any): Object on which to call retain_grad() if it is a tensor that requires grad.
    """
    if isinstance(x, torch.Tensor) and x.requires_grad:
        x.retain_grad()
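
# Note: retain_grad() makes a non-leaf tensor keep its .grad after backward, which lets the
# pipeline read gradients of intermediate stage outputs rather than only leaf parameters.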


def detach(x: Any) -> Any:
    """Call detach() on a tensor.

    Args:
        x (Any): Object to be detached if it is a tensor.

    Returns:
        Any: The detached object.
    """
    if isinstance(x, torch.Tensor):
        return x.detach()
    return x


def merge_batch(data: List[Any], batch_size_dim: int = 0) -> Any:
    """Merge micro batches into a batch.

    Args:
        data (List[Any]): A list of micro batches.
        batch_size_dim (int, optional): Dimension along which tensors are concatenated. Defaults to 0.

    Returns:
        Any: Merged batch.
    """
    if len(data) == 0:
        return
    flattened_data = []
    tree_spec = None
    for d in data:
        # Each micro-batch output may be an OrderedDict (e.g. a ModelingOutput), which tree_flatten_hf handles.
        elems, tree_spec = tree_flatten_hf(d)
        flattened_data.append(elems)
    merged_data = []

    for elem_batch in zip(*flattened_data):
        if isinstance(elem_batch[0], torch.Tensor):
            if len(elem_batch[0].shape) == 0:  # set loss to None in pipeline outputs
                merged_data.append(None)
            else:
                merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
        else:
            merged_data.append(list(elem_batch))
    return tree_unflatten(merged_data, tree_spec)
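

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, assuming a simple dict batch): split a batch
    # into micro batches with get_micro_batch, then reassemble it with merge_batch.
    full = {"x": torch.arange(8).reshape(4, 2)}
    micro_batches = [get_micro_batch(full, start=i, micro_batch_size=2) for i in (0, 2)]
    merged = merge_batch(micro_batches)
    assert torch.equal(merged["x"], full["x"])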