5 ######################################################################
8 def single_test(D, N, fun_f, fun_g, nb_max_sequences=10000):
9 n_star = torch.randint(N, (1,)).item()
11 for k in range(nb_max_sequences):
13 y = X[n_star] + torch.randn(D) * 0.5
14 for n in range(X.size(0)):
15 r = fun_f(N, k, n, X[n], r)
16 r, n_star_hat = fun_g(N, k, y, r)
17 if n_star_hat is not None:
18 return k + 1, n_star_hat == n_star
22 def multi_test(fun_f, fun_g):
23 result = {False: [], True: []}
27 nb_realizations, correctness = single_test(D, N, fun_f, fun_g)
28 result[correctness].append(nb_realizations)
30 return torch.tensor(result[False]), torch.tensor(result[True])
33 ######################################################################
43 # N is the sequence length
44 # k is the index of the realization
45 # n is the index of X within the current realization
50 def fun_f(N, k, n, x, r):
54 r[d_current_sum] = 0.0
55 r[d_current_sum_sq] = 0.0
58 if n == r[d_current_id]:
59 r[d_content:] = x[d_content:]
64 def fun_g(N, k, y, r):
65 current_mean = r[d_current_sum] / r[d_current_nb]
67 (r[d_current_sum_sq] / r[d_current_nb] - current_mean**2).sqrt().item()
72 and current_std / r[d_current_nb].sqrt() < (current_mean - r[d_best_mean]).abs()
74 if current_mean <= r[d_best_mean]:
75 r[d_best_id] = r[d_current_id]
76 r[d_best_mean] = current_mean
80 r[d_current_sum_sq] = 0
83 if r[d_current_id] == N:
84 return r, r[d_best_id].long().item()
86 norm = (y[d_content:] - r[d_content:]).norm()
88 r[d_current_sum] += norm
89 r[d_current_sum_sq] += norm**2
94 ######################################################################
96 r_failure, r_succes = multi_test(fun_f, fun_g)
98 n_failures = r_failure.size(0)
99 n_successes = r_succes.size(0)
102 f"ERRORS_RATE {n_failures/(n_failures+n_successes)} ({n_failures}/{n_failures+n_successes})"
104 print(f"K {r_succes.float().mean()} (+/- {r_succes.float().std()})")
106 ######################################################################