+class EMA:
+ def __init__(self, model, decay):
+ self.model = model
+ self.decay = decay
+ self.mem = { }
+ with torch.no_grad():
+ for p in model.parameters():
+ self.mem[p] = p.clone()
+
+ def step(self):
+ with torch.no_grad():
+ for p in self.model.parameters():
+ self.mem[p].copy_(self.decay * self.mem[p] + (1 - self.decay) * p)
+
+ def copy_to_model(self):
+ with torch.no_grad():
+ for p in self.model.parameters():
+ p.copy_(self.mem[p])
+
+######################################################################
+
+# Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
+# additional dimension / channel
+
+class TimeAppender(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, u):
+ x, t = u
+ if not torch.is_tensor(t):
+ t = x.new_full((x.size(0),), t)
+ t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
+ return torch.cat((x, t), 1)
+
+class ConvNet(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+
+ ks, nc = 5, 64
+
+ self.core = nn.Sequential(
+ TimeAppender(),
+ nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, out_channels, ks, padding = ks//2),
+ )
+
+ def forward(self, u):
+ return self.core(u)
+
+######################################################################
+# Data
+
+try:
+ train_input = samplers[args.data](args.nb_samples).to(device)
+except KeyError:
+ print(f'unknown data {args.data}')
+ exit(1)
+
+train_mean, train_std = train_input.mean(), train_input.std()
+
+######################################################################
+# Model
+
+if train_input.dim() == 2:
+ nh = 256
+
+ model = nn.Sequential(
+ TimeAppender(),
+ nn.Linear(train_input.size(1) + 1, nh),
+ nn.ReLU(),
+ nn.Linear(nh, nh),
+ nn.ReLU(),
+ nn.Linear(nh, nh),
+ nn.ReLU(),
+ nn.Linear(nh, train_input.size(1)),
+ )
+
+elif train_input.dim() == 4:
+
+ model = ConvNet(train_input.size(1), train_input.size(1))
+
+model.to(device)