mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix handle name; rm useless comments;
parent 5aee4261a6
commit fafe049b83
@@ -107,7 +107,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.local_send_backward_buffer = []
 
         # wait pp buffer
-        self.send_handles = []
+        self.wait_handles = []
 
     def assert_buffer_empty(self):
         # assert buffer is empty at end
@@ -129,7 +129,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         assert len(self.recv_backward_buffer[1]) == 0
         assert len(self.local_send_forward_buffer) == 0
         assert len(self.local_send_backward_buffer) == 0
-        # assert len(self.send_handles) == 0
 
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -891,7 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 wait_handle = communication_func(scheduled_node.chunk)
-                self.send_handles.append(wait_handle)
+                self.wait_handles.append(wait_handle)
             elif scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
@@ -915,7 +914,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        for h in self.send_handles:
+        for h in self.wait_handles:
            for hh in h:
                hh.wait()
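The renamed buffer holds async P2P work objects from both sends and receives, so `wait_handles` describes it more accurately than `send_handles`; the final loop drains every pending operation before the step returns. A minimal sketch of the pattern, assuming torch.distributed async P2P ops; the helper name is illustrative, not the scheduler's actual API:

    import torch
    import torch.distributed as dist

    wait_handles = []  # each entry: a list of pending Work objects

    def isend_tensor(tensor: torch.Tensor, peer: int) -> list:
        # batch_isend_irecv returns one Work handle per batched op
        return dist.batch_isend_irecv([dist.P2POp(dist.isend, tensor, peer)])

    # inside the schedule loop (illustrative):
    # wait_handles.append(isend_tensor(output, next_rank))

    # end of the step: block until all communication has completed
    for handles in wait_handles:
        for h in handles:
            h.wait()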
@@ -1,7 +1,5 @@
 import queue
 
-from colossalai.pipeline.stage_manager import PipelineStageManager
-
 
 class WeightGradStore:
 
@@ -32,52 +30,3 @@ class WeightGradStore:
                     weight.grad = grad_weight
         else:
             raise Exception("Pop empty queue.")
-
-    @classmethod
-    def clear(cls, stage_manager: PipelineStageManager, chunk=0):
-        pass
-        # print(f"stage {stage_manager.stage} len_chunk_0 {cls.weight_grad_queue[0].qsize()} len_chunk_1 {cls.weight_grad_queue[1].qsize()}")
-        # while cls.weight_grad_queue[chunk].qsize() > 0:
-        #     stored_grads = cls.weight_grad_queue[chunk].get()
-        #     for total_input, grad_output, weight, func in stored_grads:
-        #         if weight.grad is not None:
-        #             func(total_input, grad_output, weight.grad)
-        #         # for first bwd; weight.grad is None, assign grad_weight to weight.grad
-        #         else:
-        #             grad_weight = func(total_input, grad_output)
-        #             weight.grad = grad_weight
-
-        # weight_grad_tasks = []
-        # while cls.weight_grad_queue[chunk].qsize() > 0:
-        #     stored_grads = cls.weight_grad_queue[chunk].get()
-        #     if len(weight_grad_tasks) == 0:
-        #         for _ in stored_grads:
-        #             weight_grad_tasks.append([])
-        #     else:
-        #         assert len(weight_grad_tasks) == len(stored_grads)
-        #         for i, task in enumerate(stored_grads):
-        #             weight_grad_tasks[i].append(task)
-
-        # if stage_manager.is_last_stage(ignore_chunk=True) and chunk == 1:
-        #     assert len(weight_grad_tasks) > 0
-        #     output_layer_grads = weight_grad_tasks[0]
-        #     for j in range(len(output_layer_grads)):
-        #         total_input, grad_output, weight, func = output_layer_grads[j]
-        #         if output_layer_weight is None:
-        #             output_layer_weight = weight
-        #         assert output_layer_weight is weight
-        #         func(total_input, grad_output, weight.grad)
-        #         output_layer_grads[j] = None  # release memory
-        #     weight_grad_tasks = weight_grad_tasks[1:]
-
-        # for i in range(len(weight_grad_tasks)):
-        #     tasks = weight_grad_tasks[i]
-        #     param = None
-        #     for j in range(len(tasks)):
-        #         total_input, grad_output, weight, func = tasks[j]
-        #         if param is None:
-        #             param = weight
-        #         assert param is weight
-        #         func(total_input, grad_output, weight.grad)
-        #         tasks[j] = None  # release memory
-        #     weight_grad_tasks[i] = None  # release memory
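Zero-bubble schedules split backward into an input-gradient step (B) and a deferred weight-gradient step (W); `WeightGradStore` queues the deferred work per model chunk, and the dead `clear` method removed here only duplicated, in commented-out form, what `pop` already does. A minimal sketch of the surviving put/flush/pop pattern, with simplified signatures (an assumption, not the exact class):

    import queue

    class WeightGradStoreSketch:
        # defer dL/dW so dL/dx can be propagated first (zero-bubble B/W split)
        cache = []
        weight_grad_queue = [queue.Queue(), queue.Queue()]  # one queue per model chunk

        @classmethod
        def put(cls, total_input, grad_output, weight, func):
            # record everything needed to compute the weight gradient later
            cls.cache.append((total_input, grad_output, weight, func))

        @classmethod
        def flush(cls, chunk=0):
            # seal one microbatch worth of deferred work
            cls.weight_grad_queue[chunk].put(cls.cache)
            cls.cache = []

        @classmethod
        def pop(cls, chunk=0):
            if cls.weight_grad_queue[chunk].qsize() > 0:
                stored_grads = cls.weight_grad_queue[chunk].get()
                for total_input, grad_output, weight, func in stored_grads:
                    if weight.grad is not None:
                        func(total_input, grad_output, weight.grad)  # accumulate in place
                    else:
                        # first backward for this weight: assign the fresh gradient
                        weight.grad = func(total_input, grad_output)
            else:
                raise Exception("Pop empty queue.")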
@@ -60,10 +60,7 @@ class LlamaPolicy(Policy):
         else:
             norm_cls = RMSNorm
 
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
 
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
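The four-line branch collapses into one expression: `and` short-circuits, so `use_zbv` is only read when a stage manager exists, and `is not None` is a more precise guard than truthiness. A hypothetical illustration of the equivalence:

    manager = None  # or a PipelineStageManager instance

    # branchy form (removed)
    if manager:
        use_zbv = manager.use_zbv
    else:
        use_zbv = False

    # one-liner (added): short-circuit guards the attribute access
    use_zbv = manager is not None and manager.use_zbv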
|
@@ -96,7 +93,6 @@ class LlamaPolicy(Policy):
             target_key=attn_cls,
         )
 
-        # if self.pipeline_stage_manager is not None:
         if self.pipeline_stage_manager is None:
             self.append_or_create_method_replacement(
                 description={
@@ -410,20 +406,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                     self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                 }
             ]
-            # if self.pipeline_stage_manager.use_zbv:
-            #     return [
-            #         {
-            #             0: llama_model.embed_tokens.weight,
-            #             0: self.model.lm_head.weight,
-            #         }
-            #     ]
-            # else:
-            #     return [
-            #         {
-            #             0: llama_model.embed_tokens.weight,
-            #             self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-            #         }
-            #     ]
         return []
 
 
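The surviving mapping ties the input embedding on stage 0 to `lm_head` on the last stage, so gradients of shared weights can be synchronized across pipeline ranks. A hypothetical sketch of how such a stage-to-parameter mapping might be consumed (names and signature are assumptions, not ColossalAI's API):

    import torch.distributed as dist

    def sync_shared_params(shared_params, group):
        # shared_params: list of dicts mapping stage index -> tied parameter
        for mapping in shared_params:
            for stage, param in mapping.items():
                if param.grad is not None:
                    # ranks holding a tied copy all-reduce its gradient
                    dist.all_reduce(param.grad, group=group)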
@@ -904,7 +904,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
         torch_output.backward()
         torch_output_sum += torch_output.detach()
-        # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
         # avg dp grads follows zero optimizer
         for p in torch_model.parameters():
             if p.grad is not None:
@@ -912,7 +911,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         torch_optimizer.step()
         torch_optimizer.zero_grad()
 
-    # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
     print(f"rank {dist.get_rank()} config {test_config} test passed")
     clear_layout_converter()
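`assert_loose_close` checks the pipeline-parallel loss against the accumulated single-device loss under dtype-aware tolerances. A minimal sketch of such a helper, assuming tolerances widened for low-precision dtypes (the real one lives in ColossalAI's test utilities; exact tolerance values here are assumptions):

    import torch

    def assert_loose_close_sketch(a, b, dtype: torch.dtype):
        # low-precision dtypes accumulate more rounding error, so widen tolerances
        if dtype is torch.bfloat16:
            rtol = atol = 4e-3
        elif dtype is torch.float16:
            rtol = atol = 5e-3
        else:
            rtol, atol = 1e-5, 1e-8
        assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol), (
            f"not close: {a} vs {b} (dtype={dtype})"
        )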
@@ -1064,7 +1062,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
         torch_output.backward()
         torch_output_sum += torch_output.detach()
-        # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
         # avg dp grads follows zero optimizer
         for p in torch_model.parameters():
             if p.grad is not None:
@@ -1072,7 +1069,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         torch_optimizer.step()
         torch_optimizer.zero_grad()
 
-    # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
     print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
     clear_layout_converter()