[bug] shard param while initializing ShardedModelV2 (#381)

pull/394/head
Jiarui Fang 2022-03-10 19:28:03 +08:00 committed by Frank Lee
parent 8c18eb0998
commit 272ebfb57d
2 changed files with 3 additions and 2 deletions

View File

@ -139,7 +139,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if self.convert_fp16:
param.data = param.data.to(torch.half)
if param.grad is not None:
param.grad = param.grad.to(torch.half).to(target_device)
param.grad = param.grad.to(torch.half)
# move torch parameters to the target device
param.data = param.data.to(target_device)

View File

@ -1,3 +1,4 @@
from ast import Try
import functools
from collections import OrderedDict
from typing import Any, Optional
@ -54,7 +55,7 @@ class ShardedModelV2(nn.Module):
# In case user didn't use ZeroInitContext
for param in self.module.parameters():
if not hasattr(param, 'col_attr'):
param.col_attr = ShardedParamV2(param, process_group)
param.col_attr = ShardedParamV2(param, process_group, rm_torch_payload=True)
if self.shard_param:
self.shard_strategy.shard([param.col_attr.data])