+# Returns a triplet a, b, c where a are the standard MNIST images, c
+# the classes, and b is a Nx2 tensor, with for every n:
+#
+# b[n, 0] ~ Uniform(0, 10)
+# b[n, 1] ~ b[n, 0] + Uniform(0, 0.5) + c[n]
+
+def create_image_values_pairs(train = False):
+ ua, ub = [], []
+
+ if train:
+ input, target = train_input, train_target
+ else:
+ input, target = test_input, test_target
+
+ m = torch.zeros(used_MNIST_classes.max() + 1, dtype = torch.uint8, device = target.device)
+ m[used_MNIST_classes] = 1
+ m = m[target]
+ used_indices = torch.arange(input.size(0), device = target.device).masked_select(m)
+
+ input = input[used_indices].contiguous()
+ target = target[used_indices].contiguous()
+
+ a = input
+ c = target
+
+ b = a.new(a.size(0), 2)
+ b[:, 0].uniform_(0.0, 10.0)
+ b[:, 1].uniform_(0.0, 0.5)
+
+ if args.independent:
+ b[:, 1] += b[:, 0] + \
+ used_MNIST_classes[torch.randint(len(used_MNIST_classes), target.size())]
+ else:
+ b[:, 1] += b[:, 0] + target.float()
+
+ return a, b, c
+
+######################################################################
+
+def create_sequences_pairs(train = False):
+ nb, length = 10000, 1024
+ noise_level = 2e-2
+
+ ha = torch.randint(args.nb_classes, (nb, ), device = device) + 1
+ if args.independent:
+ hb = torch.randint(args.nb_classes, (nb, ), device = device)
+ else:
+ hb = ha
+
+ pos = torch.empty(nb, device = device).uniform_(0.0, 0.9)
+ a = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
+ a = a - pos.view(nb, 1)
+ a = (a >= 0).float() * torch.exp(-a * math.log(2) / 0.1)
+ a = a * ha.float().view(-1, 1).expand_as(a) / (1 + args.nb_classes)
+ noise = a.new(a.size()).normal_(0, noise_level)
+ a = a + noise
+
+ pos = torch.empty(nb, device = device).uniform_(0.0, 0.5)
+ b1 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
+ b1 = b1 - pos.view(nb, 1)
+ b1 = (b1 >= 0).float() * torch.exp(-b1 * math.log(2) / 0.1) * 0.25
+ pos = pos + hb.float() / (args.nb_classes + 1) * 0.5
+ # pos += pos.new(hb.size()).uniform_(0.0, 0.01)
+ b2 = torch.linspace(0, 1, length, device = device).view(1, -1).expand(nb, -1)
+ b2 = b2 - pos.view(nb, 1)
+ b2 = (b2 >= 0).float() * torch.exp(-b2 * math.log(2) / 0.1) * 0.25
+
+ b = b1 + b2
+ noise = b.new(b.size()).normal_(0, noise_level)
+ b = b + noise
+
+ # a = (a - a.mean()) / a.std()
+ # b = (b - b.mean()) / b.std()
+
+ return a, b, ha
+
+######################################################################
+
+class NetForImagePair(nn.Module):