mirror of https://github.com/hpcaitech/ColossalAI
[bug] shard param during initializing the ShardedModelV2 (#381)
parent
8c18eb0998
commit
272ebfb57d
|
@ -139,7 +139,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
if self.convert_fp16:
|
||||
param.data = param.data.to(torch.half)
|
||||
if param.grad is not None:
|
||||
param.grad = param.grad.to(torch.half).to(target_device)
|
||||
param.grad = param.grad.to(torch.half)
|
||||
|
||||
# move torch parameters to the target device
|
||||
param.data = param.data.to(target_device)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from ast import Try
|
||||
import functools
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
|
@ -54,7 +55,7 @@ class ShardedModelV2(nn.Module):
|
|||
# In case user didn't use ZeroInitContext
|
||||
for param in self.module.parameters():
|
||||
if not hasattr(param, 'col_attr'):
|
||||
param.col_attr = ShardedParamV2(param, process_group)
|
||||
param.col_attr = ShardedParamV2(param, process_group, rm_torch_payload=True)
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard([param.col_attr.data])
|
||||
|
||||
|
|
Loading…
Reference in New Issue