+class EMA:
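+    # Keeps a shadow copy of the model's parameters in self.mem and updates it
+    # as an exponential moving average, to be copied back into the model for
+    # evaluation / sampling.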
+    def __init__(self, model, decay):
+        self.model = model
+        self.decay = decay
+        self.mem = {}
+        with torch.no_grad():
+            for p in model.parameters():
+                self.mem[p] = p.clone()
+
+    def step(self):
+        with torch.no_grad():
+            for p in self.model.parameters():
+                self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
+
+    def copy_to_model(self):
+        with torch.no_grad():
+            for p in self.model.parameters():
+                p.copy_(self.mem[p])
+
+
+######################################################################
+
+# Gets a pair (x, t) and appends t (a scalar or a 1d tensor) to x as an
+# additional dimension / channel, e.g. an (N, C, H, W) input x and a scalar t
+# produce an (N, C + 1, H, W) output.
+
+
+class TimeAppender(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, u):
+        x, t = u
+        if not torch.is_tensor(t):
+            t = x.new_full((x.size(0),), t)
+        # broadcast t to one value per sample, shaped like a single channel of x
+        t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:, :1])
+        return torch.cat((x, t), 1)
+
+
+class ConvNet(nn.Module):
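+    # Plain convolutional denoiser: the time step is appended as an extra
+    # input channel, followed by a stack of 5x5 convolutions with ReLUs.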
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+        ks, nc = 5, 64  # kernel size and number of hidden channels
+
+        self.core = nn.Sequential(
+            TimeAppender(),
+            nn.Conv2d(in_channels + 1, nc, ks, padding=ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, nc, ks, padding=ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, nc, ks, padding=ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, nc, ks, padding=ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, nc, ks, padding=ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, out_channels, ks, padding=ks // 2),
+        )
+
+    def forward(self, u):
+        return self.core(u)
+
+
+######################################################################
+# Data
+
+try:
+    train_input = samplers[args.data](args.nb_samples).to(device)
+except KeyError:
+    print(f"unknown data {args.data}")
+    exit(1)
+
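+# Statistics of the training data, used in generate() to map the samples back
+# to the original scale.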
+train_mean, train_std = train_input.mean(), train_input.std()
+
+######################################################################
+# Model
+
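+# An MLP for 2d training data (a batch of vectors), a ConvNet for 4d training
+# data (a batch of images).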
+if train_input.dim() == 2:
+    nh = 256
+
+    model = nn.Sequential(
+        TimeAppender(),
+        nn.Linear(train_input.size(1) + 1, nh),
+        nn.ReLU(),
+        nn.Linear(nh, nh),
+        nn.ReLU(),
+        nn.Linear(nh, nh),
+        nn.ReLU(),
+        nn.Linear(nh, train_input.size(1)),
+    )
+
+elif train_input.dim() == 4:
+    model = ConvNet(train_input.size(1), train_input.size(1))
+
+model.to(device)
+
+print(f"nb_parameters {sum([ p.numel() for p in model.parameters() ])}")
+
+######################################################################
+# Generate
+
+
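+# Standard DDPM ancestral sampling: start from Gaussian noise and iterate
+#
+#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z,
+#
+# where z ~ N(0, I) for t > 0 and z = 0 at the last step.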
+def generate(size, T, alpha, alpha_bar, sigma, model, train_mean, train_std):
+    with torch.no_grad():
+        x = torch.randn(size, device=device)
+
+        for t in range(T - 1, -1, -1):
+            # the model gets the time step rescaled to [-0.5, 0.5]
+            output = model((x, t / (T - 1) - 0.5))
+            # no noise is added at the final step (t == 0)
+            z = torch.zeros_like(x) if t == 0 else torch.randn_like(x)
+            x = (
+                1
+                / torch.sqrt(alpha[t])
+                * (x - (1 - alpha[t]) / torch.sqrt(1 - alpha_bar[t]) * output)
+                + sigma[t] * z
+            )
+
+        # map the generated samples back to the original data scale
+        x = x * train_std + train_mean
+
+    return x
+
+
+######################################################################
+# Train