2023-09-04 09:52:23 +00:00
|
|
|
from collections import OrderedDict
|
|
|
|
from typing import Any, List, Optional, Tuple
|
2023-06-29 05:35:39 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.cuda
|
|
|
|
from torch.nn import Module
|
2023-09-04 09:52:23 +00:00
|
|
|
from torch.utils._pytree import (
|
|
|
|
SUPPORTED_NODES,
|
|
|
|
LeafSpec,
|
|
|
|
TreeSpec,
|
|
|
|
_is_leaf,
|
|
|
|
_register_pytree_node,
|
|
|
|
tree_flatten,
|
|
|
|
tree_map,
|
|
|
|
tree_unflatten,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# This registration targets older torch versions (1.13.1 and below); it may be removed in the future.
|
|
|
|
def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
|
|
|
|
return list(d.values()), list(d.keys())
|
|
|
|
|
|
|
|
|
|
|
|
def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
|
|
|
|
return OrderedDict((key, value) for key, value in zip(context, values))
|
|
|
|
|
|
|
|
|
|
|
|
# Register OrderedDict as a pytree node only when this torch version has not
# already done so. An unconditional re-registration would replace torch's
# built-in entry (which may carry extra metadata, e.g. key-path support, in
# newer releases) and may trigger an "already registered" warning.
if OrderedDict not in SUPPORTED_NODES:
    _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
|
|
|
|
|
|
|
|
|
|
|
|
def tree_map_hf(fn: Any, pytree: Any):
    """Apply *fn* to every leaf of *pytree* and rebuild the same structure.

    Flattening goes through ``tree_flatten_hf``, so OrderedDict instances
    (including subclasses) are traversed instead of being treated as leaves.

    Args:
        fn (Any): Callable applied to each leaf, in flattening order.
        pytree (Any): Structure to map over.

    Returns:
        The rebuilt pytree whose leaves are ``fn(leaf)``.
    """
    leaves, treespec = tree_flatten_hf(pytree)
    mapped_leaves = list(map(fn, leaves))
    return tree_unflatten(mapped_leaves, treespec)
|
|
|
|
|
|
|
|
|
|
|
|
# Use this flatten function to handle ModelOutput-style instances (OrderedDict subclasses); they are flattened as plain OrderedDicts.
|
|
|
|
def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
    """Flatten a pytree into a list of leaves and a TreeSpec for rebuilding it.

    Any ``OrderedDict`` instance, including subclasses, is split using the
    flatten function registered for ``OrderedDict``. Its children are then
    flattened recursively with this same function. The node is recorded in
    the TreeSpec as a plain ``OrderedDict``, so unflattening produces an
    ``OrderedDict`` rather than the original subclass. Every other object is
    delegated to torch's ``tree_flatten``.

    NOTE: OrderedDict subclasses nested inside a non-OrderedDict container
    are handled by ``tree_flatten`` as usual, not by this function.

    Args:
        pytree (Any): Structure to flatten.

    Returns:
        Tuple[List[Any], TreeSpec]: The flat leaves and the spec describing the structure.
    """
    if not isinstance(pytree, OrderedDict):
        leaves, spec = tree_flatten(pytree)
        return leaves, spec

    children, context = SUPPORTED_NODES[OrderedDict].flatten_fn(pytree)

    all_leaves: List[Any] = []
    child_specs: List[TreeSpec] = []
    for child in children:
        child_leaves, child_spec = tree_flatten_hf(child)
        all_leaves.extend(child_leaves)
        child_specs.append(child_spec)

    return all_leaves, TreeSpec(OrderedDict, context, child_specs)
|
2023-06-29 05:35:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
    """Return *x* moved to *device* when it is a tensor, otherwise *x* unchanged.

    Args:
        x (Any): Object that may be a ``torch.Tensor``.
        device (Optional[torch.device], optional): Destination passed to
            ``Tensor.to``. Defaults to None.

    Returns:
        Any: ``x.to(device)`` for tensors; the original object for anything else.
    """
    return x.to(device) if isinstance(x, torch.Tensor) else x
|
|
|
|
|
|
|
|
|
|
|
|
def get_batch_size(batch: Any) -> int:
    """Return ``size(0)`` of the first tensor leaf found in *batch*.

    Leaves are visited in ``tree_flatten`` order. Non-tensor leaves are skipped.

    Args:
        batch (Any): Pytree (list/tuple/dict/... of tensors and other values).

    Raises:
        RuntimeError: If the flattened batch contains no tensor.

    Returns:
        int: Size of dimension 0 of the first tensor leaf.
    """
    leaves, _ = tree_flatten(batch)
    first_tensor = next((leaf for leaf in leaves if isinstance(leaf, torch.Tensor)), None)
    if first_tensor is None:
        raise RuntimeError('No tensor found in the batch')
    return first_tensor.size(0)
|
|
|
|
|
|
|
|
|
|
|
|
def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:
    """Slice every tensor leaf of *batch* to rows ``[start, start + micro_batch_size)``.

    Slicing is applied along dimension 0 with ordinary Python slice semantics,
    so a range that runs past the end is truncated. Non-tensor leaves are
    returned unchanged and the pytree structure is preserved.

    Args:
        batch (Any): Pytree to slice.
        start (int): First row of the micro batch.
        micro_batch_size (int): Number of rows to take.

    Returns:
        Any: A pytree of the same structure holding the sliced tensors.
    """
    stop = start + micro_batch_size

    def _slice_leaf(leaf: Any) -> Any:
        if not isinstance(leaf, torch.Tensor):
            return leaf
        return leaf[start:stop]

    return tree_map(_slice_leaf, batch)
|
|
|
|
|
|
|
|
|
|
|
|
def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any:
    """Invoke *model* on *data*, forwarding *internal_inputs* as keyword arguments.

    Dispatch on the type of *data*:

    * ``dict``: unpacked as keyword arguments.
    * ``list`` / ``tuple``: unpacked as positional arguments.
    * anything else: passed as the single positional argument.

    Args:
        model (Module): Callable to run (typically an ``nn.Module``).
        data (Any): Data loaded from the data iterator.
        internal_inputs (Optional[dict]): Extra keyword arguments, e.g. outputs
            from the previous pipeline stage. ``None`` means there are none,
            as on the first stage.

    Returns:
        Any: Whatever the model call returns.
    """
    extra_kwargs = {} if internal_inputs is None else internal_inputs
    if isinstance(data, dict):
        return model(**data, **extra_kwargs)
    positional = data if isinstance(data, (list, tuple)) else (data,)
    return model(*positional, **extra_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def retain_grad(x: Any) -> None:
    """Ask autograd to keep ``.grad`` on *x* if it is a tensor that requires grad.

    Non-tensors and tensors with ``requires_grad=False`` are ignored.

    Args:
        x (Any): Candidate object.
    """
    if not isinstance(x, torch.Tensor):
        return
    if not x.requires_grad:
        return
    x.retain_grad()
|
|
|
|
|
|
|
|
|
|
|
|
def detach(x: Any) -> Any:
    """Return ``x.detach()`` for tensors and *x* itself for anything else.

    Args:
        x (Any): Candidate object.

    Returns:
        Any: The detached tensor, or the untouched input.
    """
    if not isinstance(x, torch.Tensor):
        return x
    return x.detach()
|
|
|
|
|
|
|
|
|
2023-09-04 09:52:23 +00:00
|
|
|
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
    """Merge a list of micro batches back into a single batch.

    Each micro batch is flattened with ``tree_flatten_hf``. Leaves in the same
    position are grouped across micro batches and merged by the type of the
    first leaf in each group:

    * tensor with at least one dimension: ``torch.cat`` along *batch_size_dim*;
    * 0-dim tensor (a loss, in pipeline outputs): replaced by ``None``;
    * anything else: collected into a list.

    The merged leaves are rebuilt with the tree spec of the last micro batch.
    All micro batches are assumed to share that structure. If leaf counts
    differ, ``zip`` silently truncates to the shortest.

    Args:
        data (List[Any]): Micro batches to merge.
        batch_size_dim (int, optional): Dimension along which tensor leaves
            are concatenated. Defaults to 0.

    Returns:
        Any: The merged batch, or ``None`` when *data* is empty.
    """
    if len(data) == 0:
        return None

    spec = None
    leaves_per_micro_batch = []
    for micro_batch in data:
        leaves, spec = tree_flatten_hf(micro_batch)
        leaves_per_micro_batch.append(leaves)

    def _merge_group(group):
        head = group[0]
        if not isinstance(head, torch.Tensor):
            return list(group)
        if head.dim() == 0:
            # scalar tensors (the loss in pipeline outputs) are dropped
            return None
        return torch.cat(group, dim=batch_size_dim)

    merged_leaves = [_merge_group(group) for group in zip(*leaves_per_micro_batch)]
    return tree_unflatten(merged_leaves, spec)
|