From bcf1be7eeec30ff2633126c56120b5389bf1fde1 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 22 Jun 2024 15:22:05 +0200
Subject: [PATCH] Update.

---
 do_all.sh   |  22 ---
 graph.py    | 185 --------------------
 main.py     |   2 +
 problems.py | 492 ----------------------------------------------------
 tasks.py    |   7 +-
 5 files changed, 4 insertions(+), 704 deletions(-)
 delete mode 100755 do_all.sh
 delete mode 100755 graph.py
 delete mode 100755 problems.py

diff --git a/do_all.sh b/do_all.sh
deleted file mode 100755
index c5d16fc..0000000
--- a/do_all.sh
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/bin/bash
-
-##################################################################
-# START_IP_HEADER                                                #
-#                                                                #
-# Written by Francois Fleuret                                    #
-# Contact for comments & bug reports                             #
-#                                                                #
-# END_IP_HEADER                                                  #
-##################################################################
-
-# set -e
-# set -o pipefail
-
-#prefix="--nb_train_samples=1000 --nb_test_samples=100 --batch_size=25 --nb_epochs=2 --max_percents_of_test_in_train=-1 --model=17K"
-prefix="--nb_epochs=25"
-
-for task in byheart learnop guessop twotargets addition picoclvr maze snake stack expr rpl
-do
-    [[ ! -d results_${task} ]] && ./main.py ${prefix} --task=${task}
-done
-
diff --git a/graph.py b/graph.py
deleted file mode 100755
index 07e376a..0000000
--- a/graph.py
+++ /dev/null
@@ -1,185 +0,0 @@
-#!/usr/bin/env python
-
-import math
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-import cairo
-
-
-######################################################################
-
-
-def save_attention_image(
-    # image to save
-    filename,
-    tokens_input,
-    tokens_output,
-    # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1
-    attention_matrices,
-    # do not draw links with a lesser attention
-    min_link_attention=0,
-    # draw only the strongest links necessary so that their summed
-    # attention is above min_total_attention
-    min_total_attention=None,
-    # draw only the top k links
-    k_top=None,
-    # the purely graphical settings
-    curved=True,
-    pixel_scale=8,
-    token_gap=15,
-    layer_gap=25,
-    y_eps=0.5,
-    padding=10,
-):
-    if k_top is not None:
-        am = []
-        for m in attention_matrices:
-            am.append(m * (m.sort(dim=-1, descending=True).indices < k_top))
-        attention_matrices = am
-
-    if min_total_attention is not None:
-        am = []
-        for m in attention_matrices:
-            s = m.sort(dim=-1)
-            m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
-            b = m.new(m.size()).scatter_(dim=-1, index=s.indices, src=m)
-            am.append(m * b)
-
-    surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
-
-    ctx = cairo.Context(surface)
-    ctx.scale(pixel_scale, pixel_scale)
-
-    ctx.set_source_rgb(0.0, 0.0, 0.0)
-    ctx.set_font_size(4.0)
-    # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
-
-    x, y = 0, 0
-
-    ctx.set_line_width(0.25)
-    for d in range(len(attention_matrices)):
-        at = attention_matrices[d].to("cpu")
-        ni = torch.arange(at.size(0))[:, None].expand_as(at)
-        nj = torch.arange(at.size(1))[None, :].expand_as(at)
-        at = at.flatten()
-        o = at.sort().indices
-        at = at[o]
-        ni = ni.flatten()[o]
-        nj = nj.flatten()[o]
-        for i, j, a in zip(ni, nj, at):
-            if a > 0 and a >= min_link_attention:
-                c = 1 - a.item()
-                ctx.set_source_rgb(c, c, c)
-                ax, ay = j * token_gap, y - y_eps
-                ctx.move_to(ax, ay)
-                dx, dy = i * token_gap, y - layer_gap + y_eps
-                if curved:
-                    bx, by = ax, ay - layer_gap * 0.5
-                    cx, cy = dx, dy + layer_gap * 0.5
-                    ctx.curve_to(bx, by, cx, cy, dx, dy)
-                else:
-                    ctx.line_to(dx, dy)
-                ctx.stroke()
-        y -= layer_gap
-
-    for d in range(0, len(attention_matrices) + 1):
-        n = (
-            attention_matrices[0].size(-1)
-            if d == 0
-            else attention_matrices[d - 1].size(-2)
-        )
-        for n in range(n):
-            xc, yc = n * token_gap, -d * layer_gap
-            ctx.set_source_rgb(1.0, 1.0, 1.0)
-            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
-            ctx.fill()
-            ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
-            ctx.fill()
-
-    ctx.set_source_rgb(0.0, 0.0, 0.0)
-
-    for k, t in enumerate(tokens_input):
-        s = str(t)
-        (
-            x_bearing,
-            y_bearing,
-            width_t,
-            height_t,
-            x_advance,
-            y_advance,
-        ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5)
-        ctx.show_text(s)
-
-    for k, t in enumerate(tokens_output):
-        s = str(t)
-        (
-            x_bearing,
-            y_bearing,
-            width_t,
-            height_t,
-            x_advance,
-            y_advance,
-        ) = ctx.text_extents(s)
-        ctx.move_to(
-            k * token_gap - width_t / 2,
-            -token_gap / 5 - len(attention_matrices) * layer_gap,
-        )
-        ctx.show_text(s)
-
-    x, y, width, height = surface.ink_extents()
-    x -= padding
-    y -= padding
-    width += 2 * padding
-    height += 2 * padding
-    pdf_surface = cairo.PDFSurface(filename, width, height)
-    ctx_pdf = cairo.Context(pdf_surface)
-    ctx_pdf.set_source_surface(surface, -x, -y)
-    ctx_pdf.paint()
-    pdf_surface.finish()
-
-
-######################################################################
-
-if __name__ == "__main__":
-    import mygpt
-
-    tokens_output = ["", "-", 3, 4, ""]
-    tokens_input = [""] + tokens_output[:-1]
-
-    vocabulary_size = 3
-    x = torch.randint(vocabulary_size, (1, len(tokens_input)))
-
-    model = mygpt.MyGPT(
-        vocabulary_size=vocabulary_size,
-        dim_model=4,
-        dim_keys=2,
-        dim_hidden=2,
-        nb_heads=2,
-        nb_blocks=5,
-        dropout=0.1,
-        causal=True,
-    )
-
-    model.eval()
-    model.record_attention()
-
-    y1 = model(mygpt.BracketedSequence(x)).x
-
-    attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
-
-    # attention_matrices = [torch.rand(*s) for s in [ (4,5),(3,4),(8,3),(5,8) ]]
-
-    save_attention_image(
-        "attention.pdf",
-        tokens_input,
-        tokens_output,
-        attention_matrices,
-        # k_top=2,
-        min_total_attention=0.9,
-    )
diff --git a/main.py b/main.py
index e058822..549e7ea 100755
--- a/main.py
+++ b/main.py
@@ -568,6 +568,8 @@ def create_quizzes(
         other_models=other_models,
     )
 
+    print(nb_correct)
+
     to_keep = new_quizzes[nb_correct == len(other_models) - 1]
     log_string(f"keep {to_keep.size(0)} quizzes")
     kept.append(to_keep)
diff --git a/problems.py b/problems.py
deleted file mode 100755
index 446e1a1..0000000
--- a/problems.py
+++ /dev/null
@@ -1,492 +0,0 @@
-#!/usr/bin/env python
-
-import math
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-
-class Problem:
-    def generate_sequences(self, nb):
-        pass
-
-    def seq2str(self, seq):
-        return "[NOT IMPLEMENTED]"
-
-    def compute_nb_correct(self, input, ar_mask, result):
-        nb_total = ar_mask.sum().item()
-        nb_correct = ((result == input).long() * ar_mask).sum().item()
-        return nb_total, nb_correct
-
-
-####################
-
-
-class ProblemDegradation(Problem):
-    def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
-        assert value_max // nb_state_tokens >= 2
-        self.nb_state_tokens = nb_state_tokens
-        self.nb_time_steps = nb_time_steps
-        self.value_max = value_max
-        self.hard = hard
-
-    def generate_sequences(self, nb):
-        x = (
-            torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
-        ).long() * self.value_max
-        seq = [x]
-
-        for t in range(self.nb_time_steps - 1):
-            v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
-            u = (v.max(dim=-1, keepdim=True).values == v).long()
-            n = (
-                (u * x)
-                .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
-                .sum(dim=-1, keepdim=True)
-            )
-            m = 1 + ((n - 1) * torch.rand(n.size())).long()
-            x = (
-                x
-                + m * u.roll(shifts=-1, dims=-1)
-                - n * u
-                + (n - m) * u.roll(shifts=1, dims=-1)
-            )
-            seq.append(x)
-
-        if self.hard:
-            seq.reverse()
-
-        seq = torch.cat(seq, dim=1)
-        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
-
-    def compute_nb_correct(self, input, ar_mask, result):
-        nb_total = result.size(0)
-        nb_correct = 0
-        e = result.new_zeros(self.nb_state_tokens)
-
-        for seq in result:
-            states = list(seq.split(self.nb_state_tokens))
-            if self.hard:
-                states.reverse()
-
-            d = states[0]
-            j = d.sort(descending=True).indices[0]
-            e.zero_()
-            e[j] = self.value_max
-            if (d - e).abs().sum() == 0:
-                nb_errors = 0
-                for k in range(len(states) - 1):
-                    d = states[k + 1] - states[k]
-                    j = d.sort(descending=False).indices[0]
-                    if (
-                        d[j] == 0
-                        or d[j] > self.value_max // 4
-                        or d[(j + 1) % e.size(0)] <= 0
-                        or d[(j + 1) % e.size(0)] >= -d[j]
-                    ):
-                        nb_errors += 1
-                    else:
-                        e.zero_()
-                        e[j] = d[j]
-                        e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
-                        e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
-                        if (d - e).abs().sum() > 0:
-                            nb_errors += 1
-                if nb_errors == 0:
-                    nb_correct += 1
-
-        return nb_total, nb_correct
-
-    def seq2str(self, seq):
-        return " | ".join(
-            [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
-        )
-
-
-####################
-
-
-class ProblemMemory(Problem):
-    def __init__(self, len_total=25):
-        self.len_total = len_total
-        self.max_len_pattern = 5
-        self.nb_noise_tokens = 10
-        self.start_pattern_token = 0
-        self.end_pattern_token = 1
-        self.start_result_token = 2
-        self.end_result_token = 3
-        self.token_string = "[]<>" + "".join(
-            [chr(ord("a") + k) for k in range(self.nb_noise_tokens)]
-        )
-
-    def generate_sequences(self, nb):
-        sequences = (
-            torch.randint(self.nb_noise_tokens, (nb, self.len_total))
-            + self.end_result_token
-            + 1
-        )
-        len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1
-        pattern_positions = torch.randint(
-            self.len_total - (5 + 2 * self.max_len_pattern), (nb,)
-        )
-        k = self.len_total - (3 + self.max_len_pattern)
-        for i in range(nb):
-            l = len_patterns[i]
-            j = pattern_positions[i]
-            sequences[i, j] = self.start_pattern_token
-            sequences[i, j + l + 2] = self.end_pattern_token
-            sequences[i, k] = self.start_result_token
-            sequences[i, k + l + 2] = self.end_result_token
-            sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l]
-
-        j = torch.arange(self.len_total)[None, :]
-        ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long()
-
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join(self.token_string[x.item()] for x in seq)
-
-
-class ProblemTwoTargets(Problem):
-    def __init__(self, len_total=10, len_targets=3):
-        assert len_targets >= 3
-        assert len_total >= 3 * len_targets - 1
-        self.len_total = len_total
-        self.len_targets = len_targets
-
-    def generate_sequences(self, nb):
-        k = torch.arange(self.len_total)[None, :]
-        s = torch.randint(10, (nb, self.len_total))
-        l = torch.rand(nb, self.len_total)
-        l = l * (k <= self.len_total - self.len_targets).long()
-        k1 = l.argmax(dim=1, keepdim=True)
-        m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
-        s = s * m + 10 * (1 - m)
-        l = l * (
-            1
-            - (k + self.len_targets - 1 >= k1).long()
-            * (k < k1 + self.len_targets).long()
-        )
-        k2 = l.argmax(dim=1, keepdim=True)
-        m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
-        s = s * m + 11 * (1 - m)
-        a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
-        a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
-        sequences = torch.cat(
-            (
-                s,
-                torch.full((nb, 1), 12),
-                a1,
-                torch.full((nb, 1), 12),
-                a2,
-                torch.full((nb, 1), 12),
-            ),
-            1,
-        )
-        ar_mask = (sequences == 12).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789-+|"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemByHeart(Problem):
-    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8, separation=1):
-        self.seq = torch.randint(
-            10, (nb_sentences, len_prompt + separation + len_result)
-        )
-        self.seq[:, len_prompt : len_prompt + separation] = 10
-
-    def generate_sequences(self, nb):
-        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
-        ar_mask = (sequences == 10).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789|"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemLearnOperator(Problem):
-    def __init__(self, nb_operators=100, len_source=6, len_result=9):
-        self.len_source = len_source
-        self.len_result = len_result
-        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
-        self.operators = F.one_hot(
-            torch.rand(nb_operators, len_result, len_source).argmax(-1),
-            num_classes=len_source,
-        )
-
-    def generate_sequences(self, nb):
-        nb_operators = torch.randint(self.operators.size(0), (nb,))
-        operators = self.operators[nb_operators]
-        nb_operators = (
-            nb_operators[:, None]
-            // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
-        ) % 10
-        marker1 = torch.full((nb, 1), 10)
-        source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
-        marker2 = torch.full((nb, 1), 11)
-        result = operators.bmm(source[:, :, None]).squeeze(-1)
-        sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
-        ar_mask = (sequences == 11).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789|>"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemGuessOperator(Problem):
-    def __init__(self, len_source=5, len_result=8):
-        self.len_source = len_source
-        self.len_result = len_result
-
-    def generate_sequences(self, nb):
-        operators = F.one_hot(
-            torch.rand(nb, self.len_result, self.len_source).argmax(-1),
-            num_classes=self.len_source,
-        )
-        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
-        marker1 = torch.full((nb, 1), 10)
-        result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
-        marker2 = torch.full((nb, 1), 11)
-        source2 = torch.randint(10, (nb, self.len_source))
-        marker3 = torch.full((nb, 1), 12)
-        result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
-
-        sequences = torch.cat(
-            (source1, marker1, result1, marker2, source2, marker3, result2), 1
-        )
-        ar_mask = (sequences == 12).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789>|~"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemAddition(Problem):
-    def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
-        self.nb_digits = nb_digits
-        self.zero_padded = zero_padded
-        self.inverted_result = inverted_result
-        self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
-        self.id2char = dict([(n, c) for c, n in self.char2id.items()])
-
-    def tensorize(self, strings):
-        len_max = max([len(x) for x in strings])
-        return torch.cat(
-            [
-                torch.tensor(
-                    [
-                        [self.char2id[c] for c in s + "$" * (len_max - len(s))]
-                        for s in strings
-                    ]
-                )
-            ],
-            0,
-        )
-
-    def generate_sequences(self, nb):
-        sequences = []
-        for k in range(nb):
-            a, b = torch.randint(10**self.nb_digits, (2,))
-            c = a + b
-            a, b, c = str(a.item()), str(b.item()), str(c.item())
-            if self.zero_padded:
-                a = "0" * (self.nb_digits - len(a)) + a
-                b = "0" * (self.nb_digits - len(b)) + b
-                c = "0" * (self.nb_digits + 1 - len(c)) + c
-            if self.inverted_result:
-                c = c[::-1]
-            sequences.append(f"{a}+{b}={c}$")
-
-        sequences = self.tensorize(sequences)
-        ar_mask = (sequences == self.char2id["="]).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join(self.id2char[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemMixing(Problem):
-    def __init__(
-        self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
-    ):
-        self.height = height
-        self.width = width
-        self.nb_time_steps = nb_time_steps
-        self.hard = hard
-        self.random_start = random_start
-
-    def start_random(self, nb):
-        y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
-
-        if self.random_start:
-            i = (
-                torch.arange(self.height)
-                .reshape(1, -1, 1)
-                .expand(nb, self.height, self.width)
-            )
-            j = (
-                torch.arange(self.width)
-                .reshape(1, 1, -1)
-                .expand(nb, self.height, self.width)
-            )
-
-            ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
-            rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
-
-            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
-
-            y = y * m + self.height * self.width * (1 - m)
-
-        y = y.reshape(nb, self.height, self.width)
-
-        return y
-
-    def start_error(self, x):
-        if self.random_start:
-            i = (
-                torch.arange(self.height, device=x.device)
-                .reshape(1, -1, 1)
-                .expand_as(x)
-            )
-            j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x)
-
-            ri = (
-                (x == self.height * self.width)
-                .long()
-                .sum(dim=-1)
-                .argmax(-1)
-                .view(-1, 1, 1)
-            )
-            rj = (
-                (x == self.height * self.width)
-                .long()
-                .sum(dim=-2)
-                .argmax(-1)
-                .view(-1, 1, 1)
-            )
-
-            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
-        else:
-            m = 1
-
-        x = x.flatten(1)
-        u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1)
-
-        d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
-
-        return d
-
-    def moves(self, x):
-        y = (
-            x[:, None, :, :]
-            .expand(-1, self.height * 2 + self.width * 2, -1, -1)
-            .clone()
-        )
-        k = 0
-
-        for i in range(self.height):
-            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
-            k += 1
-            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
-            k += 1
-
-        for j in range(self.width):
-            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
-            k += 1
-            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
-            k += 1
-
-        return y
-
-    def generate_sequences(self, nb):
-        x = self.start_random(nb)
-
-        seq = [x.flatten(1)]
-
-        for t in range(self.nb_time_steps - 1):
-            y = self.moves(x)
-            x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
-            seq.append(x.flatten(1))
-
-        if self.hard:
-            seq.reverse()
-
-        seq = torch.cat(seq, dim=1)
-        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
-
-    def compute_nb_correct(self, input, ar_mask, result):
-        a = [
-            x.reshape(result.size(0), self.height, self.width)
-            for x in result.split(self.height * self.width, dim=1)
-        ]
-        if self.hard:
-            a.reverse()
-
-        x = a[0]
-
-        d = self.start_error(x)
-
-        for t in range(self.nb_time_steps - 1):
-            x0, x = a[t], a[t + 1]
-            y = self.moves(x0)
-            d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
-
-        nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
-
-        return nb_total, nb_correct
-
-    def seq2str(self, seq):
-        return " | ".join(
-            [
-                " ".join(
-                    [
-                        "-".join(
-                            [
-                                f"{x:02d}" if x < self.height * self.width else "**"
-                                for x in s
-                            ]
-                        )
-                        for s in r.split(self.width)
-                    ]
-                )
-                for r in seq.split(self.height * self.width)
-            ]
-        )
-
-
-####################
-
-if __name__ == "__main__":
-    p = ProblemMixing(height=3, width=3, random_start=False)
-
-    s, m = p.generate_sequences(10000)
-    for x in s[:5]:
-        print(p.seq2str(x))
-    print(p.compute_nb_correct(None, None, s))
diff --git a/tasks.py b/tasks.py
index 8680ba1..f6d34a8 100755
--- a/tasks.py
+++ b/tasks.py
@@ -14,9 +14,6 @@ from torch.nn import functional as F
 
 from mygpt import BracketedSequence
 
-# from graph import save_attention_image
-save_attention_image = None
-
 ######################################################################
 
 
@@ -252,7 +249,7 @@ class World(Task):
             new_quizzes,
             ar_mask,
             deterministic_synthesis=False,
-            progress_bar_desc="new quizzes",
+            progress_bar_desc="creating quizzes",
            device=self.device,
        )
 
@@ -290,7 +287,7 @@ class World(Task):
             inverted_result,
             ar_mask,
             deterministic_synthesis=True,
-            progress_bar_desc="solving reverse quizzes",
+            progress_bar_desc="solving reversed quizzes",
             device=self.device,
         )
 
-- 
2.39.5