mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix handle name; rm useless comments;
parent
5aee4261a6
commit
fafe049b83
|
@ -107,7 +107,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
self.local_send_backward_buffer = []
|
||||
|
||||
# wait pp buffer
|
||||
self.send_handles = []
|
||||
self.wait_handles = []
|
||||
|
||||
def assert_buffer_empty(self):
|
||||
# assert buffer is empty at end
|
||||
|
@ -129,7 +129,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
assert len(self.recv_backward_buffer[1]) == 0
|
||||
assert len(self.local_send_forward_buffer) == 0
|
||||
assert len(self.local_send_backward_buffer) == 0
|
||||
# assert len(self.send_handles) == 0
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
@ -891,7 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# communication
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
wait_handle = communication_func(scheduled_node.chunk)
|
||||
self.send_handles.append(wait_handle)
|
||||
self.wait_handles.append(wait_handle)
|
||||
elif scheduled_node.type == "F":
|
||||
self.schedule_f(
|
||||
scheduled_node=scheduled_node,
|
||||
|
@ -915,7 +914,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
model_chunk_id=scheduled_node.chunk,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
for h in self.send_handles:
|
||||
for h in self.wait_handles:
|
||||
for hh in h:
|
||||
hh.wait()
|
||||
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import queue
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
|
||||
class WeightGradStore:
|
||||
|
||||
|
@ -32,52 +30,3 @@ class WeightGradStore:
|
|||
weight.grad = grad_weight
|
||||
else:
|
||||
raise Exception("Pop empty queue.")
|
||||
|
||||
@classmethod
|
||||
def clear(cls, stage_manager: PipelineStageManager, chunk=0):
|
||||
pass
|
||||
# print(f"stage {stage_manager.stage} len_chunk_0 {cls.weight_grad_queue[0].qsize()} len_chunk_1 {cls.weight_grad_queue[1].qsize()}")
|
||||
# while cls.weight_grad_queue[chunk].qsize() > 0:
|
||||
# stored_grads = cls.weight_grad_queue[chunk].get()
|
||||
# for total_input, grad_output, weight, func in stored_grads:
|
||||
# if weight.grad is not None:
|
||||
# func(total_input, grad_output, weight.grad)
|
||||
# # for first bwd; weight.grad is None, assign grad_weight to weight.grad
|
||||
# else:
|
||||
# grad_weight = func(total_input, grad_output)
|
||||
# weight.grad = grad_weight
|
||||
|
||||
# weight_grad_tasks = []
|
||||
# while cls.weight_grad_queue[chunk].qsize() > 0:
|
||||
# stored_grads = cls.weight_grad_queue[chunk].get()
|
||||
# if len(weight_grad_tasks) == 0:
|
||||
# for _ in stored_grads:
|
||||
# weight_grad_tasks.append([])
|
||||
# else:
|
||||
# assert len(weight_grad_tasks) == len(stored_grads)
|
||||
# for i, task in enumerate(stored_grads):
|
||||
# weight_grad_tasks[i].append(task)
|
||||
|
||||
# if stage_manager.is_last_stage(ignore_chunk=True) and chunk == 1:
|
||||
# assert len(weight_grad_tasks) > 0
|
||||
# output_layer_grads = weight_grad_tasks[0]
|
||||
# for j in range(len(output_layer_grads)):
|
||||
# total_input, grad_output, weight, func = output_layer_grads[j]
|
||||
# if output_layer_weight is None:
|
||||
# output_layer_weight = weight
|
||||
# assert output_layer_weight is weight
|
||||
# func(total_input, grad_output, weight.grad)
|
||||
# output_layer_grads[j] = None # release memory
|
||||
# weight_grad_tasks = weight_grad_tasks[1:]
|
||||
|
||||
# for i in range(len(weight_grad_tasks)):
|
||||
# tasks = weight_grad_tasks[i]
|
||||
# param = None
|
||||
# for j in range(len(tasks)):
|
||||
# total_input, grad_output, weight, func = tasks[j]
|
||||
# if param is None:
|
||||
# param = weight
|
||||
# assert param is weight
|
||||
# func(total_input, grad_output, weight.grad)
|
||||
# tasks[j] = None # release memory
|
||||
# weight_grad_tasks[i] = None # release memory
|
||||
|
|
|
@ -60,10 +60,7 @@ class LlamaPolicy(Policy):
|
|||
else:
|
||||
norm_cls = RMSNorm
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
use_zbv = self.pipeline_stage_manager.use_zbv
|
||||
else:
|
||||
use_zbv = False
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
sp_mode = self.shard_config.sequence_parallelism_mode or None
|
||||
sp_size = self.shard_config.sequence_parallel_size or None
|
||||
|
@ -96,7 +93,6 @@ class LlamaPolicy(Policy):
|
|||
target_key=attn_cls,
|
||||
)
|
||||
|
||||
# if self.pipeline_stage_manager is not None:
|
||||
if self.pipeline_stage_manager is None:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
|
@ -410,20 +406,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
# if self.pipeline_stage_manager.use_zbv:
|
||||
# return [
|
||||
# {
|
||||
# 0: llama_model.embed_tokens.weight,
|
||||
# 0: self.model.lm_head.weight,
|
||||
# }
|
||||
# ]
|
||||
# else:
|
||||
# return [
|
||||
# {
|
||||
# 0: llama_model.embed_tokens.weight,
|
||||
# self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
# }
|
||||
# ]
|
||||
return []
|
||||
|
||||
|
||||
|
|
|
@ -904,7 +904,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
||||
torch_output.backward()
|
||||
torch_output_sum += torch_output.detach()
|
||||
# print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
|
||||
# avg dp grads follows zero optimizer
|
||||
for p in torch_model.parameters():
|
||||
if p.grad is not None:
|
||||
|
@ -912,7 +911,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||
torch_optimizer.step()
|
||||
torch_optimizer.zero_grad()
|
||||
|
||||
# print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
|
||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||
print(f"rank {dist.get_rank()} config {test_config} test passed")
|
||||
clear_layout_converter()
|
||||
|
@ -1064,7 +1062,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
|||
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
||||
torch_output.backward()
|
||||
torch_output_sum += torch_output.detach()
|
||||
# print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
|
||||
# avg dp grads follows zero optimizer
|
||||
for p in torch_model.parameters():
|
||||
if p.grad is not None:
|
||||
|
@ -1072,7 +1069,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
|||
torch_optimizer.step()
|
||||
torch_optimizer.zero_grad()
|
||||
|
||||
# print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
|
||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
|
||||
clear_layout_converter()
|
||||
|
|
Loading…
Reference in New Issue