# mirror of https://github.com/hpcaitech/ColossalAI
import queue

# from megatron import get_args
# from megatron.core import parallel_state
# from megatron.core.distributed.finalize_model_grads import _allreduce_embedding_grads
# from megatron.core.utils import get_model_config, get_attr_wrapped_model


class WeightGradStore:
    """Caches weight-gradient computations so they can be executed later than the
    corresponding input-gradient computations (e.g. in a pipeline schedule that
    splits backward into separate dX and dW steps)."""

    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # Defer the weight-gradient computation: record the operands and the
        # grad function instead of calling it immediately.
        # func(total_input, grad_output, weight.main_grad)
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        # Seal the cached computations of the current microbatch into the queue
        # of the given model chunk and start a fresh cache.
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        # Execute one flushed batch of deferred weight-gradient computations.
        if cls.weight_grad_queue[chunk].qsize() > 0:
            stored_grads = cls.weight_grad_queue[chunk].get()
            for total_input, grad_output, weight, func in stored_grads:
                if weight.grad is not None:
                    func(total_input, grad_output, weight.grad)
                else:
                    # On the first backward pass weight.grad is still None, so take
                    # the returned grad_weight and assign it to weight.grad.
                    grad_weight = func(total_input, grad_output)
                    weight.grad = grad_weight
        else:
            raise Exception("Pop empty queue.")
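
    # Lifecycle note (descriptive addition, inferred from the methods above): a
    # layer's backward calls `put` to defer its weight-gradient work while the
    # input gradient is produced; the pipeline schedule calls `flush` once per
    # microbatch to move the cached work into `weight_grad_queue[chunk]`; later,
    # at the scheduled weight-gradient step, `pop` executes the deferred work.
    # A runnable sketch is provided at the bottom of this file.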

    # @classmethod
    # def clear(cls, model, chunk=0):
    #     weight_grad_tasks = []
    #     while cls.weight_grad_queue[chunk].qsize() > 0:
    #         stored_grads = cls.weight_grad_queue[chunk].get()
    #         if len(weight_grad_tasks) == 0:
    #             for _ in stored_grads:
    #                 weight_grad_tasks.append([])
    #         else:
    #             assert len(weight_grad_tasks) == len(stored_grads)
    #         for i, task in enumerate(stored_grads):
    #             weight_grad_tasks[i].append(task)
    #     weight_params = []
    #     handles = []
    #     if get_args().overlap_grad_reduce:
    #         handles += model.async_reduce_grad()

    #     output_layer_weight = None
    #     if parallel_state.is_pipeline_last_stage():
    #         assert len(weight_grad_tasks) > 0
    #         output_layer_grads = weight_grad_tasks[0]
    #         for j in range(len(output_layer_grads)):
    #             total_input, grad_output, weight, func = output_layer_grads[j]
    #             if output_layer_weight is None:
    #                 output_layer_weight = weight
    #             assert output_layer_weight is weight
    #             func(total_input, grad_output, weight.main_grad)
    #             output_layer_grads[j] = None  # release memory
    #         weight_grad_tasks = weight_grad_tasks[1:]
    #         if get_args().overlap_grad_reduce:
    #             handles += model.async_reduce_grad(output_layer_weight)

    #     if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
    #         model_module = get_attr_wrapped_model(model, 'pre_process', return_model_obj=True)
    #         if model_module.share_embeddings_and_output_weights:
    #             # if share_embeddings_and_output_weights, wait all-reduce for embeddings
    #             for handle in handles:
    #                 if handle is not None:
    #                     handle.wait()
    #             handles = []

    #     config = get_model_config(model)
    #     # Do the all-reduce for embedding grads asynchronously first, so that
    #     # rank 0 is not blocked
    #     embedding_handles = _allreduce_embedding_grads([model], config, async_op=True)
    #     handles += embedding_handles

    #     for i in range(len(weight_grad_tasks)):
    #         tasks = weight_grad_tasks[i]
    #         param = None
    #         for j in range(len(tasks)):
    #             total_input, grad_output, weight, func = tasks[j]
    #             if param is None:
    #                 param = weight
    #             assert param is weight
    #             assert not (weight is output_layer_weight)
    #             func(total_input, grad_output, weight.main_grad)
    #             tasks[j] = None  # release memory
    #         weight_params.append(param)
    #         if get_args().overlap_grad_reduce:
    #             # All-reduce param grad here
    #             handles += model.async_reduce_grad(param)
    #         weight_grad_tasks[i] = None  # release memory

    #     # timers('wait_all_reduce', log_level=1).start(barrier=False)
    #     for handle in embedding_handles:
    #         if handle is not None:
    #             handle.wait()
    #     # timers('wait_all_reduce').stop()
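

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the upstream file).
# It assumes a hypothetical weight-grad callable `compute_weight_grad` for a
# plain linear layer; in the real schedule each layer registers its own
# callable when it calls `put`. Running this module directly exercises the
# put -> flush -> pop path on CPU tensors.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    def compute_weight_grad(total_input, grad_output, grad=None):
        # Hypothetical example func: grad_W = dY^T @ X for a linear layer.
        grad_weight = grad_output.t().matmul(total_input)
        if grad is None:
            # First backward: return the gradient so pop() assigns weight.grad.
            return grad_weight
        # Subsequent backwards: accumulate into the existing gradient in place.
        grad.add_(grad_weight)

    weight = torch.nn.Parameter(torch.zeros(4, 8))
    total_input = torch.randn(2, 8)  # activation saved from the forward pass
    grad_output = torch.randn(2, 4)  # upstream gradient from the dX pass

    WeightGradStore.put(total_input, grad_output, weight, compute_weight_grad)
    WeightGradStore.flush(chunk=0)  # end of the microbatch
    WeightGradStore.pop(chunk=0)  # deferred dW step of the schedule
    print(weight.grad.shape)  # torch.Size([4, 8])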