[pipeline]add pipeline policy and bert forward (#4130)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name confilt
pull/4445/head
Jianghai 1 year ago committed by Hongxin Liu
parent 5c897ddb94
commit c552cefa93

@ -21,7 +21,7 @@ def check_stage_manager():
1: [0, 1], 1: [0, 1],
2: [2, 3], 2: [2, 3],
3: [2, 3], 3: [2, 3],
} }
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM) stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank() rank = dist.get_rank()

Loading…
Cancel
Save