Commit Graph

45 Commits (fee2af861078181e18a059800ba09a4edfb311f0)

Author SHA1 Message Date
YH 80aed29cd3
[zero] Refactor ZeroContextConfig class using dataclass (#3186) 2023-03-21 12:36:47 +08:00
HELSON f7f2248771
[moe] fix MoE bugs (#1628)
* remove forced FP32 modules

* correct no_shard-contexts' positions
2022-09-22 13:56:30 +08:00
ver217 a1a7899cae
[hotfix] fix zero init ctx numel (#1128) 2022-06-16 17:17:27 +08:00
Jiarui Fang 49832b2344
[refactory] add nn.parallel module (#1068) 2022-06-06 15:34:41 +08:00
ver217 7cfd6c827e
[zero] add load_state_dict for sharded model (#894)
* add load_state_dict for sharded model

* fix bug

* fix bug

* fix ckpt dtype and device

* support load state dict in zero init ctx

* fix bugs
2022-05-27 10:25:08 +08:00
ver217 0f7ed8c192
fix _post_init_method of zero init ctx (#847) 2022-04-24 14:16:50 +08:00
HELSON e5ea3fdeef
[gemini] add GeminiMemoryManger (#832)
* refactor StatefulTensor, tensor utilities

* add unitest for GeminiMemoryManager
2022-04-24 13:08:48 +08:00
Jiarui Fang eb1b89908c
[refactor] moving InsertPostInitMethodToModuleSubClasses to utils. (#824) 2022-04-21 16:03:18 +08:00
ver217 dd92b90a68
[DO NOT MERGE] [zero] init fp16 params directly in ZeroInitContext (#808)
* init fp16 param directly

* polish code
2022-04-19 16:16:48 +08:00
Jiarui Fang e761ad2cd7
Revert "[zero] add ZeroTensorShardStrategy (#793)" (#806) 2022-04-19 14:40:02 +08:00
HELSON 88759e289e
[zero] add ZeroTensorShardStrategy (#793) 2022-04-19 14:32:45 +08:00
HELSON 22c4b88d56
[zero] refactor ShardedParamV2 for convenience (#742) 2022-04-13 14:54:26 +08:00
HELSON a9b8300d54
[zero] improve adaptability for not-shard parameters (#708)
* adapt post grad hooks for not-shard parameters
* adapt optimizer for not-shard parameters
* offload gradients for not-replicated parameters
2022-04-11 13:38:51 +08:00
HELSON ee112fe1da
[zero] adapt zero hooks for unsharded module (#699) 2022-04-08 20:23:26 +08:00
HELSON d7ecaf362b
[zero] fix init bugs in zero context (#686)
* adapt model weight initialization for methods in Pytorch nn.init
2022-04-07 17:38:45 +08:00
Jiarui Fang 036404ca8a
Revert "[zero] polish init context (#645)" (#657) 2022-04-02 18:30:06 +08:00
Jiarui Fang 67b4928244
[zero] polish init context (#645) 2022-04-02 15:52:04 +08:00
HELSON 055fbf5be6
[zero] adapt zero for unsharded paramters (Optimizer part) (#601) 2022-04-01 20:10:47 +08:00
HELSON e6d50ec107
[zero] adapt zero for unsharded parameters (#561)
* support existing sharded and unsharded parameters in zero

* add unitest for moe-zero model init

* polish moe gradient handler
2022-03-31 18:34:11 +08:00
Jiarui Fang 7675366fce
[polish] rename col_attr -> colo_attr (#558) 2022-03-31 12:25:45 +08:00
HELSON 8c90d4df54
[zero] add zero context manager to change config during initialization (#546) 2022-03-29 17:57:59 +08:00
ver217 1f90a3b129
[zero] polish ZeroInitContext (#540) 2022-03-29 09:09:04 +08:00
HELSON a30e2b4c24
[zero] adapt for no-leaf module in zero (#535)
only process module's own parameters in Zero context

add zero hooks for all modules that contrain parameters

gather parameters only belonging to module itself
2022-03-28 17:42:18 +08:00
Jiarui Fang 705f56107c
[zero] refactor model data tracing (#537) 2022-03-28 16:38:18 +08:00
Jiarui Fang 8d8c5407c0
[zero] refactor model data tracing (#522) 2022-03-25 18:03:32 +08:00
Jiarui Fang 4d322b79da
[refactor] remove old zero code (#517) 2022-03-25 14:54:39 +08:00
Jiarui Fang 920c5889a7
[zero] add colo move inline (#521) 2022-03-25 14:02:55 +08:00
Jiarui Fang 0bebda6ea5
[zero] fix init device bug in zero init context unittest (#516) 2022-03-25 12:24:18 +08:00
Jiarui Fang 7ef3507ace
[zero] show model data cuda memory usage after zero context init. (#515) 2022-03-25 11:23:35 +08:00
ver217 a2e61d61d4
[zero] zero init ctx enable rm_torch_payload_on_the_fly (#512)
* enable rm_torch_payload_on_the_fly

* polish docstr
2022-03-24 23:44:00 +08:00
Jiarui Fang b334822163
[zero] polish sharded param name (#484)
* [zero] polish sharded param name

* polish code

* polish

* polish code

* polish

* polsih

* polish
2022-03-22 14:36:16 +08:00
ver217 3cb3fc275e
zero init ctx receives a dp process group (#471) 2022-03-21 11:18:55 +08:00
ver217 642846d6f9
update sharded optim and fix zero init ctx (#457) 2022-03-18 15:44:47 +08:00
Jiarui Fang e2e9f82588
Revert "[zero] update sharded optim and fix zero init ctx" (#456)
* Revert "polish code"

This reverts commit 8cf7ff08cf.

* Revert "rename variables"

This reverts commit e99af94ab8.

* Revert "remove surplus imports"

This reverts commit 46add4a5c5.

* Revert "update sharded optim and fix zero init ctx"

This reverts commit 57567ee768.
2022-03-18 15:22:43 +08:00
ver217 57567ee768 update sharded optim and fix zero init ctx 2022-03-18 14:25:25 +08:00
ver217 9506a8beb2 use double buffer to handle grad 2022-03-16 14:24:09 +08:00
Jiarui Fang 56bb412e72
[polish] use GLOBAL_MODEL_DATA_TRACER (#417) 2022-03-15 11:29:46 +08:00
Jiarui Fang 21dc54e019
[zero] memtracer to record cuda memory usage of model data and overall system (#395) 2022-03-14 22:05:30 +08:00
Jiarui Fang 272ebfb57d [bug] shard param during initializing the ShardedModelV2 (#381) 2022-03-11 15:50:28 +08:00
Jiarui Fang 6b6002962a [zero] zero init context collect numel of model (#375) 2022-03-11 15:50:28 +08:00
Jiarui Fang 44e4891f57 [zero] able to place params on cpu after zero init context (#365)
* place params on cpu after zero init context

* polish code
2022-03-11 15:50:28 +08:00
Jiarui Fang ea2872073f [zero] global model data memory tracer (#360) 2022-03-11 15:50:28 +08:00
ver217 1388671699 [zero] Update sharded model v2 using sharded param v2 (#323) 2022-03-11 15:50:28 +08:00
Jiarui Fang 11bddb6e55 [zero] update zero context init with the updated test utils (#327) 2022-03-11 15:50:28 +08:00
Jiarui Fang de0468c7a8 [zero] zero init context (#321)
* add zero init context

* add more flags for zero init context
fix bug of repeated converting param to ShardedParamV2

* polish code
2022-03-11 15:50:28 +08:00