"""Unit test for tensor-like attribute access on ``ColoProxy``."""

import torch
import torch.nn as nn
from torch.fx import GraphModule

from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing import clear_cache_before_run


class Conv1D(nn.Module):
    """Affine layer computing ``x @ weight + bias`` over the last dimension of ``x``.

    The weight is stored with shape ``(nx, nf)`` (input features first), so the
    projection is a single ``torch.addmm`` with no transpose.

    Args:
        nf: number of output features.
        nx: number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """Project the last dimension of ``x`` from ``nx`` to ``nf`` features.

        Returns a tensor with the same leading dimensions as ``x`` and a last
        dimension of size ``nf``.
        """
        # Keep every leading dim of x; only the last dim changes to nf.
        size_out = x.shape[:-1] + (self.nf,)
        # Flatten leading dims so addmm sees a 2-D (N, nx) @ (nx, nf) product.
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x


@clear_cache_before_run()
def test_coloproxy():
    """ColoProxy should report len/shape/dim/dtype/size from its ``meta_data`` tensor."""
    tracer = ColoTracer()
    model = Conv1D(3, 3)
    input_sample = {"x": torch.rand(3, 3).to("meta")}
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # Wrap the graph's first node; next(iter(...)) avoids materializing the node list.
    node = next(iter(gm.graph.nodes))
    proxy = ColoProxy(node=node, tracer=tracer)
    # Use a (4, 2) tensor, deliberately different from the traced (3, 3) input,
    # so the assertions below can only pass if the proxy reads meta_data.
    proxy.meta_data = torch.empty(4, 2, device="meta")

    assert len(proxy) == 4
    assert proxy.shape[0] == 4 and proxy.shape[1] == 2
    assert proxy.dim() == 2
    assert proxy.dtype == torch.float32
    assert proxy.size(0) == 4


if __name__ == "__main__":
    test_coloproxy()