Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 08:35:55 +0000 (10:35 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 21 Jun 2024 08:35:55 +0000 (10:35 +0200)
main.py
world.py

diff --git a/main.py b/main.py
index d92c4a5..18b19db 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -859,7 +859,7 @@ def one_epoch(model, task, learning_rate):
 ######################################################################
 
 
-def run_tests(model, task):
+def run_tests(model, task, deterministic_synthesis):
     with torch.autograd.no_grad():
         model.eval()
 
@@ -883,7 +883,7 @@ def run_tests(model, task):
             model=model,
             result_dir=args.result_dir,
             logger=log_string,
-            deterministic_synthesis=args.deterministic_synthesis,
+            deterministic_synthesis=deterministic_synthesis,
         )
 
         test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
@@ -897,7 +897,9 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     one_epoch(model, task, learning_rate)
 
-    run_tests(model, task)
+    run_tests(model, task, deterministic_synthesis=True)
+
+    # --------------------------------------------
 
     time_current_result = datetime.datetime.now()
     if time_pred_result is not None:
@@ -906,6 +908,8 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
         )
     time_pred_result = time_current_result
 
+    # --------------------------------------------
+
     checkpoint = {
         "nb_epochs_finished": n_epoch + 1,
         "model_state": model.state_dict(),
index ac201e7..97c7b1d 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -34,7 +34,7 @@ def generate(
     nb,
     height,
     width,
-    max_nb_obj=len(colors) - 2,
+    max_nb_obj=colors.size(0) - 2,
     nb_iterations=2,
 ):
     f_start = torch.zeros(nb, height, width, dtype=torch.int64)
@@ -43,7 +43,7 @@ def generate(
 
     for n in range(nb):
         nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1
-        for c in range(nb_fish):
+        for c in torch.randperm(colors.size(0) - 2)[:nb_fish].sort().values:
             i, j = (
                 torch.randint(height - 2, (1,))[0] + 1,
                 torch.randint(width - 2, (1,))[0] + 1,
@@ -117,7 +117,7 @@ if __name__ == "__main__":
 
     height, width = 6, 8
     start_time = time.perf_counter()
-    seq = generate(nb=64, height=height, width=width)
+    seq = generate(nb=64, height=height, width=width, max_nb_obj=3)
     delay = time.perf_counter() - start_time
     print(f"{seq.size(0)/delay:02f} samples/s")