3 import torch, torchvision
6 ######################################################################
10 n = sum(p.numel() for p in model.parameters())
11 big = next(model.parameters()).new(n) # Get same device and dtype
13 for p in model.parameters():
14 tmp = p.new(0).set_(p)
15 p.set_(big.storage(), k, p.size()).copy_(tmp)
18 ######################################################################
20 model = nn.Sequential(
21 nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2)
25 for p in model.parameters():
26 print(p.size(), p.storage().size())
31 for p in model.parameters():
32 print(p.size(), p.storage().size())
34 ######################################################################
38 input = torch.rand(100, 2)
39 targets = torch.rand(100, 2)
40 optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)
45 loss = mse(output, targets)