[gemini] fix tensor storage cleaning in state dict collection (#4396)

pull/4424/head^2
Baizhou Zhang 2023-08-10 15:36:46 +08:00 committed by GitHub
parent 458ae331ad
commit 6ccecc0c69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 0 additions and 6 deletions

View File

@@ -1,6 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
import gc
import math
import warnings
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
@@ -468,11 +467,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
shard_size)
# Clean gathered states
for state_shard in gathered_state_shards:
del state_shard[0]
gc.collect()
# Reshape tensors
if is_collector:
for state_name, state_tensor in collected_states.items():