# Command-line configuration: reproducibility, optimization
# hyper-parameters, and toy data-set selection.

parser.add_argument('--seed', type=int, default=0,
                    help='Random seed, < 0 is no seeding')

parser.add_argument('--nb_epochs', type=int, default=100,
                    help='How many epochs')

parser.add_argument('--batch_size', type=int, default=25,
                    help='Batch size')

parser.add_argument('--nb_samples', type=int, default=25000,
                    help='Number of training examples')

parser.add_argument('--learning_rate', type=float, default=1e-3,
                    help='Learning rate')

parser.add_argument('--ema_decay', type=float, default=0.9999,
                    help='EMA decay, < 0 is no EMA')

# Comma-separated names of the available samplers, shown in --data's help.
data_list = ', '.join(str(k) for k in samplers)

parser.add_argument('--data', type=str, default='gaussian_mixture',
                    help=f'Toy data-set to use: {data_list}')

args = parser.parse_args()
+
# Seed every PRNG we use so runs are reproducible; a negative --seed
# leaves all generators unseeded.
if args.seed >= 0:
    # Left disabled: full bit-for-bit determinism would also need these,
    # at a performance cost.
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        # Seed every visible CUDA device, not just the current one.
        torch.cuda.manual_seed_all(args.seed)
+
+######################################################################
+
class EMA:
    """Exponential moving average of a model's parameters.

    Keeps a detached shadow copy of every parameter and updates it on each
    step() as  shadow <- decay * shadow + (1 - decay) * param.  A negative
    decay disables the EMA entirely: all methods become no-ops, matching
    the '--ema_decay < 0 is no EMA' convention.
    """

    def __init__(self, model, decay):
        """model: module whose parameters are tracked.
        decay: EMA factor (typically close to 1), or < 0 to disable."""
        self.model = model
        self.decay = decay
        if self.decay < 0: return
        # Shadow values keyed by the parameter tensors themselves.
        self.ema = {}
        with torch.no_grad():
            for p in model.parameters():
                self.ema[p] = p.clone()

    def step(self):
        """Fold the current parameter values into the shadow copies."""
        if self.decay < 0: return
        with torch.no_grad():
            for p in self.model.parameters():
                # In-place update: same arithmetic as
                # decay * ema + (1 - decay) * p, but avoids allocating
                # two temporary tensors per parameter per step.
                self.ema[p].mul_(self.decay).add_(p, alpha=1 - self.decay)

    def copy(self):
        """Overwrite the model's parameters with the shadow (EMA) values."""
        if self.decay < 0: return
        with torch.no_grad():
            for p in self.model.parameters():
                p.copy_(self.ema[p])