[fix] fix handle name; rm useless comments;

pull/6083/head
duanjunwen 2024-10-29 03:24:15 +00:00
parent 5aee4261a6
commit fafe049b83
4 changed files with 4 additions and 78 deletions

View File

@@ -107,7 +107,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.local_send_backward_buffer = []
         # wait pp buffer
-        self.send_handles = []
+        self.wait_handles = []

     def assert_buffer_empty(self):
         # assert buffer is empty at end
@@ -129,7 +129,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         assert len(self.recv_backward_buffer[1]) == 0
         assert len(self.local_send_forward_buffer) == 0
         assert len(self.local_send_backward_buffer) == 0
-        # assert len(self.send_handles) == 0

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -891,7 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 wait_handle = communication_func(scheduled_node.chunk)
-                self.send_handles.append(wait_handle)
+                self.wait_handles.append(wait_handle)
             elif scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
@@ -915,7 +914,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        for h in self.send_handles:
+        for h in self.wait_handles:
             for hh in h:
                 hh.wait()
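For context, the renamed `wait_handles` buffer collects the work handles returned by each communication call, and the nested loop in the last hunk blocks on all of them before the pass returns. A minimal, self-contained sketch of that bookkeeping (not the real scheduler; `FakeWork` and `WaitHandleSketch` are stand-ins for the async work objects and the scheduler state):

# Sketch only: FakeWork mimics the async work objects whose .wait() the
# scheduler calls; the real handles come from the P2P communication functions.
class FakeWork:
    def __init__(self):
        self.done = False

    def wait(self):
        # the real call blocks until the P2P op completes
        self.done = True


class WaitHandleSketch:
    def __init__(self):
        # renamed from `send_handles` in this commit
        self.wait_handles = []

    def communicate(self):
        # each communication call returns a list of work handles
        handles = [FakeWork(), FakeWork()]
        self.wait_handles.append(handles)
        return handles

    def drain(self):
        # mirrors the nested wait loop in the last hunk above
        for h in self.wait_handles:
            for hh in h:
                hh.wait()


sched = WaitHandleSketch()
for _ in range(3):
    sched.communicate()
sched.drain()
assert all(hh.done for h in sched.wait_handles for hh in h)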

View File

@@ -1,7 +1,5 @@
 import queue

-from colossalai.pipeline.stage_manager import PipelineStageManager

 class WeightGradStore:
@@ -32,52 +30,3 @@ class WeightGradStore:
                     weight.grad = grad_weight
             else:
                 raise Exception("Pop empty queue.")
-
-    @classmethod
-    def clear(cls, stage_manager: PipelineStageManager, chunk=0):
-        pass
-        # print(f"stage {stage_manager.stage} len_chunk_0 {cls.weight_grad_queue[0].qsize()} len_chunk_1 {cls.weight_grad_queue[1].qsize()}")
-        # while cls.weight_grad_queue[chunk].qsize() > 0:
-        # stored_grads = cls.weight_grad_queue[chunk].get()
-        # for total_input, grad_output, weight, func in stored_grads:
-        # if weight.grad is not None:
-        # func(total_input, grad_output, weight.grad)
-        # # for first bwd; weight.grad is None, assign grad_weight to weight.grad
-        # else:
-        # grad_weight = func(total_input, grad_output)
-        # weight.grad = grad_weight
-        # weight_grad_tasks = []
-        # while cls.weight_grad_queue[chunk].qsize() > 0:
-        # stored_grads = cls.weight_grad_queue[chunk].get()
-        # if len(weight_grad_tasks) == 0:
-        # for _ in stored_grads:
-        # weight_grad_tasks.append([])
-        # else:
-        # assert len(weight_grad_tasks) == len(stored_grads)
-        # for i, task in enumerate(stored_grads):
-        # weight_grad_tasks[i].append(task)
-        # if stage_manager.is_last_stage(ignore_chunk=True) and chunk == 1:
-        # assert len(weight_grad_tasks) > 0
-        # output_layer_grads = weight_grad_tasks[0]
-        # for j in range(len(output_layer_grads)):
-        # total_input, grad_output, weight, func = output_layer_grads[j]
-        # if output_layer_weight is None:
-        # output_layer_weight = weight
-        # assert output_layer_weight is weight
-        # func(total_input, grad_output, weight.grad)
-        # output_layer_grads[j] = None  # release memory
-        # weight_grad_tasks = weight_grad_tasks[1:]
-        # for i in range(len(weight_grad_tasks)):
-        # tasks = weight_grad_tasks[i]
-        # param = None
-        # for j in range(len(tasks)):
-        # total_input, grad_output, weight, func = tasks[j]
-        # if param is None:
-        # param = weight
-        # assert param is weight
-        # func(total_input, grad_output, weight.grad)
-        # tasks[j] = None  # release memory
-        # weight_grad_tasks[i] = None  # release memory
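The surviving `pop` logic above is the heart of the class: weight-gradient computation is queued during backward and applied later. A rough, self-contained sketch of that put/pop pattern follows; only the `pop` body mirrors the code kept in this file, while the `put`/`flush` names and the `linear_weight_grad` helper are illustrative stand-ins.

import queue

import torch


def linear_weight_grad(total_input, grad_output, grad_acc=None):
    # stand-in for the deferred grad function stored in the queue
    grad_weight = grad_output.t() @ total_input
    if grad_acc is None:
        return grad_weight      # first backward: caller assigns weight.grad
    grad_acc.add_(grad_weight)  # later backwards: accumulate in place


class TinyWeightGradStore:
    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        # mirrors the kept lines above: apply or assign the deferred grads
        if cls.weight_grad_queue[chunk].qsize() > 0:
            for total_input, grad_output, weight, func in cls.weight_grad_queue[chunk].get():
                if weight.grad is not None:
                    func(total_input, grad_output, weight.grad)
                else:
                    weight.grad = func(total_input, grad_output)
        else:
            raise Exception("Pop empty queue.")


w = torch.nn.Parameter(torch.zeros(4, 3))
x, gy = torch.randn(2, 3), torch.randn(2, 4)
TinyWeightGradStore.put(x, gy, w, linear_weight_grad)
TinyWeightGradStore.flush()
TinyWeightGradStore.pop()
assert w.grad.shape == (4, 3)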

View File

@@ -60,10 +60,7 @@ class LlamaPolicy(Policy):
         else:
             norm_cls = RMSNorm

-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
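The four-line conditional above collapses into one boolean expression; the two forms agree whenever a real stage manager object is truthy, which holds for ordinary Python objects. A quick standalone check, with a throwaway stand-in for the stage manager:

from types import SimpleNamespace


def use_zbv_old(stage_manager):
    # shape of the removed code
    if stage_manager:
        return stage_manager.use_zbv
    return False


def use_zbv_new(stage_manager):
    # shape of the added one-liner
    return stage_manager is not None and stage_manager.use_zbv


for mgr in (None, SimpleNamespace(use_zbv=False), SimpleNamespace(use_zbv=True)):
    assert use_zbv_old(mgr) == use_zbv_new(mgr)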
@@ -96,7 +93,6 @@ class LlamaPolicy(Policy):
             target_key=attn_cls,
         )

-        # if self.pipeline_stage_manager is not None:
         if self.pipeline_stage_manager is None:
             self.append_or_create_method_replacement(
                 description={
@@ -410,20 +406,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                     self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                 }
             ]
-            # if self.pipeline_stage_manager.use_zbv:
-            # return [
-            # {
-            # 0: llama_model.embed_tokens.weight,
-            # 0: self.model.lm_head.weight,
-            # }
-            # ]
-            # else:
-            # return [
-            # {
-            # 0: llama_model.embed_tokens.weight,
-            # self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-            # }
-            # ]
         return []

View File

@@ -904,7 +904,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
             torch_output.backward()
             torch_output_sum += torch_output.detach()
-            # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
         # avg dp grads follows zero optimizer
         for p in torch_model.parameters():
             if p.grad is not None:
@@ -912,7 +911,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         torch_optimizer.step()
         torch_optimizer.zero_grad()

-        # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
         assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
         print(f"rank {dist.get_rank()} config {test_config} test passed")
     clear_layout_converter()
@@ -1064,7 +1062,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
             torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
             torch_output.backward()
             torch_output_sum += torch_output.detach()
-            # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
         # avg dp grads follows zero optimizer
         for p in torch_model.parameters():
             if p.grad is not None:
@@ -1072,7 +1069,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         torch_optimizer.step()
         torch_optimizer.zero_grad()

-        # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
         assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
         print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
     clear_layout_converter()
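assert_loose_close itself is not part of this diff; the comparisons above only need a tolerance-aware check between the parallel loss and the accumulated single-process reference. A hypothetical stand-in (the real helper lives in the test utilities and its tolerances may differ) could look like:

import torch


def loose_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> None:
    # assumed, dtype-dependent tolerances; the real assert_loose_close may use others
    rtol, atol = {
        torch.float16: (5e-3, 5e-3),
        torch.bfloat16: (4e-2, 4e-2),
    }.get(dtype, (1e-5, 1e-8))
    assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol), (
        f"not close (rtol={rtol}, atol={atol}, dtype={dtype})"
    )


loose_close(torch.tensor(1.000), torch.tensor(1.001), torch.bfloat16)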