From: François Fleuret Date: Sun, 11 Aug 2024 12:33:19 +0000 (+0200) Subject: Merge branch 'dev' X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=HEAD;hp=00f7b3d445af8bb57376faabbf74eadc145faf1f;p=culture.git Merge branch 'dev' --- diff --git a/grids.py b/grids.py index 47e5861..0564f3b 100755 --- a/grids.py +++ b/grids.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, sys, tqdm, os, warnings +import math, sys, tqdm, os, warnings, cairo import torch, torchvision @@ -14,9 +14,125 @@ from torch.nn import functional as F ###################################################################### + +def text_img(height, width, text): + pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8) + + surface = cairo.ImageSurface.create_for_data( + pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0) + ) + + ctx = cairo.Context(surface) + ctx.set_source_rgb(0, 0, 0) + ctx.set_font_size(16) + ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL) + y = None + for line in text.split("\n"): + xbearing, ybearing, width, height, dx, dy = ctx.text_extents(line) + if y is None: + y = height * 1.5 + x = height * 0.5 + + ctx.move_to(x, y) + ctx.show_text(line) + y += height * 1.5 + + ctx.stroke() + + return pixel_map.permute(2, 0, 1)[None, :3].contiguous() + + +###################################################################### + import problem +def grow_islands(nb, height, width, nb_seeds, nb_iterations): + w = torch.empty(5, 1, 3, 3) + + w[0, 0] = torch.tensor( + [ + [1.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 1.0], + ] + ) + + w[1, 0] = torch.tensor( + [ + [-1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + w[2, 0] = torch.tensor( + [ + [0.0, 1.0, -1.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 0.0], + ] + ) + + w[3, 0] = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, -1.0], + ] + ) + + w[4, 0] = torch.tensor( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [-1.0, 1.0, 0.0], + ] + ) + + Z = torch.zeros(nb, height, width) + U = Z.flatten(1) + + for _ in range(nb_seeds): + M = F.conv2d(Z[:, None, :, :], w, padding=1) + M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1) + M = ((M[:, 0] == 0) & (Z == 0)).long() + Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None] + M = M * torch.rand(M.size()) + M = M.flatten(1) + M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1)) + U += M * Q + + for _ in range(nb_iterations): + M = F.conv2d(Z[:, None, :, :], w, padding=1) + M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1) + M = ((M[:, 1] >= 0) & (Z == 0)).long() + Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None] + M = M * torch.rand(M.size()) + M = M.flatten(1) + M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1)) + U = Z.flatten(1) + U += M * Q + + M = Z.clone() + Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2)) + + while True: + W = Z.clone() + Z = F.max_pool2d(Z, 3, 1, 1) * M + if Z.equal(W): + break + + Z = Z.long() + U = Z.flatten(1) + V = F.one_hot(U).max(dim=1).values + W = V.cumsum(dim=1) - V + N = torch.arange(Z.size(0))[:, None, None].expand_as(Z) + Z = W[N, Z] + + return Z + + class Grids(problem.Problem): named_colors = [ ("white", [255, 255, 255]), @@ -32,155 +148,295 @@ class Grids(problem.Problem): ("gray", [128, 128, 128]), ] - def __init__(self, device=torch.device("cpu")): + def check_structure(self, quizzes, struct): + S = self.height * self.width + + return ( + (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]]) + & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]]) + & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]]) + & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]]) + ).all() + + def get_structure(self, quizzes): + S = self.height * self.width + struct = tuple( + self.tok2l[n.item()] + for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0] + ) + self.check_structure(quizzes, struct) + return struct + + def inject_noise(self, quizzes, noise, struct, mask): + assert self.check_structure(quizzes, struct=struct) + S = self.height * self.width + + mask = torch.tensor(mask, device=quizzes.device) + mask = mask[None, :, None].expand(1, 4, S + 1).clone() + mask[:, :, 0] = 0 + mask = mask.reshape(1, -1).expand_as(quizzes) + mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long() + random = torch.randint(self.nb_colors, mask.size()) + quizzes = mask * random + (1 - mask) * quizzes + + return quizzes + + # What a mess + def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")): + if torch.is_tensor(quizzes): + return self.reconfigure([quizzes], struct=struct)[0] + + S = self.height * self.width + result = [x.new(x.size()) for x in quizzes] + + struct_from = self.get_structure(quizzes[0][:1]) + i = self.indices_select(quizzes[0], struct_from) + + sf = dict((l, n) for n, l in enumerate(struct_from)) + + for q in range(4): + k = sf[struct[q]] + for x, y in zip(quizzes, result): + l = x.size(1) // 4 + y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l] + + j = i == False + + if j.any(): + for z, y in zip( + self.reconfigure([x[j] for x in quizzes], struct=struct), result + ): + y[j] = z + + return result + + def trivial(self, quizzes): + S = self.height * self.width + assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B")) + a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:] + return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min( + dim=1 + ).values + + def make_quiz_mask( + self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1) + ): + assert self.check_structure(quizzes, struct) + + ar_mask = quizzes.new_zeros(quizzes.size()) + + S = self.height * self.width + a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:] + a[:, 0, :] = mask[0] + a[:, 1, :] = mask[1] + a[:, 2, :] = mask[2] + a[:, 3, :] = mask[3] + + return ar_mask + + def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")): + S = self.height * self.width + q = quizzes.reshape(quizzes.size(0), 4, S + 1) + return ( + (q[:, 0, 0] == self.l2tok[struct[0]]) + & (q[:, 1, 0] == self.l2tok[struct[1]]) + & (q[:, 2, 0] == self.l2tok[struct[2]]) + & (q[:, 3, 0] == self.l2tok[struct[3]]) + ) + + def __init__( + self, + max_nb_cached_chunks=None, + chunk_size=None, + nb_threads=-1, + tasks=None, + ): self.colors = torch.tensor([c for _, c in self.named_colors]) + + self.nb_colors = len(self.colors) + self.token_A = self.nb_colors + self.token_f_A = self.token_A + 1 + self.token_B = self.token_f_A + 1 + self.token_f_B = self.token_B + 1 + + self.nb_rec_max = 5 + self.rfree = torch.tensor([]) + + self.l2tok = { + "A": self.token_A, + "f_A": self.token_f_A, + "B": self.token_B, + "f_B": self.token_f_B, + } + + self.tok2l = { + self.token_A: "A", + self.token_f_A: "f_A", + self.token_B: "B", + self.token_f_B: "f_B", + } + self.height = 10 self.width = 10 - self.device = device + self.seq_len = 4 * (1 + self.height * self.width) + self.nb_token_values = self.token_f_B + 1 + + self.cache_rec_coo = {} + + all_tasks = [ + self.task_replace_color, + self.task_translate, + self.task_grow, + self.task_half_fill, + self.task_frame, + self.task_detect, + self.task_scale, + self.task_symbols, + self.task_corners, + self.task_contact, + self.task_path, + self.task_fill, + ############################################ hard ones + self.task_isometry, + self.task_trajectory, + self.task_bounce, + # self.task_count, # NOT REVERSIBLE + # self.task_islands, # TOO MESSY + ] + + if tasks is None: + self.all_tasks = all_tasks + else: + self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")] + + super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) ###################################################################### - def frame2img(self, x, scale=15): - x = x.reshape(x.size(0), self.height, -1) - m = torch.logical_and(x >= 0, x < self.nb_token_values()).long() - x = self.colors[x * m].permute(0, 3, 1, 2) - s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) - x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) + def grid2img(self, x, scale=15): + m = torch.logical_and(x >= 0, x < self.nb_colors).long() + y = self.colors[x * m].permute(0, 3, 1, 2) + s = y.shape + y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 - x[:, :, torch.arange(0, x.size(2), scale), :] = 0 - x = x[:, :, 1:, 1:] + y[:, :, :, torch.arange(0, y.size(3), scale)] = 64 + y[:, :, torch.arange(0, y.size(2), scale), :] = 64 for n in range(m.size(0)): for i in range(m.size(1)): for j in range(m.size(2)): if m[n, i, j] == 0: - for k in range(2, scale - 2): - for l in [0, 1]: - x[n, :, i * scale + k, j * scale + k - l] = 0 - x[ - n, :, i * scale + scale - 1 - k, j * scale + k - l - ] = 0 + for k in range(3, scale - 2): + y[n, :, i * scale + k, j * scale + k] = 0 + y[n, :, i * scale + k, j * scale + scale - k] = 0 + + y = y[:, :, 1:, 1:] + + return y - return x + def add_frame(self, img, colors, thickness): + result = img.new( + img.size(0), + img.size(1), + img.size(2) + 2 * thickness, + img.size(3) + 2 * thickness, + ) + + result[...] = colors[:, :, None, None] + result[:, :, thickness:-thickness, thickness:-thickness] = img - def save_image( + return result + + def save_quizzes_as_image( self, result_dir, filename, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, + quizzes, + predicted_parts=None, + correct_parts=None, + comments=None, + comment_height=48, nrow=4, margin=8, ): - S = self.height * self.width - As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width) - f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view( - -1, self.height, self.width - ) - Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width) - prompts = torch.cat([As, f_As, Bs], dim=2) - answers = answers.reshape(answers.size(0), self.height, self.width) + quizzes = quizzes.to("cpu") - if predicted_prompts is None: - predicted_prompts = 255 + to_reconfigure = [quizzes] + if predicted_parts is not None: + to_reconfigure.append(predicted_parts) + if correct_parts is not None: + to_reconfigure.append(correct_parts) - if predicted_answers is None: - predicted_answers = 255 - - def add_frame(x, c, margin, bottom=False): - if bottom: - h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0 - else: - h, w, di, dj = ( - x.size(2) + 2 * margin, - x.size(3) + 2 * margin, - margin, - margin, - ) + to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B")) - y = x.new_full((x.size(0), x.size(1), h, w), 0) + quizzes = to_reconfigure.pop(0) + if predicted_parts is not None: + predicted_parts = to_reconfigure.pop(0) + if correct_parts is not None: + correct_parts = to_reconfigure.pop(0) - if type(c) is int: - y[...] = c - else: - c = c.long()[:, None] - c = ( - (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long())) - * torch.tensor([64, 64, 64], device=c.device) - + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) - + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) - + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) - ) - y[...] = c[:, :, None, None] - - y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x - - return y + S = self.height * self.width - img_prompts = torch.cat( - [ - add_frame( - add_frame(self.frame2img(x), c=0, margin=1), - c=predicted_prompts, - margin=margin, - ) - for x in prompts.to("cpu").split(split_size=self.width, dim=2) - ], - dim=3, + A, f_A, B, f_B = ( + quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:] + .reshape(quizzes.size(0), 4, self.height, self.width) + .permute(1, 0, 2, 3) ) - h = img_prompts.size(2) - img_answers = add_frame( - add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1), - c=predicted_answers, - margin=margin, + frame, white, gray, green, red = torch.tensor( + [[64, 64, 64], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]], + device=quizzes.device, ) - separator_size = 2 * margin - - separator = img_prompts.new_full( - ( - img_prompts.size(0), - img_prompts.size(1), - img_prompts.size(2), - separator_size, - ), - 255, - ) + img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1) + img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1) + img_B = self.add_frame(self.grid2img(B), frame[None, :], thickness=1) + img_f_B = self.add_frame(self.grid2img(f_B), frame[None, :], thickness=1) + + # predicted_parts Nx4 + # correct_parts Nx4 + + if predicted_parts is None: + colors = white[None, None, :].expand(-1, 4, -1) + else: + predicted_parts = predicted_parts.to("cpu") + if correct_parts is None: + colors = ( + predicted_parts[:, :, None] * gray[None, None, :] + + (1 - predicted_parts[:, :, None]) * white[None, None, :] + ) + else: + correct_parts = correct_parts.to("cpu") + colors = ( + predicted_parts[:, :, None] + * ( + (correct_parts[:, :, None] == 1).long() * green[None, None, :] + + (correct_parts[:, :, None] == 0).long() * gray[None, None, :] + + (correct_parts[:, :, None] == -1).long() * red[None, None, :] + ) + + (1 - predicted_parts[:, :, None]) * white[None, None, :] + ) - marker = img_prompts.new_full( - ( - img_prompts.size(0), - img_prompts.size(1), - img_prompts.size(2), - separator_size, - ), - 255, - ) + img_A = self.add_frame(img_A, colors[:, 0], thickness=8) + img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8) + img_B = self.add_frame(img_B, colors[:, 2], thickness=8) + img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8) - # marker[:, :, 0] = 0 - # marker[:, :, h - 1] = 0 + img_A = self.add_frame(img_A, white[None, :], thickness=2) + img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2) + img_B = self.add_frame(img_B, white[None, :], thickness=2) + img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2) - for k in range(1, 2 * separator_size - 8): - i = k - (separator_size - 4) - j = separator_size - 5 - abs(i) - marker[:, :, h // 2 - 1 + i, 2 + j] = 0 - marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 + img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3) - img = torch.cat( - [ - img_prompts, - marker, - img_answers, - ], - dim=3, - ) + if comments is not None: + comment_img = [text_img(comment_height, img.size(3), t) for t in comments] + comment_img = torch.cat(comment_img, dim=0) + img = torch.cat([img, comment_img], dim=2) image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image( img.float() / 255.0, image_name, @@ -191,134 +447,243 @@ class Grids(problem.Problem): ###################################################################### - def nb_token_values(self): - return len(self.colors) - # @torch.compile - def rec_coo_(self, nb_rec, min_height=3, min_width=3): - # @torch.compile - def overlap(ia, ja, ib, jb): - return ( - ia[1] >= ib[0] and ia[0] <= ib[1] and ja[1] >= jb[0] and ja[0] <= jb[1] - ) + def rec_coo( + self, + nb_rec, + min_height=3, + min_width=3, + surface_max=None, + prevent_overlap=False, + ): + if surface_max is None: + surface_max = self.height * self.width // 2 - if nb_rec == 3: + signature = (nb_rec, min_height, min_width, surface_max) + + try: + return self.cache_rec_coo[signature].pop() + except IndexError: + pass + except KeyError: + pass + + N = 10000 + while True: while True: - i = torch.randint(self.height + 1, (nb_rec, 2)).sort(dim=1).values - j = torch.randint(self.width + 1, (nb_rec, 2)).sort(dim=1).values - if ( - not ( - overlap(i[0], j[0], i[1], j[1]) - or overlap(i[0], j[0], i[2], j[2]) - or overlap(i[1], j[1], i[2], j[2]) - ) - and (i[:, 1] - i[:, 0]).min() >= min_height - and (j[:, 1] - j[:, 0]).min() >= min_width - ): + i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values + j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values + i[:, 1] += 1 + j[:, 1] += 1 + big_enough = ( + (i[:, 1] >= i[:, 0] + min_height) + & (j[:, 1] >= j[:, 0] + min_height) + & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max) + ) + + i, j = i[big_enough], j[big_enough] + + n = i.size(0) - i.size(0) % nb_rec + + if n > 0: break - return ( - (i[0, 0], j[0, 0], i[0, 1], j[0, 1]), - (i[1, 0], j[1, 0], i[1, 1], j[1, 1]), - (i[2, 0], j[2, 0], i[2, 1], j[2, 1]), - ) - # That's quite a tensorial spaghetti mess to sample - # non-overlapping rectangles quickly, but made the generation of - # 100k samples go from 1h50 with a lame pure python code to 3min30s - # with this one. - # @torch.compile - def rec_coo(self, nb_rec, min_height=3, min_width=3): - nb_trials = 200 + i = i[:n].reshape(n // nb_rec, nb_rec, -1) + j = j[:n].reshape(n // nb_rec, nb_rec, -1) + + if prevent_overlap: + can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum( + dim=-1 + ) <= self.height * self.width + i, j = i[can_fit], j[can_fit] + if nb_rec == 2: + A_i1, A_i2, A_j1, A_j2 = ( + i[:, 0, 0], + i[:, 0, 1], + j[:, 0, 0], + j[:, 0, 1], + ) + B_i1, B_i2, B_j1, B_j2 = ( + i[:, 1, 0], + i[:, 1, 1], + j[:, 1, 0], + j[:, 1, 1], + ) + no_overlap = ( + (A_i1 >= B_i2) + | (A_i2 <= B_i1) + | (A_j1 >= B_j2) + | (A_j2 <= B_j1) + ) + i, j = (i[no_overlap], j[no_overlap]) + elif nb_rec == 3: + A_i1, A_i2, A_j1, A_j2 = ( + i[:, 0, 0], + i[:, 0, 1], + j[:, 0, 0], + j[:, 0, 1], + ) + B_i1, B_i2, B_j1, B_j2 = ( + i[:, 1, 0], + i[:, 1, 1], + j[:, 1, 0], + j[:, 1, 1], + ) + C_i1, C_i2, C_j1, C_j2 = ( + i[:, 2, 0], + i[:, 2, 1], + j[:, 2, 0], + j[:, 2, 1], + ) + no_overlap = ( + ( + (A_i1 >= B_i2) + | (A_i2 <= B_i1) + | (A_j1 >= B_j2) + | (A_j2 <= B_j1) + ) + & ( + (A_i1 >= C_i2) + | (A_i2 <= C_i1) + | (A_j1 >= C_j2) + | (A_j2 <= C_j1) + ) + & ( + (B_i1 >= C_i2) + | (B_i2 <= C_i1) + | (B_j1 >= C_j2) + | (B_j2 <= C_j1) + ) + ) + i, j = (i[no_overlap], j[no_overlap]) + else: + assert nb_rec == 1 - while True: - v = ( + if i.size(0) > 1: + break + + self.cache_rec_coo[signature] = [ + [ ( - torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device) - .sort(dim=-1) - .indices - < 2 + i[n, k, 0].item(), + j[n, k, 0].item(), + i[n, k, 1].item(), + j[n, k, 1].item(), ) - .long() - .cumsum(dim=1) - == 1 - ).long() + for k in range(nb_rec) + ] + for n in range(i.size(0)) + ] + + return self.cache_rec_coo[signature].pop() - h = ( + ###################################################################### + + def contact_matrices(self, rn, ri, rj, rz): + n = torch.arange(self.nb_rec_max) + return ( + ( ( - torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device) - .sort(dim=-1) - .indices - < 2 + ( + (ri[:, :, None, 0] == ri[:, None, :, 1] + 1) + | (ri[:, :, None, 1] + 1 == ri[:, None, :, 0]) + ) + & (rj[:, :, None, 0] <= rj[:, None, :, 1]) + & (rj[:, :, None, 1] >= rj[:, None, :, 0]) ) - .long() - .cumsum(dim=1) - == 1 - ).long() + | ( + ( + (rj[:, :, None, 0] == rj[:, None, :, 1] + 1) + | (rj[:, :, None, 1] + 1 == rj[:, None, :, 0]) + ) + & (ri[:, :, None, 0] <= ri[:, None, :, 1]) + & (ri[:, :, None, 1] >= ri[:, None, :, 0]) + ) + ) + # & (rz[:, :, None] == rz[:, None, :]) + & (n[None, :, None] < rn[:, None, None]) + & (n[None, None, :] < n[None, :, None]) + ) - i = torch.logical_and( - v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width + def sample_rworld_states(self, N=1000): + while True: + ri = ( + torch.randint(self.height - 2, (N, self.nb_rec_max, 2)) + .sort(dim=2) + .values + ) + ri[:, :, 1] += 2 + rj = ( + torch.randint(self.width - 2, (N, self.nb_rec_max, 2)) + .sort(dim=2) + .values + ) + rj[:, :, 1] += 2 + rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2 + rz = torch.randint(2, (N, self.nb_rec_max)) + rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1 + n = torch.arange(self.nb_rec_max) + nb_collisions = ( + ( + (ri[:, :, None, 0] <= ri[:, None, :, 1]) + & (ri[:, :, None, 1] >= ri[:, None, :, 0]) + & (rj[:, :, None, 0] <= rj[:, None, :, 1]) + & (rj[:, :, None, 1] >= rj[:, None, :, 0]) + & (rz[:, :, None] == rz[:, None, :]) + & (n[None, :, None] < rn[:, None, None]) + & (n[None, None, :] < n[None, :, None]) + ) + .long() + .flatten(1) + .sum(dim=1) ) - v, h = v[i], h[i] - v = v[: v.size(0) - v.size(0) % nb_rec] - h = h[: h.size(0) - h.size(0) % nb_rec] - v = v.reshape(v.size(0) // nb_rec, nb_rec, -1) - h = h.reshape(h.size(0) // nb_rec, nb_rec, -1) + no_collision = nb_collisions == 0 - r = v[:, :, :, None] * h[:, :, None, :] + if no_collision.any(): + print(no_collision.long().sum() / N) + self.rn = rn[no_collision] + self.ri = ri[no_collision] + self.rj = rj[no_collision] + self.rz = rz[no_collision] + self.rc = rc[no_collision] - valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1 + nb_contact = ( + self.contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1) + ) - v = v[valid] - h = h[valid] + self.rcontact = nb_contact > 0 + self.rfree = torch.full((self.rn.size(0),), True) - if v.size(0) > 0: break - av = torch.arange(v.size(2), device=self.device)[None, :] - ah = torch.arange(h.size(2), device=self.device)[None, :] + def get_recworld_state(self): + if not self.rfree.any(): + self.sample_rworld_states() + k = torch.arange(self.rn.size(0))[self.rfree] + k = k[torch.randint(k.size(0), (1,))].item() + self.rfree[k] = False + return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k] - return [ - (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1) - for i1, j1, i2, j2 in zip( - v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values, - h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values, - (v[0] * av).max(dim=-1).values, - (h[0] * ah).max(dim=-1).values, - ) - ] + def draw_state(self, X, rn, ri, rj, rz, rc): + for n in sorted(list(range(rn)), key=lambda n: rz[n].item()): + X[ri[n, 0] : ri[n, 1] + 1, rj[n, 0] : rj[n, 1] + 1] = rc[n] - # @torch.compile - def rec_coo_(self, x, n, min_height=3, min_width=3): - collision = x.new(x.size()) - while True: - collision[...] = 0 - result = [] - for _ in range(n): - while True: - i1, i2 = torch.randint(x.size(0), (2,)) - if i1 + min_height <= i2: - break - while True: - j1, j2 = torch.randint(x.size(1), (2,)) - if j1 + min_width <= j2: - break - collision[i1:i2, j1:j2] += 1 - if collision.max() > 1: - break - result.append((i1, j1, i2, j2)) - if collision.max() == 1: - break - return result + def task_recworld_immobile(self, A, f_A, B, f_B): + for X, f_X in [(A, f_A), (B, f_B)]: + rn, ri, rj, rz, rc = self.get_recworld_state() + self.draw_state(X, rn, ri, rj, rz, rc) + ri += 1 + self.draw_state(f_X, rn, ri, rj, rz, rc) ###################################################################### # @torch.compile def task_replace_color(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] @@ -326,12 +691,16 @@ class Grids(problem.Problem): # @torch.compile def task_translate(self, A, f_A, B, f_B): - di, dj = torch.randint(3, (2,)) - 1 + while True: + di, dj = torch.randint(3, (2,)) - 1 + if di.abs() + dj.abs() > 0: + break + nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) i1, j1, i2, j2 = r[nb_rec - 1] if ( i1 + di >= 0 @@ -353,11 +722,11 @@ class Grids(problem.Problem): def task_grow(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 - direction = torch.randint(2, (1,)) + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 + direction = torch.randint(2, (1,)).item() for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) i1, j1, i2, j2 = r[nb_rec - 1] if i1 + 3 < i2 and j1 + 3 < j2: break @@ -376,13 +745,13 @@ class Grids(problem.Problem): f_X[i1:i2, j1:j2] = c[n] # @torch.compile - def task_color_grow(self, A, f_A, B, f_B): + def task_half_fill(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1 - direction = torch.randint(4, (1,)) + c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1 + direction = torch.randint(4, (1,)).item() for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[2 * n] @@ -420,27 +789,34 @@ class Grids(problem.Problem): # @torch.compile def task_frame(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] - f_X[i1:i2, j1:j2] = c[n] if n == nb_rec - 1: - f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0 + f_X[i1:i2, j1] = c[n] + f_X[i1:i2, j2 - 1] = c[n] + f_X[i1, j1:j2] = c[n] + f_X[i2 - 1, j1:j2] = c[n] + else: + f_X[i1:i2, j1:j2] = c[n] # @torch.compile def task_detect(self, A, f_A, B, f_B): nb_rec = 3 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] if n < nb_rec - 1: - f_X[i1, j1] = c[-1] + for k in range(2): + f_X[i1 + k, j1] = c[-1] + f_X[i1, j1 + k] = c[-1] # @torch.compile def contact(self, X, i, j, q): @@ -478,37 +854,66 @@ class Grids(problem.Problem): return no, nq, nq_diag - # @torch.compile - def task_count(self, A, f_A, B, f_B): - N = (torch.randint(4, (1,)) + 2).item() - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + def REMOVED_task_count(self, A, f_A, B, f_B): + while True: + error = False + + N = 3 + c = torch.zeros(N + 2, dtype=torch.int64) + c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + if not hasattr(self, "cache_count") or len(self.cache_count) == 0: + self.cache_count = list( + grow_islands( + 1000, + self.height, + self.width, + nb_seeds=self.height * self.width // 8, + nb_iterations=self.height * self.width // 5, + ) + ) - for X, f_X in [(A, f_A), (B, f_B)]: - nb = torch.zeros(N, dtype=torch.int64) - q = torch.randint(N, (self.height * self.width,)) - k = torch.randperm(self.height * self.width) - for p in range(self.height * self.width): - i, j = k[p] % self.height, k[p] // self.height - no, nq, nq_diag = self.contact(X, i, j, c[q[p]]) - if no == 0 and nq_diag == 0: - if nq == 0: - if nb[q[p]] < self.width: - X[i, j] = c[q[p]] - nb[q[p]] += 1 - if nq == 1: - X[i, j] = c[q[p]] - - for n in range(N): - for j in range(nb[n]): - f_X[n, j] = c[n] + X[...] = self.cache_count.pop() + + # k = (X.max() + 1 + (c.size(0) - 1)).item() + # V = torch.arange(k) // (c.size(0) - 1) + # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % ( + # c.size(0) - 1 + # ) + 1 + + V = torch.randint(N, (X.max() + 1,)) + 1 + V[0] = 0 + NB = F.one_hot(c[V]).sum(dim=0) + X[...] = c[V[X]] + f_X[...] = X + + if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3: + m = NB[c[:-1]].max() + if (NB[c[:-1]] == m).long().sum() == 1: + for e in range(1, N + 1): + if NB[c[e]] == m: + a = (f_X == c[e]).long() + f_X[...] = (1 - a) * f_X + a * c[-1] + else: + error = True + break + + if not error: + break + + assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3 # @torch.compile def task_trajectory(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:2] + 1 + c = torch.randperm(self.nb_colors - 1)[:2] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: di, dj = torch.randint(7, (2,)) - 3 - i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) if ( abs(di) + abs(dj) > 0 and i + 2 * di >= 0 @@ -532,7 +937,7 @@ class Grids(problem.Problem): # @torch.compile def task_bounce(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:3] + 1 + c = torch.randperm(self.nb_colors - 1)[:3] + 1 for X, f_X in [(A, f_A), (B, f_B)]: # @torch.compile def free(i, j): @@ -549,8 +954,9 @@ class Grids(problem.Problem): X[...] = 0 for _ in range((self.height * self.width) // 10): - i, j = torch.randint(self.height, (1,)), torch.randint( - self.width, (1,) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), ) X[i, j] = c[0] f_X[i, j] = c[0] @@ -560,7 +966,10 @@ class Grids(problem.Problem): if abs(di) + abs(dj) == 1: break - i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) X[i, j] = c[1] f_X[i, j] = c[1] @@ -584,6 +993,7 @@ class Grids(problem.Problem): f_X[i, j] = c[2] if l <= 1: X[i, j] = c[2] + f_X[i, j] = c[1] if l >= self.width: break @@ -596,33 +1006,39 @@ class Grids(problem.Problem): # @torch.compile def task_scale(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:2] + 1 + c = torch.randperm(self.nb_colors - 1)[:2] + 1 - i, j = torch.randint(self.height // 2, (1,)), torch.randint( - self.width // 2, (1,) + i, j = ( + torch.randint(self.height // 2, (1,)).item(), + torch.randint(self.width // 2, (1,)).item(), ) for X, f_X in [(A, f_A), (B, f_B)]: for _ in range(3): while True: - i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( - self.width // 2 + 1, (1,) + i1, j1 = ( + torch.randint(self.height // 2 + 1, (1,)).item(), + torch.randint(self.width // 2 + 1, (1,)).item(), ) - i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( - self.width // 2 + 1, (1,) + i2, j2 = ( + torch.randint(self.height // 2 + 1, (1,)).item(), + torch.randint(self.width // 2 + 1, (1,)).item(), ) if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3: break X[i + i1 : i + i2, j + j1 : j + j2] = c[0] f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0] - X[i, j] = c[1] - f_X[0:2, 0:2] = c[1] + for k in range(2): + X[i + k, j] = c[1] + X[i, j + k] = c[1] + f_X[i + k, j] = c[1] + f_X[i, j + k] = c[1] # @torch.compile def task_symbols(self, A, f_A, B, f_B): nb_rec = 4 - c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 delta = 3 for X, f_X in [(A, f_A), (B, f_B)]: while True: @@ -634,34 +1050,45 @@ class Grids(problem.Problem): if d.min() > delta: break - for k in range(1, nb_rec): - X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k] - ai, aj = i.float().mean(), j.float().mean() - q = torch.randint(3, (1,)) + 1 - - X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0] - X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0] - X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0] - X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0] + q = torch.randint(3, (1,)).item() + 1 assert i[q] != ai and j[q] != aj - X[ + for Z in [X, f_X]: + for k in range(0, nb_rec): + Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k] + # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0] + # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0] + # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0] + # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0] + + # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] + + f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q] + # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] + + ii, jj = ( i[0] + delta // 2 + (i[q] - ai).sign().long(), j[0] + delta // 2 + (j[q] - aj).sign().long(), - ] = c[nb_rec] + ) + + X[ii, jj] = c[nb_rec] + X[i[0] + delta // 2, jj] = c[nb_rec] + X[ii, j[0] + delta // 2] = c[nb_rec] - f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] + f_X[ii, jj] = c[nb_rec] + f_X[i[0] + delta // 2, jj] = c[nb_rec] + f_X[ii, j[0] + delta // 2] = c[nb_rec] # @torch.compile - def task_ortho(self, A, f_A, B, f_B): + def task_isometry(self, A, f_A, B, f_B): nb_rec = 3 di, dj = torch.randint(3, (2,)) - 1 o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]]) m = torch.eye(2) - for _ in range(torch.randint(4, (1,))): + for _ in range(torch.randint(4, (1,)).item()): m = m @ o if torch.rand(1) < 0.5: m[0, :] = -m[0, :] @@ -673,7 +1100,7 @@ class Grids(problem.Problem): X[...] = 0 f_X[...] = 0 - c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 for r in range(nb_rec): while True: @@ -710,9 +1137,88 @@ class Grids(problem.Problem): ): break + def compute_distance(self, walls, goal_i, goal_j): + max_length = walls.numel() + dist = torch.full_like(walls, max_length) + + dist[goal_i, goal_j] = 0 + pred_dist = torch.empty_like(dist) + + while True: + pred_dist.copy_(dist) + dist[1:-1, 1:-1] = ( + torch.cat( + ( + dist[None, 1:-1, 1:-1], + dist[None, 1:-1, 0:-2], + dist[None, 2:, 1:-1], + dist[None, 1:-1, 2:], + dist[None, 0:-2, 1:-1], + ), + 0, + ).min(dim=0)[0] + + 1 + ) + + dist = walls * max_length + (1 - walls) * dist + + if dist.equal(pred_dist): + return dist * (1 - walls) + # @torch.compile - def task_islands(self, A, f_A, B, f_B): - pass + def REMOVED_task_distance(self, A, f_A, B, f_B): + c = torch.randperm(self.nb_colors - 1)[:3] + 1 + dist0 = torch.empty(self.height + 2, self.width + 2) + dist1 = torch.empty(self.height + 2, self.width + 2) + for X, f_X in [(A, f_A), (B, f_B)]: + nb_rec = torch.randint(3, (1,)).item() + 1 + while True: + r = self.rec_coo(nb_rec, prevent_overlap=True) + X[...] = 0 + f_X[...] = 0 + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[0] + f_X[i1:i2, j1:j2] = c[0] + while True: + i0, j0 = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + if X[i0, j0] == 0: + break + while True: + i1, j1 = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + if X[i1, j1] == 0: + break + dist1[...] = 1 + dist1[1:-1, 1:-1] = (X != 0).long() + dist1[...] = self.compute_distance(dist1, i1 + 1, j1 + 1) + if ( + dist1[i0 + 1, j0 + 1] >= 1 + and dist1[i0 + 1, j0 + 1] < self.height * 4 + ): + break + + dist0[...] = 1 + dist0[1:-1, 1:-1] = (X != 0).long() + dist0[...] = self.compute_distance(dist0, i0 + 1, j0 + 1) + + dist0 = dist0[1:-1, 1:-1] + dist1 = dist1[1:-1, 1:-1] + + D = dist1[i0, j0] + for d in range(1, D): + M = (dist0 == d) & (dist1 == D - d) + f_X[...] = (1 - M) * f_X + M * c[1] + + X[i0, j0] = c[2] + f_X[i0, j0] = c[2] + X[i1, j1] = c[2] + f_X[i1, j1] = c[2] # for X, f_X in [(A, f_A), (B, f_B)]: # n = torch.arange(self.height * self.width).reshape(self.height, self.width) @@ -722,80 +1228,528 @@ class Grids(problem.Problem): # i,j=q%self.height,q//self.height # if - ###################################################################### + # @torch.compile + def TOO_HARD_task_puzzle(self, A, f_A, B, f_B): + S = 4 + i0, j0 = (self.height - S) // 2, (self.width - S) // 2 + c = torch.randperm(self.nb_colors - 1)[:4] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + f_X[...] = 0 + h = list(torch.randperm(c.size(0))) + n = torch.zeros(c.max() + 1) + for _ in range(2): + k = torch.randperm(S * S) + for q in k: + i, j = q % S + i0, q // S + j0 + if f_X[i, j] == 0: + r, s, t, u = ( + f_X[i - 1, j], + f_X[i, j - 1], + f_X[i + 1, j], + f_X[i, j + 1], + ) + r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)] + if r > 0 and n[r] < 6: + n[r] += 1 + f_X[i, j] = r + elif s > 0 and n[s] < 6: + n[s] += 1 + f_X[i, j] = s + elif t > 0 and n[t] < 6: + n[t] += 1 + f_X[i, j] = t + elif u > 0 and n[u] < 6: + n[u] += 1 + f_X[i, j] = u + else: + if len(h) > 0: + d = c[h.pop()] + n[d] += 1 + f_X[i, j] = d + + if n.sum() == S * S: + break - def all_tasks(self): - return [ - self.task_replace_color, - self.task_translate, - self.task_grow, - self.task_color_grow, - self.task_frame, - self.task_detect, - self.task_count, - self.task_trajectory, - self.task_bounce, - self.task_scale, - self.task_symbols, - self.task_ortho, - # self.task_islands, - ] + k = 0 + for d in range(4): + while True: + ii, jj = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + e = 0 + for i in range(S): + for j in range(S): + if ( + ii + i >= self.height + or jj + j >= self.width + or ( + f_X[i + i0, j + j0] == c[d] + and X[ii + i, jj + j] > 0 + ) + ): + e = 1 + if e == 0: + break + for i in range(S): + for j in range(S): + if f_X[i + i0, j + j0] == c[d]: + X[ii + i, jj + j] = c[d] + + def TOO_MESSY_task_islands(self, A, f_A, B, f_B): + c = torch.randperm(self.nb_colors - 1)[:2] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0: + self.cache_islands = list( + grow_islands( + 1000, + self.height, + self.width, + nb_seeds=self.height * self.width // 20, + nb_iterations=self.height * self.width // 2, + ) + ) + + A = self.cache_islands.pop() + + while True: + i, j = ( + torch.randint(self.height // 2, (1,)).item(), + torch.randint(self.width // 2, (1,)).item(), + ) + if A[i, j] > 0: + break + + X[...] = (A > 0) * c[0] + f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0] + f_X[i, j] = X[i, j] + X[i, j] = c[1] - def trivial_prompts_and_answers(self, prompts, answers): + # @torch.compile + def TOO_HARD_task_stack(self, A, f_A, B, f_B): + N = 5 + c = torch.randperm(self.nb_colors - 1)[:N] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + i1, j1, i2, j2 = ( + self.height // 2 - 1, + self.width // 2 - 1, + self.height // 2 + 1, + self.width // 2 + 1, + ) + op = torch.tensor((0, 1, 2, 3) * 4) + op = op[torch.randperm(op.size(0))[:9]] + for q in range(op.size(0)): + u = 3 * (q // 3) + v = 3 * (q % 3) + d = c[torch.randint(N, (1,)).item()] + # X[u+1,v+1]=d + if op[q] == 0: # right + X[u : u + 3, v + 2] = d + elif op[q] == 1: # let + X[u : u + 3, v] = d + elif op[q] == 2: # bottom + X[u + 2, v : v + 3] = d + elif op[q] == 3: # top + X[u, v : v + 3] = d + + if q == 0: + f_X[i1:i2, j1:j2] = d + elif op[q] == 0: # right + f_X[i1:i2, j2] = d + j2 += 1 + elif op[q] == 1: # let + j1 -= 1 + f_X[i1:i2, j1] = d + elif op[q] == 2: # bottom + f_X[i2, j1:j2] = d + i2 += 1 + elif op[q] == 3: # top + i1 -= 1 + f_X[i1, j1:j2] = d + + def randint(self, *m): + m = torch.tensor(m) + return (torch.rand(m.size()) * m).long() + + def TOO_HARD_task_matrices(self, A, f_A, B, f_B): + N = 6 + c = torch.randperm(self.nb_colors - 1)[:N] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + M1 = torch.randint(2, (5, 5)) + M2 = torch.randint(2, (5, 5)) + P = M1 @ M2 + for i in range(5): + for j in range(5): + X[i, j] = c[M1[i, j]] + X[i, j + 5] = c[M2[i, j]] + f_X[i, j] = c[M1[i, j]] + f_X[i, j + 5] = c[M2[i, j]] + f_X[i + 5, j + 5] = c[P[i, j]] + + def TOO_HARD_task_compute(self, A, f_A, B, f_B): + N = 6 + c = torch.randperm(self.nb_colors - 1)[:N] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + v = torch.randint((self.width - 1) // 2, (N,)) + 1 + chain = torch.randperm(N) + eq = [] + for i in range(chain.size(0) - 1): + i1, i2 = chain[i], chain[i + 1] + v1, v2 = v[i1], v[i2] + k = torch.arange(self.width // 2) + 1 + d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1 + d = d[torch.randint(d.size(0), (1,)).item()] + w1, w2 = d + eq.append((c[i1], w1, c[i2], w2)) + + ii = torch.randperm(self.height - 2)[: len(eq)] + + for k, x in enumerate(eq): + i = ii[k] + c1, w1, c2, w2 = x + s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item() + X[i, s : s + w1] = c1 + X[i, s + w1 : s + w1 + w2] = c2 + f_X[i, s : s + w1] = c1 + f_X[i, s + w1 : s + w1 + w2] = c2 + + i1, i2 = torch.randperm(N)[:2] + v1, v2 = v[i1], v[i2] + k = torch.arange(self.width // 2) + 1 + d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1 + d = d[torch.randint(d.size(0), (1,)).item()] + w1, w2 = d + c1, c2 = c[i1], c[i2] + s = 0 # torch.randint(self.width - (w1 + w2) + 1, (1,)).item() + i = self.height - 1 + X[i, s : s + w1] = c1 + X[i, s + w1 : s + w1 + 1] = c2 + f_X[i, s : s + w1] = c1 + f_X[i, s + w1 : s + w1 + w2] = c2 + + # @torch.compile + # [ai1,ai2] [bi1,bi2] + def task_contact(self, A, f_A, B, f_B): + def rec_dist(a, b): + ai1, aj1, ai2, aj2 = a + bi1, bj1, bi2, bj2 = b + v = max(ai1 - bi2, bi1 - ai2) + h = max(aj1 - bj2, bj1 - aj2) + return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0)) + + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + r = self.rec_coo(nb_rec, prevent_overlap=True) + d = [rec_dist(r[0], r[k]) for k in range(nb_rec)] + if min(d[1:]) == 0: + break + + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + if d[n] == 0: + f_X[i1, j1:j2] = c[0] + f_X[i2 - 1, j1:j2] = c[0] + f_X[i1:i2, j1] = c[0] + f_X[i1:i2, j2 - 1] = c[0] + + # @torch.compile + # [ai1,ai2] [bi1,bi2] + def task_corners(self, A, f_A, B, f_B): + polarity = torch.randint(2, (1,)).item() + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + r = self.rec_coo(nb_rec, prevent_overlap=True) + + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + for k in range(2): + if polarity == 0: + X[i1 + k, j1] = c[n] + X[i2 - 1 - k, j2 - 1] = c[n] + X[i1, j1 + k] = c[n] + X[i2 - 1, j2 - 1 - k] = c[n] + else: + X[i1 + k, j2 - 1] = c[n] + X[i2 - 1 - k, j1] = c[n] + X[i1, j2 - 1 - k] = c[n] + X[i2 - 1, j1 + k] = c[n] + f_X[i1:i2, j1:j2] = c[n] + + def compdist(self, X, i, j): + dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width) + d = dd[1:-1, 1:-1] + m = (X > 0).long() + d[i, j] = 0 + e = d.clone() + while True: + e[...] = d + d[...] = ( + d.min(dd[:-2, 1:-1] + 1) + .min(dd[2:, 1:-1] + 1) + .min(dd[1:-1, :-2] + 1) + .min(dd[1:-1, 2:] + 1) + ) + d[...] = (1 - m) * d + m * self.height * self.width + if e.equal(d): + break + + return d + + # @torch.compile + def task_path(self, A, f_A, B, f_B): + nb_rec = 2 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + X[...] = 0 + f_X[...] = 0 + + r = self.rec_coo(nb_rec, prevent_overlap=True) + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + + i1, i2 = torch.randint(self.height, (2,)) + j1, j2 = torch.randint(self.width, (2,)) + if ( + abs(i1 - i2) + abs(j1 - j2) > 2 + and X[i1, j1] == 0 + and X[i2, j2] == 0 + ): + d2 = self.compdist(X, i2, j2) + d = self.compdist(X, i1, j1) + + if d2[i1, j1] < 2 * self.width: + break + + m = ((d + d2) == d[i2, j2]).long() + f_X[...] = m * c[-1] + (1 - m) * f_X + + X[i1, j1] = c[-2] + X[i2, j2] = c[-2] + f_X[i1, j1] = c[-2] + f_X[i2, j2] = c[-2] + + # @torch.compile + def task_fill(self, A, f_A, B, f_B): + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + accept_full = torch.rand(1) < 0.5 + + while True: + X[...] = 0 + f_X[...] = 0 + + r = self.rec_coo(nb_rec, prevent_overlap=True) + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + + while True: + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + if X[i, j] == 0: + break + + d = self.compdist(X, i, j) + m = (d < self.height * self.width).long() + X[i, j] = c[-1] + f_X[...] = m * c[-1] + (1 - m) * f_X + f_X[i, j] = 0 + + if accept_full or (d * (X == 0)).max() == self.height * self.width: + break + + def TOO_HARD_task_addition(self, A, f_A, B, f_B): + c = torch.randperm(self.nb_colors - 1)[:4] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() + N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item() + S = N1 + N2 + for j in range(self.width): + r1 = (N1 // (2**j)) % 2 + X[0, -j - 1] = c[r1] + f_X[0, -j - 1] = c[r1] + r2 = (N2 // (2**j)) % 2 + X[1, -j - 1] = c[r2] + f_X[1, -j - 1] = c[r2] + rs = (S // (2**j)) % 2 + f_X[2, -j - 1] = c[2 + rs] + + def task_science_implicit(self, A, f_A, B, f_B): + nb_rec = 5 + c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + i1, i2 = torch.randint(self.height, (2,)).sort().values + if i1 >= 1 and i2 < self.height and i1 + 3 < i2: + break + + while True: + j1, j2 = torch.randint(self.width, (2,)).sort().values + if j1 >= 1 and j2 < self.width and j1 + 3 < j2: + break + + f_X[i1:i2, j1:j2] = c[0] + + # --------------------- + + while True: + ii1, ii2 = torch.randint(self.height, (2,)).sort().values + if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2: + break + jj = torch.randint(j1, (1,)) + X[ii1:ii2, jj:j1] = c[1] + f_X[ii1:ii2, jj:j1] = c[1] + + while True: + ii1, ii2 = torch.randint(self.height, (2,)).sort().values + if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2: + break + jj = torch.randint(self.width - j2, (1,)) + j2 + 1 + X[ii1:ii2, j2:jj] = c[2] + f_X[ii1:ii2, j2:jj] = c[2] + + # --------------------- + + while True: + jj1, jj2 = torch.randint(self.width, (2,)).sort().values + if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2: + break + ii = torch.randint(i1, (1,)) + X[ii:i1, jj1:jj2] = c[3] + f_X[ii:i1, jj1:jj2] = c[3] + + while True: + jj1, jj2 = torch.randint(self.width, (2,)).sort().values + if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2: + break + ii = torch.randint(self.height - i2, (1,)) + i2 + 1 + X[i2:ii, jj1:jj2] = c[4] + f_X[i2:ii, jj1:jj2] = c[4] + + def task_science_dot(self, A, f_A, B, f_B): + nb_rec = 3 + c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + X[...] = 0 + f_X[...] = 0 + r = self.rec_coo(nb_rec, prevent_overlap=True) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + q = 0 + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[n] + f_X[i1:i2, j1:j2] = c[n] + if i >= i1 and i < i2: + q += 1 + f_X[i, j1:j2] = c[-1] + if j >= j1 and j < j2: + q += 1 + f_X[i1:i2, j] = c[-1] + X[i, j] = c[-1] + f_X[i, j] = c[-1] + if q >= 2: + break + + def collide(self, s, r, rs): + i, j = r + for i2, j2 in rs: + if abs(i - i2) < s and abs(j - j2) < s: + return True + return False + + def task_science_tag(self, A, f_A, B, f_B): + c = torch.randperm(self.nb_colors - 1)[:4] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + rs = [] + while len(rs) < 4: + i, j = ( + torch.randint(self.height - 3, (1,)).item(), + torch.randint(self.width - 3, (1,)).item(), + ) + if not self.collide(s=3, r=(i, j), rs=rs): + rs.append((i, j)) + + for k in range(len(rs)): + i, j = rs[k] + q = min(k, 2) + X[i, j : j + 3] = c[q] + X[i + 2, j : j + 3] = c[q] + X[i : i + 3, j] = c[q] + X[i : i + 3, j + 2] = c[q] + + f_X[i, j : j + 3] = c[q] + f_X[i + 2, j : j + 3] = c[q] + f_X[i : i + 3, j] = c[q] + f_X[i : i + 3, j + 2] = c[q] + if q == 2: + f_X[i + 1, j + 1] = c[-1] + + # end_tasks + + ###################################################################### + + def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")): S = self.height * self.width - Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S] - f_Bs = answers - return (Bs == f_Bs).long().min(dim=-1).values > 0 + quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64) + quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]] + quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]] + quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]] + quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]] - def generate_prompts_and_answers( - self, nb, tasks=None, progress_bar=False, device="cpu" - ): - if tasks is None: - tasks = self.all_tasks() + return quizzes + def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False): S = self.height * self.width - prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64) - answers = torch.zeros(nb, S, dtype=torch.int64) - bunch = zip(prompts, answers) + if tasks is None: + tasks = self.all_tasks + + quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B")) if progress_bar: - bunch = tqdm.tqdm( - bunch, + quizzes = tqdm.tqdm( + quizzes, dynamic_ncols=True, - desc="world generation", - total=prompts.size(0), + desc="world quizzes generation", + total=quizzes.size(0), ) - for prompt, answer in bunch: - A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width) - f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width) - B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width) - f_B = answer.view(self.height, self.width) - task = tasks[torch.randint(len(tasks), (1,))] + for quiz in quizzes: + q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width) + q[...] = 0 + A, f_A, B, f_B = q + task = tasks[torch.randint(len(tasks), (1,)).item()] task(A, f_A, B, f_B) - return prompts.flatten(1), answers.flatten(1) + return quizzes - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - nrow=4, - ): - self.save_image( - result_dir, - filename_prefix + ".png", - prompts, - answers, - predicted_prompts, - predicted_answers, - nrow, - ) + def save_some_examples(self, result_dir, prefix=""): + nb, nrow = 128, 4 + for t in self.all_tasks: + print(t.__name__) + quizzes = self.generate_w_quizzes_(nb, tasks=[t]) + self.save_quizzes_as_image( + result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow + ) ###################################################################### @@ -803,37 +1757,87 @@ class Grids(problem.Problem): if __name__ == "__main__": import time + # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4) + grids = Grids() + # nb = 5 + # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill]) + # print(quizzes) + # print(grids.get_structure(quizzes)) + # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B")) + # print("DEBUG2", quizzes) + # print(grids.get_structure(quizzes)) + # print(quizzes) + + # i = torch.rand(quizzes.size(0)) < 0.5 + + # quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A")) + + # j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A")) + + # print( + # i.equal(j), + # grids.get_structure(quizzes[j]), + # grids.get_structure(quizzes[j == False]), + # ) + + # exit(0) + # nb = 1000 # grids = problem.MultiThreadProblem( # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1 # ) # time.sleep(10) # start_time = time.perf_counter() - # prompts, answers = grids.generate_prompts_and_answers(nb) + # prompts, answers = grids.generate_w_quizzes(nb) # delay = time.perf_counter() - start_time # print(f"{prompts.size(0)/delay:02f} seq/s") # exit(0) - if True: - nb = 72 + # if True: + nb, nrow = 128, 4 + # nb, nrow = 8, 2 - for t in grids.all_tasks(): - # for t in [grids.task_ortho]: - print(t.__name__) - prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) - grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) + # for t in grids.all_tasks: - exit(0) + for t in [grids.task_recworld_immobile]: + print(t.__name__) + w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) + grids.save_quizzes_as_image( + "/tmp", + t.__name__ + ".png", + w_quizzes, + comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))], + ) - nb = 500 + exit(0) - for t in grids.all_tasks(): + nb = 1000 + + for t in [ + # grids.task_bounce, + # grids.task_contact, + # grids.task_corners, + # grids.task_detect, + # grids.task_fill, + # grids.task_frame, + # grids.task_grow, + # grids.task_half_fill, + # grids.task_isometry, + # grids.task_path, + # grids.task_replace_color, + # grids.task_scale, + grids.task_symbols, + # grids.task_trajectory, + # grids.task_translate, + ]: + # for t in [grids.task_path]: start_time = time.perf_counter() - prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) + w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) delay = time.perf_counter() - start_time - print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s") + print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s") + grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128]) exit(0) @@ -841,9 +1845,9 @@ if __name__ == "__main__": predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - grids.save_quizzes( + grids.save_quizzes_as_image( "/tmp", - "test", + "test.png", prompts[:nb], answers[:nb], # You can add a bool to put a frame around the predicted parts diff --git a/main.py b/main.py index 9c3d7f1..40772c2 100755 --- a/main.py +++ b/main.py @@ -3,6 +3,9 @@ # Any copyright is dedicated to the Public Domain. # https://creativecommons.org/publicdomain/zero/1.0/ +# > A > f(A) > B ; > f(B) +# < f(B) ; < B < f(A) < A + # Written by Francois Fleuret import math, sys, argparse, time, tqdm, os, datetime, warnings @@ -15,22 +18,16 @@ import ffutils import mygpt import sky, grids, quiz_machine -from problem import MultiThreadProblem -# world quizzes vs. culture quizzes +from quiz_machine import one_batch_masked_inplace_autoregression -###################################################################### +import threading, subprocess -if torch.cuda.is_available(): - device = torch.device("cuda") - torch.backends.cuda.matmul.allow_tf32 = True -else: - device = torch.device("cpu") +import torch.multiprocessing as mp ###################################################################### parser = argparse.ArgumentParser( - description="An implementation of GPT with cache.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) @@ -40,9 +37,13 @@ parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) -parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) +parser.add_argument("--resume", action="store_true", default=False) + +parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1) + +parser.add_argument("--log_command", type=str, default=None) -######################################## +# ---------------------------------- parser.add_argument("--nb_epochs", type=int, default=10000) @@ -50,14 +51,21 @@ parser.add_argument("--batch_size", type=int, default=None) parser.add_argument("--physical_batch_size", type=int, default=None) +parser.add_argument("--inference_batch_size", type=int, default=None) + parser.add_argument("--nb_train_samples", type=int, default=None) parser.add_argument("--nb_test_samples", type=int, default=None) +parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) + +parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) + parser.add_argument("--learning_rate", type=float, default=5e-4) -######################################## +parser.add_argument("--schedule_free", action="store_true", default=False) +# ---------------------------------- parser.add_argument("--model", type=str, default=None) parser.add_argument("--dim_model", type=int, default=None) @@ -72,30 +80,57 @@ parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.1) -######################################## - +# ---------------------------------- parser.add_argument("--deterministic_synthesis", action="store_true", default=False) parser.add_argument("--problem", type=str, default="grids") -parser.add_argument("--multi_thread_problem", action="store_true", default=False) +parser.add_argument("--nb_threads", type=int, default=1) + +parser.add_argument("--gpus", type=str, default="all") + +# ---------------------------------- parser.add_argument("--nb_gpts", type=int, default=5) -parser.add_argument("--min_to_validate", type=int, default=None) +parser.add_argument("--max_fail_to_validate", type=int, default=3) + +parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) -parser.add_argument("--max_to_validate", type=int, default=None) +parser.add_argument("--proba_understands", type=float, default=0.95) -parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) +parser.add_argument("--proba_not_understands", type=float, default=0.1) -parser.add_argument("--generation_temperature", type=float, default=2.0) +parser.add_argument("--temperature_hot", type=float, default=1.5) -parser.add_argument("--deterministic_validation", action="store_true", default=False) +parser.add_argument("--temperature_cold", type=float, default=1) -parser.add_argument("--bidirectional_validation", action="store_true", default=False) +parser.add_argument("--prompt_noise", type=float, default=0.05) parser.add_argument("--dirty_debug", action="store_true", default=False) +parser.add_argument("--test", type=str, default=None) + +###################################################################### + +grids_tasks = ", ".join( + [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks] +) + +parser.add_argument( + "--grids_world_tasks", + type=str, + default="replace_color,translate,grow,frame", + help="A comma-separated subset of: " + grids_tasks + ".", +) + +parser.add_argument( + "--grids_science_tasks", + type=str, + default=None, + help="A comma-separated subset of: " + grids_tasks + ", or None.", +) + ###################################################################### parser.add_argument("--sky_height", type=int, default=6) @@ -112,22 +147,25 @@ parser.add_argument("--sky_speed", type=int, default=3) args = parser.parse_args() -if args.min_to_validate is None: - args.min_to_validate = args.nb_gpts - 1 - -if args.max_to_validate is None: - args.max_to_validate = args.nb_gpts - 1 - if args.result_dir is None: args.result_dir = f"results_culture" +assert not args.grids_science_tasks or ( + len( + set(args.grids_world_tasks.split(",")) + & set(args.grids_science_tasks.split(",")) + ) + == 0 +), "World and science tasks have to be disjoint" + ###################################################################### default_args = { "model": "37M", - "batch_size": 100, - "nb_train_samples": 100000, - "nb_test_samples": 10000, + "batch_size": 25, + "inference_batch_size": 50, + "nb_train_samples": 40000, + "nb_test_samples": 1000, } for k, v in default_args.items(): @@ -183,11 +221,15 @@ else: ###################################################################### -try: - os.mkdir(args.result_dir) -except FileExistsError: - print(f"result directory {args.result_dir} already exists") - exit(1) +if args.resume: + assert os.path.isdir(args.result_dir) + +else: + try: + os.mkdir(args.result_dir) + except FileExistsError: + print(f"result directory {args.result_dir} already exists") + exit(1) log_file = open(os.path.join(args.result_dir, args.log_filename), "a") @@ -213,6 +255,18 @@ def log_string(s): sys.stdout.flush() +###################################################################### +# Create a time-stamped archive of the source code + +with open("this_run.sh", "w") as f: + f.write(f"{' '.join(sys.argv)}\n") + +now = time.strftime("%Y%m%d-%H%M%S", time.localtime()) + +os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh") + +###################################################################### + log_string(f"argv {' '.join(sys.argv)}") for n in vars(args): @@ -221,6 +275,19 @@ for n in vars(args): ###################################################################### +if args.gpus == "all": + gpus_idx = range(torch.cuda.device_count()) +else: + gpus_idx = [int(k) for k in args.gpus.split(",")] + +gpus = [torch.device(f"cuda:{n}") for n in gpus_idx] + +if torch.cuda.is_available(): + main_device = gpus[0] +else: + assert len(gpus) == 0 + main_device = torch.device("cpu") + if args.dirty_debug: args.nb_train_samples = 2500 args.nb_test_samples = 100 @@ -240,31 +307,52 @@ if args.problem == "sky": nb_birds=args.sky_nb_birds, nb_iterations=args.sky_nb_iterations, speed=args.sky_speed, + max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, + chunk_size=100, + nb_threads=args.nb_threads, ) - back_accuracy = False + elif args.problem == "grids": - problem = grids.Grids(device=device) - back_accuracy = True + problem = grids.Grids( + max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, + chunk_size=100, + nb_threads=args.nb_threads, + tasks=args.grids_world_tasks, + ) + + if args.grids_science_tasks is None: + science_w_quizzes = None + else: + science_problem = grids.Grids( + max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, + chunk_size=100, + nb_threads=args.nb_threads, + tasks=args.grids_science_tasks, + ) + science_w_quizzes = science_problem.generate_w_quizzes(100) + + if not args.resume: + science_problem.save_some_examples(args.result_dir, "science_") + + else: raise ValueError -if args.multi_thread_problem: - problem = MultiThreadProblem(problem, args.nb_train_samples, chunk_size=1000) +if not args.resume: + problem.save_some_examples(args.result_dir) quiz_machine = quiz_machine.QuizMachine( problem=problem, - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - back_accuracy=back_accuracy, - batch_size=args.physical_batch_size, + batch_size=args.inference_batch_size, result_dir=args.result_dir, + prompt_noise=args.prompt_noise, logger=log_string, - device=device, + device=main_device, ) ###################################################################### -log_string(f"device {device}") +log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}") vocabulary_size = quiz_machine.vocabulary_size() @@ -272,230 +360,674 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### -# Compute the entropy of the training tokens -token_count = 0 -for input in quiz_machine.batches(split="train", desc="train-entropy"): - token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum( - (0, 1) - ) -token_probas = token_count / token_count.sum() -entropy = -torch.xlogy(token_probas, token_probas).sum() -train_set_perplexity = math.exp(entropy) +def optimizer_to(optim, device): + for param in optim.state.values(): + # Not sure there are any global tensors in the state dict + if isinstance(param, torch.Tensor): + param.data = param.data.to(device) + if param._grad is not None: + param._grad.data = param._grad.data.to(device) + elif isinstance(param, dict): + for subparam in param.values(): + if isinstance(subparam, torch.Tensor): + subparam.data = subparam.data.to(device) + if subparam._grad is not None: + subparam._grad.data = subparam._grad.data.to(device) + ###################################################################### -# A bit of paranoia never hurts - -if args.max_percents_of_test_in_train >= 0: - - def subsets_as_tuples(batches, cs): - s = set() - for batch in batches: - for x in batch: - s.add(tuple([v.item() for v in x])) - if len(s) == cs: - yield s - s = set() - yield s - - nb_test, nb_in_train = 0, 0 - for test_subset in subsets_as_tuples( - quiz_machine.batches(split="test", desc="test-check"), 25000 - ): - in_train = set() - for train_subset in subsets_as_tuples( - quiz_machine.batches(split="train", desc="train-check"), 25000 + + +def run_tests(model, quiz_machine, local_device=main_device): + with torch.autograd.no_grad(): + model.to(local_device).eval() + if args.schedule_free: + model.optimizer.eval() + + nb_test_samples, acc_test_loss = 0, 0.0 + nb_samples_accumulated = 0 + + full_input, full_mask_loss = quiz_machine.data_input(model, split="test") + src = zip( + full_input.split(args.batch_size), full_mask_loss.split(args.batch_size) + ) + + for input, mask_loss in tqdm.tqdm( + src, + dynamic_ncols=True, + desc="test", + total=full_input.size(0) // args.batch_size, ): - in_train.update(test_subset.intersection(train_subset)) - nb_in_train += len(in_train) - nb_test += len(test_subset) + input = input.to(local_device) + mask_loss = mask_loss.to(local_device) + targets = input - log_string( - f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set" - ) + output = model(mygpt.BracketedSequence(input)).x + loss_per_token = F.cross_entropy( + output.transpose(1, 2), targets, reduction="none" + ) + loss = (loss_per_token * mask_loss).mean() + acc_test_loss += loss.item() * input.size(0) + nb_test_samples += input.size(0) - assert ( - nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100 - ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set" + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) -############################## + log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}") + model.main_test_accuracy = quiz_machine.produce_results( + n_epoch=n_epoch, + model=model, + input=full_input[:2000], + result_dir=args.result_dir, + ) -def one_epoch(model, quiz_machine): - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - model.train() +###################################################################### + + +def one_epoch(model, quiz_machine, local_device=main_device): + model.to(local_device).train() + optimizer_to(model.optimizer, local_device) + + if args.schedule_free: + model.optimizer.train() nb_train_samples, acc_train_loss = 0, 0.0 - for input in quiz_machine.batches(split="train"): - input = input.to(device) + hard_w_quizzes = [] + + full_input, full_mask_loss = quiz_machine.data_input(model, split="train") + src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)) + + for input, mask_loss in tqdm.tqdm( + src, + dynamic_ncols=True, + desc="training", + total=full_input.size(0) // args.batch_size, + ): + input = input.to(local_device) + mask_loss = mask_loss.to(local_device) if nb_train_samples % args.batch_size == 0: - optimizer.zero_grad() + model.optimizer.zero_grad() + + targets = input output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) + loss_per_token = F.cross_entropy( + output.transpose(1, 2), targets, reduction="none" + ) + loss = (loss_per_token * mask_loss).mean() + model.loss acc_train_loss += loss.item() * input.size(0) + loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) + nb_train_samples += input.size(0) loss.backward() if nb_train_samples % args.batch_size == 0: - optimizer.step() + model.optimizer.step() train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - log_string(f"train_perplexity {n_epoch} {train_perplexity}") + log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}") + + run_tests(model, quiz_machine) + + # threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values + # threshold = threshold[threshold.size(0) // 2] + + # model.hard_w_quizzes = torch.cat( + # [x[l >= threshold] for x, l in hard_w_quizzes], dim=0 + # ) + + model.to(main_device) + optimizer_to(model.optimizer, main_device) ###################################################################### -def run_tests(model, quiz_machine, deterministic_synthesis): - with torch.autograd.no_grad(): - model.eval() +def model_transformer_hot(model): + model.temperature = args.temperature_hot + # model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2)) - nb_test_samples, acc_test_loss = 0, 0.0 - nb_samples_accumulated = 0 - for input in quiz_machine.batches(split="test"): - input = input.to(device) +def model_transformer_cold(model): + model.temperature = args.temperature_cold + # pass - bs = model(mygpt.BracketedSequence(input)) - output = bs.x - loss = F.cross_entropy(output.transpose(1, 2), input) +c_quizzes_procedure = [ + (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold), + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold), +] - acc_test_loss += loss.item() * input.size(0) +###################################################################### - nb_test_samples += input.size(0) - test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) +def save_additional_results(model, models, science_w_quizzes): + # Save generated quizzes with the successive steps - log_string(f"test_perplexity {n_epoch} {test_perplexity}") + recorder = [] - model.main_test_accuracy = quiz_machine.produce_results( - n_epoch=n_epoch, + c_quizzes = quiz_machine.generate_c_quizzes( + 64, + model_for_generation=model, + procedure=c_quizzes_procedure, + recorder=recorder, + ) + + # This is nb_quizzes x nb_models + + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + probas = seq_logproba.exp() + + comments = [] + + for l in seq_logproba: + comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) + + ## + + c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1) + predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1) + nb_steps = c_quizzes.size(1) + c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1)) + predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1)) + + # We have comments only for the final quiz, not the successive + # steps, so we have to add nb_steps-1 empty comments + + steps_comments = [] + for c in comments: + steps_comments += [""] * (nb_steps - 1) + [c] + + filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png" + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=c_quizzes, + predicted_parts=predicted_parts, + comments=steps_comments, + nrow=nb_steps * 2, # two quiz per row + ) + + log_string(f"wrote {filename}") + + ###################################################################### + + if science_w_quizzes is not None: + struct = ("A", "f_A", "B", "f_B") + mask = (0, 0, 0, 1) + result, correct = quiz_machine.predict( model=model, - result_dir=args.result_dir, - deterministic_synthesis=deterministic_synthesis, + quizzes=science_w_quizzes.to(main_device), + struct=struct, + mask=mask, + ) + + predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand( + correct.size(0), -1 + ) + correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long() + + nb_correct = (correct == 1).long().sum() + nb_total = (correct != 0).long().sum() + + log_string( + f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" + ) + + i = correct == 1 + j = correct != 1 + + result = torch.cat([result[i], result[j]], dim=0) + correct = torch.cat([correct[i], correct[j]], dim=0) + correct_parts = predicted_parts * correct[:, None] + + result = result[:128] + predicted_parts = predicted_parts[:128] + correct_parts = correct_parts[:128] + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + f"culture_science_{n_epoch:04d}_{model.id:02d}.png", + quizzes=result, + predicted_parts=predicted_parts, + correct_parts=correct_parts, ) ###################################################################### -def valid_c_quizzes(recorded, criteria): - result = [q[criteria(c)] for q, c in recorded] - return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([]) +def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): + nb_to_validate = nb_for_train + nb_for_test + nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate) + nb_validated = 0 + + recorded_validated = [] + + start_time = time.perf_counter() + + nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64) + + while nb_validated_per_model.sum() < nb_to_validate: + # We use the model that has generated the fewest quizzes to + # balance the number of quizzes per model overall + + # model_for_generation = sorted( + # models, key=lambda m: nb_validated_per_model[m.id] + # )[0] + + model_for_generation = models[torch.randint(len(models), (1,)).item()] + + # We generate quizzes with a procedure that injects some + # structured noise + + c_quizzes = quiz_machine.generate_c_quizzes( + nb_to_generate_per_iteration, + model_for_generation=model, + procedure=c_quizzes_procedure, + ) + + # We discard the trivial ones, according to a criterion + # specific to the world quizzes (e.g. B=f(B)) + + to_keep = quiz_machine.problem.trivial(c_quizzes) == False + + c_quizzes = c_quizzes[to_keep] + + # This is nb_quizzes x nb_models + + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + probas = seq_logproba.exp() + + nb_succeed = (probas >= args.proba_understands).long().sum(dim=1) + nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1) + + to_keep = ( + (nb_succeed + nb_fail == probas.size(1)) + & (nb_fail >= 1) + & (nb_fail <= args.max_fail_to_validate) + ) + + c_quizzes = c_quizzes[to_keep] + + if c_quizzes.size(0) > 0: + nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0) + recorded_validated.append(c_quizzes) + nb_validated = c_quizzes.size(0) + else: + nb_validated = 0 + + total_nb_validated = nb_validated_per_model.sum().item() + + duration = time.perf_counter() - start_time + + if total_nb_validated > 0: + if total_nb_validated < nb_to_validate: + d = ( + (nb_to_validate - total_nb_validated) + * duration + / total_nb_validated + ) + e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime( + "%a %H:%M" + ) + else: + e = "now!" + else: + e = "???" + + log_string( + f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)" + ) + + validated_quizzes = torch.cat(recorded_validated, dim=0) + + ###################################################################### + # store the new c_quizzes which have been validated + + v_train = validated_quizzes[:nb_for_train] + quiz_machine.store_c_quizzes(v_train, for_train=True) + + v_test = validated_quizzes[nb_for_train:nb_to_validate] + quiz_machine.store_c_quizzes(v_test, for_train=False) + + ###################################################################### + # save images + + vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] + + if vq.size(0) > 0: + seq_logproba = quiz_machine.models_logprobas( + models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + probas = seq_logproba.exp() + + comments = [] + + for l in seq_logproba: + comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l])) + + filename = f"culture_c_quiz_{n_epoch:04d}.png" + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, filename, vq, comments=comments + ) ###################################################################### +# The generator is very similar to a "solving GPT" except that it +# deals with quizzes prologued with one token per solving GPT that +# indicates if the said model solves it or not. +# +# There are three levels of solving 0->proba<=proba_not_understands, +# 2->proba>=proba_understands and 1 otherwise. -def create_c_quizzes( - models, - quiz_machine, - nb_for_train=1000, - nb_for_test=100, -): - quizzes_and_nb_correct_records = [] - nb_to_create = nb_for_train + nb_for_test +def generate_c_quizzes_with_generator(generator, quiz_machine, nb): + generator.to(main_device) + + struct = ("A", "f_A", "B", "f_B") - # ------------------------------------------------------------ + c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct) + ar_mask = quiz_machine.make_quiz_mask(c_quizzes, struct, (1, 1, 1, 1)) - standard_validity = lambda nb_correct: torch.logical_and( - nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate + i = F.one_hot( + torch.randint(args.nb_gpts, (c_quizzes.size(0),)), + num_classes=args.nb_gpts, ) - file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") + prologs_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i) + prologs_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prologs_c_quizzes.size(1)) - with open(file_name, "w") as logp_file: - while ( - valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0) - < nb_to_create - ): - # Select a model at random to generate the new quizzes + prologued_c_quizzes = torch.cat([prologs_c_quizzes, c_quizzes], dim=1).to( + main_device + ) + prologued_ar_mask = torch.cat([prologs_ar_mask, ar_mask], dim=1).to(main_device) - model_for_generation = models[torch.randint(len(models), (1,))] + seq_logproba = torch.zeros( + prologued_c_quizzes.size(0), device=prologued_c_quizzes.device + ) - c_quizzes = quiz_machine.generate_quizzes( - nb_to_create, - model_for_generation=model_for_generation, - temperature=args.generation_temperature, - ) + generator.temperature = args.temperature_hot + + with torch.autograd.no_grad(): + t = generator.training + generator.eval() + + one_batch_masked_inplace_autoregression( + generator, + prologued_c_quizzes, + prologued_ar_mask, + seq_logproba, + deterministic_synthesis=False, + ) + + generator.train(t) + + generator.reset_transformations() + + prologued_c_quizzes = ( + prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long() + ) + + c_quizzes = prologued_c_quizzes[:, prologs_c_quizzes.size(1) :] + + return c_quizzes.to("cpu"), prologs_c_quizzes.to("cpu") - c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] + +def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0): + samples = [] + + for _ in range(args.nb_train_samples // args.batch_size): + while sum([x.size(0) for x in samples]) < args.batch_size: + # Generate a bunch of quizzes + + if torch.rand(1).item() <= fraction_w_quizzes: + # Either we start with the world quizzes + c_quizzes = quiz_machine.problem.generate_w_quizzes( + args.batch_size, progress_bar=False + ) + else: + # Or we use the generator itself to generate them + c_quizzes, _ = generate_c_quizzes_with_generator( + generator, quiz_machine, args.batch_size + ) + + # We remove the trivial ones + to_keep = quiz_machine.problem.trivial(c_quizzes) == False + c_quizzes = c_quizzes[to_keep] + + # If there are remaining ones, we compute the true prolog + # that indicates how the GPTs solve it if c_quizzes.size(0) > 0: - nb_correct, seq_logproba = quiz_machine.compute_correctness( + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, + ("A", "f_A", "B", "f_B"), + (0, 0, 0, 1), + (0, 0, 1, 0), + ) + quiz_machine.models_logprobas( models, - bidirectional_validation=args.bidirectional_validation, - deterministic_validation=args.deterministic_validation, + c_quizzes, + ("f_A", "A", "f_B", "B"), + (0, 0, 0, 1), + (0, 0, 1, 0), ) - for n, l in zip(nb_correct, seq_logproba): - s = " ".join([str(x.item()) for x in l]) - logp_file.write(f"{n} {s}\n") + probas = seq_logproba.exp() + + u0 = probas <= args.proba_not_understands + u2 = probas >= args.proba_understands + u1 = (u0 | u2) == False + + prologs = ( + (u0.long() * token_prolog_0) + + (u1.long() * token_prolog_1) + + (u2.long() * token_prolog_2) + ) + + prologued_c_quizzes = torch.cat([prologs, c_quizzes], dim=1) + + # nb_u2 = u2.long().sum(dim=1) + # nb_u0 = u0.long().sum(dim=1) + # prologued_c_quizzes = prologued_c_quizzes[(nb_u2 >= 1) & (nb_u0 >= 1)] + + if prologued_c_quizzes.size(0) > 0: + samples.append(prologued_c_quizzes) + + # Now we yield a batch + + x = torch.cat(samples, dim=0) + samples = [x[args.batch_size :]] + + yield x[: args.batch_size] + + +def one_generator_epoch( + generator, quiz_machine, models, fraction_w_quizzes, local_device=main_device +): + model.to(local_device).train() + + optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate) + + nb_train_samples, acc_train_loss = 0, 0.0 + + src = batches_for_generator( + generator=generator, + quiz_machine=quiz_machine, + models=models, + fraction_w_quizzes=fraction_w_quizzes, + ) + + for input in tqdm.tqdm( + src, + dynamic_ncols=True, + desc="training", + total=args.nb_train_samples // args.batch_size, + ): + input = input.to(local_device) + + if nb_train_samples % args.batch_size == 0: + optimizer.zero_grad() + + targets = input - if args.dirty_debug: - nb_correct = torch.randint( - len(models) + 1, nb_correct.size(), device=c_quizzes.device + output = generator(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), targets) + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) + + loss.backward() + + if nb_train_samples % args.batch_size == 0: + optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train_perplexity {n_epoch} generator - {train_perplexity}") + + generator.to(main_device) + + +###################################################################### + + +def train_complexifier(model_gen, model_pred1, model_pred2): + samples = [] + perf = [] + + optimizer = torch.optim.Adam(model_gen.parameters(), lr=args.learning_rate) + + nb_train_samples, acc_train_loss = 0, 0.0 + + for n_epoch in range(args.nb_epochs): + for b in range(args.nb_train_samples // args.batch_size): + while sum([x.size(0) for x in samples]) < args.batch_size: + c_quizzes = quiz_machine.generate_c_quizzes( + args.inference_batch_size, + model_for_generation=model_gen, + procedure=c_quizzes_procedure, + ) + to_keep = quiz_machine.problem.trivial(c_quizzes) == False + c_quizzes = c_quizzes[to_keep] + if c_quizzes.size(0) > 0: + seq_logproba = quiz_machine.models_logprobas( + [model_pred1, model_pred2], + c_quizzes, + ("A", "f_A", "B", "f_B"), + (0, 0, 0, 1), + ) + quiz_machine.models_logprobas( + [model_pred1, model_pred2], + c_quizzes, + ("f_A", "A", "f_B", "B"), + (0, 0, 0, 1), ) + probas = seq_logproba.exp() + to_keep = (probas[:, model_pred1.id] >= args.proba_understands) & ( + probas[:, model_pred2.id] <= args.proba_not_understands + ) + log_string( + f"generating {to_keep.long().sum()} / {c_quizzes.size(0)}" + ) + c_quizzes = c_quizzes[to_keep] + if c_quizzes.size(0): + samples.append(c_quizzes) + + log_string(f"full batch {sum([x.size(0) for x in samples])}") - quizzes_and_nb_correct_records.append((c_quizzes, nb_correct)) + x = torch.cat(samples, dim=0) - nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) - nv = " ".join([str(x.item()) for x in nv]) + input = x[: args.batch_size] + samples = [x[args.batch_size :]] - nb_validated = valid_c_quizzes( - quizzes_and_nb_correct_records, standard_validity - ).size(0) + # ------------------- - log_string( - f"keep c_quizzes model {model_for_generation.id} kept {nv} nb_accumulated {nb_validated} / {nb_to_create}" + seq_logproba = quiz_machine.models_logprobas( + [model_pred1, model_pred2], + input, + ("A", "f_A", "B", "f_B"), + (0, 0, 0, 1), + ) + quiz_machine.models_logprobas( + [model_pred1, model_pred2], + input, + ("f_A", "A", "f_B", "B"), + (0, 0, 0, 1), ) - # store the new c_quizzes which have been validated + comments = [] - new_c_quizzes = valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity) + for l in seq_logproba: + comments.append( + f"proba {l[model_pred1.id].exp().item():.02f} {l[model_pred2.id].exp().item():.02f}" + ) - quiz_machine.reverse_random_half_in_place(new_c_quizzes) + filename = f"batch_{n_epoch:04d}_{b:04d}.png" + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, filename, input, comments=comments + ) + log_string(f"wrote {filename}") - quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) - quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) + # ------------------------ - # save a bunch of images to investigate what quizzes with a - # certain nb of correct predictions look like + input = input.to(main_device) - for n in range(len(models) + 1): - s = ( - "_validated" - if n >= args.min_to_validate and n <= args.max_to_validate - else "" - ) + if nb_train_samples % args.batch_size == 0: + optimizer.zero_grad() - q = valid_c_quizzes( - quizzes_and_nb_correct_records, criteria=lambda nb_correct: nb_correct == n - )[:72] + output = model_gen(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) - quiz_machine.reverse_random_half_in_place(q) + loss.backward() - if q.size(0) > 0: - quiz_machine.save_quizzes( - args.result_dir, f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", q - ) + if nb_train_samples % args.batch_size == 0: + optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}") ###################################################################### models = [] + +def compute_causal_attzero(t_q, t_k): + return t_q < t_k + + +if args.schedule_free: + import schedulefree + for k in range(args.nb_gpts): + log_string(f"creating model {k} and its w_quizzes") + model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -503,26 +1035,102 @@ for k in range(args.nb_gpts): dim_hidden=args.dim_hidden, nb_heads=args.nb_heads, nb_blocks=args.nb_blocks, - causal=True, + compute_attzero=compute_causal_attzero, dropout=args.dropout, - ).to(device) + ).to(main_device) - model.main_test_accuracy = 0.0 model.id = k + if args.schedule_free: + model.optimizer = schedulefree.AdamWScheduleFree( + model.parameters(), lr=args.learning_rate + ) + else: + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + model.main_test_accuracy = 0.0 + + model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes( + args.nb_train_samples + ) + + model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) + models.append(model) +###################################################################### + +if args.test == "quant": + nb_bits = 8 + for model in models: + model.trunk.insert( + 12, + mygpt.CacheWrapper( + mygpt.RandomBypass( + nn.Sequential( + nn.Linear(args.dim_model, nb_bits), + mygpt.BSQ(nb_bits), + nn.Linear(nb_bits, args.dim_model), + ), + 0.1, + ) + ), + ) + + print(model) + exit(0) + + +###################################################################### + +current_epoch = 0 + +if args.resume: + for model in models: + filename = f"gpt_{model.id:03d}.pth" + + try: + d = torch.load(os.path.join(args.result_dir, filename)) + model.load_state_dict(d["state_dict"]) + model.optimizer.load_state_dict(d["optimizer_state_dict"]) + model.main_test_accuracy = d["main_test_accuracy"] + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + + try: + filename = "c_quizzes.pth" + quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename)) + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + + try: + filename = "state.pth" + state = torch.load(os.path.join(args.result_dir, filename)) + log_string(f"successfully loaded {filename}") + current_epoch = state["current_epoch"] + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + +###################################################################### nb_parameters = sum(p.numel() for p in models[0].parameters()) log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### -nb_new_c_quizzes_for_train = args.nb_train_samples // 50 -nb_new_c_quizzes_for_test = args.nb_test_samples // 50 +if args.nb_new_c_quizzes_for_train is None: + args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100 + +if args.nb_new_c_quizzes_for_test is None: + args.nb_new_c_quizzes_for_test = args.nb_test_samples // 100 log_string( - f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}" + f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}" ) ###################################################################### @@ -530,57 +1138,239 @@ log_string( if args.dirty_debug: args.accuracy_to_make_c_quizzes = 0.0 args.nb_gpts = 2 - nb_new_c_quizzes_for_train = 100 - nb_new_c_quizzes_for_test = 10 + args.nb_new_c_quizzes_for_train = 100 + args.nb_new_c_quizzes_for_test = 10 ###################################################################### -for n_epoch in range(args.nb_epochs): - log_string(f"--- epoch {n_epoch} ----------------------------------------") +if args.test == "tsne": + model = models[0] - cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models]) - log_string(f"current_test_accuracies {cta}") + quizzes = [] + labels = [] + nb_samples_per_task = 1000 - ################################################## - # Select, improve, and eval the worst model + for n, t in enumerate(args.grids_world_tasks.split(",")): + quizzes.append( + quiz_machine.problem.generate_w_quizzes(nb_samples_per_task, [t]) + ) + labels.append(torch.full((quizzes[-1].size(0),), n)) - weakest_model = min(models, key=lambda m: float(m.main_test_accuracy)) + quizzes = torch.cat(quizzes, dim=0) + labels = torch.cat(labels, dim=0) - log_string( - f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}" - ) + with torch.autograd.no_grad(): + model.eval().to(main_device) + record = [] + for input, targets in zip( + quizzes.split(args.batch_size), labels.split(args.batch_size) + ): + input = input.to(main_device) + bs = mygpt.BracketedSequence(input) + bs = mygpt.BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) + bs = model.embedding(bs) + bs = model.trunk[args.nb_blocks // 2](bs) + record.append((bs.x.to("cpu"), targets)) - one_epoch(weakest_model, quiz_machine) + x = torch.cat([x for x, y in record], dim=0).flatten(1) + y = torch.cat([y for x, y in record], dim=0) - log_string( - f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}" - ) + print(f"{x.size()=} {y.size()=}") + # torch.save((x,y), "/tmp/embed.pth") + # exit(0) - run_tests(weakest_model, quiz_machine, deterministic_synthesis=False) + from sklearn.manifold import TSNE - log_string( - f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}" - ) + x_np = x.numpy() + z_np = TSNE(n_components=2, perplexity=50).fit_transform(x_np) + z = torch.from_numpy(z_np) - ################################################## - # Replace a fraction of the w_quizzes with fresh ones + print(f"{z.size()=}") + + with open("/tmp/result.dat", "w") as f: + for k in range(z.size(0)): + f.write(f"{y[k]} {z[k,0]} {z[k,1]}\n") + + exit(0) + +###################################################################### + +if args.test == "generator": + token_prolog_0 = vocabulary_size + 0 + token_prolog_1 = vocabulary_size + 1 + token_prolog_2 = vocabulary_size + 2 + generator_vocabulary_size = vocabulary_size + 3 + + generator = mygpt.MyGPT( + vocabulary_size=generator_vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + compute_attzero=compute_causal_attzero, + dropout=args.dropout, + ).to(main_device) + + generator.main_test_accuracy = 0.0 + + filename = f"generator.pth" + + try: + d = torch.load(os.path.join(args.result_dir, filename)) + generator.load_state_dict(d[0]) + generator.main_test_accuracy = d[1] + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + + for n_epoch in range(args.nb_epochs): + one_generator_epoch( + generator, + quiz_machine=quiz_machine, + models=models, + fraction_w_quizzes=1 if n_epoch < 25 else 0.5, + local_device=main_device, + ) + + filename = f"generator.pth" + torch.save( + (generator.state_dict(), generator.main_test_accuracy), + os.path.join(args.result_dir, filename), + ) + log_string(f"wrote {filename}") + + c_quizzes, prologs = generate_c_quizzes_with_generator( + generator, quiz_machine, args.batch_size + ) + + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + probas = seq_logproba.exp() + + u0 = probas <= args.proba_not_understands + u2 = probas >= args.proba_understands + u1 = (u0 | u2) == False + + predicted_prologs = ( + (u0.long() * token_prolog_0) + + (u1.long() * token_prolog_1) + + (u2.long() * token_prolog_2) + ) + + comments = [] + + nb_errors = (predicted_prologs != prologs).long().sum() + nb_total = prologs.numel() + + log_string(f"generator_error {nb_errors} / {nb_total}") + + def readable(prologs): + return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2) + + for aa, ee, ff in zip(probas, readable(predicted_prologs), readable(prologs)): + sa = "prolog " + " ".join( + [f"{e.item()}/{f.item()}" for e, f in zip(ee, ff)] + ) + sp = "proba " + " ".join([f"{p.item():.02f}" for p in aa]) + comments.append(sa + "\n" + sp) + + filename = f"generator_batch_{n_epoch:04d}.png" + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, filename, c_quizzes, comments=comments + ) + log_string(f"wrote {filename}") + + exit(0) + +###################################################################### + +for n_epoch in range(current_epoch, args.nb_epochs): + state = {"current_epoch": n_epoch} + filename = "state.pth" + torch.save(state, os.path.join(args.result_dir, filename)) + log_string(f"wrote {filename}") + + log_string(f"--- epoch {n_epoch} ----------------------------------------") - quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts) + cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models]) + log_string(f"current_test_accuracies {cta}") ################################################## # If all the models are good enough, generate new quizzes and # re-compute the test errors if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: - create_c_quizzes( + record_new_c_quizzes( models, quiz_machine, - nb_for_train=nb_new_c_quizzes_for_train, - nb_for_test=nb_new_c_quizzes_for_test, + nb_for_train=args.nb_new_c_quizzes_for_train, + nb_for_test=args.nb_new_c_quizzes_for_test, ) + filename = "c_quizzes.pth" + quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename)) + log_string(f"wrote {filename}") + + # Force one epoch of training for model in models: - run_tests(model, quiz_machine, deterministic_synthesis=False) + model.main_test_accuracy = 0.0 + + ################################################## + # Select, improve, and eval the worst model(s) + + ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy)) + + weakest_models = ranked_models[: len(gpus)] + + threads = [] + + for gpu, model in zip(gpus, weakest_models): + log_string(f"training model {model.id}") + + t = threading.Thread( + target=one_epoch, daemon=True, args=(model, quiz_machine, gpu) + ) + + threads.append(t) + + t.start() + + for t in threads: + t.join() + + # Save the models to disk + + for model in weakest_models: + filename = f"gpt_{model.id:03d}.pth" + torch.save( + { + "state_dict": model.state_dict(), + "optimizer_state_dict": model.optimizer.state_dict(), + "main_test_accuracy": model.main_test_accuracy, + }, + os.path.join(args.result_dir, filename), + ) + log_string(f"wrote {filename}") + + for model in weakest_models: + save_additional_results(model, models, science_w_quizzes) + + ###################################################################### + + # Renew the training samples + + for model in weakest_models: + quiz_machine.renew_train_w_quizzes(model=model) + if args.log_command is not None: + s = args.log_command.split() + s.insert(1, args.result_dir) + subprocess.run(s) ###################################################################### diff --git a/mygpt.py b/mygpt.py index d0fda7e..041d28c 100755 --- a/mygpt.py +++ b/mygpt.py @@ -19,6 +19,45 @@ from torch.nn import functional as F ###################################################################### + +class BSQ(nn.Module): + def __init__(self, L): + super().__init__() + self.L = L + + def forward(self, input, indexes=False): + norm = input.pow(2).sum(dim=2, keepdim=True).sqrt() + u = input / norm + + if indexes: + return ((u >= 0).long() * (2 ** torch.arange(self.L))[None, :]).sum(dim=1) + + hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1) + if self.training: + self.loss += u.mean(dim=0).tanh().pow(2).mean() + return hat_u + u - u.detach() + else: + return hat_u + + +class RandomBypass(nn.Module): + def __init__(self, m, p): + super().__init__() + self.m = m + self.p = p + + def forward(self, x): + y = self.m(x) + + if self.training: + u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None] + return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size()) + else: + return y + + +###################################################################### + # A BracketedSequence is a BxTx... tensor with a first and a nb time # steps to compute. @@ -114,6 +153,30 @@ class AddPositionalEncoding(nn.Module): ############################## +class EncoderHead(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.fc = nn.Linear(dim_in, dim_out) + + def forward(self, bs): + z = self.fc(bs.x).mean(dim=1) + return z, bs.x.shape + + +class DecoderBottom(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.fc = nn.Linear(dim_in, dim_out) + + def forward(self, z_shape): + z, shape = z_shape + y = self.fc(z)[:, None, :].expand(shape) + return BracketedSequence(y) + + +############################## + + class QKVAttention(nn.Module): def __init__( self, @@ -121,7 +184,7 @@ class QKVAttention(nn.Module): dim_qk, dim_v, nb_heads=1, - causal=False, + compute_attzero=None, attention_dropout=0.0, ): super().__init__() @@ -129,7 +192,7 @@ class QKVAttention(nn.Module): def randw(*d): return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) - self.causal = causal + self.compute_attzero = compute_attzero self.attention_dropout = attention_dropout self.record_attention = False @@ -141,10 +204,6 @@ class QKVAttention(nn.Module): def forward(self, bs_q): x_q = bs_q.x - assert ( - self.causal or bs_q.complete() - ), "Partial evaluation is only possible for causal models" - if bs_q.first == 0: self.cache_k = x_q.new_zeros( x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1) @@ -169,12 +228,12 @@ class QKVAttention(nn.Module): "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb] ) / math.sqrt(self.w_q.size(1)) - if self.causal: + if self.compute_attzero is not None: if bs_q.first == 0: - self.cache_attzero = ( - torch.arange(x_q.size(1), device=q.device)[None, None, :, None] - < torch.arange(x_q.size(1), device=q.device)[None, None, None, :] - ) + self.cache_attzero = self.compute_attzero( + torch.arange(x_q.size(1), device=q.device)[:, None], + torch.arange(x_q.size(1), device=q.device)[None, :], + )[None, None, :, :] a = a.masked_fill( self.cache_attzero[ :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb @@ -202,22 +261,19 @@ class QKVAttention(nn.Module): class NoiseInjector(nn.Module): - def __init__(self): + def __init__(self, identifier=None): super().__init__() self.noise_std = 0.0 + self.identifier = identifier def forward(self, x): if self.noise_std > 0: - x = x + torch.randn(x.size(), device=x.device) * self.noise_std + x = x * ( + 1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long() + ) return x -def set_noise_injection(model, noise_std): - for m in model.modules(): - if isinstance(m, NoiseInjector): - m.noise_std = noise_std - - ############################## @@ -230,7 +286,8 @@ class MyGPT(nn.Module): dim_hidden, nb_heads, nb_blocks, - causal=False, + compute_attzero=None, + autoencoder_dim=-1, dropout=0.0, len_max=1e5, ): @@ -238,6 +295,8 @@ class MyGPT(nn.Module): assert dim_model % nb_heads == 0 + self.temperature = 1.0 + self.embedding = nn.Sequential( CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), AddPositionalEncoding(len_max), @@ -250,21 +309,21 @@ class MyGPT(nn.Module): WithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), - NoiseInjector(), + NoiseInjector(identifier=("attention", b)), ), QKVAttention( dim_in=dim_model, dim_qk=dim_keys, dim_v=dim_model // nb_heads, nb_heads=nb_heads, - causal=causal, + compute_attzero=compute_attzero, attention_dropout=dropout, ), ), WithResidual( CacheWrapper( nn.LayerNorm((dim_model,)), - NoiseInjector(), + NoiseInjector(identifier=("ffw", b)), nn.Linear(in_features=dim_model, out_features=dim_hidden), nn.ReLU(), nn.Linear(in_features=dim_hidden, out_features=dim_model), @@ -279,6 +338,26 @@ class MyGPT(nn.Module): nn.Linear(in_features=dim_model, out_features=vocabulary_size) ) + # ------------------------------------------------------- + if autoencoder_dim > 0: + self.encoder = nn.Sequential( + *( + trunk_blocks[: nb_blocks // 2] + + [EncoderHead(dim_model, autoencoder_dim)] + ) + ) + + self.decoder = nn.Sequential( + *( + [ + DecoderBottom(autoencoder_dim, dim_model), + AddPositionalEncoding(len_max), + ] + + trunk_blocks[nb_blocks // 2 :] + ) + ) + # ------------------------------------------------------- + with torch.no_grad(): for m in self.modules(): if isinstance(m, nn.Embedding): @@ -288,13 +367,59 @@ class MyGPT(nn.Module): m.weight.fill_(1.0) def forward(self, bs): - # print(f"GENERATE {bs.first} {bs.first+bs.nb}") + for m in self.modules(): + m.loss = 0 + bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) bs = self.embedding(bs) bs = self.trunk(bs) bs = self.readout(bs) + bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature + + for m in self.modules(): + self.loss += m.loss + + return bs + + def encode(self, bs): + bs = self.embedding(bs) + z = self.encoder(bs) + return z + + def decode(self, z_shape): + bs = self.decoder(z_shape) + bs = self.readout(bs) return bs + def partial_forward(self, bs, start_layer=None, end_layer=None): + if start_layer is None: + # print(f"GENERATE {bs.first} {bs.first+bs.nb}") + bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) + bs = self.embedding(bs) + if end_layer is not None: + return self.trunk[:end_layer](bs) + else: + bs = self.trunk(bs) + bs = self.readout(bs) + return bs + else: + bs = self.trunk[start_layer:](bs) + bs = self.trunk(bs) + bs = self.readout(bs) + return bs + + def reset_transformations(self): + self.temperature = 1.0 + for m in self.modules(): + if isinstance(m, NoiseInjector): + m.noise_std = 0.0 + + def set_noise_injection(self, noise_std, identifier=None): + for m in self.modules(): + if isinstance(m, NoiseInjector): + if identifier is None or identifier == m.identifier: + m.noise_std = noise_std + def record_attention(self, v=True): for m in self.modules(): if isinstance(m, QKVAttention): @@ -324,7 +449,6 @@ if __name__ == "__main__": nb_heads=2, nb_blocks=2, dropout=0.1, - causal=True, ) model.eval() diff --git a/problem.py b/problem.py index a49634d..9bee5b2 100755 --- a/problem.py +++ b/problem.py @@ -9,99 +9,90 @@ import threading, queue, torch, tqdm class Problem: - def nb_token_values(self): - pass - - def trivial_prompts_and_answers(self, prompts, answers): - pass - - # returns two tensors nb x D and nb x D' - def generate_prompts_and_answers(self, nb): - pass - - # save a file to vizualize quizzes, you can save a txt or png file - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - pass - - -class MultiThreadProblem: - def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1): - self.problem = problem - self.chunk_size = chunk_size - self.queue = queue.Queue(maxsize=max_nb_cached_chunks) - for _ in range(nb_threads): - threading.Thread(target=self.fill_cache, daemon=True).start() - self.rest = None - - def nb_token_values(self): - return self.problem.nb_token_values() + def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1): + if nb_threads > 0: + self.chunk_size = chunk_size + self.queue = queue.Queue(maxsize=max_nb_cached_chunks) + for _ in range(nb_threads): + threading.Thread(target=self.fill_cache, daemon=True).start() + self.rest = None + else: + self.queue = None - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - self.problem.save_quizzes( - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ) + def nb_cached_quizzes(self): + if self.queue is None: + return None + else: + return self.queue.qsize() * self.chunk_size def fill_cache(self): while True: - prompts, answers = self.problem.generate_prompts_and_answers( - self.chunk_size - ) + quizzes = self.generate_w_quizzes_(self.chunk_size) + self.queue.put(quizzes.to("cpu"), block=True) - self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True) + def generate_w_quizzes(self, nb, progress_bar=True): + if self.queue is None: + return self.generate_w_quizzes_(nb) - def trivial_prompts_and_answers(self, prompts, answers): - return self.problem.trivial_prompts_and_answers(prompts, answers) - - def generate_prompts_and_answers(self, nb): if self.rest is not None: - prompts, answers = rest + quizzes = rest else: - prompts, answers = [], [] + quizzes = [] self.rest = None - n = sum([p.size(0) for p in prompts]) - - with tqdm.tqdm( - total=nb, - dynamic_ncols=True, - desc="world generation", - ) as pbar: + n = sum([q.size(0) for q in quizzes]) + + if progress_bar: + with tqdm.tqdm( + total=nb, + dynamic_ncols=True, + desc="world generation", + ) as pbar: + while n < nb: + q = self.queue.get(block=True) + quizzes.append(q) + n += q.size(0) + pbar.update(q.size(0)) + else: while n < nb: - p, s = self.queue.get(block=True) - prompts.append(p) - answers.append(s) - n += p.size(0) - pbar.update(p.size(0)) + q = self.queue.get(block=True) + quizzes.append(q) + n += q.size(0) - prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0) - assert n == prompts.size(0) + quizzes = torch.cat(quizzes, dim=0) + assert n == quizzes.size(0) k = n - nb if k > 0: - rest = (prompts[-k:], answers[-k:]) - prompts, answers = prompts[:-k], answers[:-k] + rest = quizzes[-k:] + quizzes = quizzes[:-k] + + return quizzes + + ###################################################################### + + def trivial_prompts_and_answers(self, prompts, answers): + pass + + # The one to implement, returns two tensors nb x D and nb x D' + def generate_w_quizzes_(self, nb): + pass + + # save a file to vizualize quizzes, you can save a txt or png file + def save_quiz_illustrations( + self, + result_dir, + filename_prefix, + prompts, + answers, + predicted_prompts=None, + predicted_answers=None, + ): + pass + + def save_some_examples(self, result_dir): + pass - return prompts, answers + ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index f0fb408..92da03d 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, os, tqdm, warnings +import math, os, tqdm, warnings, sys import torch, torchvision @@ -15,6 +15,8 @@ from torch.nn import functional as F import mygpt from mygpt import BracketedSequence +import threading + ###################################################################### # ar_mask is a tensor with 0s and 1s, of same shape as input, with @@ -27,9 +29,11 @@ def one_batch_masked_inplace_autoregression( input, ar_mask, seq_logproba, - temperature, - deterministic_synthesis, + deterministic_synthesis=False, ): + if input.size(0) == 0: + return + to_generate = (ar_mask.sum(0) > 0).nonzero() if to_generate.min() > 0: @@ -41,8 +45,6 @@ def one_batch_masked_inplace_autoregression( logits = output[:, s] - logits = (logits / temperature).log_softmax(dim=-1) - if deterministic_synthesis: t_next = logits.argmax(-1) else: @@ -56,486 +58,373 @@ def one_batch_masked_inplace_autoregression( input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] -def masked_inplace_autoregression( - model, - batch_size, - input, - ar_mask, - seq_logproba, - temperature, - deterministic_synthesis, - forbidden_tokens=None, - logit_biases=None, - progress_bar_desc=None, - device=torch.device("cpu"), -): - assert input.size() == ar_mask.size() - - batches = zip( - input.split(batch_size), - ar_mask.split(batch_size), - seq_logproba.split(batch_size), - ) - - if progress_bar_desc is not None: - batches = tqdm.tqdm( - batches, - dynamic_ncols=True, - desc=progress_bar_desc, - total=(input.size(0) + batch_size - 1) // batch_size, - ) - - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, ar_mask, seq_logproba in batches: - one_batch_masked_inplace_autoregression( - model=model, - input=input, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=deterministic_synthesis, - ) - - model.train(t) - - ###################################################################### class QuizMachine: - def indices_forward_and_backward(self, quizzes): - i_forward = quizzes[:, 0] == self.token_forward - j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward - i_backward = quizzes[:, 0] == self.token_backward - j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward - assert torch.logical_or( - torch.logical_and(i_forward, j_forward), - torch.logical_and(i_backward, j_backward), - ).all() - return i_forward, i_backward - - def non_trivial(self, quizzes): - quizzes = quizzes.clone() - n_forward = quizzes[quizzes[:, 0] == self.token_forward] - n_backward = quizzes[:, 0] == self.token_backward - backward = quizzes[n_backward] - quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - return torch.logical_not( - self.problem.trivial_prompts_and_answers( - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], - ) - ) - - def reverse_time(self, quizzes): - i_forward, i_backward = self.indices_forward_and_backward(quizzes) - - forward_to_backward = torch.cat( - [ - quizzes[:, 0:1], - quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len], - quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1], - quizzes[:, 1 : 1 + self.prompt_len], - ], - dim=1, - ) - - forward_to_backward[:, 0] = self.token_backward - forward_to_backward[:, 1 + self.answer_len] = self.token_backward - - backward_to_forward = torch.cat( - [ - quizzes[:, 0:1], - quizzes[:, 2 + self.answer_len :], - quizzes[:, 1 + self.answer_len : 2 + self.answer_len], - quizzes[:, 1 : 1 + self.answer_len], - ], - dim=1, - ) - - backward_to_forward[:, 0] = self.token_forward - backward_to_forward[:, 1 + self.prompt_len] = self.token_forward - - m = i_forward.long()[:, None] - - return m * forward_to_backward + (1 - m) * backward_to_forward - - def reverse_random_half_in_place(self, quizzes): - i = torch.rand(quizzes.size(0)) < 0.5 - if i.any(): - quizzes[i] = self.reverse_time(quizzes[i]) - - def make_ar_mask(self, quizzes, first=False): - i_forward, i_backward = self.indices_forward_and_backward(quizzes) - - t = torch.arange(quizzes.size(1), device=quizzes.device) - - if first: - m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long() - m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long() - else: - m_forward = (t >= 2 + self.prompt_len).long() - m_backward = (t >= 2 + self.answer_len).long() - - m = i_forward.long()[:, None] - - return m * m_forward + (1 - m) * m_backward - - def generate_token_sequences(self, nb): - prompts, answers = self.problem.generate_prompts_and_answers(nb) - - if self.prompt_len is None: - self.prompt_len = prompts.size(1) - - if self.answer_len is None: - self.answer_len = answers.size(1) - - assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len - - result = [] - - for prompt, answer in zip(prompts, answers): - a = [ - torch.tensor([self.token_forward]), - prompt, - torch.tensor([self.token_forward]), - answer, - ] - - result.append(torch.cat(a, dim=0)[None, :]) - - return torch.cat(result, dim=0) - def __init__( self, problem, - nb_train_samples, - nb_test_samples, - back_accuracy, batch_size, result_dir, + prompt_noise, logger, device=torch.device("cpu"), ): super().__init__() - v = problem.nb_token_values() - self.token_forward = v - self.token_backward = v + 1 - self.nb_token_values = v + 2 - self.problem = problem - self.back_accuracy = back_accuracy self.batch_size = batch_size self.device = device self.logger = logger self.prompt_len = None self.answer_len = None - - self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) - self.reverse_random_half_in_place(self.train_w_quizzes) - self.train_w_quizzes = self.train_w_quizzes.to(device) - - self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) - self.reverse_random_half_in_place(self.test_w_quizzes) - self.test_w_quizzes = self.test_w_quizzes.to(device) - + self.prompt_noise = prompt_noise + + # struct, mask_generate, mask_noise, mask_loss + self.train_structures = [ + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), + (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), + # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), + # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), + # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), + # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + ] + + self.test_structures = self.train_structures + + self.LOCK_C_QUIZZES = threading.Lock() self.train_c_quizzes = [] self.test_c_quizzes = [] - if result_dir is not None: - self.save_quizzes( - result_dir, - "culture_w_quizzes", - self.train_w_quizzes[:72], - ) + def vocabulary_size(self): + return self.problem.nb_token_values - def save_quizzes( + ###################################################################### + + def autoregression( self, - result_dir, - filename_prefix, - quizzes, - mistakes=None, + model, + input, + ar_mask, + seq_logproba=None, + progress_bar_desc=None, ): - quizzes = quizzes.clone() - n_forward = quizzes[quizzes[:, 0] == self.token_forward] - n_backward = quizzes[:, 0] == self.token_backward - backward = quizzes[n_backward] - assert n_forward.size(0) + backward.size(0) == quizzes.size(0) - quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) - - predicted_prompts = n_backward.long() - predicted_answers = 1 - predicted_prompts - if mistakes is not None: - # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct - predicted_prompts *= mistakes - predicted_answers *= mistakes - else: - # 0/2 ~ not-to-predict / to predict - predicted_prompts *= 2 - predicted_answers *= 2 + assert input.size() == ar_mask.size() - self.problem.save_quizzes( - result_dir, - filename_prefix, - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], - predicted_prompts, - predicted_answers, + if seq_logproba is None: + seq_logproba = torch.empty(input.size(0), device=self.device) + + batches = zip( + input.split(self.batch_size), + ar_mask.split(self.batch_size), + seq_logproba.split(self.batch_size), ) - def batches(self, split="train", desc=None): - assert split in {"train", "test"} - if split == "train": - w_quizzes = self.train_w_quizzes - c_quizzes = self.train_c_quizzes - else: - w_quizzes = self.test_w_quizzes - c_quizzes = self.test_c_quizzes + if progress_bar_desc is not None: + batches = tqdm.tqdm( + batches, + dynamic_ncols=True, + desc=progress_bar_desc, + total=(input.size(0) + self.batch_size - 1) // self.batch_size, + ) - if len(c_quizzes) > 0: - c_quizzes = torch.cat(c_quizzes, dim=0) - if c_quizzes.size(0) > w_quizzes.size(0) // 2: - i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] - c_quizzes = c_quizzes[i] + with torch.autograd.no_grad(): + t = model.training + model.eval() - i = torch.randperm(w_quizzes.size(0))[ - : w_quizzes.size(0) - c_quizzes.size(0) - ] - w_quizzes = w_quizzes[i] + for input, ar_mask, seq_logproba in batches: + one_batch_masked_inplace_autoregression( + model=model, + input=input, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + deterministic_synthesis=False, + ) - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = c_quizzes.size(0) + model.train(t) - input = torch.cat([w_quizzes, c_quizzes], dim=0) - else: - input = w_quizzes - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = 0 + ###################################################################### - # Shuffle - input = input[torch.randperm(input.size(0))] + def data_input(self, model, split="train"): + assert split in {"train", "test"} - if desc is None: - desc = f"epoch-{split}" - for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=desc - ): - yield batch + with self.LOCK_C_QUIZZES: + if split == "train": + w_quizzes = model.train_w_quizzes + c_quizzes = self.train_c_quizzes + else: + w_quizzes = model.test_w_quizzes + c_quizzes = self.test_c_quizzes - def vocabulary_size(self): - return self.nb_token_values + if len(c_quizzes) > 0: + c_quizzes = torch.cat(c_quizzes, dim=0) - def produce_results( - self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 - ): - def compute_accuracy(input, log_prefix=None): - ar_mask = self.make_ar_mask(input) - result = input.clone() * (1 - ar_mask) - seq_logproba = torch.empty(input.size(0), device=self.device) + if c_quizzes.size(0) > w_quizzes.size(0) // 2: + i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] + c_quizzes = c_quizzes[i] - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba, - temperature=1.0, - deterministic_synthesis=deterministic_synthesis, - progress_bar_desc=None, - device=self.device, - ) + i = torch.randperm(w_quizzes.size(0))[ + : w_quizzes.size(0) - c_quizzes.size(0) + ] + w_quizzes = w_quizzes[i] - correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device) + quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) + from_w = torch.arange( + quizzes.size(0), device=quizzes.device + ) < w_quizzes.size(0) - n_forward = input[:, 0] == self.token_forward - n_backward = input[:, 0] == self.token_backward + else: + quizzes = w_quizzes.clone() + from_w = torch.full((quizzes.size(0),), True, device=quizzes.device) - correct[n_forward] = ( - (input[n_forward] == result[n_forward]).long().min(dim=1).values - ) + i = torch.randperm(quizzes.size(0), device=quizzes.device) + quizzes, from_w = quizzes[i], from_w[i] - if self.back_accuracy and n_backward.any(): - # accuracy of B->A*->B*=B instead of B->A*=A - back_input = self.reverse_time(result[n_backward]) - back_input[:, 2 + self.prompt_len :] = input[ - n_backward, 1 : 1 + self.answer_len - ] - _, correct[n_backward] = compute_accuracy(back_input) + self.randomize_configuations_inplace( + quizzes, structs=[s for s, _, _, _ in self.train_structures] + ) - if log_prefix is not None: - forward_nb_correct = correct[n_forward].sum() - forward_nb_total = correct[n_forward].size(0) - backward_nb_correct = correct[n_backward].sum() - backward_nb_total = correct[n_backward].size(0) + quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) - self.logger( - f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)" - ) + if self.prompt_noise > 0.0: + for struct, _, mask_noise, mask_loss in self.train_structures: + i = self.problem.indices_select(quizzes=quizzes, struct=struct) + if i.any(): + quizzes[i] = self.problem.inject_noise( + quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise + ) + quiz_mask_loss[i] = self.make_quiz_mask( + quizzes=quizzes[i], struct=struct, mask=mask_loss + ) - self.logger( - f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)" - ) + return quizzes, quiz_mask_loss - return result, correct + ###################################################################### - compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train") + def make_quiz_mask(self, quizzes, struct, mask): + assert struct in [s for s, _, _, _ in self.train_structures] + return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask) - test_result, test_correct = compute_accuracy( - self.test_w_quizzes[:nmax], log_prefix="test" - ) + ###################################################################### - main_test_accuracy = test_correct.sum() / test_correct.size(0) - self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}") + def predict(self, model, quizzes, struct, mask): + ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask) + result = quizzes * (1 - ar_mask) - ############################## + seq_logproba = torch.empty(quizzes.size(0), device=self.device) - self.save_quizzes( - result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}", - quizzes=test_result[:72], - mistakes=test_correct[:72] * 2 - 1, + self.autoregression( + model=model, + input=result, + ar_mask=ar_mask, + seq_logproba=seq_logproba, + progress_bar_desc="accuracy", ) - return main_test_accuracy + correct = (result == quizzes).min(dim=1).values.long() - def renew_w_quizzes(self, nb, for_train=True): - input = self.train_w_quizzes if for_train else self.test_w_quizzes - nb = min(nb, input.size(0)) - input[:-nb] = input[nb:].clone() - fresh_w_quizzes = self.generate_token_sequences(nb) - self.reverse_random_half_in_place(fresh_w_quizzes) - input[-nb:] = fresh_w_quizzes.to(self.device) + return result, correct - def store_c_quizzes(self, new_c_quizzes, for_train=True): - if for_train: - self.train_c_quizzes.append(new_c_quizzes) - else: - self.test_c_quizzes.append(new_c_quizzes) + ###################################################################### - def compute_correctness( - self, - c_quizzes, - models_for_validation, - bidirectional_validation=False, - deterministic_validation=True, - ): - if bidirectional_validation: - backward_c_quizzes = self.forward_to_backward(c_quizzes) + def produce_results(self, n_epoch, model, input, result_dir): + input = input.to(self.device) + result = input.new(input.size()) + correct = input.new(input.size(0)) + predicted_parts = input.new(input.size(0), 4) - seq_logproba = torch.zeros( - c_quizzes.size(0), - max([m.id for m in models_for_validation]) + 1, - device=self.device, + nb = 0 + + # We consider all the configurations that we train for + for struct, mask_generate, _, _ in self.test_structures: + i = self.problem.indices_select(quizzes=input, struct=struct) + nb += i.long().sum() + result[i], correct[i] = self.predict( + model=model, quizzes=input[i], struct=struct, mask=mask_generate + ) + predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[ + None, : + ] + solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1 + correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long() + + assert nb == input.size(0) + + nb_correct = (correct == 1).long().sum() + nb_total = (correct != 0).long().sum() + self.logger( + f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" ) - nb_correct = 0 + main_test_accuracy = nb_correct / nb_total - seq_logproba[...] = 0.0 + ############################## - for model in models_for_validation: - result = c_quizzes.clone() - - ar_mask = self.make_ar_mask(result) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - seq_logproba=seq_logproba[:, model.id], - temperature=1.0, - deterministic_synthesis=deterministic_validation, - # progress_bar_desc="solving c_quizzes", - device=self.device, - ) + correct_parts = predicted_parts * correct[:, None] - correct = (c_quizzes == result).long().min(dim=-1).values + result = result[:128] + predicted_parts = predicted_parts[:128] + correct_parts = correct_parts[:128] - if bidirectional_validation: - backward_result = backward_c_quizzes.clone() + self.problem.save_quizzes_as_image( + result_dir, + f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png", + quizzes=result, + predicted_parts=predicted_parts, + correct_parts=correct_parts, + ) - ar_mask = self.make_ar_mask(backward_result) + return main_test_accuracy - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=backward_result, - ar_mask=ar_mask, - seq_logproba=seq_logproba[:, model.id], - temperature=1.0, - deterministic_synthesis=deterministic_validation, - # progress_bar_desc="solving backward c_quizzes", - device=self.device, - ) + ###################################################################### + + def randomize_configuations_inplace(self, quizzes, structs): + r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device) + for c in range(len(structs)): + quizzes[r == c] = self.problem.reconfigure( + quizzes[r == c], struct=structs[c] + ) + + ###################################################################### - backward_correct = ( - (backward_c_quizzes == backward_result).long().min(dim=-1).values + def renew_train_w_quizzes(self, model): + if hasattr(model, "hard_w_quizzes"): + hard_w_quizzes = self.problem.reconfigure( + model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B") + ) + self.logger( + f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}" + ) + if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0): + nb_to_generate = 0 + model.train_w_quizzes[...] = hard_w_quizzes[ + torch.randperm(hard_w_quizzes.size(0))[ + model.train_w_quizzes.size(0) + ] + ] + else: + nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0) + model.train_w_quizzes[...] = torch.cat( + [ + hard_w_quizzes, + self.problem.generate_w_quizzes(nb_to_generate), + ], + dim=0, ) + else: + nb_to_generate = 0 + model.train_w_quizzes[...] = self.problem.generate_w_quizzes( + model.train_w_quizzes.size(0) + ) + + ###################################################################### - correct *= backward_correct + def store_c_quizzes(self, new_c_quizzes, for_train=True): + with self.LOCK_C_QUIZZES: + if for_train: + self.train_c_quizzes.append(new_c_quizzes.to("cpu")) + else: + self.test_c_quizzes.append(new_c_quizzes.to("cpu")) + + def save_c_quizzes(self, filename): + torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) - # endif + def load_c_quizzes(self, filename): + self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) - nb_correct += correct + ###################################################################### - return nb_correct, seq_logproba + def models_logprobas( + self, + models_for_validation, + c_quizzes, + struct, + mask_loss, + mask_noise=None, + device=None, + ): + if device is None: + device = self.device - ############################################################### + c_quizzes = self.problem.reconfigure(c_quizzes, struct) - def generate_quizzes(self, nb, model_for_generation, temperature=1.0): - c_quizzes = torch.empty( - nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 + seq_logproba = torch.zeros( + c_quizzes.size(0), + max([m.id for m in models_for_validation]) + 1, + device=device, ) + # if self.prompt_noise > 0.0 and mask_noise is not None: + # c_quizzes = self.problem.inject_noise( + # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise + # ) + + for model in models_for_validation: + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), + seq_logproba.split(self.batch_size), + ): + input = input.to(device) + quiz_mask_loss = self.make_quiz_mask( + input, struct=struct, mask=mask_loss + ) + output = model(mygpt.BracketedSequence(input)).x + l[:, model.id] = ( + -F.cross_entropy( + output.transpose(1, 2), input, reduction="none" + ) + * quiz_mask_loss + ).sum(dim=1) + + model.train(t) + + return seq_logproba.to("cpu") + + ###################################################################### + + def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None): seq_logproba = torch.zeros(nb, device=self.device) - # First, we generate the answer at high temperature + c_quizzes = None - c_quizzes[:, 0] = self.token_backward - c_quizzes[:, 1 + self.answer_len] = self.token_backward + for s, m, mt in procedure: + if c_quizzes is None: + c_quizzes = self.problem.create_empty_quizzes(nb, s) + c_quizzes = c_quizzes.to(self.device) + elif s != pred_s: + c_quizzes = self.problem.reconfigure(c_quizzes, s) + pred_s = s - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes, first=True), - seq_logproba=seq_logproba, - temperature=temperature, - deterministic_synthesis=False, - device=self.device, - ) + if mt is not None: + mt(model_for_generation) - # Then, we generate the prompt at low temperature + self.autoregression( + model=model_for_generation, + input=c_quizzes, + ar_mask=self.make_quiz_mask(c_quizzes, s, m), + seq_logproba=seq_logproba, + ) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), - seq_logproba=seq_logproba, - temperature=1 / temperature, - deterministic_synthesis=False, - device=self.device, - ) + model_for_generation.reset_transformations() - # Then we return the quizz, and re-generate the response, now - # at low temperature + if recorder is not None: + x = c_quizzes.clone() + t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1) + recorder.append( + self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B")) + ) - c_quizzes = self.reverse_time(c_quizzes) + c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) - masked_inplace_autoregression( - model=model_for_generation, - batch_size=self.batch_size, - input=c_quizzes, - ar_mask=self.make_ar_mask(c_quizzes), - seq_logproba=seq_logproba, - temperature=1 / temperature, - deterministic_synthesis=False, - device=self.device, - ) + return c_quizzes.to("cpu") - return c_quizzes + ###################################################################### diff --git a/sky.py b/sky.py index ed440d3..cc5bd4f 100755 --- a/sky.py +++ b/sky.py @@ -50,7 +50,11 @@ class Sky(problem.Problem): speed=2, nb_iterations=2, avoid_collision=True, + max_nb_cached_chunks=None, + chunk_size=None, + nb_threads=-1, ): + super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) self.height = height self.width = width self.nb_birds = nb_birds @@ -296,7 +300,7 @@ class Sky(problem.Problem): return prompts, answers - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -327,7 +331,7 @@ if __name__ == "__main__": predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1 predicted_answers = torch.randint(3, (prompts.size(0),)) - 1 - sky.save_quizzes( + sky.save_quiz_illustrations( "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers )