From 64dc96ddfa84511ba07d1929481e93e864735409 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 18 Jan 2024 07:51:11 +0100 Subject: [PATCH] Update. --- fridge | 88 +++++++++++++++++++++++++++++ grid.py | 165 +++++++++++++++++++++++++++++++++++++++++++++++++++---- main.py | 37 ++++++++----- mygpt.py | 120 +++++++++++++--------------------------- tasks.py | 6 +- 5 files changed, 307 insertions(+), 109 deletions(-) diff --git a/fridge b/fridge index f87c1df..d09e92d 100644 --- a/fridge +++ b/fridge @@ -204,3 +204,91 @@ def insert_flash_back(rec_V, V, rec_K, K, t0, t1, CL, proba): + dropout_head * (1 - epsilon - G.detach()) - dropout_tail * G.detach() ) + +###################################################################### + +2024 Jan 18 07:39:29 (from mygpt.py) + +class Calibrator: + def __init__(self, w=None, b=None): + self.w = w + self.b = b + self.s, self.s_sq, self.n = 0, 0, 0 + self.mean, self.std = 0, 0 + + def update(self, X): + X = X.detach() + self.s += X.sum(dim=0) + self.s_sq += X.pow(2).sum(dim=0) + self.n += X.size(0) + + def moments(self): + mean = self.s / self.n + std = (self.s_sq / self.n - mean * mean).sqrt() + return mean, std + + def normalize(self): + mean, std = self.moments() + if self.b is not None: + self.b.sub_(mean) + if self.w is not None: + self.w.div_(std) + result = mean - self.mean, std - self.std + self.mean, self.std = mean, std + self.s, self.s_sq, self.n = 0, 0, 0 + return result + + + +###################################################################### + +2024 Jan 18 07:39:34 (from mygpt.py) + + # self.calibrator_G = Calibrator() + # self.calibrator_rec_V = Calibrator() + # self.calibrator_rec_K = Calibrator() + + +###################################################################### + +2024 Jan 18 07:39:37 (from mygpt.py) + + # self.calibrator_G.update(G.reshape(-1, G.size(-1))) + + +###################################################################### + +2024 Jan 18 07:39:42 (from mygpt.py) + + # self.calibrator_rec_V.update( + # next_V.permute(0, 1, 3, 2).reshape(-1, next_V.size(2)) + # ) + # self.calibrator_rec_K.update( + # next_K.permute(0, 1, 3, 2).reshape(-1, next_K.size(2)) + # ) + + +###################################################################### + +2024 Jan 18 07:47:12 (from mygpt.py) + + ###################################################################### + # Roll the gating indexes + + # warnings.warn("rotating barrel", RuntimeWarning) + + # r_barrel = torch.arange(R, device=G.device)[None, None, :, None] + # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] + # r_barrel = (r_barrel + (t_barrel + t0) // L) % R + # G = G.gather(dim=2, index=r_barrel.expand_as(G)) + + +###################################################################### + +2024 Jan 18 07:47:25 (from mygpt.py) + + # warnings.warn("harmonic recurrence", RuntimeWarning) + # har = torch.arange(t0, t1, device = G.device).float() + 1 + # A = har / (har + 1) + # G = G / har + diff --git a/grid.py b/grid.py index 268f4ee..f9f1557 100755 --- a/grid.py +++ b/grid.py @@ -9,10 +9,6 @@ import math import torch, torchvision import torch.nn.functional as F -name_shapes = ["A", "B", "C", "D", "E", "F"] - -name_colors = ["red", "yellow", "blue", "green", "white", "purple"] - ###################################################################### @@ -23,20 +19,160 @@ class GridFactory: max_nb_items=4, max_nb_transformations=3, nb_questions=4, + nb_shapes=6, + nb_colors=6, ): assert size % 2 == 0 self.size = size self.max_nb_items = max_nb_items self.max_nb_transformations = max_nb_transformations self.nb_questions = nb_questions + self.name_shapes = [chr(ord("A") + k) for k in range(nb_shapes)] + self.name_colors = [ + "red", + "yellow", + "blue", + "green", + "white", + "black", + "maroon", + "dark_red", + "brown", + "firebrick", + "crimson", + "tomato", + "coral", + "indian_red", + "light_coral", + "dark_salmon", + "salmon", + "light_salmon", + "orange_red", + "dark_orange", + "orange", + "gold", + "dark_golden_rod", + "golden_rod", + "pale_golden_rod", + "dark_khaki", + "khaki", + "olive", + "yellow_green", + "dark_olive_green", + "olive_drab", + "lawn_green", + "chartreuse", + "green_yellow", + "dark_green", + "forest_green", + "lime", + "lime_green", + "light_green", + "pale_green", + "dark_sea_green", + "medium_spring_green", + "spring_green", + "sea_green", + "medium_aqua_marine", + "medium_sea_green", + "light_sea_green", + "dark_slate_gray", + "teal", + "dark_cyan", + "aqua", + "cyan", + "light_cyan", + "dark_turquoise", + "turquoise", + "medium_turquoise", + "pale_turquoise", + "aqua_marine", + "powder_blue", + "cadet_blue", + "steel_blue", + "corn_flower_blue", + "deep_sky_blue", + "dodger_blue", + "light_blue", + "sky_blue", + "light_sky_blue", + "midnight_blue", + "navy", + "dark_blue", + "medium_blue", + "royal_blue", + "blue_violet", + "indigo", + "dark_slate_blue", + "slate_blue", + "medium_slate_blue", + "medium_purple", + "dark_magenta", + "dark_violet", + "dark_orchid", + "medium_orchid", + "purple", + "thistle", + "plum", + "violet", + "magenta", + "orchid", + "medium_violet_red", + "pale_violet_red", + "deep_pink", + "hot_pink", + "light_pink", + "pink", + "antique_white", + "beige", + "bisque", + "blanched_almond", + "wheat", + "corn_silk", + "lemon_chiffon", + "light_golden_rod_yellow", + "light_yellow", + "saddle_brown", + "sienna", + "chocolate", + "peru", + "sandy_brown", + "burly_wood", + "tan", + "rosy_brown", + "moccasin", + "navajo_white", + "peach_puff", + "misty_rose", + "lavender_blush", + "linen", + "old_lace", + "papaya_whip", + "sea_shell", + "mint_cream", + "slate_gray", + "light_slate_gray", + "light_steel_blue", + "lavender", + "floral_white", + "alice_blue", + "ghost_white", + "honeydew", + "ivory", + "azure", + "snow", + "silver", + "gainsboro", + "white_smoke", + ][:nb_colors] def generate_scene(self): nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 col = torch.full((self.size * self.size,), -1) shp = torch.full((self.size * self.size,), -1) - a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items] - col[:nb_items] = a % len(name_colors) - shp[:nb_items] = a // len(name_colors) + a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items] + col[:nb_items] = a % len(self.name_colors) + shp[:nb_items] = a // len(self.name_colors) i = torch.randperm(self.size * self.size) col = col[i] shp = shp[i] @@ -76,12 +212,15 @@ class GridFactory: # for i in range(self.size): # for j in range(self.size): # if col[i,j] >= 0: - # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}") + # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}") for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="") + print( + f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}", + end="", + ) elif j == 0: print(" +", end="") else: @@ -103,7 +242,7 @@ class GridFactory: for i in range(self.size): for j in range(self.size): if col[i, j] >= 0: - n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}" + n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}" properties += [f"a {n} at {i} {j}"] return properties @@ -116,7 +255,9 @@ class GridFactory: for i1 in range(self.size): for j1 in range(self.size): if col[i1, j1] >= 0: - n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}" + n1 = ( + f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}" + ) properties += [f"there is a {n1}"] if i1 < self.size // 2: properties += [f"a {n1} is in the top half"] @@ -129,7 +270,7 @@ class GridFactory: for i2 in range(self.size): for j2 in range(self.size): if col[i2, j2] >= 0: - n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}" + n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}" if i1 > i2: properties += [f"a {n1} is below a {n2}"] if i1 < i2: diff --git a/main.py b/main.py index 04e5652..79841f3 100755 --- a/main.py +++ b/main.py @@ -133,6 +133,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False) parser.add_argument("--grid_size", type=int, default=6) +parser.add_argument("--grid_nb_colors", type=int, default=6) + +parser.add_argument("--grid_nb_shapes", type=int, default=6) + ############################## # picoclvr options @@ -701,6 +705,8 @@ elif args.task == "grid": nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, size=args.grid_size, + nb_shapes=args.grid_nb_shapes, + nb_colors=args.grid_nb_colors, logger=log_string, device=device_data, ) @@ -835,21 +841,22 @@ if args.max_percents_of_test_in_train >= 0: ############################## -for input in task.batches(split="train", desc="calibrate"): - input = input.to(device) - output = model(mygpt.BracketedSequence(input)).x - -for n, m in model.named_modules(): - for a in dir(m): - x = getattr(m, a) - if isinstance(x, mygpt.Calibrator): - print(f"####### ${n} | ${a} ########################") - mean, std = x.moments() - print("mean\n", mean, "\n") - print("std\n", std, "\n") - print(f"############################################\n\n") - -exit(0) +if "calibrate" in sup_args: + for input in task.batches(split="train", desc="calibrate"): + input = input.to(device) + output = model(mygpt.BracketedSequence(input)).x + + for n, m in model.named_modules(): + for a in dir(m): + x = getattr(m, a) + if isinstance(x, mygpt.Calibrator): + print(f"####### ${n} | ${a} ########################") + mean, std = x.moments() + print("mean\n", mean, "\n") + print("std\n", std, "\n") + print(f"############################################\n\n") + + exit(0) ############################## diff --git a/mygpt.py b/mygpt.py index aded796..a27b99e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -126,7 +126,6 @@ class AddPositionalEncoding(nn.Module): import pscan - # X is /.../xTxD A is /.../xT Y_init is /.../xD @@ -147,6 +146,18 @@ def pscan_dim(A, X, Y_init, dim=-2): return Y +def pscan_rgrad(grad_Y, A, X, Y_init, dim=-2, eps=1e-2): + with torch.no_grad(): + s_A, s_X = 0, 0 + for t in range(X.size(dim) - 1, 0, -1): + delta = (grad_Y[t] - s_A) / A[t].grad + s_A += A[t].grad * delta + A[t].grad = delta + delta = (grad_Y[t] - s_X) / X[t].grad + s_X += X[t].grad * delta + X[t].grad = delta + + def pscan_shape(A, X, Y_init): s = X.size() A = A.reshape(-1, s[-2]) @@ -464,36 +475,6 @@ def moving_window(x, dim, win_dim, win_size): ############################## -class Calibrator: - def __init__(self, w=None, b=None): - self.w = w - self.b = b - self.s, self.s_sq, self.n = 0, 0, 0 - self.mean, self.std = 0, 0 - - def update(self, X): - X = X.detach() - self.s += X.sum(dim=0) - self.s_sq += X.pow(2).sum(dim=0) - self.n += X.size(0) - - def moments(self): - mean = self.s / self.n - std = (self.s_sq / self.n - mean * mean).sqrt() - return mean, std - - def normalize(self): - mean, std = self.moments() - if self.b is not None: - self.b.sub_(mean) - if self.w is not None: - self.w.div_(std) - result = mean - self.mean, std - self.std - self.mean, self.std = mean, std - self.s, self.s_sq, self.n = 0, 0, 0 - return result - - class Caterpillar(nn.Module): def __init__( self, @@ -561,10 +542,6 @@ class Caterpillar(nn.Module): dim_v, ) - self.calibrator_G = Calibrator() - self.calibrator_rec_V = Calibrator() - self.calibrator_rec_K = Calibrator() - def reset_inner_loss(self): self.acc_attention = 0 self.acc_nb = 0 @@ -620,8 +597,6 @@ class Caterpillar(nn.Module): torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() - self.calibrator_G.update(G.reshape(-1, G.size(-1))) - # warnings.warn("softmax gating", RuntimeWarning) # G = ( @@ -646,64 +621,47 @@ class Caterpillar(nn.Module): G = alpha * (1 - kill) - ###################################################################### - # Clip the gating to avoid values greater than 1 when several - # heads hit the same row + def recurrence(G, V, K): + # Clip the gating to avoid values greater than 1 when several + # heads hit the same row - G = G / G.sum(1, keepdim=True).clamp(min=1) + G = G / G.sum(1, keepdim=True).clamp(min=1) - ###################################################################### - # Roll the gating indexes - - # warnings.warn("rotating barrel", RuntimeWarning) + # We prepare the arguments for the parallel scan - # r_barrel = torch.arange(R, device=G.device)[None, None, :, None] - # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :] - # r_barrel = (r_barrel + (t_barrel + t0) // L) % R - # G = G.gather(dim=2, index=r_barrel.expand_as(G)) + A = 1 - G.sum(1) - # We prepare the arguments for the parallel scan + gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) + gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) - A = 1 - G.sum(1) + # We start from cached values, which matters in inference - # warnings.warn("harmonic recurrence", RuntimeWarning) - # har = torch.arange(t0, t1, device = G.device).float() + 1 - # A = har / (har + 1) - # G = G / har + init_rec_V = self.rec_V[:, :, t0 - L : t0] + init_rec_K = self.rec_K[:, :, t0 - L : t0] - gated_V = torch.einsum("nhrt,nhtd->nrtd", G, V) - gated_K = torch.einsum("nhrt,nhtd->nrtd", G, K) + # Associative scan - # We start from cached values, which matters in inference + # Here there is a trick: Since the stack at position t is + # computed by updating that at position t-L, the parallel + # scan operates with a period of L. To do so we split the + # sequence indexing in two axes, the second of size L, and + # run the parallel scan using the first as the sequence index. - init_rec_V = self.rec_V[:, :, t0 - L : t0] - init_rec_K = self.rec_K[:, :, t0 - L : t0] - - ################################################################# - # Associative scan + A = A.unflatten(2, (-1, L)) + gated_V = gated_V.unflatten(2, (-1, L)) + gated_K = gated_K.unflatten(2, (-1, L)) - # Here there is a trick: Since the stack at position t is - # computed by updating that at position t-L, the parallel - # scan operates with a period of L. To do so we split the - # sequence indexing in two axes, the second of size L, and - # run the parallel scan using the first as the sequence index. + next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) + next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) - A = A.unflatten(2, (-1, L)) - gated_V = gated_V.unflatten(2, (-1, L)) - gated_K = gated_K.unflatten(2, (-1, L)) + next_V = next_V.flatten(2, 3) + next_K = next_K.flatten(2, 3) - next_V = pscan_dim(A, gated_V, init_rec_V, dim=2) - next_K = pscan_dim(A, gated_K, init_rec_K, dim=2) + return next_V, next_K - next_V = next_V.flatten(2, 3) - next_K = next_K.flatten(2, 3) + ################################################################# - self.calibrator_rec_V.update( - next_V.permute(0, 1, 3, 2).reshape(-1, next_V.size(2)) - ) - self.calibrator_rec_K.update( - next_K.permute(0, 1, 3, 2).reshape(-1, next_K.size(2)) - ) + next_V, next_K = recurrence(G, V, K) self.rec_V[:, :, t0:t1] = next_V self.rec_K[:, :, t0:t1] = next_K diff --git a/tasks.py b/tasks.py index 4777a11..727b196 100755 --- a/tasks.py +++ b/tasks.py @@ -1473,6 +1473,8 @@ class Grid(Task): nb_test_samples, batch_size, size, + nb_shapes, + nb_colors, logger=None, device=torch.device("cpu"), ): @@ -1480,7 +1482,9 @@ class Grid(Task): self.device = device self.batch_size = batch_size - self.grid_factory = grid.GridFactory(size=size) + self.grid_factory = grid.GridFactory( + size=size, nb_shapes=nb_shapes, nb_colors=nb_colors + ) if logger is not None: logger( -- 2.39.5