Rename mask_generate to mask, fix stale self. references in generate(), and save a sample of generated quizzes each epoch.
author     François Fleuret <francois@fleuret.org>
           Mon, 16 Sep 2024 21:25:33 +0000 (23:25 +0200)
committer  François Fleuret <francois@fleuret.org>
           Mon, 16 Sep 2024 21:25:33 +0000 (23:25 +0200)
main.py

diff --git a/main.py b/main.py
index d21c54b..899a099 100755
--- a/main.py
+++ b/main.py
@@ -640,22 +640,22 @@ def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device):
 
 def batch_prediction(input, proba_hints=0.0):
     nb = input.size(0)
-    mask_generate = input.new_zeros(input.size())
-    u = F.one_hot(torch.randint(4, (nb,), device=mask_generate.device), num_classes=4)
-    mask_generate.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+    mask = input.new_zeros(input.size())
+    u = F.one_hot(torch.randint(4, (nb,), device=mask.device), num_classes=4)
+    mask.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
 
     if proba_hints > 0:
-        h = torch.rand(input.size(), device=input.device) * mask_generate
+        h = torch.rand(input.size(), device=input.device) * mask
         mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
         v = torch.rand(nb, device=input.device)[:, None]
         mask_hints = mask_hints * (v < proba_hints).long()
-        mask_generate = (1 - mask_hints) * mask_generate
+        mask = (1 - mask_hints) * mask
 
     # noise = quiz_machine.problem.pure_noise(nb, input.device)
     targets = input
-    input = (1 - mask_generate) * targets  # + mask_generate * noise
+    input = (1 - mask) * targets  # + mask * noise
 
-    return input, targets, mask_generate
+    return input, targets, mask
 
 
 def predict(model, input, targets, mask, local_device=main_device):
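
The quadrant mask built at the top of batch_prediction can be tried in isolation. A minimal sketch, assuming quizzes are flat sequences of four equal-length quadrants whose first token (presumably a structural marker) is never masked; the sizes below are hypothetical:

    import torch
    import torch.nn.functional as F

    nb, quadrant_size = 2, 5
    input = torch.zeros(nb, 4 * quadrant_size, dtype=torch.long)

    # Pick one quadrant per quiz and flag all of its tokens except the
    # first one as positions to generate.
    mask = input.new_zeros(input.size())
    u = F.one_hot(torch.randint(4, (nb,), device=mask.device), num_classes=4)
    mask.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]

    print(mask.view(nb, 4, -1))  # one quadrant of ones per quiz, first column zero
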
@@ -704,10 +704,10 @@ def batch_generation(input):
 
     targets = input
     input = (1 - mask_erased) * input + mask_erased * noise
-    mask_generate = input.new_full(input.size(), 1)
-    mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0
+    mask = input.new_full(input.size(), 1)
+    mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0
 
-    return input, targets, mask_generate
+    return input, targets, mask
 
 
 def prioritized_rand(low):
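
The body of prioritized_rand is cut off by the hunk above. Judging from its call site below (r = prioritized_rand(input != output), then r <= args.diffusion_proba_corruption), it should return random values that are smallest where low is true, so those positions are preferentially selected by the threshold. A plausible reconstruction, offered purely as an assumption:

    # Assumed sketch, not necessarily the file's actual body: rank uniform
    # noise so that positions flagged in `low` receive the smallest values
    # and are therefore prioritized when thresholded with r <= p.
    def prioritized_rand(low):
        x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
        k = (torch.rand(low.size(), device=low.device) + low.long()).sort(dim=1).indices
        y = x.new_empty(x.size())
        y.scatter_(dim=1, index=k, src=x)
        return y
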
@@ -721,20 +721,19 @@ def prioritized_rand(low):
 
 def generate(model, nb, local_device=main_device):
     input = quiz_machine.problem.pure_noise(nb, local_device)
-    mask_generate = input.new_full(input.size(), 1)
-    mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0
+    mask = input.new_full(input.size(), 1)
+    mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0
 
     changed = True
-    for it in range(self.diffusion_nb_iterations):
+    for it in range(args.diffusion_nb_iterations):
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(NTC_channel_cat(input, mask_generate))
+            logits = model(NTC_channel_cat(input, mask))
         dist = torch.distributions.categorical.Categorical(logits=logits)
         output = dist.sample()
 
-        r = self.prioritized_rand(input != output)
-        mask_changes = (r <= self.proba_corruption).long()
+        r = prioritized_rand(input != output)
+        mask_changes = (r <= args.diffusion_proba_corruption).long() * mask
         update = (1 - mask_changes) * input + mask_changes * output
-
         if update.equal(input):
             break
         else:
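
The second fix in this hunk is easy to miss: multiplying mask_changes by mask guarantees that frozen positions (the first token of each quadrant, zeroed in the mask above) can never be resampled, no matter how small their random value r is. A tiny illustration with hypothetical values:

    import torch

    input  = torch.tensor([[7, 1, 2, 3]])   # token 7 plays the frozen first token
    output = torch.tensor([[9, 1, 5, 3]])   # the model resampled positions 0 and 2
    mask   = torch.tensor([[0, 1, 1, 1]])   # position 0 must never change
    r      = torch.tensor([[0.01, 0.9, 0.02, 0.9]])

    mask_changes = (r <= 0.05).long() * mask           # [[0, 0, 1, 0]]
    update = (1 - mask_changes) * input + mask_changes * output
    print(update)                                      # tensor([[7, 1, 5, 3]])

    # Without the "* mask" factor, mask_changes would be [[1, 0, 1, 0]] and
    # the frozen token 7 would be overwritten by 9.
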
@@ -803,9 +802,13 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
 
 
 def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
+    # train
+
     one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True)
     one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=False)
 
+    # predict
+
     quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
     input, targets, mask = batch_prediction(quizzes.to(local_device))
     result = predict(model, input, targets, mask).to("cpu")
@@ -825,6 +828,15 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
     model.test_accuracy = correct.sum() / quizzes.size(0)
 
+    # generate
+
+    result = generate(model, 25).to("cpu")
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_generation_{n_epoch}_{model.id}.png",
+        quizzes=result[:128],
+    )
+
 
 ######################################################################