From bbe5b7ddb723696fb5388be950af252cb95eb5fb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 17 Mar 2024 22:42:46 +0100 Subject: [PATCH] Update. --- ideal_rnn.py | 106 +++++++++++++++++++++++++++++++++++++++++++++++++++ tiny_vae.py | 6 +-- 2 files changed, 109 insertions(+), 3 deletions(-) create mode 100755 ideal_rnn.py diff --git a/ideal_rnn.py b/ideal_rnn.py new file mode 100755 index 0000000..16d6059 --- /dev/null +++ b/ideal_rnn.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python + +import torch + +###################################################################### + + +def single_test(D, N, fun_f, fun_g, nb_max_sequences=10000): + n_star = torch.randint(N, (1,)).item() + r = torch.zeros(D) + for k in range(nb_max_sequences): + X = torch.randn(N, D) + y = X[n_star] + torch.randn(D) * 0.5 + for n in range(X.size(0)): + r = fun_f(N, k, n, X[n], r) + r, n_star_hat = fun_g(N, k, y, r) + if n_star_hat is not None: + return k + 1, n_star_hat == n_star + return -1, False + + +def multi_test(fun_f, fun_g): + result = {False: [], True: []} + N = 100 + D = 25 + for u in range(100): + nb_realizations, correctness = single_test(D, N, fun_f, fun_g) + result[correctness].append(nb_realizations) + + return torch.tensor(result[False]), torch.tensor(result[True]) + + +###################################################################### + +d_best_id = 0 +d_best_mean = 1 +d_current_id = 2 +d_current_sum = 3 +d_current_sum_sq = 4 +d_current_nb = 5 +d_content = 6 + +# N is the sequence length +# k the index of the realization +# n the index of the X in the current realization +# x is X^k_n +# r is R^k + + +def fun_f(N, k, n, x, r): + if k == 0 and n == 0: + r[d_best_mean] = 1e9 + r[d_current_id] = 0 + r[d_current_sum] = 0.0 + r[d_current_sum_sq] = 0.0 + r[d_current_nb] = 0 + + if n == r[d_current_id]: + r[d_content:] = x[d_content:] + + return r + + +def fun_g(N, k, y, r): + current_mean = r[d_current_sum] / r[d_current_nb] + current_std = ( + (r[d_current_sum_sq] / r[d_current_nb] - current_mean**2).sqrt().item() + ) + + if ( + r[d_current_nb] > 1 + and current_std / r[d_current_nb].sqrt() < (current_mean - r[d_best_mean]).abs() + ): + if current_mean <= r[d_best_mean]: + r[d_best_id] = r[d_current_id] + r[d_best_mean] = current_mean + + r[d_current_nb] = 0 + r[d_current_sum] = 0 + r[d_current_sum_sq] = 0 + r[d_current_id] += 1 + + if r[d_current_id] == N: + return r, r[d_best_id].long().item() + + norm = (y[d_content:] - r[d_content:]).norm() + r[d_current_nb] += 1 + r[d_current_sum] += norm + r[d_current_sum_sq] += norm**2 + + return r, None + + +###################################################################### + +r_failure, r_succes = multi_test(fun_f, fun_g) + +n_failures = r_failure.size(0) +n_successes = r_succes.size(0) + +print( + f"ERRORS_RATE {n_failures/(n_failures+n_successes)} ({n_failures}/{n_failures+n_successes})" +) +print(f"K {r_succes.float().mean()} (+/- {r_succes.float().std()})") + +###################################################################### diff --git a/tiny_vae.py b/tiny_vae.py index fa09831..4d11c7f 100755 --- a/tiny_vae.py +++ b/tiny_vae.py @@ -175,12 +175,12 @@ def save_images(model, prefix=""): def save_image(x, filename): x = x * train_std + train_mu x = x.clamp(min=0, max=255) / 255 - torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8) + torchvision.utils.save_image(1 - x, filename, nrow=12, pad_value=1.0) log_string(f"wrote {filename}") # Save a bunch of train images - x = train_input[:256] + x = train_input[:36] save_image(x, f"{prefix}train_input.png") # Save the same images after encoding / decoding @@ -194,7 +194,7 @@ def save_images(model, prefix=""): # Save a bunch of test images - x = test_input[:256] + x = test_input[:36] save_image(x, f"{prefix}input.png") # Save the same images after encoding / decoding -- 2.39.5