--- /dev/null
+#!/usr/bin/env python
+
+import torch
+
+######################################################################
+
+
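+# Test protocol: n_star is drawn once at random. At each realization k, a
+# fresh sequence X of N vectors of dimension D is generated and streamed one
+# element at a time through fun_f, which can only update the D-dimensional
+# memory r. Then fun_g receives y, a noisy copy of X[n_star], and either
+# returns an estimate of n_star, or None to request another realization.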
+def single_test(D, N, fun_f, fun_g, nb_max_sequences=10000):
+ n_star = torch.randint(N, (1,)).item()
+ r = torch.zeros(D)
+ for k in range(nb_max_sequences):
+ X = torch.randn(N, D)
+ y = X[n_star] + torch.randn(D) * 0.5
+ for n in range(X.size(0)):
+ r = fun_f(N, k, n, X[n], r)
+ r, n_star_hat = fun_g(N, k, y, r)
+ if n_star_hat is not None:
+ return k + 1, n_star_hat == n_star
+ return -1, False
+
+
+def multi_test(fun_f, fun_g):
+ result = {False: [], True: []}
+ N = 100
+ D = 25
+ for u in range(100):
+ nb_realizations, correctness = single_test(D, N, fun_f, fun_g)
+ result[correctness].append(nb_realizations)
+
+ return torch.tensor(result[False]), torch.tensor(result[True])
+
+
+######################################################################
+
+d_best_id = 0
+d_best_mean = 1
+d_current_id = 2
+d_current_sum = 3
+d_current_sum_sq = 4
+d_current_nb = 5
+d_content = 6
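+
+# The first d_content slots of r are bookkeeping scalars (the indices above),
+# the remaining D - d_content slots hold the content of the candidate
+# currently under evaluation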
+
+# N is the sequence length
+# k is the index of the realization
+# n is the index of x in the current realization
+# x is X^k_n
+# r is R^k, the D-dimensional memory carried from one call to the next
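+
+# A minimal illustration of the calling convention (these two hypothetical
+# functions are not used anywhere below): a baseline that stores nothing in r
+# and always answers candidate 0 after the first realization, hence is
+# correct with probability roughly 1/N.
+
+
+def baseline_f(N, k, n, x, r):
+    return r
+
+
+def baseline_g(N, k, y, r):
+    return r, 0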
+
+
+def fun_f(N, k, n, x, r):
+    # Very first call: initialize the bookkeeping slots of the memory
+    if k == 0 and n == 0:
+        r[d_best_mean] = 1e9
+        r[d_current_id] = 0
+        r[d_current_sum] = 0.0
+        r[d_current_sum_sq] = 0.0
+        r[d_current_nb] = 0
+
+    # Store the content of the candidate currently under evaluation. Only its
+    # last D - d_content coordinates fit, since the first slots of r are used
+    # for bookkeeping
+    if n == r[d_current_id]:
+        r[d_content:] = x[d_content:]
+
+    return r
+
+
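+# fun_g probes the candidates one at a time: it accumulates the distances
+# between the stored content and the successive queries y, and moves on to
+# the next candidate (recording the current one as the new best if its mean
+# distance is lower) only once the standard error of the estimate,
+# std / sqrt(nb), is smaller than the gap | mean - best_mean |. Once all N
+# candidates have been assessed, it answers with the best one.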
+def fun_g(N, k, y, r):
+    # Mean and standard deviation of the distances accumulated so far for the
+    # candidate currently under evaluation
+    current_mean = r[d_current_sum] / r[d_current_nb]
+    current_std = (
+        (r[d_current_sum_sq] / r[d_current_nb] - current_mean**2).sqrt().item()
+    )
+
+    # The estimate is reliable enough to compare to the best candidate so far:
+    # possibly record a new best, then move on to the next candidate
+    if (
+        r[d_current_nb] > 1
+        and current_std / r[d_current_nb].sqrt() < (current_mean - r[d_best_mean]).abs()
+    ):
+        if current_mean <= r[d_best_mean]:
+            r[d_best_id] = r[d_current_id]
+            r[d_best_mean] = current_mean
+
+        r[d_current_nb] = 0
+        r[d_current_sum] = 0
+        r[d_current_sum_sq] = 0
+        r[d_current_id] += 1
+
+        # All the candidates have been assessed, answer with the best one
+        if r[d_current_id] == N:
+            return r, r[d_best_id].long().item()
+
+    # Accumulate the distance between the query and the stored content
+    norm = (y[d_content:] - r[d_content:]).norm()
+    r[d_current_nb] += 1
+    r[d_current_sum] += norm
+    r[d_current_sum_sq] += norm**2
+
+    return r, None
+
+
+######################################################################
+
+r_failure, r_success = multi_test(fun_f, fun_g)
+
+n_failures = r_failure.size(0)
+n_successes = r_success.size(0)
+
+print(
+    f"ERROR_RATE {n_failures/(n_failures+n_successes)} ({n_failures}/{n_failures+n_successes})"
+)
+print(f"K {r_success.float().mean()} (+/- {r_success.float().std()})")
+
+######################################################################
def save_image(x, filename):
x = x * train_std + train_mu
x = x.clamp(min=0, max=255) / 255
- torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+ torchvision.utils.save_image(1 - x, filename, nrow=12, pad_value=1.0)
log_string(f"wrote {filename}")
# Save a bunch of train images
- x = train_input[:256]
+ x = train_input[:36]
save_image(x, f"{prefix}train_input.png")
# Save the same images after encoding / decoding
# Save a bunch of test images
- x = test_input[:256]
+ x = test_input[:36]
save_image(x, f"{prefix}input.png")
# Save the same images after encoding / decoding