+class EMA:
+    # Exponential moving average of the model parameters. A negative
+    # decay value disables the EMA entirely: all three methods become
+    # no-ops.
+    def __init__(self, model, decay):
+        self.model = model
+        self.decay = decay
+        if self.decay < 0: return
+        self.ema = {}
+        with torch.no_grad():
+            for p in model.parameters():
+                self.ema[p] = p.clone()
+
+    # Update the shadow copy: ema <- decay * ema + (1 - decay) * p
+    def step(self):
+        if self.decay < 0: return
+        with torch.no_grad():
+            for p in self.model.parameters():
+                self.ema[p].mul_(self.decay).add_(p, alpha = 1 - self.decay)
+
+    # Overwrite the live model parameters with their averaged values.
+    def copy(self):
+        if self.decay < 0: return
+        with torch.no_grad():
+            for p in self.model.parameters():
+                p.copy_(self.ema[p])
+
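+# A minimal usage sketch (illustration only, not part of this script;
+# `optimizer` and the loss computation are assumed):
+#
+#   ema = EMA(model, decay = 0.9999)
+#   for x in batches:
+#       loss = compute_loss(model, x)
+#       optimizer.zero_grad()
+#       loss.backward()
+#       optimizer.step()
+#       ema.step()   # refresh the shadow average after each update
+#   ema.copy()       # install the averaged weights for evaluation
+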
+######################################################################
+
+class ConvNet(nn.Module):
+    # Resolution-preserving convolutional network: five 5x5 conv + ReLU
+    # blocks with 64 channels, followed by a final conv down to
+    # out_channels. The ks // 2 padding keeps the spatial size unchanged.
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+        ks, nc = 5, 64
+
+        self.core = nn.Sequential(
+            nn.Conv2d(in_channels, nc, ks, padding = ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, nc, ks, padding = ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, nc, ks, padding = ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, nc, ks, padding = ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, nc, ks, padding = ks // 2),
+            nn.ReLU(),
+            nn.Conv2d(nc, out_channels, ks, padding = ks // 2),
+        )
+
+    def forward(self, x):
+        return self.core(x)
+
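+# Shape sanity-check sketch (illustration only): since the padding
+# preserves H and W, an (N, in_channels, H, W) input maps to
+# (N, out_channels, H, W), e.g.
+#
+#   net = ConvNet(in_channels = 4, out_channels = 3)
+#   y = net(torch.randn(8, 4, 32, 32))   # torch.Size([8, 3, 32, 32])
+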
+######################################################################
+# Data
+
+# Draw the training set from the sampler selected on the command line,
+# and compute the statistics used to standardize it.
+try:
+    train_input = samplers[args.data](args.nb_samples).to(device)
+except KeyError:
+    print(f'unknown data {args.data}')
+    exit(1)
+
+train_mean, train_std = train_input.mean(), train_input.std()
+
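+# Sketch of how these statistics are typically used (illustration only;
+# the rest of the script is not shown here): standardize the inputs for
+# training, and undo the transformation on generated samples.
+#
+#   x = (train_input - train_mean) / train_std
+#   generated = sampled * train_std + train_mean
+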
+######################################################################
+# Model
+
+# The model maps a sample concatenated with a time value to a tensor of
+# the sample's shape, hence the + 1 on the input dimension (2d data) or
+# on the channel count (image data).
+if train_input.dim() == 2:
+    nh = 64
+
+    model = nn.Sequential(
+        nn.Linear(train_input.size(1) + 1, nh),
+        nn.ReLU(),
+        nn.Linear(nh, nh),
+        nn.ReLU(),
+        nn.Linear(nh, nh),
+        nn.ReLU(),
+        nn.Linear(nh, train_input.size(1)),
+    )
+
+elif train_input.dim() == 4:
+
+    model = ConvNet(train_input.size(1) + 1, train_input.size(1))
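+
+# Input-assembly sketch (illustration only; names are assumed): a time
+# value t is concatenated to each sample, matching the + 1 above. For
+# the 2d case:
+#
+#   x = torch.randn(100, train_input.size(1))
+#   t = torch.rand(100, 1)                     # one time value per sample
+#   out = model(torch.cat((x, t), dim = 1))    # out has x's shape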