@@ -1,5 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
+import warnings
 from contextlib import contextmanager
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
@@ -87,6 +88,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._partition_grads = partition_grad

         self._cpu_offload = cpu_offload
+        if cpu_offload:
+            warnings.warn(
+                "Please enlarge OMP_NUM_THREADS to speed up the CPU computation. "
+                "Using the default OMP_NUM_THREADS=1 would significantly slow down training."
+            )

         # grad accumulation
         self.require_grad_sync = True
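For context, a minimal sketch of how a caller would act on the new warning before launching CPU-offloaded training. The thread count chosen here is illustrative and not part of this patch; PyTorch reads `OMP_NUM_THREADS` at import time on OpenMP builds.

```python
# Minimal sketch (not part of this patch): raise OMP_NUM_THREADS before
# torch is imported so CPU-offloaded optimizer steps can use intra-op
# parallelism instead of the single-thread default the warning refers to.
import os

# Illustrative choice: one thread per available core for this process.
os.environ.setdefault("OMP_NUM_THREADS", str(os.cpu_count() or 1))

import torch  # noqa: E402  # imported after the env var is set

# On OpenMP builds, PyTorch initializes its intra-op thread pool from
# OMP_NUM_THREADS, so this reflects the value set above.
print(torch.get_num_threads())
```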