+def generate_one_batch(s):
+ problem_number, batch_size, cuda, random_seed = s
+ svrt.seed(random_seed)
+ target = torch.LongTensor(batch_size).bernoulli_(0.5)
+ input = svrt.generate_vignettes(problem_number, target)
+ input = input.float().view(input.size(0), 1, input.size(1), input.size(2))
+ if cuda:
+ input = input.cuda()
+ target = target.cuda()
+ return [ input, target ]
+