+ def copy_to_model(self):
+ with torch.no_grad():
+ for p in self.model.parameters():
+ p.copy_(self.mem[p])
+
+######################################################################
+
+class ConvNet(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+
+ ks, nc = 5, 64
+
+ self.core = nn.Sequential(
+ nn.Conv2d(in_channels, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, out_channels, ks, padding = ks//2),
+ )
+
+ def forward(self, x):
+ return self.core(x)
+
+######################################################################
+# Data
+
+try:
+ train_input = samplers[args.data](args.nb_samples).to(device)
+except KeyError:
+ print(f'unknown data {args.data}')
+ exit(1)
+
+train_mean, train_std = train_input.mean(), train_input.std()
+
+######################################################################
+# Model
+
+if train_input.dim() == 2:
+ nh = 256
+
+ model = nn.Sequential(
+ nn.Linear(train_input.size(1) + 1, nh),
+ nn.ReLU(),
+ nn.Linear(nh, nh),
+ nn.ReLU(),
+ nn.Linear(nh, nh),
+ nn.ReLU(),
+ nn.Linear(nh, train_input.size(1)),
+ )
+
+elif train_input.dim() == 4:
+
+ model = ConvNet(train_input.size(1) + 1, train_input.size(1))
+
+model.to(device)
+
+print(f'nb_parameters {sum([ p.numel() for p in model.parameters() ])}')
+
+######################################################################
+# Generate
+
+def generate(size, alpha, alpha_bar, sigma, model, train_mean, train_std):
+
+ with torch.no_grad():
+
+ x = torch.randn(size, device = device)
+
+ for t in range(T-1, -1, -1):
+ z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
+ input = torch.cat((x, torch.full_like(x[:,:1], t / (T - 1) - 0.5)), 1)
+ x = 1/torch.sqrt(alpha[t]) \
+ * (x - (1-alpha[t]) / torch.sqrt(1-alpha_bar[t]) * model(input)) \
+ + sigma[t] * z
+
+ x = x * train_std + train_mean
+
+ return x
+
+######################################################################
+# Train