Merge pull request #4856 from KKZ20/test/model_support_for_low_level_zero

[test] remove the redundant code of model output transformation in torchrec
commit ad23460cf8 by ppt0011 (committed via GitHub)
@@ -44,12 +44,6 @@ We've tested compatibility on some famous models, following models may not be supported:
 - `timm.models.convit_base`
 - dlrm and deepfm models in `torchrec`
-- `diffusers.VQModel`
-- `transformers.AlbertModel`
-- `transformers.AlbertForPreTraining`
-- `transformers.BertModel`
-- `transformers.BertForPreTraining`
-- `transformers.GPT2DoubleHeadsModel`
 Compatibility problems will be fixed in the future.
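
The removals above take `diffusers.VQModel` and the listed `transformers` models off the LowLevelZero incompatibility list, i.e. the plugin is now expected to handle them. Below is a minimal sketch of running one of these models under the plugin; the model choice, hyperparameters, shapes, and launch setup are illustrative assumptions, not taken from this PR:

```python
# Hedged sketch: training one of the newly supported models with
# LowLevelZeroPlugin (Zero-2). Run under torchrun; every name below that
# is not in the diff (model choice, lr, shapes) is illustrative.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from transformers import BertConfig, BertModel

colossalai.launch_from_torch(config={})  # expects torchrun env vars

plugin = LowLevelZeroPlugin(stage=2, precision="fp16")
booster = Booster(plugin=plugin)

model = BertModel(BertConfig())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

input_ids = torch.randint(0, 1000, (2, 16), device=torch.cuda.current_device())
loss = model(input_ids=input_ids).last_hidden_state.float().mean()
booster.backward(loss, optimizer)  # plugin-aware backward for sharded grads
optimizer.step()
optimizer.zero_grad()
```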

@@ -42,12 +42,6 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累
 - `timm.models.convit_base`
 - dlrm and deepfm models in `torchrec`
-- `diffusers.VQModel`
-- `transformers.AlbertModel`
-- `transformers.AlbertForPreTraining`
-- `transformers.BertModel`
-- `transformers.BertForPreTraining`
-- `transformers.GPT2DoubleHeadsModel`
 兼容性问题将在未来修复。

@@ -53,16 +53,6 @@ def output_transform_fn(x):
     return dict(output=x)
 
-
-def output_transform_fn(x):
-    if isinstance(x, KeyedTensor):
-        output = dict()
-        for key in x.keys():
-            output[key] = x[key]
-        return output
-    else:
-        return dict(output=x)
-
 
 def get_ebc():
     # EmbeddingBagCollection
     eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
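
For context, the deleted branch flattened torchrec's `KeyedTensor` outputs into a plain `{key: tensor}` dict before wrapping them for the test harness. A small sketch of that transformation follows; the keys, embedding dims, and batch size are illustrative, only the flattening pattern mirrors the removed code:

```python
# Hedged sketch of the removed KeyedTensor handling. Keys, dims, and batch
# size are made up; only the key-by-key flattening mirrors the diff.
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

kt = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[4, 4],     # each key owns 4 of the 8 columns
    values=torch.randn(2, 8),  # batch of 2, embeddings concatenated on dim 1
)

# Same effect as the removed loop: output[key] = x[key] for each key
output = {key: kt[key] for key in kt.keys()}
assert output["f1"].shape == (2, 4)

# torchrec also ships this flattening as a built-in, KeyedTensor.to_dict()
assert all(torch.equal(output[k], v) for k, v in kt.to_dict().items())
```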
