Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 21:14:15 +0000 (23:14 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 21:14:15 +0000 (23:14 +0200)
grids.py
main.py

index fb31c7d..0613043 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -134,20 +134,20 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations):
 
 
 class Grids(problem.Problem):
-    grid_gray = 64
-    thickness = 1
-    background_gray = 255
-    dots = False
+    grid_gray = 64
+    thickness = 1
+    background_gray = 255
+    dots = False
 
     # grid_gray=240
     # thickness=1
     # background_gray=240
     # dots = False
 
-    grid_gray = 200
-    thickness = 0
-    background_gray = 240
-    dots = True
+    grid_gray = 200
+    thickness = 0
+    background_gray = 240
+    dots = True
 
     named_colors = [
         ("white", [background_gray, background_gray, background_gray]),
@@ -1835,7 +1835,7 @@ if __name__ == "__main__":
         print(t.__name__)
         w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
 
-        w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size())
+        w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size())
 
         grids.save_quizzes_as_image(
             "/tmp",
diff --git a/main.py b/main.py
index 5493b7d..52505de 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -250,7 +250,7 @@ assert args.nb_test_samples % args.batch_size == 0
 ######################################################################
 
 
-def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
+def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
     if c_quizzes is None:
         quizzes = problem.generate_w_quizzes(nb_samples)
         nb_w_quizzes = quizzes.size(0)
@@ -486,7 +486,7 @@ def ae_generate(model, nb, local_device=main_device):
 
 
 def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
-    quizzes = quiz_set(
+    quizzes = generate_quiz_set(
         args.nb_train_samples if train else args.nb_test_samples,
         c_quizzes,
         args.c_quiz_multiplier,
@@ -559,7 +559,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
     # Save some original world quizzes and the full prediction (the four grids)
 
-    quizzes = quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device)
+    quizzes = generate_quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device)
     problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
     )
@@ -570,7 +570,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
     # Save some images of the prediction results
 
-    quizzes = quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+    quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
     imt_set = samples_for_prediction_imt(quizzes.to(local_device))
     result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")