Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 4 Oct 2024 20:45:42 +0000 (22:45 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 4 Oct 2024 20:45:42 +0000 (22:45 +0200)
attae.py
main.py
world.py

index c04c5d3..2b231de 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -92,10 +92,42 @@ class MHAttention(nn.Module):
 ######################################################################
 
 
+def create_trunk(dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout=0.0):
+    trunk_blocks = []
+
+    for b in range(nb_blocks):
+        trunk_blocks += [
+            WithResidual(
+                nn.LayerNorm((dim_model,)),
+                MHAttention(
+                    dim_model=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_model // nb_heads,
+                    nb_heads=nb_heads,
+                    attention=vanilla_attention,
+                    attention_dropout=dropout,
+                ),
+            ),
+            WithResidual(
+                nn.LayerNorm((dim_model,)),
+                nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                nn.ReLU(),
+                nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                nn.Dropout(dropout),
+            ),
+        ]
+
+    return nn.Sequential(*trunk_blocks)
+
+
+######################################################################
+
+
 class AttentionAE(nn.Module):
     def __init__(
         self,
-        vocabulary_size,
+        vocabulary_size_in,
+        vocabulary_size_out,
         dim_model,
         dim_keys,
         dim_hidden,
@@ -109,39 +141,24 @@ class AttentionAE(nn.Module):
         assert dim_model % nb_heads == 0
 
         self.embedding = nn.Sequential(
-            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Embedding(vocabulary_size_in, dim_model),
             nn.Dropout(dropout),
         )
 
         self.positional_encoding = VaswaniPositionalEncoding(len_max)
 
-        trunk_blocks = []
-
-        for b in range(nb_blocks):
-            trunk_blocks += [
-                WithResidual(
-                    nn.LayerNorm((dim_model,)),
-                    MHAttention(
-                        dim_model=dim_model,
-                        dim_qk=dim_keys,
-                        dim_v=dim_model // nb_heads,
-                        nb_heads=nb_heads,
-                        attention=vanilla_attention,
-                        attention_dropout=dropout,
-                    ),
-                ),
-                WithResidual(
-                    nn.LayerNorm((dim_model,)),
-                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                    nn.ReLU(),
-                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                    nn.Dropout(dropout),
-                ),
-            ]
-
-        self.trunk = nn.Sequential(*trunk_blocks)
+        self.trunk = create_trunk(
+            dim_model=dim_model,
+            dim_keys=dim_keys,
+            dim_hidden=dim_hidden,
+            nb_heads=nb_heads,
+            nb_blocks=nb_blocks,
+            dropout=dropout,
+        )
 
-        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+        self.readout = nn.Linear(
+            in_features=dim_model, out_features=vocabulary_size_out
+        )
 
         with torch.no_grad():
             for m in self.modules():
@@ -271,7 +288,8 @@ class FunctionalAttentionAE(nn.Module):
 
 if __name__ == "__main__":
     model = FunctionalAttentionAE(
-        vocabulary_size=100,
+        vocabulary_size_in=100,
+        vocabulary_size_out=100,
         dim_model=16,
         dim_keys=64,
         dim_hidden=32,
diff --git a/main.py b/main.py
index d699bc6..f6cf450 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -19,7 +19,7 @@ import threading, subprocess
 
 torch.set_float32_matmul_precision("high")
 
-torch.set_default_dtype(torch.bfloat16)
+torch.set_default_dtype(torch.bfloat16)
 
 ######################################################################
 
@@ -234,18 +234,17 @@ for n in vars(args):
 
 ######################################################################
 
-if args.gpus == "all":
-    gpus_idx = range(torch.cuda.device_count())
+if args.gpus == "none" or not torch.cuda.is_available():
+    gpus = [torch.device("cpu")]
 else:
-    gpus_idx = [int(k) for k in args.gpus.split(",")]
+    if args.gpus == "all":
+        gpus_idx = range(torch.cuda.device_count())
+    else:
+        gpus_idx = [int(k) for k in args.gpus.split(",")]
 
-gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]
+    gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]
 
-if torch.cuda.is_available():
-    main_device = gpus[0]
-else:
-    assert len(gpus) == 0
-    main_device = torch.device("cpu")
+main_device = gpus[0]
 
 if args.train_batch_size is None:
     args.train_batch_size = args.batch_size
@@ -318,6 +317,8 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
 
 ######################################################################
 
+# IMT stands for image/mask/target
+
 
 def add_hints_imt(imt_set, proba_hints):
     """Set every component of the mask to zero with probability proba,
@@ -356,7 +357,7 @@ def add_input_noise_imt(imt_set, proba_input_noise):
 # Prediction
 
 
-def samples_for_prediction_imt(input):
+def make_imt_samples_for_prediction(input):
     nb = input.size(0)
     masks = input.new_zeros(input.size())
     u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
@@ -386,7 +387,7 @@ def ae_predict(model, imt_set, local_device=main_device):
         imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
 
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(imt[:, 0] * 2 + imt[:, 1])
+            logits = model(imt[:, 0] + imt[:, 1] * vocabulary_size)
         dist = torch.distributions.categorical.Categorical(logits=logits)
         result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
         record.append(result)
@@ -421,7 +422,7 @@ def predict_the_four_grids(
 ######################################################################
 
 
-def samples_for_generation_imt(input):
+def make_imt_samples_for_generation(input):
     nb = input.size(0)
     probs_iterations = 0.1 ** torch.linspace(
         0, 1, args.diffusion_nb_iterations, device=input.device
@@ -480,7 +481,7 @@ def ae_generate(model, nb, local_device=main_device):
 
         for input, masks, changed in src:
             with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-                logits = model(input * 2 + masks)
+                logits = model(input + masks * vocabulary_size)
             dist = torch.distributions.categorical.Categorical(logits=logits)
             output = dist.sample()
             r = prioritized_rand(input != output)
@@ -510,7 +511,7 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
     q_p, q_g = quizzes.to(local_device).chunk(2)
 
     # Half of the samples are used to train the prediction.
-    b_p = samples_for_prediction_imt(q_p)
+    b_p = make_imt_samples_for_prediction(q_p)
     # We inject noise in all to avoid drift of the culture toward
     # "finding waldo" type of complexity
     b_p = add_input_noise_imt(b_p, args.proba_input_noise)
@@ -521,7 +522,7 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
 
     # The other half are denoising examples to train the generative
     # process.
-    b_g = samples_for_generation_imt(q_g)
+    b_g = make_imt_samples_for_generation(q_g)
 
     imt_set = torch.cat([b_p, b_g])
     imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
@@ -550,7 +551,7 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
             model.optimizer.zero_grad()
 
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(input * 2 + masks)
+            logits = model(input + masks * vocabulary_size)
 
         loss_per_token = F.cross_entropy(
             logits.transpose(1, 2), targets, reduction="none"
@@ -575,7 +576,7 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de
     # Save some images of the prediction results
 
     quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier)
-    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+    imt_set = make_imt_samples_for_prediction(quizzes.to(local_device))
     result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")
 
@@ -622,7 +623,7 @@ def one_complete_epoch(
         # c_quizzes=test_c_quizzes,
         c_quiz_multiplier=args.c_quiz_multiplier,
     )
-    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+    imt_set = make_imt_samples_for_prediction(quizzes.to(local_device))
     result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
     correct = (quizzes == result).min(dim=1).values.long()
 
@@ -855,7 +856,8 @@ def new_model(id=-1):
         raise ValueError(f"Unknown model type {args.model_type}")
 
     model = model_constructor(
-        vocabulary_size=vocabulary_size * 2,
+        vocabulary_size_in=vocabulary_size * 2,
+        vocabulary_size_out=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
         dim_hidden=args.dim_hidden,
@@ -910,6 +912,29 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
+if args.test == "aebn":
+    model = new_model()
+
+    # model.trunk = (
+    # model.trunk[: len(model.trunk) // 2] + model.trunk[len(model.trunk) // 2 :]
+    # )
+
+    model.id = 0
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.test_accuracy = 0.0
+    model.nb_epochs = 0
+
+    for n_epoch in range(args.nb_epochs):
+        one_complete_epoch(
+            model,
+            n_epoch,
+            train_c_quizzes=None,
+            test_c_quizzes=None,
+            local_device=main_device,
+        )
+
+######################################################################
+
 train_c_quizzes, test_c_quizzes = None, None
 
 models = []
index 3ab6944..cd34cf8 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -560,31 +560,39 @@ class Grids(problem.Problem):
                 result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow, delta=True
             )
 
-    def detect_rectangles(self, q1, q2):
+    def oracle(self, q1, q2):
         c = torch.arange(self.nb_colors)
-        I = torch.arange(self.height)[None, :, None]
-        J = torch.arange(self.width)[None, :, None]
+        all_i = torch.arange(self.height)[None, :, None]
+        all_j = torch.arange(self.width)[None, :, None]
 
         def corners(q):
             q = q.reshape(-1, self.height, self.width)
             a = (q[:, :, :, None] == c[None, None, None, :]).long()
             mi = a.max(dim=2).values
-            i = mi * I
+            i = mi * all_i
             i1 = (i + (1 - mi) * q.size(1)).min(dim=1).values
             i2 = (i + (1 - mi) * (-1)).max(dim=1).values + 1
             mj = a.max(dim=1).values
-            j = mj * J
+            j = mj * all_j
             j1 = (j + (1 - mj) * q.size(2)).min(dim=1).values
             j2 = (j + (1 - mj) * (-1)).max(dim=1).values + 1
             m = (
-                ((I > i1[:, None, :]) & (I < i2[:, None, :] - 1))[:, :, None, :]
-                & ((J > j1[:, None, :]) & (J < j2[:, None, :] - 1))[:, None, :, :]
+                ((all_i > i1[:, None, :]) & (all_i < i2[:, None, :] - 1))[:, :, None, :]
+                & ((all_j > j1[:, None, :]) & (all_j < j2[:, None, :] - 1))[
+                    :, None, :, :
+                ]
             ).long()
             f = ((a * m).long().sum(dim=(1, 2)) > 0).long()
+            i1[:, 0], i2[:, 0], j1[:, 0], j2[:, 0] = self.height, 0, self.width, 0
             return i1, i2, j1, j2, f
 
+        # Coordinates and frame-shape per grid per color
+        #
+        # NxC
+        #
         q1_i1, q1_i2, q1_j1, q1_j2, q1_f = corners(q1)
         q2_i1, q2_i2, q2_j1, q2_j2, q2_f = corners(q2)
+
         u1, u2 = 0, 0
 
         for _ in range(10):
@@ -592,54 +600,48 @@ class Grids(problem.Problem):
             r2 = q.new_zeros(q1.size(0), self.height, self.width)
 
             m1 = (
-                ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :]))[:, :, None, :]
-                & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[:, None, :, :]
+                ((all_i >= q1_i1[:, None, :]) & (all_i < q1_i2[:, None, :]))[
+                    :, :, None, :
+                ]
+                & ((all_j >= q1_j1[:, None, :]) & (all_j < q1_j2[:, None, :]))[
+                    :, None, :, :
+                ]
             ).long()
 
             f1 = (
-                (
-                    ((I == q1_i1[:, None, :]) | (I == q1_i2[:, None, :] - 1))[
-                        :, :, None, :
-                    ]
-                    & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[
-                        :, None, :, :
-                    ]
-                )
-                | (
-                    ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :] - 1))[
+                m1
+                * (
+                    ((all_i == q1_i1[:, None, :]) | (all_i == q1_i2[:, None, :] - 1))[
                         :, :, None, :
                     ]
-                    & ((J == q1_j1[:, None, :]) | (J == q1_j2[:, None, :] - 1))[
+                    | ((all_j == q1_j1[:, None, :]) | (all_j == q1_j2[:, None, :] - 1))[
                         :, None, :, :
                     ]
-                )
-            ).long()
+                ).long()
+            )
 
             r2 = q.new_zeros(q2.size(0), self.height, self.width)
 
             m2 = (
-                ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :]))[:, :, None, :]
-                & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[:, None, :, :]
+                ((all_i >= q2_i1[:, None, :]) & (all_i < q2_i2[:, None, :]))[
+                    :, :, None, :
+                ]
+                & ((all_j >= q2_j1[:, None, :]) & (all_j < q2_j2[:, None, :]))[
+                    :, None, :, :
+                ]
             ).long()
 
             f2 = (
-                (
-                    ((I == q2_i1[:, None, :]) | (I == q2_i2[:, None, :] - 1))[
-                        :, :, None, :
-                    ]
-                    & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[
-                        :, None, :, :
-                    ]
-                )
-                | (
-                    ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :] - 1))[
+                m2
+                * (
+                    ((all_i == q2_i1[:, None, :]) | (all_i == q2_i2[:, None, :] - 1))[
                         :, :, None, :
                     ]
-                    & ((J == q2_j1[:, None, :]) | (J == q2_j2[:, None, :] - 1))[
+                    | ((all_j == q2_j1[:, None, :]) | (all_j == q2_j2[:, None, :] - 1))[
                         :, None, :, :
                     ]
-                )
-            ).long()
+                ).long()
+            )
 
             for c in torch.randperm(self.nb_colors - 1) + 1:
                 r1[...] = q1_f[:, None, None, c] * (
@@ -661,17 +663,119 @@ class Grids(problem.Problem):
             u1 = (1 - match) * u1 + match * r1
             u2 = (1 - match) * u2 + match * r2
 
-        return u1.flatten(1), u2.flatten(1)
+        ok = (u1.flatten(1) == q1).min(dim=1).values & (u2.flatten(1) == q2).min(
+            dim=1
+        ).values
 
-        # o = F.one_hot(q * (1 - m)).sum(dim=1)
-        # print(o)
-        # print(o.sort(dim=1, descending=True))
-        # c = N x nb_col x 4
+        # q1_i1, q1_i2, q1_j1, q1_j2, q1_f = corners(q1)
+        # q2_i1, q2_i2, q2_j1, q2_j2, q2_f = corners(q2)
+        # NxC
+
+        q1i1, q1i2, q1j1, q1j2, q1f = (
+            q1_i1[:, :, None],
+            q1_i2[:, :, None],
+            q1_j1[:, :, None],
+            q1_j2[:, :, None],
+            q1_f[:, :, None],
+        )
+        q2i1, q2i2, q2j1, q2j2, q2f = (
+            q2_i1[:, None, :],
+            q2_i2[:, None, :],
+            q2_j1[:, None, :],
+            q2_j2[:, None, :],
+            q2_f[:, None, :],
+        )
+
+        match = (
+            (q1i1 < q1i2)
+            & (q2i1 < q2i2)
+            & (q1i1 == q2i1)
+            & (q1i2 == q2i2)
+            & (q1j1 == q2j1)
+            & (q1j2 == q2j2)
+        ).long()
+        translate = (
+            ((q1i1 - q2i1).abs() <= 1)
+            & (q1i1 - q2i1 == q1i2 - q2i2)
+            & ((q1j1 - q2j1).abs() <= 1)
+            & (q1j1 - q2j1 == q1j2 - q2j2)
+            & ((q1i1 - q2i1).abs() + (q1j1 - q2j1).abs() > 0)
+        ).long()
+        grow = (
+            (
+                (q2i1 == q1i1 - 1)
+                & (q2i2 == q1i2 + 1)
+                & (q2j1 == q1j1 - 1)
+                & (q2j2 == q1j2 + 1)
+            )
+            | (
+                (q2i1 == q1i1 + 1)
+                & (q2i2 == q1i2 - 1)
+                & (q2j1 == q1j1 + 1)
+                & (q2j2 == q1j2 - 1)
+            )
+        ).long()
+
+        nb_same_color_not_frame = torch.einsum("ncc->n", match * q2f)
+        nb_change_color = torch.einsum("ncd->n", match * q2f) - nb_same_color_not_frame
+        nb_frame = torch.einsum("ncc->n", match * (1 - q2f))
+        nb_translate = torch.einsum("ncc->n", translate)
+        nb_translate_change_color = torch.einsum("ncd->n", translate) - nb_translate
+        nb_grow = torch.einsum("ncc->n", grow)
+        nb_grow_change_color = torch.einsum("ncd->n", grow) - nb_grow
+
+        print("-------------------------")
+        print("nb_same_color_not_frame", nb_same_color_not_frame)
+        print("nb_change_color", nb_change_color)
+        print("nb_frame", nb_frame)
+        print("nb_translate", nb_translate)
+        print("nb_translate_change_color", nb_translate_change_color)
+        print("nb_grow", nb_grow)
+        print("nb_grow_change_color", nb_grow_change_color)
+
+        # ok = ok & ( <= 1) & (translate.sum(dim=(1,2) == 3)
+
+        # print("match", match, "\n\n")
+
+        # print("translate", translate, "\n\n")
+
+        # print("grow", grow, "\n\n")
+
+        return u1.flatten(1), u2.flatten(1)
 
 
 ######################################################################
 
+def recenv(a):
+    s_row = a.sum(dim=2, keepdim=True)
+    c_row = s_row.cumsum(dim=1)
+    s_col = a.sum(dim=1, keepdim=True)
+    c_col = s_col.cumsum(dim=2)
+    env_row = ((c_row > 0) & ((c_row < c_row[:, -1:, :]) | (s_row > 0))).long()
+    env_col = ((c_col > 0) & ((c_col < c_col[:, :, -1:]) | (s_col > 0))).long()
+    return env_row * env_col
+
+def valid(q1, q2, m1=None, m2=None):
+    if m1 is None:
+        m1=q1.new_zeros(m1.size())
+    if m2 is None:
+        m2=q2.new_zeros(m2.size())
+
+def valid_exact_match(q1, q2, m1, m2, c1, c2):
+    # q1, q2, m1, m2 are NxHxW
+    q1,m1=F.hone_hot(q1,numclasses=self.nb_colors),m1[:,:,:,None]
+    q2,m2=F.hone_hot(q2,numclasses=self.nb_colors),m2[:,:,:,None]
+    a1 = (1 - m1) * q1
+    b1 = (1 - m1) * (1-q1)
+    a2 = (1 - m2) * q2
+    b2 = (1 - m2) * (1-q2)
+    rec_a1 = recenv(a1)
+    rec_a2 = recenv(a2)
+    rec = recenv(1-(1-rec_a1)*(1-rec_a2))
+    ok = rec_a1 * rec * (1-m1) 
+
 if __name__ == "__main__":
+
     import time
 
     grids = Grids()
@@ -679,6 +783,23 @@ if __name__ == "__main__":
     nb, nrow = 64, 4
     nb_rows = 12
 
+    q = grids.generate_w_quizzes_(
+        1,
+        tasks=[
+            grids.task_replace_color,
+            # grids.task_translate,
+            # grids.task_grow,
+            # grids.task_frame,
+        ],
+    )
+
+    q = q.reshape(q.size(0), 4, q.size(1)//4)
+
+    print(q)
+    print(valid_exact_match(q[:,0], q[:,1]
+
+    exit(0)
+
     # c_quizzes = torch.load("/home/fleuret/state.pth")["train_c_quizzes"]
     # c_quizzes = c_quizzes[torch.randperm(c_quizzes.size(0))[: nrow * nb_rows]]
 
@@ -694,7 +815,7 @@ if __name__ == "__main__":
     # )
 
     w_quizzes = grids.generate_w_quizzes_(
-        16,
+        1,
         tasks=[
             grids.task_replace_color,
             grids.task_translate,
@@ -705,8 +826,8 @@ if __name__ == "__main__":
 
     q = w_quizzes.reshape(-1, 4, w_quizzes.size(1) // 4)
     r = q.new_zeros(q.size())
-    r[:, 0], r[:, 1] = grids.detect_rectangles(q[:, 0], q[:, 1])
-    r[:, 2], r[:, 3] = grids.detect_rectangles(q[:, 2], q[:, 3])
+    r[:, 0], r[:, 1] = grids.oracle(q[:, 0], q[:, 1])
+    r[:, 2], r[:, 3] = grids.oracle(q[:, 2], q[:, 3])
 
     grids.save_quizzes_as_image(
         "/tmp",