Update.
[pytorch.git] / ideal_rnn.py
1 #!/usr/bin/env python
2
3 import torch
4
5 ######################################################################
6
7
def single_test(D, N, fun_f, fun_g, nb_max_sequences=10000):
    """Run one episode of the retrieval task and report how it ended.

    A hidden index ``n_star`` in ``[0, N)`` is drawn once. Each realization
    ``k`` draws fresh vectors ``X`` of shape ``(N, D)`` and a query ``y``
    equal to ``X[n_star]`` plus Gaussian noise scaled by 0.5.

    The rows of ``X`` are fed one at a time to ``fun_f``, which threads the
    state ``r`` (initially ``torch.zeros(D)``). Then ``fun_g`` sees ``y`` and
    returns the new state plus either ``None`` (keep going) or a guess for
    ``n_star``.

    Returns:
        ``(k + 1, guess == n_star)`` for the first realization at which
        ``fun_g`` commits to a guess, or ``(-1, False)`` if it never does
        within ``nb_max_sequences`` realizations.
    """
    hidden_index = torch.randint(N, (1,)).item()
    state = torch.zeros(D)

    for realization in range(nb_max_sequences):
        vectors = torch.randn(N, D)
        # Same RNG consumption order as before: X first, then the noise.
        query = vectors[hidden_index] + torch.randn(D) * 0.5

        for position, row in enumerate(vectors):
            state = fun_f(N, realization, position, row, state)

        state, guess = fun_g(N, realization, query, state)
        if guess is None:
            continue
        return realization + 1, guess == hidden_index

    return -1, False
20
21
def multi_test(fun_f, fun_g, N=100, D=25, nb_tests=100, nb_max_sequences=10000):
    """Run ``nb_tests`` independent episodes of ``single_test`` and collect results.

    Args:
        fun_f: state-update function, called as ``fun_f(N, k, n, x, r)``.
        fun_g: decision function, called as ``fun_g(N, k, y, r)``.
        N: sequence length (number of candidate vectors per realization).
        D: dimension of each vector and of the state tensor.
        nb_tests: number of independent episodes to run.
        nb_max_sequences: per-episode cap on realizations, forwarded to
            ``single_test``.

    Returns:
        ``(failures, successes)``: two 1-D ``torch.long`` tensors holding the
        number of realizations each failed / successful episode consumed.
        Episodes that hit ``nb_max_sequences`` are counted as failures with
        the value -1.
    """
    result = {False: [], True: []}
    for _ in range(nb_tests):
        nb_realizations, correctness = single_test(
            D, N, fun_f, fun_g, nb_max_sequences=nb_max_sequences
        )
        # bool(): a tensor-valued comparison would otherwise be hashed by
        # identity and miss both dict keys.
        result[bool(correctness)].append(nb_realizations)

    # Explicit dtype: torch.tensor([]) defaults to float32, while a list of
    # ints gives int64, so without it the dtype depends on whether any
    # episode landed in the bucket.
    return (
        torch.tensor(result[False], dtype=torch.long),
        torch.tensor(result[True], dtype=torch.long),
    )
31
32
33 ######################################################################
34
# Slot indices into the state tensor r shared by fun_f and fun_g:
#   r[d_best_id]         index of the best candidate settled so far
#   r[d_best_mean]       mean distance to y of that best candidate
#   r[d_current_id]      index of the candidate currently being evaluated
#   r[d_current_sum]     running sum of the current candidate's distances
#   r[d_current_sum_sq]  running sum of the squares of those distances
#   r[d_current_nb]      number of distances accumulated for it so far
#   r[d_content:]        copy of x[d_content:] for the current candidate
(
    d_best_id,
    d_best_mean,
    d_current_id,
    d_current_sum,
    d_current_sum_sq,
    d_current_nb,
    d_content,
) = range(7)

# Argument conventions for fun_f / fun_g:
#   N  the sequence length (number of X vectors per realization)
#   k  the index of the realization
#   n  the index of the X in the current realization
#   x  the vector X^k_n
#   r  the recurrent state R^k
48
49
def fun_f(N, k, n, x, r):
    """Consume the vector X^k_n and return the updated state ``r``.

    On the first vector of the first realization (``k == 0`` and ``n == 0``)
    the bookkeeping slots are (re)initialised:
      * best mean to 1e9;
      * current candidate to 0;
      * sum, sum of squares and count to 0.

    ``r[d_best_id]`` is not written here, so it keeps whatever value ``r``
    already had.

    Whenever ``n`` equals the candidate under evaluation (``r[d_current_id]``),
    its content ``x[d_content:]`` is copied into ``r[d_content:]`` so that
    fun_g can compare it with the query. ``r`` is modified in place and also
    returned. ``N`` is unused.
    """
    if k == 0 and n == 0:
        for slot, initial in (
            (d_best_mean, 1e9),
            (d_current_id, 0),
            (d_current_sum, 0.0),
            (d_current_sum_sq, 0.0),
            (d_current_nb, 0),
        ):
            r[slot] = initial

    tracked = r[d_current_id]
    if n == tracked:
        r[d_content:] = x[d_content:]

    return r
62
63
def fun_g(N, k, y, r):
    """Score the current candidate against the query ``y`` and maybe settle it.

    ``r[d_content:]`` holds ``x[d_content:]`` of candidate ``r[d_current_id]``,
    as stored by fun_f during this realization. Its Euclidean distance to
    ``y[d_content:]`` is added to that candidate's running count, sum and
    sum of squares.

    The candidate is settled once two conditions hold:
      * at least two distances are available;
      * the standard error of their mean is smaller than the gap to
        ``r[d_best_mean]``.

    Settling a candidate means:
      * it becomes the best one if its mean is not larger than the best mean;
      * the accumulators are reset;
      * evaluation moves to the next index.

    ``k`` is unused.

    Returns:
        ``(r, best_id)`` with ``best_id`` a Python int once all ``N``
        candidates have been settled; otherwise ``(r, None)``. ``r`` is
        modified in place.
    """
    # Accumulate before any switch: the content slot belongs to the candidate
    # that was current while fun_f ran. Accumulating after incrementing
    # d_current_id would credit the next candidate with this one's distance.
    dist = (y[d_content:] - r[d_content:]).norm()
    r[d_current_nb] += 1
    r[d_current_sum] += dist
    r[d_current_sum_sq] += dist**2

    # A spread estimate needs at least two samples (this also avoids 0/0).
    if r[d_current_nb] <= 1:
        return r, None

    nb = r[d_current_nb]
    current_mean = r[d_current_sum] / nb
    # clamp(min=0): float cancellation in E[d^2] - E[d]^2 can yield a tiny
    # negative variance, whose sqrt is NaN and would block every decision.
    current_var = (r[d_current_sum_sq] / nb - current_mean**2).clamp(min=0)
    current_std = current_var.sqrt().item()

    # Keep sampling while the mean is not clearly separated from the best.
    if not current_std / nb.sqrt() < (current_mean - r[d_best_mean]).abs():
        return r, None

    if current_mean <= r[d_best_mean]:
        r[d_best_id] = r[d_current_id]
        r[d_best_mean] = current_mean

    r[d_current_nb] = 0
    r[d_current_sum] = 0
    r[d_current_sum_sq] = 0
    r[d_current_id] += 1

    if r[d_current_id] == N:
        return r, r[d_best_id].long().item()

    return r, None
92
93
94 ######################################################################
95
# Run the benchmark, then report the error rate and the number of
# realizations (K) needed by successful episodes.
r_failure, r_succes = multi_test(fun_f, fun_g)

n_failures = r_failure.size(0)
n_successes = r_succes.size(0)
n_total = n_failures + n_successes

print(f"ERRORS_RATE {n_failures / n_total} ({n_failures}/{n_total})")

k_success = r_succes.float()
print(f"K {k_success.mean()} (+/- {k_success.std()})")
105
106 ######################################################################