Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 13:22:05 +0000 (15:22 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 13:22:05 +0000 (15:22 +0200)
do_all.sh [deleted file]
graph.py [deleted file]
main.py
problems.py [deleted file]
tasks.py

diff --git a/do_all.sh b/do_all.sh
deleted file mode 100755 (executable)
index c5d16fc..0000000
--- a/do_all.sh
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/bin/bash
-
-##################################################################
-# START_IP_HEADER                                                #
-#                                                                #
-# Written by Francois Fleuret                                    #
-# Contact <francois.fleuret@unige.ch> for comments & bug reports #
-#                                                                #
-# END_IP_HEADER                                                  #
-##################################################################
-
-# set -e
-# set -o pipefail
-
-#prefix="--nb_train_samples=1000 --nb_test_samples=100 --batch_size=25 --nb_epochs=2 --max_percents_of_test_in_train=-1 --model=17K"
-prefix="--nb_epochs=25"
-
-for task in byheart learnop guessop twotargets addition picoclvr maze snake stack expr rpl
-do
-    [[ ! -d results_${task} ]] && ./main.py ${prefix} --task=${task}
-done
-
diff --git a/graph.py b/graph.py
deleted file mode 100755 (executable)
index 07e376a..0000000
--- a/graph.py
+++ /dev/null
@@ -1,185 +0,0 @@
-#!/usr/bin/env python
-
-import math
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-import cairo
-
-
-######################################################################
-
-
-def save_attention_image(
-    # image to save
-    filename,
-    tokens_input,
-    tokens_output,
-    # list of 2d tensors T2xT1, T3xT2, ..., TkxTk-1
-    attention_matrices,
-    # do not draw links with a lesser attention
-    min_link_attention=0,
-    # draw only the strongest links necessary so that their summed
-    # attention is above min_total_attention
-    min_total_attention=None,
-    # draw only the top k links
-    k_top=None,
-    # the purely graphical settings
-    curved=True,
-    pixel_scale=8,
-    token_gap=15,
-    layer_gap=25,
-    y_eps=0.5,
-    padding=10,
-):
-    if k_top is not None:
-        am = []
-        for m in attention_matrices:
-            am.append(m * (m.sort(dim=-1, descending=True).indices < k_top))
-        attention_matrices = am
-
-    if min_total_attention is not None:
-        am = []
-        for m in attention_matrices:
-            s = m.sort(dim=-1)
-            m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long()
-            b = m.new(m.size()).scatter_(dim=-1, index=s.indices, src=m)
-            am.append(m * b)
-
-    surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None)
-
-    ctx = cairo.Context(surface)
-    ctx.scale(pixel_scale, pixel_scale)
-
-    ctx.set_source_rgb(0.0, 0.0, 0.0)
-    ctx.set_font_size(4.0)
-    # ctx.select_font_face("Arial", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
-
-    x, y = 0, 0
-
-    ctx.set_line_width(0.25)
-    for d in range(len(attention_matrices)):
-        at = attention_matrices[d].to("cpu")
-        ni = torch.arange(at.size(0))[:, None].expand_as(at)
-        nj = torch.arange(at.size(1))[None, :].expand_as(at)
-        at = at.flatten()
-        o = at.sort().indices
-        at = at[o]
-        ni = ni.flatten()[o]
-        nj = nj.flatten()[o]
-        for i, j, a in zip(ni, nj, at):
-            if a > 0 and a >= min_link_attention:
-                c = 1 - a.item()
-                ctx.set_source_rgb(c, c, c)
-                ax, ay = j * token_gap, y - y_eps
-                ctx.move_to(ax, ay)
-                dx, dy = i * token_gap, y - layer_gap + y_eps
-                if curved:
-                    bx, by = ax, ay - layer_gap * 0.5
-                    cx, cy = dx, dy + layer_gap * 0.5
-                    ctx.curve_to(bx, by, cx, cy, dx, dy)
-                else:
-                    ctx.line_to(dx, dy)
-                ctx.stroke()
-        y -= layer_gap
-
-    for d in range(0, len(attention_matrices) + 1):
-        n = (
-            attention_matrices[0].size(-1)
-            if d == 0
-            else attention_matrices[d - 1].size(-2)
-        )
-        for n in range(n):
-            xc, yc = n * token_gap, -d * layer_gap
-            ctx.set_source_rgb(1.0, 1.0, 1.0)
-            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
-            ctx.fill()
-            ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
-            ctx.fill()
-
-    ctx.set_source_rgb(0.0, 0.0, 0.0)
-
-    for k, t in enumerate(tokens_input):
-        s = str(t)
-        (
-            x_bearing,
-            y_bearing,
-            width_t,
-            height_t,
-            x_advance,
-            y_advance,
-        ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, 2 * token_gap / 5)
-        ctx.show_text(s)
-
-    for k, t in enumerate(tokens_output):
-        s = str(t)
-        (
-            x_bearing,
-            y_bearing,
-            width_t,
-            height_t,
-            x_advance,
-            y_advance,
-        ) = ctx.text_extents(s)
-        ctx.move_to(
-            k * token_gap - width_t / 2,
-            -token_gap / 5 - len(attention_matrices) * layer_gap,
-        )
-        ctx.show_text(s)
-
-    x, y, width, height = surface.ink_extents()
-    x -= padding
-    y -= padding
-    width += 2 * padding
-    height += 2 * padding
-    pdf_surface = cairo.PDFSurface(filename, width, height)
-    ctx_pdf = cairo.Context(pdf_surface)
-    ctx_pdf.set_source_surface(surface, -x, -y)
-    ctx_pdf.paint()
-    pdf_surface.finish()
-
-
-######################################################################
-
-if __name__ == "__main__":
-    import mygpt
-
-    tokens_output = ["<wat>", "-", 3, 4, "<end>"]
-    tokens_input = [""] + tokens_output[:-1]
-
-    vocabulary_size = 3
-    x = torch.randint(vocabulary_size, (1, len(tokens_input)))
-
-    model = mygpt.MyGPT(
-        vocabulary_size=vocabulary_size,
-        dim_model=4,
-        dim_keys=2,
-        dim_hidden=2,
-        nb_heads=2,
-        nb_blocks=5,
-        dropout=0.1,
-        causal=True,
-    )
-
-    model.eval()
-    model.record_attention()
-
-    y1 = model(mygpt.BracketedSequence(x)).x
-
-    attention_matrices = [m[0, 0] for m in model.retrieve_attention()]
-
-    # attention_matrices = [torch.rand(*s) for s in [ (4,5),(3,4),(8,3),(5,8) ]]
-
-    save_attention_image(
-        "attention.pdf",
-        tokens_input,
-        tokens_output,
-        attention_matrices,
-        # k_top=2,
-        min_total_attention=0.9,
-    )
diff --git a/main.py b/main.py
index e058822..549e7ea 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -568,6 +568,8 @@ def create_quizzes(
             other_models=other_models,
         )
 
+        print(nb_correct)
+
         to_keep = new_quizzes[nb_correct == len(other_models) - 1]
         log_string(f"keep {to_keep.size(0)} quizzes")
         kept.append(to_keep)
diff --git a/problems.py b/problems.py
deleted file mode 100755 (executable)
index 446e1a1..0000000
+++ /dev/null
@@ -1,492 +0,0 @@
-#!/usr/bin/env python
-
-import math
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-
-class Problem:
-    def generate_sequences(self, nb):
-        pass
-
-    def seq2str(self, seq):
-        return "[NOT IMPLEMENTED]"
-
-    def compute_nb_correct(self, input, ar_mask, result):
-        nb_total = ar_mask.sum().item()
-        nb_correct = ((result == input).long() * ar_mask).sum().item()
-        return nb_total, nb_correct
-
-
-####################
-
-
-class ProblemDegradation(Problem):
-    def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
-        assert value_max // nb_state_tokens >= 2
-        self.nb_state_tokens = nb_state_tokens
-        self.nb_time_steps = nb_time_steps
-        self.value_max = value_max
-        self.hard = hard
-
-    def generate_sequences(self, nb):
-        x = (
-            torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0
-        ).long() * self.value_max
-        seq = [x]
-
-        for t in range(self.nb_time_steps - 1):
-            v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long()
-            u = (v.max(dim=-1, keepdim=True).values == v).long()
-            n = (
-                (u * x)
-                .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size()))
-                .sum(dim=-1, keepdim=True)
-            )
-            m = 1 + ((n - 1) * torch.rand(n.size())).long()
-            x = (
-                x
-                + m * u.roll(shifts=-1, dims=-1)
-                - n * u
-                + (n - m) * u.roll(shifts=1, dims=-1)
-            )
-            seq.append(x)
-
-        if self.hard:
-            seq.reverse()
-
-        seq = torch.cat(seq, dim=1)
-        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
-
-    def compute_nb_correct(self, input, ar_mask, result):
-        nb_total = result.size(0)
-        nb_correct = 0
-        e = result.new_zeros(self.nb_state_tokens)
-
-        for seq in result:
-            states = list(seq.split(self.nb_state_tokens))
-            if self.hard:
-                states.reverse()
-
-            d = states[0]
-            j = d.sort(descending=True).indices[0]
-            e.zero_()
-            e[j] = self.value_max
-            if (d - e).abs().sum() == 0:
-                nb_errors = 0
-                for k in range(len(states) - 1):
-                    d = states[k + 1] - states[k]
-                    j = d.sort(descending=False).indices[0]
-                    if (
-                        d[j] == 0
-                        or d[j] > self.value_max // 4
-                        or d[(j + 1) % e.size(0)] <= 0
-                        or d[(j + 1) % e.size(0)] >= -d[j]
-                    ):
-                        nb_errors += 1
-                    else:
-                        e.zero_()
-                        e[j] = d[j]
-                        e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)]
-                        e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j]
-                        if (d - e).abs().sum() > 0:
-                            nb_errors += 1
-                if nb_errors == 0:
-                    nb_correct += 1
-
-        return nb_total, nb_correct
-
-    def seq2str(self, seq):
-        return " | ".join(
-            [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)]
-        )
-
-
-####################
-
-
-class ProblemMemory(Problem):
-    def __init__(self, len_total=25):
-        self.len_total = len_total
-        self.max_len_pattern = 5
-        self.nb_noise_tokens = 10
-        self.start_pattern_token = 0
-        self.end_pattern_token = 1
-        self.start_result_token = 2
-        self.end_result_token = 3
-        self.token_string = "[]<>" + "".join(
-            [chr(ord("a") + k) for k in range(self.nb_noise_tokens)]
-        )
-
-    def generate_sequences(self, nb):
-        sequences = (
-            torch.randint(self.nb_noise_tokens, (nb, self.len_total))
-            + self.end_result_token
-            + 1
-        )
-        len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1
-        pattern_positions = torch.randint(
-            self.len_total - (5 + 2 * self.max_len_pattern), (nb,)
-        )
-        k = self.len_total - (3 + self.max_len_pattern)
-        for i in range(nb):
-            l = len_patterns[i]
-            j = pattern_positions[i]
-            sequences[i, j] = self.start_pattern_token
-            sequences[i, j + l + 2] = self.end_pattern_token
-            sequences[i, k] = self.start_result_token
-            sequences[i, k + l + 2] = self.end_result_token
-            sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l]
-
-        j = torch.arange(self.len_total)[None, :]
-        ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long()
-
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join(self.token_string[x.item()] for x in seq)
-
-
-class ProblemTwoTargets(Problem):
-    def __init__(self, len_total=10, len_targets=3):
-        assert len_targets >= 3
-        assert len_total >= 3 * len_targets - 1
-        self.len_total = len_total
-        self.len_targets = len_targets
-
-    def generate_sequences(self, nb):
-        k = torch.arange(self.len_total)[None, :]
-        s = torch.randint(10, (nb, self.len_total))
-        l = torch.rand(nb, self.len_total)
-        l = l * (k <= self.len_total - self.len_targets).long()
-        k1 = l.argmax(dim=1, keepdim=True)
-        m = (k != k1).long() * (k != k1 + self.len_targets - 1).long()
-        s = s * m + 10 * (1 - m)
-        l = l * (
-            1
-            - (k + self.len_targets - 1 >= k1).long()
-            * (k < k1 + self.len_targets).long()
-        )
-        k2 = l.argmax(dim=1, keepdim=True)
-        m = (k != k2).long() * (k != k2 + self.len_targets - 1).long()
-        s = s * m + 11 * (1 - m)
-        a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
-        a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
-        sequences = torch.cat(
-            (
-                s,
-                torch.full((nb, 1), 12),
-                a1,
-                torch.full((nb, 1), 12),
-                a2,
-                torch.full((nb, 1), 12),
-            ),
-            1,
-        )
-        ar_mask = (sequences == 12).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789-+|"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemByHeart(Problem):
-    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8, separation=1):
-        self.seq = torch.randint(
-            10, (nb_sentences, len_prompt + separation + len_result)
-        )
-        self.seq[:, len_prompt : len_prompt + separation] = 10
-
-    def generate_sequences(self, nb):
-        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
-        ar_mask = (sequences == 10).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789|"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemLearnOperator(Problem):
-    def __init__(self, nb_operators=100, len_source=6, len_result=9):
-        self.len_source = len_source
-        self.len_result = len_result
-        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
-        self.operators = F.one_hot(
-            torch.rand(nb_operators, len_result, len_source).argmax(-1),
-            num_classes=len_source,
-        )
-
-    def generate_sequences(self, nb):
-        nb_operators = torch.randint(self.operators.size(0), (nb,))
-        operators = self.operators[nb_operators]
-        nb_operators = (
-            nb_operators[:, None]
-            // 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
-        ) % 10
-        marker1 = torch.full((nb, 1), 10)
-        source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
-        marker2 = torch.full((nb, 1), 11)
-        result = operators.bmm(source[:, :, None]).squeeze(-1)
-        sequences = torch.cat((nb_operators, marker1, source, marker2, result), 1)
-        ar_mask = (sequences == 11).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789|>"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemGuessOperator(Problem):
-    def __init__(self, len_source=5, len_result=8):
-        self.len_source = len_source
-        self.len_result = len_result
-
-    def generate_sequences(self, nb):
-        operators = F.one_hot(
-            torch.rand(nb, self.len_result, self.len_source).argmax(-1),
-            num_classes=self.len_source,
-        )
-        source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
-        marker1 = torch.full((nb, 1), 10)
-        result1 = operators.bmm(source1[:, :, None]).squeeze(-1)
-        marker2 = torch.full((nb, 1), 11)
-        source2 = torch.randint(10, (nb, self.len_source))
-        marker3 = torch.full((nb, 1), 12)
-        result2 = operators.bmm(source2[:, :, None]).squeeze(-1)
-
-        sequences = torch.cat(
-            (source1, marker1, result1, marker2, source2, marker3, result2), 1
-        )
-        ar_mask = (sequences == 12).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join("0123456789>|~"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemAddition(Problem):
-    def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
-        self.nb_digits = nb_digits
-        self.zero_padded = zero_padded
-        self.inverted_result = inverted_result
-        self.char2id = dict([(c, n) for n, c in enumerate("0123456789+=$")])
-        self.id2char = dict([(n, c) for c, n in self.char2id.items()])
-
-    def tensorize(self, strings):
-        len_max = max([len(x) for x in strings])
-        return torch.cat(
-            [
-                torch.tensor(
-                    [
-                        [self.char2id[c] for c in s + "$" * (len_max - len(s))]
-                        for s in strings
-                    ]
-                )
-            ],
-            0,
-        )
-
-    def generate_sequences(self, nb):
-        sequences = []
-        for k in range(nb):
-            a, b = torch.randint(10**self.nb_digits, (2,))
-            c = a + b
-            a, b, c = str(a.item()), str(b.item()), str(c.item())
-            if self.zero_padded:
-                a = "0" * (self.nb_digits - len(a)) + a
-                b = "0" * (self.nb_digits - len(b)) + b
-                c = "0" * (self.nb_digits + 1 - len(c)) + c
-            if self.inverted_result:
-                c = c[::-1]
-            sequences.append(f"{a}+{b}={c}$")
-
-        sequences = self.tensorize(sequences)
-        ar_mask = (sequences == self.char2id["="]).long()
-        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
-        return sequences, ar_mask
-
-    def seq2str(self, seq):
-        return "".join(self.id2char[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemMixing(Problem):
-    def __init__(
-        self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
-    ):
-        self.height = height
-        self.width = width
-        self.nb_time_steps = nb_time_steps
-        self.hard = hard
-        self.random_start = random_start
-
-    def start_random(self, nb):
-        y = torch.arange(self.height * self.width).reshape(1, -1).expand(nb, -1)
-
-        if self.random_start:
-            i = (
-                torch.arange(self.height)
-                .reshape(1, -1, 1)
-                .expand(nb, self.height, self.width)
-            )
-            j = (
-                torch.arange(self.width)
-                .reshape(1, 1, -1)
-                .expand(nb, self.height, self.width)
-            )
-
-            ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
-            rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)
-
-            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
-
-            y = y * m + self.height * self.width * (1 - m)
-
-        y = y.reshape(nb, self.height, self.width)
-
-        return y
-
-    def start_error(self, x):
-        if self.random_start:
-            i = (
-                torch.arange(self.height, device=x.device)
-                .reshape(1, -1, 1)
-                .expand_as(x)
-            )
-            j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x)
-
-            ri = (
-                (x == self.height * self.width)
-                .long()
-                .sum(dim=-1)
-                .argmax(-1)
-                .view(-1, 1, 1)
-            )
-            rj = (
-                (x == self.height * self.width)
-                .long()
-                .sum(dim=-2)
-                .argmax(-1)
-                .view(-1, 1, 1)
-            )
-
-            m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
-        else:
-            m = 1
-
-        x = x.flatten(1)
-        u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1)
-
-        d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
-
-        return d
-
-    def moves(self, x):
-        y = (
-            x[:, None, :, :]
-            .expand(-1, self.height * 2 + self.width * 2, -1, -1)
-            .clone()
-        )
-        k = 0
-
-        for i in range(self.height):
-            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
-            k += 1
-            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
-            k += 1
-
-        for j in range(self.width):
-            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
-            k += 1
-            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
-            k += 1
-
-        return y
-
-    def generate_sequences(self, nb):
-        x = self.start_random(nb)
-
-        seq = [x.flatten(1)]
-
-        for t in range(self.nb_time_steps - 1):
-            y = self.moves(x)
-            x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
-            seq.append(x.flatten(1))
-
-        if self.hard:
-            seq.reverse()
-
-        seq = torch.cat(seq, dim=1)
-        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
-
-    def compute_nb_correct(self, input, ar_mask, result):
-        a = [
-            x.reshape(result.size(0), self.height, self.width)
-            for x in result.split(self.height * self.width, dim=1)
-        ]
-        if self.hard:
-            a.reverse()
-
-        x = a[0]
-
-        d = self.start_error(x)
-
-        for t in range(self.nb_time_steps - 1):
-            x0, x = a[t], a[t + 1]
-            y = self.moves(x0)
-            d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
-
-        nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
-
-        return nb_total, nb_correct
-
-    def seq2str(self, seq):
-        return " | ".join(
-            [
-                " ".join(
-                    [
-                        "-".join(
-                            [
-                                f"{x:02d}" if x < self.height * self.width else "**"
-                                for x in s
-                            ]
-                        )
-                        for s in r.split(self.width)
-                    ]
-                )
-                for r in seq.split(self.height * self.width)
-            ]
-        )
-
-
-####################
-
-if __name__ == "__main__":
-    p = ProblemMixing(height=3, width=3, random_start=False)
-
-    s, m = p.generate_sequences(10000)
-    for x in s[:5]:
-        print(p.seq2str(x))
-    print(p.compute_nb_correct(None, None, s))
index 8680ba1..f6d34a8 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -14,9 +14,6 @@ from torch.nn import functional as F
 
 from mygpt import BracketedSequence
 
-# from graph import save_attention_image
-save_attention_image = None
-
 ######################################################################
 
 
@@ -252,7 +249,7 @@ class World(Task):
             new_quizzes,
             ar_mask,
             deterministic_synthesis=False,
-            progress_bar_desc="new quizzes",
+            progress_bar_desc="creating quizzes",
             device=self.device,
         )
 
@@ -290,7 +287,7 @@ class World(Task):
                 inverted_result,
                 ar_mask,
                 deterministic_synthesis=True,
-                progress_bar_desc="solving reverse quizzes",
+                progress_bar_desc="solving reversed quizzes",
                 device=self.device,
             )