ver217
a1a7899cae
[hotfix] fix zero init ctx numel ( #1128 )
2 years ago
Jiarui Fang
49832b2344
[refactory] add nn.parallel module ( #1068 )
3 years ago
ver217
7cfd6c827e
[zero] add load_state_dict for sharded model ( #894 )
...
* add load_state_dict for sharded model
* fix bug
* fix bug
* fix ckpt dtype and device
* support load state dict in zero init ctx
* fix bugs
3 years ago
ver217
0f7ed8c192
fix _post_init_method of zero init ctx ( #847 )
3 years ago
HELSON
e5ea3fdeef
[gemini] add GeminiMemoryManger ( #832 )
...
* refactor StatefulTensor, tensor utilities
* add unitest for GeminiMemoryManager
3 years ago
Jiarui Fang
eb1b89908c
[refactor] moving InsertPostInitMethodToModuleSubClasses to utils. ( #824 )
3 years ago
ver217
dd92b90a68
[DO NOT MERGE] [zero] init fp16 params directly in ZeroInitContext ( #808 )
...
* init fp16 param directly
* polish code
3 years ago
Jiarui Fang
e761ad2cd7
Revert "[zero] add ZeroTensorShardStrategy ( #793 )" ( #806 )
3 years ago
HELSON
88759e289e
[zero] add ZeroTensorShardStrategy ( #793 )
3 years ago
HELSON
22c4b88d56
[zero] refactor ShardedParamV2 for convenience ( #742 )
3 years ago
HELSON
a9b8300d54
[zero] improve adaptability for not-shard parameters ( #708 )
...
* adapt post grad hooks for not-shard parameters
* adapt optimizer for not-shard parameters
* offload gradients for not-replicated parameters
3 years ago
HELSON
ee112fe1da
[zero] adapt zero hooks for unsharded module ( #699 )
3 years ago
HELSON
d7ecaf362b
[zero] fix init bugs in zero context ( #686 )
...
* adapt model weight initialization for methods in Pytorch nn.init
3 years ago
Jiarui Fang
036404ca8a
Revert "[zero] polish init context ( #645 )" ( #657 )
3 years ago
Jiarui Fang
67b4928244
[zero] polish init context ( #645 )
3 years ago
HELSON
055fbf5be6
[zero] adapt zero for unsharded paramters (Optimizer part) ( #601 )
3 years ago
HELSON
e6d50ec107
[zero] adapt zero for unsharded parameters ( #561 )
...
* support existing sharded and unsharded parameters in zero
* add unitest for moe-zero model init
* polish moe gradient handler
3 years ago
Jiarui Fang
7675366fce
[polish] rename col_attr -> colo_attr ( #558 )
3 years ago
HELSON
8c90d4df54
[zero] add zero context manager to change config during initialization ( #546 )
3 years ago
ver217
1f90a3b129
[zero] polish ZeroInitContext ( #540 )
3 years ago
HELSON
a30e2b4c24
[zero] adapt for no-leaf module in zero ( #535 )
...
only process module's own parameters in Zero context
add zero hooks for all modules that contrain parameters
gather parameters only belonging to module itself
3 years ago
Jiarui Fang
705f56107c
[zero] refactor model data tracing ( #537 )
3 years ago
Jiarui Fang
8d8c5407c0
[zero] refactor model data tracing ( #522 )
3 years ago
Jiarui Fang
4d322b79da
[refactor] remove old zero code ( #517 )
3 years ago
Jiarui Fang
920c5889a7
[zero] add colo move inline ( #521 )
3 years ago
Jiarui Fang
0bebda6ea5
[zero] fix init device bug in zero init context unittest ( #516 )
3 years ago
Jiarui Fang
7ef3507ace
[zero] show model data cuda memory usage after zero context init. ( #515 )
3 years ago
ver217
a2e61d61d4
[zero] zero init ctx enable rm_torch_payload_on_the_fly ( #512 )
...
* enable rm_torch_payload_on_the_fly
* polish docstr
3 years ago
Jiarui Fang
b334822163
[zero] polish sharded param name ( #484 )
...
* [zero] polish sharded param name
* polish code
* polish
* polish code
* polish
* polsih
* polish
3 years ago
ver217
3cb3fc275e
zero init ctx receives a dp process group ( #471 )
3 years ago
ver217
642846d6f9
update sharded optim and fix zero init ctx ( #457 )
3 years ago
Jiarui Fang
e2e9f82588
Revert "[zero] update sharded optim and fix zero init ctx" ( #456 )
...
* Revert "polish code"
This reverts commit 8cf7ff08cf
.
* Revert "rename variables"
This reverts commit e99af94ab8
.
* Revert "remove surplus imports"
This reverts commit 46add4a5c5
.
* Revert "update sharded optim and fix zero init ctx"
This reverts commit 57567ee768
.
3 years ago
ver217
57567ee768
update sharded optim and fix zero init ctx
3 years ago
ver217
9506a8beb2
use double buffer to handle grad
3 years ago
Jiarui Fang
56bb412e72
[polish] use GLOBAL_MODEL_DATA_TRACER ( #417 )
3 years ago
Jiarui Fang
21dc54e019
[zero] memtracer to record cuda memory usage of model data and overall system ( #395 )
3 years ago
Jiarui Fang
272ebfb57d
[bug] shard param during initializing the ShardedModelV2 ( #381 )
3 years ago
Jiarui Fang
6b6002962a
[zero] zero init context collect numel of model ( #375 )
3 years ago
Jiarui Fang
44e4891f57
[zero] able to place params on cpu after zero init context ( #365 )
...
* place params on cpu after zero init context
* polish code
3 years ago
Jiarui Fang
ea2872073f
[zero] global model data memory tracer ( #360 )
3 years ago
ver217
1388671699
[zero] Update sharded model v2 using sharded param v2 ( #323 )
3 years ago
Jiarui Fang
11bddb6e55
[zero] update zero context init with the updated test utils ( #327 )
3 years ago
Jiarui Fang
de0468c7a8
[zero] zero init context ( #321 )
...
* add zero init context
* add more flags for zero init context
fix bug of repeated converting param to ShardedParamV2
* polish code
3 years ago